//! Socket handling use super::*; use std::str; use std::io; use std::path::{ Path, PathBuf }; use std::{fmt, error}; use std::{ task::{Context, Poll}, pin::Pin, }; use tokio::io::{ AsyncWrite, AsyncRead, }; #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct SocketAddrUnix { pub path: PathBuf, } impl str::FromStr for SocketAddrUnix { type Err = AddrParseError; fn from_str(s: &str) -> Result { let path = Path::new(s); if path.exists() && !path.is_dir() { Ok(Self{path: path.into()}) } else { Err(AddrParseError) } } } #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub enum SocketAddr { Unix(SocketAddrUnix), IP(std::net::SocketAddr), } impl TryFrom for SocketAddrUnix { type Error = AddrParseError; fn try_from(from: tokio::net::unix::SocketAddr) -> Result { from.as_pathname().ok_or(AddrParseError).map(|path| Self{path: path.into()}) } } impl TryFrom for SocketAddr { type Error = AddrParseError; fn try_from(from: tokio::net::unix::SocketAddr) -> Result { SocketAddrUnix::try_from(from).map(Self::Unix) } } impl From for SocketAddr { fn from(from: SocketAddrUnix) -> Self { Self::Unix(from) } } impl From for SocketAddr { fn from(from: std::net::SocketAddr) -> Self { Self::IP(from) } } #[derive(Debug, Clone, PartialEq, Eq)] pub struct AddrParseError; impl error::Error for AddrParseError{} impl fmt::Display for AddrParseError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "failed to parse address") } } impl From for AddrParseError { fn from(_: std::net::AddrParseError) -> Self { Self } } const UNIX_SOCK_PREFIX: &str = "unix:/"; impl str::FromStr for SocketAddr { type Err = AddrParseError; fn from_str(s: &str) -> Result { Ok(if s.starts_with(UNIX_SOCK_PREFIX) { Self::Unix(s[(UNIX_SOCK_PREFIX.len())..].parse()?) } else { Self::IP(s.parse()?) }) } } #[derive(Debug)] enum ListenerInner { Unix(tokio::net::UnixListener), Tcp(tokio::net::TcpListener), } #[derive(Debug)] enum StreamInner { Unix(tokio::net::UnixStream), Tcp(tokio::net::TcpStream), } impl ListenerInner { #[inline] fn con_unix(sock: &SocketAddrUnix) -> io::Result { tokio::net::UnixListener::bind(&sock.path).map(Self::Unix) } async fn con_tcp(sock: &std::net::SocketAddr) -> io::Result { tokio::net::TcpListener::bind(sock).await.map(Self::Tcp) } } /// A connected socket. //TODO: Stream::connect(), direct connection #[derive(Debug)] pub struct Stream(Box, SocketAddr); impl From<(tokio::net::UnixStream, tokio::net::unix::SocketAddr)> for Stream { fn from((s, a): (tokio::net::UnixStream, tokio::net::unix::SocketAddr)) -> Self { Self(Box::new(StreamInner::Unix(s)), SocketAddrUnix::try_from(a).unwrap().into()) } } impl From<(tokio::net::TcpStream, std::net::SocketAddr)> for Stream { fn from((s, a): (tokio::net::TcpStream, std::net::SocketAddr)) -> Self { Self(Box::new(StreamInner::Tcp(s)), a.into()) } } /// A network listener, for either a Unix socket or TCP socket. #[derive(Debug)] pub struct Listener(Box); impl Listener { /// Bind to Unix socket or IP/port. /// /// Completes immediately when binding unix socket pub async fn bind(sock: impl Into) -> io::Result { match sock.into() { SocketAddr::Unix(ref unix) => ListenerInner::con_unix(unix).map(Box::new), SocketAddr::IP(ref tcp) => ListenerInner::con_tcp(tcp).await.map(Box::new), }.map(Self) } /// Accept connection on this listener. pub async fn accept(&self) -> io::Result { match self.0.as_ref() { ListenerInner::Unix(un) => un.accept().await.map(Into::into), ListenerInner::Tcp(tcp) => tcp.accept().await.map(Into::into), } } /// Local bound address pub fn local_addr(&self) -> io::Result { match self.0.as_ref() { ListenerInner::Unix(un) => un.local_addr().and_then(|addr| addr.try_into().map_err(|e| io::Error::new(io::ErrorKind::Unsupported, e))), ListenerInner::Tcp(t) => t.local_addr().map(Into::into) } } /// Poll to accept a new incoming connection to this listener pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll> { match self.0.as_ref() { ListenerInner::Unix(un) => un.poll_accept(cx).map_ok(Into::into), ListenerInner::Tcp(tcp) => tcp.poll_accept(cx).map_ok(Into::into) } } } //TODO: impl Stream impl Stream { #[inline(always)] pub fn reader_ref(&self) -> &(dyn AsyncRead + Unpin + Send + Sync + '_) { match self.0.as_ref() { StreamInner::Unix(un) => un, StreamInner::Tcp(tc) => tc, } } #[inline(always)] pub fn writer_ref(&self) -> &(dyn AsyncWrite + Unpin + Send + Sync + '_) { match self.0.as_ref() { StreamInner::Unix(un) => un, StreamInner::Tcp(tc) => tc, } } #[inline(always)] pub fn reader_mut(&mut self) -> &mut (dyn AsyncRead + Unpin + Send + Sync + '_) { match self.0.as_mut() { StreamInner::Unix(un) => un, StreamInner::Tcp(tc) => tc, } } #[inline(always)] pub fn writer_mut(&mut self) -> &mut (dyn AsyncWrite + Unpin + Send + Sync + '_) { match self.0.as_mut() { StreamInner::Unix(un) => un, StreamInner::Tcp(tc) => tc, } } #[inline(always)] fn reader_pinned_mut(self: Pin<&mut Self>) -> Pin<&mut (dyn AsyncRead + Unpin + Send + Sync + '_)> { Pin::new(self.get_mut().reader_mut()) } #[inline(always)] fn writer_pinned_mut(self: Pin<&mut Self>) -> Pin<&mut (dyn AsyncWrite + Unpin + Send + Sync + '_)> { Pin::new(self.get_mut().writer_mut()) } } impl AsyncWrite for Stream { #[inline] fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { self.writer_pinned_mut().poll_write(cx, buf) } #[inline] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.writer_pinned_mut().poll_flush(cx) } #[inline] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.writer_pinned_mut().poll_shutdown(cx) } } impl AsyncRead for Stream { #[inline] fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll> { self.reader_pinned_mut().poll_read(cx, buf) } }