You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
273 lines
6.6 KiB
273 lines
6.6 KiB
//! Socket handling
|
|
use super::*;
|
|
use std::str;
|
|
use std::io;
|
|
use std::path::{
|
|
Path, PathBuf
|
|
};
|
|
use std::{fmt, error};
|
|
use std::{
|
|
task::{Context, Poll},
|
|
pin::Pin,
|
|
};
|
|
use tokio::io::{
|
|
AsyncWrite, AsyncRead,
|
|
};
|
|
|
|
/// Address of a Unix-domain socket, identified by its filesystem path.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct SocketAddrUnix
{
    /// Filesystem path of the socket.
    pub path: PathBuf,
}
|
|
|
|
impl str::FromStr for SocketAddrUnix
|
|
{
|
|
type Err = AddrParseError;
|
|
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
|
let path = Path::new(s);
|
|
if path.exists() && !path.is_dir() {
|
|
Ok(Self{path: path.into()})
|
|
} else {
|
|
Err(AddrParseError)
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
|
|
pub enum SocketAddr
|
|
{
|
|
Unix(SocketAddrUnix),
|
|
IP(std::net::SocketAddr),
|
|
}
|
|
|
|
impl TryFrom<tokio::net::unix::SocketAddr> for SocketAddrUnix
|
|
{
|
|
type Error = AddrParseError;
|
|
|
|
fn try_from(from: tokio::net::unix::SocketAddr) -> Result<Self, Self::Error>
|
|
{
|
|
from.as_pathname().ok_or(AddrParseError).map(|path| Self{path: path.into()})
|
|
}
|
|
}
|
|
impl TryFrom<tokio::net::unix::SocketAddr> for SocketAddr
|
|
{
|
|
type Error = AddrParseError;
|
|
|
|
fn try_from(from: tokio::net::unix::SocketAddr) -> Result<Self, Self::Error>
|
|
{
|
|
SocketAddrUnix::try_from(from).map(Self::Unix)
|
|
}
|
|
}
|
|
|
|
impl From<SocketAddrUnix> for SocketAddr
|
|
{
|
|
fn from(from: SocketAddrUnix) -> Self
|
|
{
|
|
Self::Unix(from)
|
|
}
|
|
}
|
|
|
|
|
|
impl From<std::net::SocketAddr> for SocketAddr
|
|
{
|
|
fn from(from: std::net::SocketAddr) -> Self
|
|
{
|
|
Self::IP(from)
|
|
}
|
|
}
|
|
|
|
/// Error produced when text or a platform socket address cannot be turned into an address
/// type of this module. Carries no detail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AddrParseError;

impl error::Error for AddrParseError {}

impl fmt::Display for AddrParseError
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
    {
        f.write_str("failed to parse address")
    }
}
|
|
|
|
|
|
impl From<std::net::AddrParseError> for AddrParseError
|
|
{
|
|
fn from(_: std::net::AddrParseError) -> Self
|
|
{
|
|
Self
|
|
}
|
|
}
|
|
|
|
const UNIX_SOCK_PREFIX: &str = "unix:/";
|
|
impl str::FromStr for SocketAddr
|
|
{
|
|
type Err = AddrParseError;
|
|
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
|
Ok(if s.starts_with(UNIX_SOCK_PREFIX) {
|
|
Self::Unix(s[(UNIX_SOCK_PREFIX.len())..].parse()?)
|
|
} else {
|
|
Self::IP(s.parse()?)
|
|
})
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
enum ListenerInner
|
|
{
|
|
Unix(tokio::net::UnixListener),
|
|
Tcp(tokio::net::TcpListener),
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
enum StreamInner
|
|
{
|
|
Unix(tokio::net::UnixStream),
|
|
Tcp(tokio::net::TcpStream),
|
|
}
|
|
|
|
impl ListenerInner
|
|
{
|
|
#[inline] fn con_unix(sock: &SocketAddrUnix) -> io::Result<Self>
|
|
{
|
|
tokio::net::UnixListener::bind(&sock.path).map(Self::Unix)
|
|
}
|
|
async fn con_tcp(sock: &std::net::SocketAddr) -> io::Result<Self>
|
|
{
|
|
tokio::net::TcpListener::bind(sock).await.map(Self::Tcp)
|
|
}
|
|
}
|
|
|
|
/// A connected socket.
|
|
//TODO: Stream::connect(), direct connection
|
|
#[derive(Debug)]
|
|
pub struct Stream(Box<StreamInner>, SocketAddr);
|
|
|
|
impl From<(tokio::net::UnixStream, tokio::net::unix::SocketAddr)> for Stream
|
|
{
|
|
fn from((s, a): (tokio::net::UnixStream, tokio::net::unix::SocketAddr)) -> Self
|
|
{
|
|
Self(Box::new(StreamInner::Unix(s)), SocketAddrUnix::try_from(a).unwrap().into())
|
|
}
|
|
}
|
|
|
|
impl From<(tokio::net::TcpStream, std::net::SocketAddr)> for Stream
|
|
{
|
|
fn from((s, a): (tokio::net::TcpStream, std::net::SocketAddr)) -> Self
|
|
{
|
|
Self(Box::new(StreamInner::Tcp(s)), a.into())
|
|
}
|
|
}
|
|
|
|
/// A network listener, for either a Unix socket or TCP socket.
|
|
#[derive(Debug)]
|
|
pub struct Listener(Box<ListenerInner>);
|
|
|
|
impl Listener
|
|
{
|
|
/// Bind to Unix socket or IP/port.
|
|
///
|
|
/// Completes immediately when binding unix socket
|
|
pub async fn bind(sock: impl Into<SocketAddr>) -> io::Result<Self>
|
|
{
|
|
match sock.into() {
|
|
SocketAddr::Unix(ref unix) => ListenerInner::con_unix(unix).map(Box::new),
|
|
SocketAddr::IP(ref tcp) => ListenerInner::con_tcp(tcp).await.map(Box::new),
|
|
}.map(Self)
|
|
}
|
|
|
|
/// Accept connection on this listener.
|
|
pub async fn accept(&self) -> io::Result<Stream>
|
|
{
|
|
match self.0.as_ref()
|
|
{
|
|
ListenerInner::Unix(un) => un.accept().await.map(Into::into),
|
|
ListenerInner::Tcp(tcp) => tcp.accept().await.map(Into::into),
|
|
}
|
|
}
|
|
|
|
/// Local bound address
|
|
pub fn local_addr(&self) -> io::Result<SocketAddr>
|
|
{
|
|
match self.0.as_ref()
|
|
{
|
|
ListenerInner::Unix(un) => un.local_addr().and_then(|addr| addr.try_into().map_err(|e| io::Error::new(io::ErrorKind::Unsupported, e))),
|
|
ListenerInner::Tcp(t) => t.local_addr().map(Into::into)
|
|
}
|
|
}
|
|
|
|
/// Poll to accept a new incoming connection to this listener
|
|
pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<Stream>>
|
|
{
|
|
match self.0.as_ref()
|
|
{
|
|
ListenerInner::Unix(un) => un.poll_accept(cx).map_ok(Into::into),
|
|
ListenerInner::Tcp(tcp) => tcp.poll_accept(cx).map_ok(Into::into)
|
|
}
|
|
}
|
|
}
|
|
|
|
//TODO: impl Stream
|
|
impl Stream
|
|
{
|
|
#[inline(always)] pub fn reader_ref(&self) -> &(dyn AsyncRead + Unpin + Send + Sync + '_)
|
|
{
|
|
match self.0.as_ref()
|
|
{
|
|
StreamInner::Unix(un) => un,
|
|
StreamInner::Tcp(tc) => tc,
|
|
}
|
|
}
|
|
#[inline(always)] pub fn writer_ref(&self) -> &(dyn AsyncWrite + Unpin + Send + Sync + '_)
|
|
{
|
|
match self.0.as_ref()
|
|
{
|
|
StreamInner::Unix(un) => un,
|
|
StreamInner::Tcp(tc) => tc,
|
|
}
|
|
}
|
|
|
|
#[inline(always)] pub fn reader_mut(&mut self) -> &mut (dyn AsyncRead + Unpin + Send + Sync + '_)
|
|
{
|
|
match self.0.as_mut()
|
|
{
|
|
StreamInner::Unix(un) => un,
|
|
StreamInner::Tcp(tc) => tc,
|
|
}
|
|
}
|
|
#[inline(always)] pub fn writer_mut(&mut self) -> &mut (dyn AsyncWrite + Unpin + Send + Sync + '_)
|
|
{
|
|
match self.0.as_mut()
|
|
{
|
|
StreamInner::Unix(un) => un,
|
|
StreamInner::Tcp(tc) => tc,
|
|
}
|
|
}
|
|
|
|
#[inline(always)] fn reader_pinned_mut(self: Pin<&mut Self>) -> Pin<&mut (dyn AsyncRead + Unpin + Send + Sync + '_)>
|
|
{
|
|
Pin::new(self.get_mut().reader_mut())
|
|
}
|
|
#[inline(always)] fn writer_pinned_mut(self: Pin<&mut Self>) -> Pin<&mut (dyn AsyncWrite + Unpin + Send + Sync + '_)>
|
|
{
|
|
Pin::new(self.get_mut().writer_mut())
|
|
}
|
|
}
|
|
|
|
impl AsyncWrite for Stream
|
|
{
|
|
#[inline] fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
|
|
self.writer_pinned_mut().poll_write(cx, buf)
|
|
}
|
|
#[inline] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
|
self.writer_pinned_mut().poll_flush(cx)
|
|
}
|
|
#[inline] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
|
self.writer_pinned_mut().poll_shutdown(cx)
|
|
}
|
|
}
|
|
|
|
impl AsyncRead for Stream
|
|
{
|
|
#[inline] fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll<io::Result<()>> {
|
|
self.reader_pinned_mut().poll_read(cx, buf)
|
|
}
|
|
}
|