You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

273 lines
6.6 KiB

//! Socket handling
use super::*;
use std::str;
use std::io;
use std::path::{
Path, PathBuf
};
use std::{fmt, error};
use std::{
task::{Context, Poll},
pin::Pin,
};
use tokio::io::{
AsyncWrite, AsyncRead,
};
/// Address of a Unix-domain socket, identified by its filesystem path.
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct SocketAddrUnix {
    /// Filesystem path of the socket.
    pub path: PathBuf,
}
impl str::FromStr for SocketAddrUnix
{
    type Err = AddrParseError;

    /// Parse a Unix-domain socket path.
    ///
    /// The path does not need to exist. A socket file that a listener is about to
    /// bind must *not* already exist, because binding an existing path fails with
    /// `AddrInUse`. Requiring existence here would make every parsed address unusable
    /// for binding.
    ///
    /// # Errors
    /// Returns [`AddrParseError`] if `s` is empty or names an existing directory.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let path = Path::new(s);
        if s.is_empty() || path.is_dir() {
            Err(AddrParseError)
        } else {
            Ok(Self { path: path.into() })
        }
    }
}
/// A socket address: either a Unix-domain socket path or an IP address with port.
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum SocketAddr {
    /// Unix-domain socket, addressed by filesystem path.
    Unix(SocketAddrUnix),
    /// IPv4 or IPv6 address with port.
    IP(std::net::SocketAddr),
}
impl TryFrom<tokio::net::unix::SocketAddr> for SocketAddrUnix
{
    type Error = AddrParseError;

    /// Convert a tokio Unix socket address into a path-based [`SocketAddrUnix`].
    ///
    /// # Errors
    /// Fails with [`AddrParseError`] when the address has no pathname, i.e. when
    /// `as_pathname()` returns `None`.
    fn try_from(from: tokio::net::unix::SocketAddr) -> Result<Self, Self::Error>
    {
        match from.as_pathname() {
            Some(p) => Ok(Self { path: p.to_path_buf() }),
            None => Err(AddrParseError),
        }
    }
}
impl TryFrom<tokio::net::unix::SocketAddr> for SocketAddr
{
type Error = AddrParseError;
fn try_from(from: tokio::net::unix::SocketAddr) -> Result<Self, Self::Error>
{
SocketAddrUnix::try_from(from).map(Self::Unix)
}
}
impl From<SocketAddrUnix> for SocketAddr
{
fn from(from: SocketAddrUnix) -> Self
{
Self::Unix(from)
}
}
impl From<std::net::SocketAddr> for SocketAddr
{
fn from(from: std::net::SocketAddr) -> Self
{
Self::IP(from)
}
}
/// Error returned when text or a foreign address cannot be turned into a socket address.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AddrParseError;

impl error::Error for AddrParseError {}

impl fmt::Display for AddrParseError
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
    {
        f.write_str("failed to parse address")
    }
}

impl From<std::net::AddrParseError> for AddrParseError
{
    /// Discard the details of a standard-library IP address parse failure.
    fn from(_: std::net::AddrParseError) -> Self
    {
        AddrParseError
    }
}
/// Prefix marking a textual socket address as a Unix-domain socket path,
/// e.g. `unix:/run/app.sock`. The trailing `/` belongs to the (absolute) path itself.
const UNIX_SOCK_PREFIX: &str = "unix:/";
impl str::FromStr for SocketAddr
{
    type Err = AddrParseError;

    /// Parse either `unix:/absolute/path` or an IP socket address such as
    /// `127.0.0.1:8080` or `[::1]:8080`.
    ///
    /// For the Unix form only the `unix:` scheme is removed, so `unix:/run/app.sock`
    /// yields the absolute path `/run/app.sock` rather than the relative `run/app.sock`.
    ///
    /// # Errors
    /// [`AddrParseError`] if the Unix path or the IP address fails to parse.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s.starts_with(UNIX_SOCK_PREFIX) {
            // Drop only "unix:"; the prefix's final '/' is the root of the path.
            let path = &s[UNIX_SOCK_PREFIX.len() - 1..];
            Ok(Self::Unix(path.parse()?))
        } else {
            Ok(Self::IP(s.parse()?))
        }
    }
}
/// Concrete listener backing a `Listener`.
#[derive(Debug)]
enum ListenerInner {
    /// Listening on a Unix-domain socket path.
    Unix(tokio::net::UnixListener),
    /// Listening on a TCP address.
    Tcp(tokio::net::TcpListener),
}
/// Concrete connected stream backing a `Stream`.
#[derive(Debug)]
enum StreamInner {
    /// Connection over a Unix-domain socket.
    Unix(tokio::net::UnixStream),
    /// Connection over TCP.
    Tcp(tokio::net::TcpStream),
}
impl ListenerInner
{
    /// Bind a Unix-domain listener at `sock.path`.
    ///
    /// Not `async`: `UnixListener::bind` is called synchronously.
    #[inline]
    fn con_unix(sock: &SocketAddrUnix) -> io::Result<Self>
    {
        let listener = tokio::net::UnixListener::bind(&sock.path)?;
        Ok(Self::Unix(listener))
    }

    /// Bind a TCP listener on `sock`.
    async fn con_tcp(sock: &std::net::SocketAddr) -> io::Result<Self>
    {
        let listener = tokio::net::TcpListener::bind(sock).await?;
        Ok(Self::Tcp(listener))
    }
}
/// A connected socket.
///
/// Holds the underlying Unix or TCP stream together with the address it was built
/// from (the peer address returned by `accept`).
//TODO: Stream::connect(), direct connection
#[derive(Debug)]
pub struct Stream(
    /// Underlying transport.
    Box<StreamInner>,
    /// Remote address.
    SocketAddr,
);
impl From<(tokio::net::UnixStream, tokio::net::unix::SocketAddr)> for Stream
{
fn from((s, a): (tokio::net::UnixStream, tokio::net::unix::SocketAddr)) -> Self
{
Self(Box::new(StreamInner::Unix(s)), SocketAddrUnix::try_from(a).unwrap().into())
}
}
impl From<(tokio::net::TcpStream, std::net::SocketAddr)> for Stream
{
fn from((s, a): (tokio::net::TcpStream, std::net::SocketAddr)) -> Self
{
Self(Box::new(StreamInner::Tcp(s)), a.into())
}
}
/// A network listener, for either a Unix socket or TCP socket.
///
/// Create one with `Listener::bind`, then accept connections with
/// `Listener::accept` or `Listener::poll_accept`.
#[derive(Debug)]
pub struct Listener(
    /// Boxed concrete listener.
    Box<ListenerInner>,
);
impl Listener
{
/// Bind to Unix socket or IP/port.
///
/// Completes immediately when binding unix socket
pub async fn bind(sock: impl Into<SocketAddr>) -> io::Result<Self>
{
match sock.into() {
SocketAddr::Unix(ref unix) => ListenerInner::con_unix(unix).map(Box::new),
SocketAddr::IP(ref tcp) => ListenerInner::con_tcp(tcp).await.map(Box::new),
}.map(Self)
}
/// Accept connection on this listener.
pub async fn accept(&self) -> io::Result<Stream>
{
match self.0.as_ref()
{
ListenerInner::Unix(un) => un.accept().await.map(Into::into),
ListenerInner::Tcp(tcp) => tcp.accept().await.map(Into::into),
}
}
/// Local bound address
pub fn local_addr(&self) -> io::Result<SocketAddr>
{
match self.0.as_ref()
{
ListenerInner::Unix(un) => un.local_addr().and_then(|addr| addr.try_into().map_err(|e| io::Error::new(io::ErrorKind::Unsupported, e))),
ListenerInner::Tcp(t) => t.local_addr().map(Into::into)
}
}
/// Poll to accept a new incoming connection to this listener
pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<Stream>>
{
match self.0.as_ref()
{
ListenerInner::Unix(un) => un.poll_accept(cx).map_ok(Into::into),
ListenerInner::Tcp(tcp) => tcp.poll_accept(cx).map_ok(Into::into)
}
}
}
//TODO: impl Stream
impl Stream
{
    /// Borrow the underlying transport as a type-erased [`AsyncRead`].
    #[inline]
    pub fn reader_ref(&self) -> &(dyn AsyncRead + Unpin + Send + Sync + '_)
    {
        match &*self.0 {
            StreamInner::Unix(unix) => unix,
            StreamInner::Tcp(tcp) => tcp,
        }
    }

    /// Borrow the underlying transport as a type-erased [`AsyncWrite`].
    #[inline]
    pub fn writer_ref(&self) -> &(dyn AsyncWrite + Unpin + Send + Sync + '_)
    {
        match &*self.0 {
            StreamInner::Unix(unix) => unix,
            StreamInner::Tcp(tcp) => tcp,
        }
    }

    /// Mutably borrow the underlying transport as a type-erased [`AsyncRead`].
    #[inline]
    pub fn reader_mut(&mut self) -> &mut (dyn AsyncRead + Unpin + Send + Sync + '_)
    {
        match &mut *self.0 {
            StreamInner::Unix(unix) => unix,
            StreamInner::Tcp(tcp) => tcp,
        }
    }

    /// Mutably borrow the underlying transport as a type-erased [`AsyncWrite`].
    #[inline]
    pub fn writer_mut(&mut self) -> &mut (dyn AsyncWrite + Unpin + Send + Sync + '_)
    {
        match &mut *self.0 {
            StreamInner::Unix(unix) => unix,
            StreamInner::Tcp(tcp) => tcp,
        }
    }

    /// Project a pinned `Stream` onto its reader.
    ///
    /// `Stream` and the erased reader are both `Unpin`, so unpinning and re-pinning
    /// with [`Pin::new`] needs no `unsafe`.
    #[inline]
    fn reader_pinned_mut(self: Pin<&mut Self>) -> Pin<&mut (dyn AsyncRead + Unpin + Send + Sync + '_)>
    {
        let this: &mut Self = Pin::into_inner(self);
        Pin::new(this.reader_mut())
    }

    /// Project a pinned `Stream` onto its writer (see `reader_pinned_mut`).
    #[inline]
    fn writer_pinned_mut(self: Pin<&mut Self>) -> Pin<&mut (dyn AsyncWrite + Unpin + Send + Sync + '_)>
    {
        let this: &mut Self = Pin::into_inner(self);
        Pin::new(this.writer_mut())
    }
}
/// Forwards every write operation to the underlying Unix or TCP stream.
impl AsyncWrite for Stream
{
    #[inline] fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
        self.writer_pinned_mut().poll_write(cx, buf)
    }
    /// Forward vectored writes. Without this, the trait default writes only the first
    /// non-empty buffer, hiding the socket's native vectored-write support.
    #[inline] fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>]) -> Poll<Result<usize, io::Error>> {
        self.writer_pinned_mut().poll_write_vectored(cx, bufs)
    }
    /// Report the underlying stream's vectored-write capability. The trait default
    /// would always return `false`.
    #[inline] fn is_write_vectored(&self) -> bool {
        self.writer_ref().is_write_vectored()
    }
    #[inline] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        self.writer_pinned_mut().poll_flush(cx)
    }
    #[inline] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        self.writer_pinned_mut().poll_shutdown(cx)
    }
}
impl AsyncRead for Stream
{
    /// Delegate reads to the underlying Unix or TCP stream.
    #[inline]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let reader = self.reader_pinned_mut();
        reader.poll_read(cx, buf)
    }
}