diff --git a/src/ext.rs b/src/ext.rs index 27aea52..d7b45f0 100644 --- a/src/ext.rs +++ b/src/ext.rs @@ -18,6 +18,11 @@ pub mod prelude pub use super::StreamGateExt as _; pub use super::StreamLagExt as _; pub use super::INodeExt as _; + + pub use super::async_write_ext::{ + EitherWrite, + DeadSink, + }; } pub trait INodeExt @@ -346,3 +351,147 @@ impl fmt::Display for SoftAssertionFailedError } }; } + +mod async_write_ext { + use std::ops::{Deref, DerefMut}; + use tokio::io::{ + self, + AsyncWrite, AsyncRead, + }; + use std::{ + pin::Pin, + task::{Poll, Context}, + }; + use std::marker::PhantomData; + + #[derive(Debug, Clone)] + pub enum EitherWrite<'a, T,U> + { + First(T), + Second(U, PhantomData<&'a mut U>), + } + + impl<'a, T, U> Deref for EitherWrite<'a, T,U> + where T: AsyncWrite + Unpin + 'a, + U: AsyncWrite + Unpin + 'a + { + type Target = dyn AsyncWrite + Unpin + 'a; + + fn deref(&self) -> &Self::Target { + match self { + Self::First(t) => t, + Self::Second(u, _) => u, + } + } + } + + impl<'a, T, U> DerefMut for EitherWrite<'a, T,U> + where T: AsyncWrite + Unpin + 'a, + U: AsyncWrite + Unpin + 'a + { + fn deref_mut(&mut self) -> &mut Self::Target { + match self { + Self::First(t) => t, + Self::Second(u, _) => u, + } + } + } + + impl<'a, T, U> From> for EitherWrite<'a, T, U> + where T: AsyncWrite + Unpin + 'a, + U: AsyncWrite + Unpin + 'a + { + #[inline] fn from(from: Result) -> Self + { + match from { + Ok(v) => Self::First(v), + Err(v) => Self::Second(v, PhantomData), + } + } + } + + + + impl<'a, T> EitherWrite<'a, T, DeadSink> + { + #[inline] fn as_first_infallible(&mut self) -> &mut T + { + match self { + Self::Second(_, _) => unsafe { core::hint::unreachable_unchecked() }, + Self::First(t) => t + } + } + } + + impl<'a, U> EitherWrite<'a, DeadSink, U> + { + #[inline] fn as_second_infallible(&mut self) -> &mut U + { + match self { + Self::First(_) => unsafe { core::hint::unreachable_unchecked() }, + Self::Second(t, _) => t + } + } + } + +/* impl<'a, T> AsyncWrite for EitherWrite<'a, T, DeadSink> + where T: AsyncWrite + Unpin + 'a + { + #[inline] fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + let this = unsafe { self.map_unchecked_mut(|x| x.as_first_infallible()) }; + this.poll_write(cx, buf) + } + #[inline] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = unsafe { self.map_unchecked_mut(|x| x.as_first_infallible()) }; + this.poll_flush(cx) + } + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = unsafe { self.map_unchecked_mut(|x| x.as_first_infallible()) }; + this.poll_shutdown(cx) + } + }*/ + + impl<'a, U> AsyncWrite for EitherWrite<'a, DeadSink, U> + where U: AsyncWrite + Unpin + 'a + { + #[inline] fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + let this = unsafe { self.map_unchecked_mut(|x| x.as_second_infallible()) }; + this.poll_write(cx, buf) + } + #[inline] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = unsafe { self.map_unchecked_mut(|x| x.as_second_infallible()) }; + this.poll_flush(cx) + } + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = unsafe { self.map_unchecked_mut(|x| x.as_second_infallible()) }; + this.poll_shutdown(cx) + } + } + + + /// An `Infallible` type for `AsyncWrite` & `AsyncRead` + #[derive(Debug)] + pub enum DeadSink { } + + impl AsyncWrite for DeadSink + { + #[inline] fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + unreachable!(); + } + #[inline] fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, _buf: &[u8]) -> Poll> { + unreachable!(); + + } + #[inline] fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + unreachable!(); + + } + } + + impl AsyncRead for DeadSink + { + #[inline] fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, _buf: &mut [u8]) -> Poll> { + unreachable!(); + } + } +} diff --git a/src/main.rs b/src/main.rs index 6b304b4..26ac098 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,4 @@ +#![feature(never_type)] #![allow(dead_code)] @@ -101,7 +102,7 @@ async fn normal(cfg: config::Config) -> eyre::Result<()> // Some(path) -> prealloc Some((stream_fut, None)) => { let stream = stream_fut.await?; - serial::write_async(stream, &graph).await + serial::write_async(stream, &graph, serial::compress::No).await .wrap_err(eyre!("Failed to serialise graph to stream"))?; }, #[cfg(feature="prealloc")] Some((_task_fut, Some(output_file))) => { diff --git a/src/serial.rs b/src/serial.rs index 364b4d6..d8c2c96 100644 --- a/src/serial.rs +++ b/src/serial.rs @@ -12,16 +12,71 @@ type Decompressor = BzDecoder; const DEFER_DROP_SIZE_FLOOR: usize = 1024 * 1024; // 1 MB +pub trait Compression +{ + type OutputStream: AsyncWrite + Unpin; + type InputStream: AsyncRead + Unpin; + + fn create_output(from: W) -> Result; + fn create_input(from: W) -> Result; +} + +pub mod compress +{ + + use super::*; + /// No compression. + #[derive(Debug)] + pub struct No; + + impl Compression for No + { + type OutputStream = DeadSink; + type InputStream = DeadSink; + + fn create_input(from: W) -> Result { + Err(from) + } + fn create_output(from: W) -> Result { + Err(from) + } + } + + #[derive(Debug)] + pub struct Bz; + + + impl Compression for Bz + { + type OutputStream = Box; + type InputStream = Box; + + fn create_input(from: W) -> Result { + panic!() + } + fn create_output(from: W) -> Result { + Ok(Box::new(super::Compressor::new(from))) + } + } + +} + +#[inline] fn _type_name(_val: &T) -> &'static str { + std::any::type_name::() +} + /// Serialise this object asynchronously /// /// # Note /// This compresses the output stream. /// It cannot be used by `prealloc` read/write functions, as they do not compress. -pub async fn write_async(mut to: W, item: &T) -> eyre::Result<()> -where W: AsyncWrite + Unpin +pub async fn write_async(mut to: impl AsyncWrite + Unpin, item: &impl Serialize, _comp: Compress) -> eyre::Result<()> { - let sect_type_name = || std::any::type_name::().header("Type trying to serialise was"); - let sect_stream_type_name = || std::any::type_name::().header("Stream type was"); + let name_of_item = _type_name(item); + let name_of_stream = _type_name(&to); + + let sect_type_name = || name_of_item.header("Type trying to serialise was"); + let sect_stream_type_name = || name_of_stream.header("Stream type was"); let vec = tokio::task::block_in_place(|| serde_cbor::to_vec(item)) .wrap_err(eyre!("Failed to serialise item")) @@ -29,9 +84,9 @@ where W: AsyncWrite + Unpin .with_section(sect_type_name.clone())?; { - let mut stream = Compressor::new(&mut to); + let mut stream: EitherWrite<_, _> = Compress::create_output(&mut to).into();//Compressor::new(&mut to); - cfg_eprintln!(Verbose; config::get_global(), "Writing {} bytes of type {:?} to stream of type {:?}", vec.len(), std::any::type_name::(), std::any::type_name::()); + cfg_eprintln!(Verbose; config::get_global(), "Writing {} bytes of type {:?} to stream of type {:?}", vec.len(), name_of_item, name_of_stream); stream.write_all(&vec[..]) .await