diff --git a/src/chunking.rs b/src/chunking.rs
index 47dcc30..0658f08 100644
--- a/src/chunking.rs
+++ b/src/chunking.rs
@@ -6,6 +6,7 @@ use std::{
         Context,
     },
     pin::Pin,
+    marker::PhantomData,
 };
 use tokio::{
     io::{
@@ -173,3 +174,86 @@ mod tests
         assert_eq!(&output[..], "Hello world\nHow are you");
     }
 }
+
+/// A stream that chunks its input.
+#[pin_project]
+pub struct ChunkingStream<S, T, Into=Vec<T>>
+{
+    #[pin] stream: Fuse<S>,
+    buf: Vec<T>,
+    cap: usize,
+    _output: PhantomData<Into>,
+
+    push_now: bool,
+}
+
+
+impl<S, T, Into> ChunkingStream<S, T, Into>
+where S: Stream<Item=T>,
+      Into: From<Vec<T>>
+{
+    pub fn new(stream: S, sz: usize) -> Self
+    {
+        Self {
+            stream: stream.fuse(),
+            buf: Vec::with_capacity(sz),
+            cap: sz,
+            _output: PhantomData,
+            push_now: false,
+        }
+    }
+    pub fn into_inner(self) -> S
+    {
+        self.stream.into_inner()
+    }
+    pub fn cap(&self) -> usize
+    {
+        self.cap
+    }
+    pub fn buffer(&self) -> &[T]
+    {
+        &self.buf[..]
+    }
+
+    /// Force the next read to send the buffer even if it's not full.
+    ///
+    /// # Note
+    /// The buffer still won't send if it's empty.
+    pub fn push_now(&mut self)
+    {
+        self.push_now = true;
+    }
+}
+
+impl<S, T, Into> Stream for ChunkingStream<S, T, Into>
+where S: Stream<Item=T>,
+      Into: From<Vec<T>>
+{
+    type Item = Into;
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        while !(self.push_now && !self.buf.is_empty()) && self.buf.len() < self.cap {
+            // Buffer isn't full, keep filling
+            let this = self.as_mut().project();
+
+            match this.stream.poll_next(cx) {
+                Poll::Ready(None) => {
+                    // Stream is over
+                    break;
+                },
+                Poll::Ready(Some(item)) => {
+                    this.buf.push(item);
+                },
+                _ => return Poll::Pending,
+            }
+        }
+        // Buffer is full or we reached the end of the stream
+        Poll::Ready(if self.buf.len() == 0 {
+            None
+        } else {
+            let this = self.project();
+            *this.push_now = false;
+            let output = std::mem::replace(this.buf, Vec::with_capacity(*this.cap));
+            Some(output.into())
+        })
+    }
+}
diff --git a/src/ext.rs b/src/ext.rs
index 5b05901..6ca144d 100644
--- a/src/ext.rs
+++ b/src/ext.rs
@@ -1,4 +1,5 @@
 //! Extensions
+use super::*;
 use std::{
     iter,
     ops::{
@@ -162,3 +163,21 @@ impl DerefMut for AssertNotSend
         &mut self.0
     }
 }
+
+pub trait ChunkStreamExt<T>: Sized
+{
+    fn chunk_into<I: From<Vec<T>>>(self, sz: usize) -> chunking::ChunkingStream<Self, T, I>;
+    fn chunk(self, sz: usize) -> chunking::ChunkingStream<Self, T>
+    {
+        self.chunk_into(sz)
+    }
+}
+
+impl<S, T> ChunkStreamExt<T> for S
+where S: Stream<Item=T>
+{
+    fn chunk_into<I: From<Vec<T>>>(self, sz: usize) -> chunking::ChunkingStream<Self, T, I>
+    {
+        chunking::ChunkingStream::new(self, sz)
+    }
+}
diff --git a/src/feed.rs b/src/feed.rs
index ef91bc5..166f0d1 100644
--- a/src/feed.rs
+++ b/src/feed.rs
@@ -2,8 +2,10 @@ use super::*;
 #[cfg(any(feature="feed-sentance", feature="split-sentance"))]
 use sanitise::Sentance;
+use std::iter;
 
-pub const DEFAULT_FEED_BOUNDS: std::ops::RangeFrom<usize> = 2..; //TODO: Add to config somehow
+
+pub const DEFAULT_FEED_BOUNDS: std::ops::RangeFrom<usize> = 2..;
 
 /// Feed `what` into `chain`, at least `bounds` tokens.
 ///
@@ -73,12 +75,12 @@ pub async fn full(who: &IpAddr, state: State, body: impl Unpin + Stream
     macro_rules! feed {
-        ($chain:expr, $buffer:expr) => {
+        ($buffer:expr) => {
             {
                 let buffer = $buffer;
-                feed($chain, &buffer, bounds)
+                state.chain_write(buffer.map(ToOwned::to_owned)).await.map_err(|_| FillBodyError)?;
             }
         }
     }
@@ -101,15 +103,11 @@ pub async fn full(who: &IpAddr, state: State, body: impl Unpin + Stream
 {:?}", who, buffer);
-            let mut chain = state.chain().write().await;
             cfg_if! {
                 if #[cfg(feature="split-newlines")] {
-                    for buffer in buffer.split('\n').filter(|line| !line.trim().is_empty()) {
-                        feed!(&mut chain, buffer);
-                    }
+                    feed!(buffer.split('\n').filter(|line| !line.trim().is_empty()))
                 } else {
-                    feed!(&mut chain, buffer);
-
+                    feed!(iter::once(buffer));
                 }
             }
         } else {
@@ -124,10 +122,10 @@ pub async fn full(who: &IpAddr, state: State, body: impl Unpin + Stream
 {:?}", who, line);
         }
         written+=line.len();
diff --git a/src/gen.rs b/src/gen.rs
index 5e75bcb..27575e0 100644
--- a/src/gen.rs
+++ b/src/gen.rs
@@ -1,34 +1,46 @@
 //! Generating the strings
 use super::*;
+use tokio::sync::mpsc::error::SendError;
+use futures::StreamExt;
 
-#[derive(Debug)]
-pub struct GenBodyError(pub String);
+#[derive(Debug, Default)]
+pub struct GenBodyError(Option<String>);
 
 impl error::Error for GenBodyError{}
 impl fmt::Display for GenBodyError
 {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
     {
-        write!(f, "failed to write {:?} to body", self.0)
+        if let Some(z) = &self.0 {
+            write!(f, "failed to write generated string {:?} to body", z)
+        } else {
+            write!(f, "failed to read string from chain. it might be empty.")
+        }
     }
 }
 
 pub async fn body(state: State, num: Option<usize>, mut output: mpsc::Sender<String>) -> Result<(), GenBodyError>
 {
-    let chain = state.chain().read().await;
-    if !chain.is_empty() {
-        let filter = state.outbound_filter();
-        match num {
-            Some(num) if num < state.config().max_gen_size => {
-                //This could DoS `full_body` and writes, potentially.
-                for string in chain.str_iter_for(num) {
-                    output.send(filter.filter_owned(string)).await.map_err(|e| GenBodyError(e.0))?;
-                }
-            },
-            _ => output.send(filter.filter_owned(chain.generate_str())).await.map_err(|e| GenBodyError(e.0))?,
-        }
+    let mut chain = state.chain_read();
+    let filter = state.outbound_filter();
+    match num {
+        Some(num) if num < state.config().max_gen_size => {
+            let mut chain = chain.take(num);
+            while let Some(string) = chain.next().await {
+                output.send(filter.filter_owned(string)).await?;
+            }
+        },
+        _ => output.send(filter.filter_owned(chain.next().await.ok_or_else(GenBodyError::default)?)).await?,
     }
     Ok(())
 }
+
+impl From<SendError<String>> for GenBodyError
+{
+    #[inline] fn from(from: SendError<String>) -> Self
+    {
+        Self(Some(from.0))
+    }
+}
diff --git a/src/handle.rs b/src/handle.rs
new file mode 100644
index 0000000..0669c15
--- /dev/null
+++ b/src/handle.rs
@@ -0,0 +1,285 @@
+//! Chain handler.
+use super::*;
+use std::{
+    marker::Send,
+    sync::Weak,
+    num::NonZeroUsize,
+    task::{Poll, Context,},
+    pin::Pin,
+};
+use tokio::{
+    sync::{
+        RwLock,
+        RwLockReadGuard,
+        mpsc::{
+            self,
+            error::SendError,
+        },
+    },
+    task::JoinHandle,
+    time::{
+        self,
+        Duration,
+    },
+};
+use futures::StreamExt;
+
+/// Settings for chain handler
+#[derive(Debug, Clone, PartialEq)]
+pub struct Settings
+{
+    pub backlog: usize,
+    pub capacity: usize,
+    pub timeout: Duration,
+    pub throttle: Option<Duration>,
+    pub bounds: range::DynRange<usize>,
+}
+
+impl Settings
+{
+    /// Should we keep this string.
+    #[inline] fn matches(&self, s: &str) -> bool
+    {
+        true
+    }
+}
+
+impl Default for Settings
+{
+    #[inline]
+    fn default() -> Self
+    {
+        Self {
+            backlog: 32,
+            capacity: 4,
+            timeout: Duration::from_secs(5),
+            throttle: Some(Duration::from_millis(200)),
+            bounds: feed::DEFAULT_FEED_BOUNDS.into(),
+        }
+    }
+}
+
+
+#[derive(Debug)]
+struct HostInner
+{
+    input: mpsc::Receiver<Vec<String>>,
+}
+
+#[derive(Debug)]
+struct Handle
+{
+    chain: RwLock<chain::Chain<String>>,
+    input: mpsc::Sender<Vec<String>>,
+    opt: Settings,
+
+    /// Data used only for the worker task.
+    host: msg::Once<HostInner>,
+}
+
+#[derive(Clone, Debug)]
+pub struct ChainHandle(Arc<Box<Handle>>);
+
+impl ChainHandle
+{
+    #[inline] pub fn new(chain: chain::Chain<String>) -> Self
+    {
+        Self::with_settings(chain, Default::default())
+    }
+    pub fn with_settings(chain: chain::Chain<String>, opt: Settings) -> Self
+    {
+        let (itx, irx) = mpsc::channel(opt.backlog);
+        Self(Arc::new(Box::new(Handle{
+            chain: RwLock::new(chain),
+            input: itx,
+            opt,
+
+            host: msg::Once::new(HostInner{
+                input: irx,
+            })
+        })))
+    }
+
+    /// Acquire the chain read lock
+    async fn chain(&self) -> RwLockReadGuard<'_, chain::Chain<String>>
+    {
+        self.0.chain.read().await
+    }
+
+    /// A reference to the chain
+    pub fn chain_ref(&self) -> &RwLock<chain::Chain<String>>
+    {
+        &self.0.chain
+    }
+
+    /// Create a stream that reads generated values forever.
+    pub fn read(&self) -> ChainStream
+    {
+        ChainStream{
+            chain: Arc::downgrade(&self.0),
+            buffer: Vec::with_capacity(self.0.opt.backlog),
+        }
+    }
+
+    /// Send this buffer to the chain
+    pub async fn write(&self, buf: Vec<String>) -> Result<(), SendError<Vec<String>>>
+    {
+        self.0.input.clone().send(buf).await
+    }
+}
+
+impl ChainHandle
+{
+    #[deprecated = "use read() pls"]
+    pub async fn generate_body(&self, state: &state::State, num: Option<usize>, mut output: mpsc::Sender<String>) -> Result<(), SendError<String>>
+    {
+        let chain = self.chain().await;
+        if !chain.is_empty() {
+            let filter = state.outbound_filter();
+            match num {
+                Some(num) if num < state.config().max_gen_size => {
+                    //This could DoS writes, potentially.
+                    for string in chain.str_iter_for(num) {
+                        output.send(filter.filter_owned(string)).await?;
+                    }
+                },
+                _ => output.send(filter.filter_owned(chain.generate_str())).await?,
+            }
+        }
+        Ok(())
+    }
+}
+
+/// Host this handle on the current task.
+///
+/// # Panics
+/// If `from` has already been hosted.
+pub async fn host(from: ChainHandle)
+{
+    let opt = from.0.opt.clone();
+    let data = from.0.host.unwrap().await;
+
+    let (mut tx, child) = {
+        // The `real` input channel.
+        let from = from.clone();
+        let opt = opt.clone();
+        let (tx, rx) = mpsc::channel::<Vec<Vec<String>>>(opt.backlog);
+        (tx, tokio::spawn(async move {
+            let mut rx = if let Some(thr) = opt.throttle {
+                time::throttle(thr, rx).boxed()
+            } else {
+                rx.boxed()
+            };
+            trace!("child: Begin waiting on parent");
+            while let Some(item) = rx.next().await {
+                let mut lock = from.0.chain.write().await;
+                for item in item.into_iter()
+                {
+                    use std::ops::DerefMut;
+                    for item in item.into_iter() {
+                        feed::feed(lock.deref_mut(), item, &from.0.opt.bounds);
+                    }
+                }
+            }
+            trace!("child: exiting");
+        }))
+    };
+
+    trace!("Begin polling on child");
+    tokio::select!{
+        v = child => {
+            match v {
+                #[cold] Ok(_) => {warn!("Child exited before we have? This should probably never happen.")}, //Should never happen.
+                Err(e) => {error!("Child exited abnormally. Aborting: {}", e)}, //Child panic or cancel.
+            }
+        },
+        _ = async move {
+            let mut rx = data.input.chunk(opt.capacity); //we don't even need this tbh
+
+            while Arc::strong_count(&from.0) > 2 {
+                tokio::select!{
+                    _ = time::delay_for(opt.timeout) => {
+                        rx.push_now();
+                    }
+                    Some(buffer) = rx.next() => {
+                        if let Err(err) = tx.send(buffer).await {
+                            // Receive closed?
+                            //
+                            // This probably shouldn't happen, as we `select!` for it up there and child never calls `close()` on `rx`.
+                            // In any case, it means we should abort.
+                            error!("Failed to send buffer: {}", err);
+                            break;
+                        }
+                    }
+                }
+            }
+        } => {
+            // Normal exit
+            trace!("Normal exit")
+        },
+    }
+    // No more handles except child, no more possible inputs.
+    trace!("Returning");
+}
+
+/// Spawn a new chain handler for this chain.
+pub fn spawn(from: chain::Chain<String>, opt: Settings) -> (JoinHandle<()>, ChainHandle)
+{
+    let handle = ChainHandle::with_settings(from, opt);
+    (tokio::spawn(host(handle.clone())), handle)
+}
+
+#[derive(Debug)]
+pub struct ChainStream
+{
+    chain: Weak<Box<Handle>>,
+    buffer: Vec<String>,
+}
+
+impl ChainStream
+{
+    async fn try_pull(&mut self, n: usize) -> Option<NonZeroUsize>
+    {
+        if n == 0 {
+            return None;
+        }
+        if let Some(read) = self.chain.upgrade() {
+            let chain = read.chain.read().await;
+            if chain.is_empty() {
+                return None;
+            }
+
+            let n = if n == 1 {
+                self.buffer.push(chain.generate_str());
+                1
+            } else {
+                self.buffer.extend(chain.str_iter_for(n));
+                n //for now
+            };
+            Some(unsafe{NonZeroUsize::new_unchecked(n)})
+        } else {
+            None
+        }
+    }
+}
+
+impl Stream for ChainStream
+{
+    type Item = String;
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        use futures::Future;
+        let this = self.get_mut();
+
+        if this.buffer.len() == 0 {
+            let pull = this.try_pull(this.buffer.capacity());
+            tokio::pin!(pull);
+            match pull.poll(cx) {
+                Poll::Ready(Some(_)) => {},
+                Poll::Pending => return Poll::Pending,
+                _ => return Poll::Ready(None),
+            };
+        }
+        debug_assert!(this.buffer.len() > 0);
+        Poll::Ready(Some(this.buffer.remove(0)))
+    }
+}
diff --git a/src/main.rs b/src/main.rs
index a0b6f30..547a2da 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -78,6 +78,7 @@ use state::State;
 mod save;
 mod forwarded_list;
 use forwarded_list::XForwardedFor;
+mod handle;
 mod feed;
 mod gen;
 
@@ -134,7 +135,7 @@ async fn main() {
     debug!("Using config {:?}", config);
     trace!("With config cached: {:?}", ccache);
 
-    let chain = Arc::new(RwLock::new(match save::load(&config.file).await {
+    let (chain_handle, chain) = handle::spawn(match save::load(&config.file).await {
 	Ok(chain) => {
 	    info!("Loaded chain from {:?}", config.file);
 	    chain
 	},
@@ -144,7 +145,7 @@ async fn main() {
 	    trace!("Error: {}", e);
 	    Chain::new()
 	},
-    }));
+    }, Default::default()/*TODO*/);
     {
 	let mut tasks = Vec::<JoinHandle<()>>::new();
 	let (state, chain) = {
@@ -152,7 +153,7 @@ async fn main() {
 	    let state = State::new(config,
 				   ccache,
-				   Arc::clone(&chain),
+				   chain,
 				   Arc::clone(&save_when));
 	    let state2 = state.clone();
 	    let saver = tokio::spawn(save::host(Box::new(state.clone())));
diff --git a/src/msg.rs b/src/msg.rs
index 42ba70b..cd9daaf 100644
--- a/src/msg.rs
+++ b/src/msg.rs
@@ -3,6 +3,7 @@ use super::*;
 use tokio::{
     sync::{
         watch,
+        Mutex,
     },
 };
 use std::{
@@ -12,7 +13,9 @@ use std::{
     error,
 };
 use futures::{
-    future::Future,
+    future::{
+        Future,
+    },
 };
 
 #[derive(Debug)]
@@ -160,3 +163,48 @@ impl Future for Initialiser
         uhh.poll(cx)
     }
 }
+
+/// A value that can be consumed once.
+#[derive(Debug)]
+pub struct Once<T>(Mutex<Option<T>>);
+
+impl<T> Once<T>
+{
+    /// Create a new instance
+    pub fn new(from: T) -> Self
+    {
+        Self(Mutex::new(Some(from)))
+    }
+    /// Consume the value from behind a potentially shared reference.
+    pub async fn consume_shared(self: Arc<Self>) -> Option<T>
+    {
+        match Arc::try_unwrap(self) {
+            Ok(x) => x.0.into_inner(),
+            Err(x) => x.0.lock().await.take(),
+        }
+    }
+
+    /// Consume from a shared reference and panic if the value has already been consumed.
+    pub async fn unwrap_shared(self: Arc<Self>) -> T
+    {
+        self.consume_shared().await.unwrap()
+    }
+
+    /// Consume the value.
+    pub async fn consume(&self) -> Option<T>
+    {
+        self.0.lock().await.take()
+    }
+
+    /// Consume and panic if the value has already been consumed.
+    pub async fn unwrap(&self) -> T
+    {
+        self.consume().await.unwrap()
+    }
+
+    /// Consume into the inner value
+    pub fn into_inner(self) -> Option<T>
+    {
+        self.0.into_inner()
+    }
+}
diff --git a/src/sanitise/filter.rs b/src/sanitise/filter.rs
index cbc896b..ba4f79c 100644
--- a/src/sanitise/filter.rs
+++ b/src/sanitise/filter.rs
@@ -272,6 +272,6 @@ mod tests
         let string = "abcdef ghi jk1\nhian";
 
         assert_eq!(filter.filter_str(&string).to_string(), filter.filter_cow(&string).to_string());
-        assert_eq!(filter.filter_cow(&string).to_string(), filter.filter(string.chars()).collect::<String>());
+        assert_eq!(filter.filter_cow(&string).to_string(), filter.filter_iter(string.chars()).collect::<String>());
     }
 }
diff --git a/src/sanitise/sentance.rs b/src/sanitise/sentance.rs
index edae602..dc4ee22 100644
--- a/src/sanitise/sentance.rs
+++ b/src/sanitise/sentance.rs
@@ -25,7 +25,7 @@ macro_rules! new {
     };
 }
 
-const DEFAULT_BOUNDARIES: &[char] = &['\n', '.', ':', '!', '?'];
+const DEFAULT_BOUNDARIES: &[char] = &['\n', '.', ':', '!', '?', '~'];
 
 lazy_static! {
     static ref BOUNDARIES: smallmap::Map = {
diff --git a/src/sanitise/word.rs b/src/sanitise/word.rs
index c50a5fc..92f639d 100644
--- a/src/sanitise/word.rs
+++ b/src/sanitise/word.rs
@@ -25,7 +25,7 @@ macro_rules! new {
     };
 }
 
-const DEFAULT_BOUNDARIES: &[char] = &['!', '.', ','];
+const DEFAULT_BOUNDARIES: &[char] = &['!', '.', ',', '*'];
 
 lazy_static! {
     static ref BOUNDARIES: smallmap::Map = {
diff --git a/src/save.rs b/src/save.rs
index f083b0c..5f6b74c 100644
--- a/src/save.rs
+++ b/src/save.rs
@@ -43,7 +43,7 @@ type Decompressor = BzDecoder;
 
 pub async fn save_now(state: &State) -> io::Result<()>
 {
-    let chain = state.chain().read().await;
+    let chain = state.chain_ref().read().await;
     use std::ops::Deref;
     let to = &state.config().file;
     save_now_to(chain.deref(),to).await
@@ -82,7 +82,7 @@ pub async fn host(mut state: Box)
     debug!("Begin save handler");
     while Arc::strong_count(state.when()) > 1 {
         {
-            let chain = state.chain().read().await;
+            let chain = state.chain_ref().read().await;
             use std::ops::Deref;
             if let Err(e) = save_now_to(chain.deref(), &to).await {
                 error!("Failed to save chain: {}", e);
diff --git a/src/sentance.rs b/src/sentance.rs
index 41dd16d..36d3fdf 100644
--- a/src/sentance.rs
+++ b/src/sentance.rs
@@ -1,17 +1,19 @@
 //! /sentance/
 use super::*;
+use futures::StreamExt;
 
 pub async fn body(state: State, num: Option<usize>, mut output: mpsc::Sender<String>) -> Result<(), gen::GenBodyError>
 {
     let string = {
-        let chain = state.chain().read().await;
-        if chain.is_empty() {
-            return Ok(());
-        }
+        let mut chain = state.chain_read();
         match num {
-            None => chain.generate_str(),
-            Some(num) => (0..num).map(|_| chain.generate_str()).join("\n"),
+            None => chain.next().await.ok_or_else(gen::GenBodyError::default)?,
+            Some(num) if num < state.config().max_gen_size => { //(0..num).map(|_| chain.generate_str()).join("\n"),
+                let chain = chain.take(num);
+                chain.collect::<Vec<_>>().await.join("\n") //TODO: Stream version of JoinStrExt
+            },
+            _ => return Err(Default::default()),
         }
     };
 
@@ -20,14 +22,14 @@ pub async fn body(state: State, num: Option, mut output: mpsc::Sender
             Some(x) => x,
             #[cold] None => return Ok(()),
-        }.to_owned())).await.map_err(|e| gen::GenBodyError(e.0))?;
+        }.to_owned())).await?;
     }
 
     Ok(())
 }
diff --git a/src/signals.rs b/src/signals.rs
index 672ec3f..3a9be9e 100644
--- a/src/signals.rs
+++ b/src/signals.rs
@@ -36,7 +36,7 @@ pub async fn handle(mut state: State)
             match save::load(&state.config().file).await {
                 Ok(new) => {
                     {
-                        let mut chain = state.chain().write().await;
+                        let mut chain = state.chain_ref().write().await;
                         *chain = new;
                     }
                     trace!("Replaced with read chain");
diff --git a/src/state.rs b/src/state.rs
index f5cd137..6bb78c2 100644
--- a/src/state.rs
+++ b/src/state.rs
@@ -3,6 +3,7 @@ use super::*;
 use tokio::{
     sync::{
         watch,
+        mpsc::error::SendError,
     },
 };
 use config::Config;
@@ -25,7 +26,7 @@ impl fmt::Display for ShutdownError
 pub struct State
 {
     config: Arc<Box<Config>>, //to avoid cloning config
-    chain: Arc<RwLock<Chain<String>>>,
+    chain: handle::ChainHandle,
     save: Arc<Notify>,
 
     begin: Initialiser,
@@ -78,7 +79,7 @@ impl State
         &self.config_cache().outbound_filter
     }
 
-    pub fn new(config: Config, cache: config::Cache, chain: Arc<RwLock<Chain<String>>>, save: Arc<Notify>) -> Self
+    pub fn new(config: Config, cache: config::Cache, chain: handle::ChainHandle, save: Arc<Notify>) -> Self
     {
         let (shutdown, shutdown_recv) = watch::channel(false);
         Self {
@@ -106,9 +107,23 @@ impl State
         self.save.notify();
     }
 
-    pub fn chain(&self) -> &RwLock<Chain<String>>
+    /*pub fn chain(&self) -> &RwLock<Chain<String>>
     {
         &self.chain.as_ref()
-    }
+}*/
+    pub fn chain_ref(&self) -> &RwLock<Chain<String>>
+    {
+        &self.chain.chain_ref()
+    }
+
+    pub fn chain_read(&self) -> handle::ChainStream
+    {
+        self.chain.read()
+    }
+
+    pub async fn chain_write(&self, buffer: impl IntoIterator<Item=String>) -> Result<(), SendError<Vec<String>>>
+    {
+        self.chain.write(buffer.into_iter().collect()).await
+    }
 
     pub fn when(&self) -> &Arc
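
For reference, a rough sketch of how the handler API introduced above is meant to be driven end to end. The handle::spawn, ChainHandle::write/read and Settings::default calls follow the signatures added in src/handle.rs, and chain::Chain::new() mirrors its use in main.rs; the #[tokio::main] entry point and the example strings are illustrative assumptions only, not part of the patch.

    use futures::StreamExt;

    #[tokio::main]
    async fn main() {
        // Spawn the host task; `task` is its JoinHandle, `handle` the cloneable ChainHandle.
        let (task, handle) = handle::spawn(chain::Chain::new(), handle::Settings::default());

        // Buffers written here are queued, chunked and throttled by host()
        // before being fed into the chain under its write lock.
        handle.write(vec!["hello world".to_owned(), "how are you".to_owned()])
            .await
            .expect("chain handler has shut down");

        // read() returns a ChainStream: an async Stream of generated strings.
        let mut generated = handle.read();
        if let Some(line) = generated.next().await {
            println!("{}", line);
        }

        // Dropping the last outside handle lets host() wind down and the task finish.
        drop(handle);
        task.await.expect("chain handler panicked");
    }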