You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

466 lines
12 KiB

//! Multiplexing
use super::*;
use std::io;
use futures::prelude::*;
use std::marker::PhantomData;
use tokio::io::{
AsyncWrite,
};
use std::{
task::{Context, Poll},
pin::Pin,
};
use std::{fmt, error};
/// Error type for `WriteRule::compare_byte_sizes()` implementations that can never fail.
///
/// The enum has no variants, so no value of it can ever be constructed. Every method
/// body below is therefore unreachable, and the empty `match` expressions type-check
/// as any return type.
#[derive(Debug)]
pub enum CompareInfallible {}

impl error::Error for CompareInfallible {}

impl fmt::Display for CompareInfallible
{
    fn fmt(&self, _: &mut fmt::Formatter<'_>) -> fmt::Result
    {
        // Zero variants means zero arms: this can never run.
        match *self {}
    }
}

impl From<CompareInfallible> for io::Error
{
    /// Lets `CompareInfallible` satisfy the `Into<io::Error>` bound on
    /// `WriteRule::CompareFailedError`; it can never actually be called.
    fn from(never: CompareInfallible) -> Self
    {
        match never {}
    }
}
/// Static rules for a `MultiplexWrite` to use.
///
/// `MultiplexWrite` never holds a value of the rule type: it stores only `PhantomData<Rule>` and calls the rule's associated functions.
pub trait WriteRule
{
/// Error that `compare_byte_sizes()` may return. It is converted into `io::Error` (via `Into`) when surfaced from `MultiplexWrite::poll_write`.
///
/// NOTE(review): associated type defaults are an unstable Rust feature (`associated_type_defaults`), so this presumably relies on a nightly compiler with that feature enabled at the crate root — confirm.
type CompareFailedError: Into<io::Error> = CompareInfallible;
/// Called once both streams' writes have succeeded, to choose which of the two byte counts to report.
///
/// `a` is the count reported by the first stream and `b` the count reported by the second.
/// It is called for every successful pair of writes, including when `a == b`.
///
/// # Errors
/// You can also choose to return an error if the sizes do not match.
/// The error must be convertible to `io::Error`.
/// By default, the error type is `CompareInfallible`, an uninhabited enum (it has no variants, so it can never be constructed) that implements `Into<io::Error>`. Using it documents, and lets the compiler see, that this function never fails.
///
/// If you are using an error type that may actually be returned, set `CompareFailedError` to that type, as long as it implements `Into<io::Error>` (or simply set it to `io::Error`). The generic conversion exists only to allow a vacant type here when this function never returns an error.
///
/// # Default
/// The default is to return the **lowest** number of bytes written.
#[inline(always)] fn compare_byte_sizes(a: usize, b: usize) -> Result<usize, Self::CompareFailedError>
{
Ok(std::cmp::min(a, b))
}
}
/// The default `WriteRule` for a `MultiplexWrite` to follow.
///
/// It overrides nothing, so it uses the trait's defaults: `compare_byte_sizes()` reports the smaller of the two byte counts and never fails (`CompareFailedError = CompareInfallible`).
///
/// The private `()` field stops this type from being constructed outside this module. In this file it is only ever used as a type parameter.
#[derive(Debug)]
pub struct DefaultWriteRule(());
impl WriteRule for DefaultWriteRule{}
/// Per-operation completion state used by `MultiplexWrite`'s poll methods.
///
/// An operation may complete on one of the two backing streams while the other is still
/// pending. In that case the finished result is parked here (`First` for stream 1,
/// `Second` for stream 2), and only the unfinished stream is polled on later calls.
/// Once both streams have completed, the state is reset to `None`.
#[derive(Debug)]
enum StatRes<T>
{
    /// Stream 1 finished with this result; stream 2 is still pending.
    First(io::Result<T>),
    /// Stream 2 finished with this result; stream 1 is still pending.
    Second(io::Result<T>),
    /// Nothing is parked; both streams need to be polled.
    None,
}

impl<T> From<StatRes<T>> for Option<io::Result<T>>
{
    /// Extract the parked result, if any, discarding which stream produced it.
    #[inline] fn from(from: StatRes<T>) -> Self
    {
        match from {
            StatRes::First(r) | StatRes::Second(r) => Some(r),
            // Listed explicitly rather than `_`, so a new variant is a compile error here.
            StatRes::None => None,
        }
    }
}

impl<T> StatRes<T>
{
    /// Does this stat hold a parked result (i.e. is it anything other than `None`)?
    pub fn has_result(&self) -> bool
    {
        !matches!(self, Self::None)
    }
}

impl<T> Default for StatRes<T>
{
    /// The idle state: nothing parked.
    #[inline]
    fn default() -> Self
    {
        Self::None
    }
}
type StatWrite = StatRes<usize>;
type StatFlush = StatRes<()>;
type StatShutdown = StatRes<()>;

/// Parked half-finished results for each `AsyncWrite` operation of a `MultiplexWrite`.
///
/// Write, flush and shutdown each get their own slot. Every slot starts out as
/// `StatRes::None`, the `Default` of `StatRes`.
#[derive(Debug, Default)]
struct Stat
{
    write: StatWrite,
    flush: StatFlush,
    shutdown: StatShutdown,
}
/// An `AsyncWrite` stream that dispatches every write, flush and shutdown to 2 outputs.
///
/// Each operation is polled on both outputs and completes only once both have completed. A result from whichever output finishes first is held until the other catches up.
///
/// # Notes
/// If the backing streams' `write` impls report different numbers of bytes written, which count to return is determined by the `Rule` parameter's `compare_byte_sizes()` function.
/// `AsyncWrite` callers resubmit whatever was not reported as written, so differing counts can make the two outputs diverge. The inherent `write_all` method avoids this by writing the whole buffer to each output independently.
#[pin_project]
#[derive(Debug)]
pub struct MultiplexWrite<T,U, Rule: WriteRule + ?Sized = DefaultWriteRule>
{
// First output (structurally pinned).
#[pin] s1: T,
// Second output (structurally pinned).
#[pin] s2: U,
// Results parked by write/flush/shutdown while the other output is still pending.
// `Stat` is big, box it.
// NOTE(review): `Stat` is just three `StatRes` slots; whether boxing it pays off has not been measured.
stat: Box<Stat>,
// Type-level marker for the static rule; no `Rule` value is stored.
_rule: PhantomData<Rule>,
}
impl<T,U> MultiplexWrite<T, U>
where T: AsyncWrite,
U: AsyncWrite
{
/// Create a new `AsyncWrite` multiplexer
///
/// The default static write rule will be used
#[inline(always)] pub fn new(s1: T, s2: U) -> Self
{
Self::new_ruled(s1, s2)
}
}
impl<T, U, Rule: WriteRule + ?Sized> MultiplexWrite<T, U, Rule>
where
    T: AsyncWrite,
    U: AsyncWrite,
{
    /// Build a multiplexer over `s1` and `s2` whose differing write sizes are resolved by `Rule`.
    #[inline] pub fn new_ruled(s1: T, s2: U) -> Self
    {
        Self {
            s1,
            s2,
            stat: Box::default(),
            _rule: PhantomData,
        }
    }

    /// Take the multiplexer apart, returning both outputs as `(first, second)`.
    ///
    /// Any parked half-finished operation result is dropped.
    #[inline] pub fn into_inner(self) -> (T, U)
    {
        let Self { s1, s2, .. } = self;
        (s1, s2)
    }

    /// Nest this multiplexer with a third output `s3`.
    ///
    /// The outer multiplexer always uses the default rule, whatever `Rule` is here.
    #[inline] pub fn chain<V: AsyncWrite>(self, s3: V) -> MultiplexWrite<Self, V>
    {
        MultiplexWrite::new(self, s3)
    }

    /// Shared references to both outputs, as `(first, second)`.
    #[inline] pub fn streams(&self) -> (&T, &U)
    {
        let Self { s1, s2, .. } = self;
        (s1, s2)
    }

    /// Mutable references to both outputs, as `(first, second)`.
    #[inline] pub fn streams_mut(&mut self) -> (&mut T, &mut U)
    {
        let Self { s1, s2, .. } = self;
        (s1, s2)
    }

    /// Write all of `buf` to both outputs concurrently, so that on success each has received exactly `buf`.
    ///
    /// Both writes are driven to completion (or failure) before this returns. If both fail,
    /// the first output's error is reported.
    #[inline] pub async fn write_all(&mut self, buf: &[u8]) -> io::Result<()>
    where T: Unpin, U: Unpin
    {
        use tokio::prelude::*;
        let (first, second) = self.streams_mut();
        let (first_res, second_res) = tokio::join![
            first.write_all(buf),
            second.write_all(buf),
        ];
        // The first output's error wins; otherwise the second output's result is returned.
        first_res.and(second_res)
    }
}
impl<T: AsyncWrite> UniplexWrite<T>
{
    /// Wrap a single output; the unused second slot is a `tokio::io::sink()` that discards everything.
    ///
    /// Uses the default static write rule.
    #[inline] pub fn single(s1: T) -> Self
    {
        let discard = tokio::io::sink();
        Self::new(s1, discard)
    }
}
impl<T: AsyncWrite, Rule: WriteRule + ?Sized> UniplexWrite<T, Rule>
{
    /// Wrap a single output under an explicit static write rule.
    ///
    /// The second slot is a `tokio::io::sink()`.
    #[inline] pub fn single_ruled(s1: T) -> Self
    {
        let discard = tokio::io::sink();
        Self::new_ruled(s1, discard)
    }

    /// Swap the placeholder sink for a real second output `s2`, keeping the same `Rule`.
    ///
    /// The sink and any parked operation state are dropped; the returned multiplexer starts fresh.
    #[inline] pub fn into_multi<U: AsyncWrite>(self, s2: U) -> MultiplexWrite<T, U, Rule>
    {
        let Self { s1, .. } = self;
        MultiplexWrite::new_ruled(s1, s2)
    }
}
/// A `MultiplexWrite` with only 1 output.
///
/// The second output is a `tokio::io::Sink`, so everything sent to it is discarded. Build one with `single`/`single_ruled` and add a real second output later with `into_multi`.
pub type UniplexWrite<T, Rule = DefaultWriteRule> = MultiplexWrite<T, tokio::io::Sink, Rule>;
impl<T, U, Rule: WriteRule + ?Sized> AsyncWrite for MultiplexWrite<T, U, Rule>
where T: AsyncWrite, U: AsyncWrite
{
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
let this = self.project();
let (r1, r2) = match std::mem::replace(&mut this.stat.write, StatRes::None) {
StatRes::First(r1) => {
let r2 = this.s2.poll_write(cx, buf);
(Poll::Ready(r1), r2)
},
StatRes::Second(r2) => {
let r1 = this.s1.poll_write(cx, buf);
(r1, Poll::Ready(r2))
}
StatRes::None => {
let r1 = this.s1.poll_write(cx, buf);
let r2 = this.s2.poll_write(cx, buf);
(r1, r2)
}
};
match (r1, r2) {
(Poll::Ready(r1), Poll::Ready(r2)) => {
// Both ready. Return result that has the most bytes written.
// Note: No need to update `stat` for this branch, it already has been set to `None` in the above match expr.
return Poll::Ready(Rule::compare_byte_sizes(r1?, r2?).map_err(Into::into));
},
(Poll::Ready(r1), _) => {
// First ready. Update stat to first
this.stat.write = StatRes::First(r1);
},
(_, Poll::Ready(r2)) => {
// Second ready. Update stat to second
this.stat.write = StatRes::Second(r2);
}
// Both are pending, re-poll both next time (as `stat.write` is set to `None`).
_ => ()
}
Poll::Pending
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
let this = self.project();
let (r1, r2) = match std::mem::replace(&mut this.stat.flush, StatRes::None) {
StatRes::First(r1) => {
let r2 = this.s2.poll_flush(cx);
(Poll::Ready(r1), r2)
},
StatRes::Second(r2) => {
let r1 = this.s1.poll_flush(cx);
(r1, Poll::Ready(r2))
}
StatRes::None => {
let r1 = this.s1.poll_flush(cx);
let r2 = this.s2.poll_flush(cx);
(r1, r2)
}
};
match (r1, r2) {
(Poll::Ready(r1), Poll::Ready(r2)) => {
// Both ready.
// Note: No need to update `stat` for this branch, it already has been set to `None` in the above match expr.
r1?;
r2?;
return Poll::Ready(Ok(()));
},
(Poll::Ready(r1), _) => {
// First ready. Update stat to first
this.stat.flush = StatRes::First(r1);
},
(_, Poll::Ready(r2)) => {
// Second ready. Update stat to second
this.stat.flush = StatRes::Second(r2);
}
// Both are pending, re-poll both next time (as `stat.flush` is set to `None`).
_ => ()
}
Poll::Pending
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
let this = self.project();
let (r1, r2) = match std::mem::replace(&mut this.stat.shutdown, StatRes::None) {
StatRes::First(r1) => {
let r2 = this.s2.poll_shutdown(cx);
(Poll::Ready(r1), r2)
},
StatRes::Second(r2) => {
let r1 = this.s1.poll_shutdown(cx);
(r1, Poll::Ready(r2))
}
StatRes::None => {
let r1 = this.s1.poll_shutdown(cx);
let r2 = this.s2.poll_shutdown(cx);
(r1, r2)
}
};
match (r1, r2) {
(Poll::Ready(r1), Poll::Ready(r2)) => {
// Both ready.
// Note: No need to update `stat` for this branch, it already has been set to `None` in the above match expr.
r1?;
r2?;
return Poll::Ready(Ok(()));
},
(Poll::Ready(r1), _) => {
// First ready. Update stat to first
this.stat.shutdown = StatRes::First(r1);
},
(_, Poll::Ready(r2)) => {
// Second ready. Update stat to second
this.stat.shutdown = StatRes::Second(r2);
}
// Both are pending, re-poll both next time (as `stat.shutdown` is set to `None`).
_ => ()
}
Poll::Pending
}
}
impl<T, U> From<(T, U)> for MultiplexWrite<T, U>
where T: AsyncWrite,
U: AsyncWrite,
{
#[inline] fn from((s1, s2): (T, U)) -> Self
{
Self::new(s1, s2)
}
}
impl<T,U> From<MultiplexWrite<T, U>> for (T, U)
where T: AsyncWrite,
U: AsyncWrite,
{
fn from(from: MultiplexWrite<T, U>) -> Self
{
from.into_inner()
}
}
/// Extension trait adding `multiplex`/`multiplex_ruled` to `AsyncWrite` types, so any writer can be paired with a second output.
pub trait MultiplexStreamExt: Sized + AsyncWrite
{
/// Create a multi-outputting `AsyncWrite` stream that writes to both this stream and `other`, using the static write rule `Rule`.
fn multiplex_ruled<T: AsyncWrite, Rule: ?Sized + WriteRule>(self, other: T) -> MultiplexWrite<Self, T, Rule>;
/// Create a multi-outputting `AsyncWrite` stream that writes to both this stream and `other`, using `DefaultWriteRule`.
#[inline(always)] fn multiplex<T: AsyncWrite>(self, other: T) -> MultiplexWrite<Self, T>
{
self.multiplex_ruled::<T, DefaultWriteRule>(other)
}
}
/// Blanket impl: every `AsyncWrite` type gets the multiplexing extension methods.
impl<S: AsyncWrite> MultiplexStreamExt for S
{
    #[inline] fn multiplex_ruled<T: AsyncWrite, Rule: ?Sized + WriteRule>(self, other: T) -> MultiplexWrite<Self, T, Rule>
    {
        // `self` becomes the first output and `other` the second.
        MultiplexWrite::<S, T, Rule>::new_ruled(self, other)
    }
}
// Round-trip tests writing a fixed string through `MultiplexWrite`/`UniplexWrite` into in-memory `Vec<u8>` outputs.
#[cfg(test)]
mod tests
{
// Provides the `AsyncWriteExt` methods used below (`write`, `flush`, `shutdown`), none of which are inherent on `MultiplexWrite`.
use tokio::prelude::*;
use super::{
MultiplexWrite,
UniplexWrite,
};
// Exercises the inherent `MultiplexWrite::write_all` (inherent methods take precedence over the `AsyncWriteExt` trait method of the same name).
// That method writes the whole buffer to each output concurrently.
#[tokio::test]
async fn mp_write_all()
{
const INPUT: &'static str = "Hello world.";
let mut o1 = Vec::new();
let mut o2 = Vec::new();
// Inner scope ends the `&mut` borrows of `o1`/`o2` held by `mp` before the asserts.
{
let mut mp = MultiplexWrite::new(&mut o1, &mut o2);
mp.write_all(INPUT.as_bytes()).await.expect("mp write failed");
mp.flush().await.expect("mp flush");
mp.shutdown().await.expect("mp shutdown");
}
assert_eq!(o1.len(), o2.len());
assert_eq!(&o1[..], INPUT.as_bytes());
assert_eq!(&o2[..], INPUT.as_bytes());
}
// Exercises `AsyncWrite::poll_write` through a single `write` call. The test expects both in-memory outputs to accept the whole buffer at once, so the reported count is `INPUT.len()`.
#[tokio::test]
async fn multiplex()
{
const INPUT: &'static str = "Hello world.";
let mut o1 = Vec::new();
let mut o2 = Vec::new();
{
let mut mp = MultiplexWrite::new(&mut o1, &mut o2);
assert_eq!(mp.write(INPUT.as_bytes()).await.expect("mp write failed"), INPUT.len());
mp.flush().await.expect("mp flush");
mp.shutdown().await.expect("mp shutdown");
}
assert_eq!(o1.len(), o2.len());
assert_eq!(&o1[..], INPUT.as_bytes());
assert_eq!(&o2[..], INPUT.as_bytes());
}
// Single-output variant: the second output is a sink, so only `o1` is checked.
#[tokio::test]
async fn uniplex()
{
const INPUT: &'static str = "Hello world.";
let mut o1 = Vec::new();
{
let mut mp = UniplexWrite::single(&mut o1);
assert_eq!(mp.write(INPUT.as_bytes()).await.expect("mp write failed"), INPUT.len());
mp.flush().await.expect("mp flush");
mp.shutdown().await.expect("mp shutdown");
}
assert_eq!(&o1[..], INPUT.as_bytes());
}
}