diff --git a/Cargo.toml b/Cargo.toml index 79a15b5..5a1a5dc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,7 @@ generational-arena = "0.2.8" jemallocator = "0.3.2" lazy_static = "1.4.0" memmap = "0.7.0" -pin-project = "1.0.2" +pin-project = "1.0.7" rand = "0.8.3" serde = {version = "1.0.118", features= ["derive"]} serde_cbor = "0.11.1" diff --git a/src/ext/mod.rs b/src/ext/mod.rs index 832c4c3..196b372 100644 --- a/src/ext/mod.rs +++ b/src/ext/mod.rs @@ -32,12 +32,33 @@ pub use lag::*; mod defer_drop; pub use defer_drop::*; +pub mod sync; + +pub mod plex; +pub use plex::MultiplexStreamExt; + // The extension traits are defined in this file, no need to re-export anything from here. pub mod chunking; /// How many elements should `precollect` allocate on the stack before spilling to the heap. pub const PRECOLLECT_STACK_SIZE: usize = 64; +/// Implement `Default` for a type. +#[macro_export] macro_rules! default { + ($ty:ty: $ex:expr) => { + + impl Default for $ty + { + #[inline] + fn default() -> Self + { + $ex + } + } + } +} + + /// Create a duration with time suffix `h`, `m`, `s`, `ms` or `ns`. /// /// # Combination @@ -197,38 +218,38 @@ pub const PRECOLLECT_STACK_SIZE: usize = 64; ($vis:vis ref $name:ident -> $ty:ty => $internal:ident $(; $comment:literal)?) => { $(#[doc=$comment])? #[inline] $vis fn $name(&self) -> &$ty { - &self.$internal - } + &self.$internal + } }; ($vis:vis ref $name:ident -> $ty:ty => $internal:tt $(; $comment:literal)?) => { $(#[doc=$comment])? #[inline] $vis fn $name(&self) -> &$ty { - &self.$internal - } + &self.$internal + } }; ($vis:vis mut $name:ident -> $ty:ty => $internal:ident $(; $comment:literal)?) => { $(#[doc=$comment])? #[inline] $vis fn $name(&self) -> &mut $ty { - &mut self.$internal - } + &mut self.$internal + } }; ($vis:vis mut $name:ident -> $ty:ty => $internal:tt $(; $comment:literal)?) => { $(#[doc=$comment])? #[inline] $vis fn $name(&self) -> &mut $ty { - &mut self.$internal - } + &mut self.$internal + } }; ($vis:vis move $name:ident -> $ty:ty => $internal:ident $(; $comment:literal)?) => { $(#[doc=$comment])? #[inline] $vis fn $name(&self) -> $ty { - self.$internal - } + self.$internal + } }; ($vis:vis move $name:ident -> $ty:ty => $internal:tt $(; $comment:literal)?) => { $(#[doc=$comment])? #[inline] $vis fn $name(&self) -> $ty { - self.$internal - } + self.$internal + } }; } diff --git a/src/ext/plex.rs b/src/ext/plex.rs new file mode 100644 index 0000000..bdb11e0 --- /dev/null +++ b/src/ext/plex.rs @@ -0,0 +1,465 @@ +//! Multiplexing +use super::*; +use std::io; +use futures::prelude::*; +use std::marker::PhantomData; +use tokio::io::{ + AsyncWrite, +}; +use std::{ + task::{Context, Poll}, + pin::Pin, +}; +use std::{fmt, error}; + +/// For a `WriteRule::compare_byte_sizes()` implementation that never errors. +#[derive(Debug)] +pub enum CompareInfallible{} + +impl error::Error for CompareInfallible{} +impl fmt::Display for CompareInfallible +{ + fn fmt(&self, _: &mut fmt::Formatter<'_>) -> fmt::Result + { + match *self {} + } +} + +impl From for io::Error +{ + fn from(from: CompareInfallible) -> Self + { + match from {} + } +} + + +/// Static rules for a `MultiplexingStream` to use. +pub trait WriteRule +{ + type CompareFailedError: Into = CompareInfallible; + /// When a successful poll to both streams and the number of bytes written differs, chooses which to return. + /// + /// # Errors + /// You can also choose to return an error if the sizes do not match. + /// The error must be convertable to `io::Error`. + /// By default, the error is `CompareInfallible`, which is a never-type alias that implements `Into` for convenience and better optimisations over using a non-vacant error type that is never returned. + /// + /// If you are using an error type that may be returned, set the `CompareFailedError` to the error type you choose, as long as it implements `Into` (or simply set it to `io::Error`, the generic conversion exists just to allow for using a vacant type here when an error will never be returned by this function.) + /// + /// # Default + /// The default is to return the **lowest** number of bytes written. + #[inline(always)] fn compare_byte_sizes(a: usize, b: usize) -> Result + { + Ok(std::cmp::min(a, b)) + } +} + +/// The default `WriteRule` static rules for a `MultiplexingStream` to follow. +#[derive(Debug)] +pub struct DefaultWriteRule(()); + +impl WriteRule for DefaultWriteRule{} + +// When one completes but not the other, we set this enum to the completed one. Then, we keep polling the other until it also completes. After that, this is reset to `None`. +#[derive(Debug)] +enum StatRes +{ + First(io::Result), + Second(io::Result), + None, +} + +impl From> for Option> +{ + #[inline] fn from(from: StatRes) -> Self + { + match from { + StatRes::First(r) | + StatRes::Second(r) => Some(r), + _ => None + } + } +} + +impl StatRes +{ + /// Does this stat have a result? + pub fn has_result(&self) -> bool + { + if let Self::None = self { + false + } else { + true + } + } +} + +impl Default for StatRes +{ + #[inline] + fn default() -> Self + { + Self::None + } +} + +type StatWrite = StatRes; +type StatFlush = StatRes<()>; +type StatShutdown = StatRes<()>; + +#[derive(Debug)] +struct Stat +{ + write: StatWrite, + flush: StatFlush, + shutdown: StatShutdown, +} + +default!(Stat: Self { + write: Default::default(), + flush: Default::default(), + shutdown: Default::default(), +}); + +/// An `AsyncWrite` stream that dispatches its writes to 2 outputs +/// +/// # Notes +/// If the backing stream's `write` impls provide different results for the number of bytes written, which to return is determined by the `Rule` parameter's `compare_byte_sizes()` function. +#[pin_project] +#[derive(Debug)] +pub struct MultiplexWrite +{ + #[pin] s1: T, + #[pin] s2: U, + + // `Stat` is big, box it. + stat: Box, + + _rule: PhantomData, +} + +impl MultiplexWrite +where T: AsyncWrite, + U: AsyncWrite +{ + /// Create a new `AsyncWrite` multiplexer + /// + /// The default static write rule will be used + #[inline(always)] pub fn new(s1: T, s2: U) -> Self + { + Self::new_ruled(s1, s2) + } +} +impl MultiplexWrite +where T: AsyncWrite, + U: AsyncWrite +{ + /// Create a new `AsyncWrite` multiplexer with a specified static write rule. + #[inline] pub fn new_ruled(s1: T, s2: U) -> Self + { + Self { + s1, s2, stat: Box::new(Default::default()), + + _rule: PhantomData + } + } + /// Consume into the 2 backing streams + #[inline] pub fn into_inner(self) -> (T, U) + { + (self.s1, self.s2) + } + /// Chain to another `AsyncWrite` stream + #[inline] pub fn chain(self, s3: V) -> MultiplexWrite + { + MultiplexWrite::new(self, s3) + } + + /// References to the inner streams + #[inline] pub fn streams(&self) -> (&T, &U) + { + (&self.s1, &self.s2) + } + /// Mutable references to the inner streams + #[inline] pub fn streams_mut(&mut self) -> (&mut T, &mut U) + { + (&mut self.s1, &mut self.s2) + } + + /// Extension method for `write_all` that ensures both streams have the same number of bytes written. + #[inline] pub async fn write_all(&mut self, buf: &[u8]) -> io::Result<()> + where T: Unpin, U: Unpin + { + use tokio::prelude::*; + let (s1, s2) = self.streams_mut(); + + let (r1, r2) = tokio::join![ + s1.write_all(buf), + s2.write_all(buf), + ]; + r1?; + r2?; + Ok(()) + } +} + +impl UniplexWrite +{ + /// Create a `MultiplexWrite` with only one output. + /// + /// The default static write rule will be used. + #[inline] pub fn single(s1: T) -> Self + { + Self::new(s1, tokio::io::sink()) + } + +} +impl UniplexWrite +{ + /// Create a `MultiplexWrite` with only one output with a specified static write rule. + #[inline] pub fn single_ruled(s1: T) -> Self + { + Self::new_ruled(s1, tokio::io::sink()) + } + + /// Add a second output to this writer + #[inline] pub fn into_multi(self, s2: U) -> MultiplexWrite + { + MultiplexWrite::new_ruled(self.s1, s2) + } +} + +/// A `MultiplexWrite` with only 1 output. +pub type UniplexWrite = MultiplexWrite; + +impl AsyncWrite for MultiplexWrite +where T: AsyncWrite, U: AsyncWrite +{ + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + let this = self.project(); + let (r1, r2) = match std::mem::replace(&mut this.stat.write, StatRes::None) { + StatRes::First(r1) => { + let r2 = this.s2.poll_write(cx, buf); + + (Poll::Ready(r1), r2) + }, + StatRes::Second(r2) => { + let r1 = this.s1.poll_write(cx, buf); + + (r1, Poll::Ready(r2)) + } + StatRes::None => { + let r1 = this.s1.poll_write(cx, buf); + let r2 = this.s2.poll_write(cx, buf); + + (r1, r2) + } + }; + + match (r1, r2) { + (Poll::Ready(r1), Poll::Ready(r2)) => { + // Both ready. Return result that has the most bytes written. + // Note: No need to update `stat` for this branch, it already has been set to `None` in the above match expr. + return Poll::Ready(Rule::compare_byte_sizes(r1?, r2?).map_err(Into::into)); + }, + (Poll::Ready(r1), _) => { + // First ready. Update stat to first + this.stat.write = StatRes::First(r1); + }, + (_, Poll::Ready(r2)) => { + // Second ready. Update stat to second + this.stat.write = StatRes::Second(r2); + } + // Both are pending, re-poll both next time (as `stat.write` is set to `None`). + _ => () + } + Poll::Pending + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + let (r1, r2) = match std::mem::replace(&mut this.stat.flush, StatRes::None) { + StatRes::First(r1) => { + let r2 = this.s2.poll_flush(cx); + + (Poll::Ready(r1), r2) + }, + StatRes::Second(r2) => { + let r1 = this.s1.poll_flush(cx); + + (r1, Poll::Ready(r2)) + } + StatRes::None => { + let r1 = this.s1.poll_flush(cx); + let r2 = this.s2.poll_flush(cx); + + (r1, r2) + } + }; + + match (r1, r2) { + (Poll::Ready(r1), Poll::Ready(r2)) => { + // Both ready. + // Note: No need to update `stat` for this branch, it already has been set to `None` in the above match expr. + r1?; + r2?; + return Poll::Ready(Ok(())); + }, + (Poll::Ready(r1), _) => { + // First ready. Update stat to first + this.stat.flush = StatRes::First(r1); + }, + (_, Poll::Ready(r2)) => { + // Second ready. Update stat to second + this.stat.flush = StatRes::Second(r2); + } + // Both are pending, re-poll both next time (as `stat.flush` is set to `None`). + _ => () + } + Poll::Pending + } + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + let (r1, r2) = match std::mem::replace(&mut this.stat.shutdown, StatRes::None) { + StatRes::First(r1) => { + let r2 = this.s2.poll_shutdown(cx); + + (Poll::Ready(r1), r2) + }, + StatRes::Second(r2) => { + let r1 = this.s1.poll_shutdown(cx); + + (r1, Poll::Ready(r2)) + } + StatRes::None => { + let r1 = this.s1.poll_shutdown(cx); + let r2 = this.s2.poll_shutdown(cx); + + (r1, r2) + } + }; + + match (r1, r2) { + (Poll::Ready(r1), Poll::Ready(r2)) => { + // Both ready. + // Note: No need to update `stat` for this branch, it already has been set to `None` in the above match expr. + r1?; + r2?; + return Poll::Ready(Ok(())); + }, + (Poll::Ready(r1), _) => { + // First ready. Update stat to first + this.stat.shutdown = StatRes::First(r1); + }, + (_, Poll::Ready(r2)) => { + // Second ready. Update stat to second + this.stat.shutdown = StatRes::Second(r2); + } + // Both are pending, re-poll both next time (as `stat.shutdown` is set to `None`). + _ => () + } + Poll::Pending + } +} + +impl From<(T, U)> for MultiplexWrite +where T: AsyncWrite, + U: AsyncWrite, +{ + #[inline] fn from((s1, s2): (T, U)) -> Self + { + Self::new(s1, s2) + } +} + +impl From> for (T, U) +where T: AsyncWrite, + U: AsyncWrite, +{ + fn from(from: MultiplexWrite) -> Self + { + from.into_inner() + } +} + +pub trait MultiplexStreamExt: Sized + AsyncWrite +{ + /// Create a multi-outputting `AsyncWrite` stream writing to both this an another with a static write rule. + fn multiplex_ruled(self, other: T) -> MultiplexWrite; + + /// Create a multi-outputting `AsyncWrite` stream writing to both this an another. + #[inline(always)] fn multiplex(self, other: T) -> MultiplexWrite + { + self.multiplex_ruled::(other) + } +} + +impl MultiplexStreamExt for S +{ + #[inline] fn multiplex_ruled(self, other: T) -> MultiplexWrite { + MultiplexWrite::new_ruled(self, other) + } +} + +#[cfg(test)] +mod tests +{ + use tokio::prelude::*; + use super::{ + MultiplexWrite, + UniplexWrite, + }; + + #[tokio::test] + async fn mp_write_all() + { + const INPUT: &'static str = "Hello world."; + let mut o1 = Vec::new(); + let mut o2 = Vec::new(); + + { + let mut mp = MultiplexWrite::new(&mut o1, &mut o2); + mp.write_all(INPUT.as_bytes()).await.expect("mp write failed"); + mp.flush().await.expect("mp flush"); + mp.shutdown().await.expect("mp shutdown"); + } + + assert_eq!(o1.len(), o2.len()); + assert_eq!(&o1[..], INPUT.as_bytes()); + assert_eq!(&o2[..], INPUT.as_bytes()); + } + #[tokio::test] + async fn multiplex() + { + const INPUT: &'static str = "Hello world."; + let mut o1 = Vec::new(); + let mut o2 = Vec::new(); + + { + let mut mp = MultiplexWrite::new(&mut o1, &mut o2); + assert_eq!(mp.write(INPUT.as_bytes()).await.expect("mp write failed"), INPUT.len()); + mp.flush().await.expect("mp flush"); + mp.shutdown().await.expect("mp shutdown"); + } + + assert_eq!(o1.len(), o2.len()); + assert_eq!(&o1[..], INPUT.as_bytes()); + assert_eq!(&o2[..], INPUT.as_bytes()); + } + #[tokio::test] + async fn uniplex() + { + const INPUT: &'static str = "Hello world."; + let mut o1 = Vec::new(); + + { + let mut mp = UniplexWrite::single(&mut o1); + assert_eq!(mp.write(INPUT.as_bytes()).await.expect("mp write failed"), INPUT.len()); + mp.flush().await.expect("mp flush"); + mp.shutdown().await.expect("mp shutdown"); + } + + assert_eq!(&o1[..], INPUT.as_bytes()); + } + +} + diff --git a/src/ext/sync.rs b/src/ext/sync.rs new file mode 100644 index 0000000..ac5f538 --- /dev/null +++ b/src/ext/sync.rs @@ -0,0 +1,97 @@ +//! Sync utils +use super::*; +use std::sync::Arc; +use std::sync::atomic::AtomicBool; +use std::mem::MaybeUninit; +use std::cell::UnsafeCell; + +#[derive(Debug)] +struct SharedInitialiser +{ + data: Arc<(UnsafeCell>, AtomicBool)>, +} + +impl Clone for SharedInitialiser +{ + #[inline] fn clone(&self) -> Self { + Self { data: Arc::clone(&self.data) } + } +} + +#[derive(Debug)] +pub struct SharedInitRx(SharedInitialiser); + +/// Permits initialising across a thread. +// Do we even need this? No.. We can just use `tokio::sync::oneshot`.... +#[derive(Debug)] +pub struct SharedInitTx(SharedInitialiser); + +impl SharedInitTx +{ + /// Consume this instance and initialise it. + /// + /// # Panics + /// If there is already a value set (this should never happen). + pub fn initialise(self, value: T) + { + todo!() + } +} + +impl SharedInitRx +{ + /// Create a sender and receiver pair + pub fn pair() -> (SharedInitTx, Self) + { + let this = Self::new(); + (this.create_tx(), this) + } + /// Create a new, uninitialised receiver. + #[inline] fn new() -> Self + { + Self(SharedInitialiser{data: Arc::new((UnsafeCell::new(MaybeUninit::uninit()), false.into()))}) + } + /// Create an initialiser + /// + /// # Panics (debug) + /// If an initialiser already exists + #[inline] fn create_tx(&self) -> SharedInitTx + { + debug_assert_eq!(Arc::strong_count(&self.0.data), 1, "Sender already exists"); + + SharedInitTx(self.0.clone()) + } + + /// Checks if there is a value present, or if it is possible for a value to be present. + pub fn is_pending(&self) -> bool + { + todo!("Either self.0.data.1 == true, OR, strong_count() == 2") + } + + /// Has a value already been set + pub fn has_value(&self) -> bool + { + todo!("self.0.data.1 == true") + } + + /// Try to consume into the initialised value. + pub fn try_into_value(self) -> Result + { + todo!() + } + + /// Consume into the initialised value + /// + /// # Panics + /// If the value hasn't been initialised + #[inline] pub fn into_value(self) -> T + { + self.try_into_value().map_err(|_| "No initialised value present").unwrap() + } + + /// Does this receiver have an initialser that hasn't yet produced a value? + pub fn has_initialiser(&self) -> bool + { + Arc::strong_count(&self.0.data) == 2 + } +} diff --git a/src/main.rs b/src/main.rs index 23a61fe..e72d81b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,5 @@ #![feature(never_type)] +#![feature(associated_type_defaults)] #![feature(asm)] #![allow(dead_code)] @@ -7,6 +8,7 @@ #[macro_use] extern crate ad_hoc_iter; #[macro_use] extern crate bitflags; #[macro_use] extern crate lazy_static; +#[macro_use] extern crate pin_project; use serde::{Serialize, Deserialize}; diff --git a/src/service/config.rs b/src/service/config.rs index ae60eda..f46a5c1 100644 --- a/src/service/config.rs +++ b/src/service/config.rs @@ -23,14 +23,7 @@ pub enum StopDirective Error, } -impl Default for StopDirective -{ - #[inline] - fn default() -> Self - { - Self::Error - } -} +default!(StopDirective: Self::Error); /// Settings for how a background service runs #[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, Serialize, Deserialize)] diff --git a/src/service/host.rs b/src/service/host.rs index a77539e..f164f19 100644 --- a/src/service/host.rs +++ b/src/service/host.rs @@ -5,7 +5,7 @@ use futures::prelude::*; pub type SupervisorError = (); //TODO pub type Error = (); // TODO -pub fn spawn_supervisor() -> JoinHandle> +pub(super) fn spawn_supervisor(service: Service) -> JoinHandle> { tokio::spawn(async move { //TODO: Spawn slave and handle its exiting, restarting, etc according to config @@ -13,8 +13,9 @@ pub fn spawn_supervisor() -> JoinHandle> }) } -fn spawn_slave(rx: mpsc::Receiver<()>) -> JoinHandle> +fn spawn_slave(service: Service) -> JoinHandle> { + let Service { inner: service, rx } = service; tokio::spawn(async move { let mut rx = rx @@ -26,7 +27,7 @@ fn spawn_slave(rx: mpsc::Receiver<()>) -> JoinHandle> block = rx.next() => { match block { Some(block) => { - // TODO: Process this block + // TODO: Filter and/or process this block }, None => { // Reached the end of stream, exit gracefully. diff --git a/src/service/mod.rs b/src/service/mod.rs index dc7a442..1a5d693 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -75,6 +75,15 @@ pub struct Channel { tx: mpsc::Sender, } +/// The service's counterpart to `Channel`. Contains the metadata `ChannelInner` and the receiver for `Channel`s. +#[derive(Debug)] +struct Service +{ + inner: Arc, + + rx: mpsc::Receiver, +} + impl Eq for Channel{} impl PartialEq for Channel { #[inline] fn eq(&self, other: &Self) -> bool