diff --git a/Cargo.lock b/Cargo.lock index c7df56f..6ba6139 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -616,7 +616,7 @@ dependencies = [ [[package]] name = "markov" -version = "0.8.2" +version = "0.9.0" dependencies = [ "async-compression", "bzip2-sys", diff --git a/Cargo.toml b/Cargo.toml index 6df6cd5..62d6627 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "markov" -version = "0.8.2" +version = "0.9.0" description = "Generate string of text from Markov chain fed by stdin" authors = ["Avril "] edition = "2018" @@ -36,14 +36,7 @@ split-sentance = [] # NOTE: This does nothing if `split-newlines` is not enabled always-aggregate = [] -# Feeds will hog the buffer lock until the whole body has been fed, instead of acquiring lock every time -# This will make feeds of many lines faster but can potentially cause DoS -# -# With: ~169ms -# Without: ~195ms -# -# NOTE: -# This does nothing if `always-aggregate` is enabled and/or `split-newlines` is not enabled +# Does nothing, legacy thing. hog-buffer = [] # Enable the /api/ route diff --git a/markov.toml b/markov.toml index 6b24665..d0b324f 100644 --- a/markov.toml +++ b/markov.toml @@ -7,5 +7,12 @@ trust_x_forwarded_for = false feed_bounds = '2..' [filter] -inbound = '<>/\\' +inbound = '' outbound = '' + +[writer] +backlog = 32 +internal_backlog = 8 +capacity = 4 +timeout_ms = 5000 +throttle_ms = 50 diff --git a/src/chunking.rs b/src/chunking.rs index 47dcc30..e447ba1 100644 --- a/src/chunking.rs +++ b/src/chunking.rs @@ -6,6 +6,7 @@ use std::{ Context, }, pin::Pin, + marker::PhantomData, }; use tokio::{ io::{ @@ -173,3 +174,109 @@ mod tests assert_eq!(&output[..], "Hello world\nHow are you"); } } + +/// A stream that chunks its input. +#[pin_project] +pub struct ChunkingStream> +{ + #[pin] stream: Fuse, + buf: Vec, + cap: usize, + _output: PhantomData, + + push_now: bool, +} + + +impl ChunkingStream +where S: Stream, + Into: From> +{ + pub fn new(stream: S, sz: usize) -> Self + { + Self { + stream: stream.fuse(), + buf: Vec::with_capacity(sz), + cap: sz, + _output: PhantomData, + push_now: false, + } + } + pub fn into_inner(self) -> S + { + self.stream.into_inner() + } + pub fn cap(&self) -> usize + { + self.cap + } + pub fn buffer(&self) -> &[T] + { + &self.buf[..] + } + + pub fn get_ref(&self) -> &S + { + self.stream.get_ref() + } + + pub fn get_mut(&mut self)-> &mut S + { + self.stream.get_mut() + } + + /// Force the next read to send the buffer even if it's not full. + /// + /// # Note + /// The buffer still won't send if it's empty. + pub fn push_now(&mut self) + { + self.push_now= true; + } + + /// Consume into the current held buffer + pub fn into_buffer(self) -> Vec + { + self.buf + } + + /// Take the buffer now + pub fn take_now(&mut self) -> Into + { + std::mem::replace(&mut self.buf, Vec::with_capacity(self.cap)).into() + } +} + +impl Stream for ChunkingStream +where S: Stream, + Into: From> +{ + type Item = Into; + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + while !(self.push_now && !self.buf.is_empty()) && self.buf.len() < self.cap { + // Buffer isn't full, keep filling + let this = self.as_mut().project(); + + match this.stream.poll_next(cx) { + Poll::Ready(None) => { + // Stream is over + break; + }, + Poll::Ready(Some(item)) => { + this.buf.push(item); + }, + _ => return Poll::Pending, + } + } + debug!("Sending buffer of {} (cap {})", self.buf.len(), self.cap); + // Buffer is full or we reach end of stream + Poll::Ready(if self.buf.len() == 0 { + None + } else { + let this = self.project(); + *this.push_now = false; + let output = std::mem::replace(this.buf, Vec::with_capacity(*this.cap)); + Some(output.into()) + }) + } +} diff --git a/src/config.rs b/src/config.rs index 12706ca..3587e32 100644 --- a/src/config.rs +++ b/src/config.rs @@ -28,9 +28,11 @@ pub struct Config pub save_interval_secs: Option, pub trust_x_forwarded_for: bool, #[serde(default)] + pub feed_bounds: String, + #[serde(default)] pub filter: FilterConfig, #[serde(default)] - pub feed_bounds: String, + pub writer: WriterConfig, } #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Hash, Serialize, Deserialize)] @@ -41,6 +43,49 @@ pub struct FilterConfig outbound: String, } +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Hash, Serialize, Deserialize)] +pub struct WriterConfig +{ + pub backlog: usize, + pub internal_backlog: usize, + pub capacity: usize, + pub timeout_ms: Option, + pub throttle_ms: Option, +} + +impl Default for WriterConfig +{ + #[inline] + fn default() -> Self + { + Self { + backlog: 32, + internal_backlog: 8, + capacity: 4, + timeout_ms: None, + throttle_ms: None, + } + } +} + +impl WriterConfig +{ + fn create_settings(self, bounds: range::DynRange) -> handle::Settings + { + + handle::Settings{ + backlog: self.backlog, + internal_backlog: self.internal_backlog, + capacity: self.capacity, + timeout: self.timeout_ms.map(tokio::time::Duration::from_millis).unwrap_or(handle::DEFAULT_TIMEOUT), + throttle: self.throttle_ms.map(tokio::time::Duration::from_millis), + bounds, + } + + } +} + + impl FilterConfig { fn get_inbound_filter(&self) -> sanitise::filter::Filter @@ -77,6 +122,7 @@ impl Default for Config trust_x_forwarded_for: false, filter: Default::default(), feed_bounds: "2..".to_owned(), + writer: Default::default(), } } } @@ -95,13 +141,15 @@ impl Config } } use std::ops::RangeBounds; - - Ok(Cache { - feed_bounds: section!("feed_bounds", self.parse_feed_bounds()).and_then(|bounds| if bounds.contains(&0) { + + let feed_bounds = section!("feed_bounds", self.parse_feed_bounds()).and_then(|bounds| if bounds.contains(&0) { Err(InvalidConfigError("feed_bounds", Box::new(opaque_error!("Bounds not allowed to contains 0 (they were `{}`)", bounds)))) } else { Ok(bounds) - })?, + })?; + Ok(Cache { + handler_settings: self.writer.create_settings(feed_bounds.clone()), + feed_bounds, inbound_filter: self.filter.get_inbound_filter(), outbound_filter: self.filter.get_outbound_filter(), }) @@ -205,12 +253,13 @@ impl fmt::Display for InvalidConfigError /// Caches some parsed config arguments -#[derive(Clone, PartialEq, Eq)] +#[derive(Clone, PartialEq)] pub struct Cache { pub feed_bounds: range::DynRange, pub inbound_filter: sanitise::filter::Filter, pub outbound_filter: sanitise::filter::Filter, + pub handler_settings: handle::Settings, } impl fmt::Debug for Cache @@ -221,6 +270,7 @@ impl fmt::Debug for Cache .field("feed_bounds", &self.feed_bounds) .field("inbound_filter", &self.inbound_filter.iter().collect::()) .field("outbound_filter", &self.outbound_filter.iter().collect::()) + .field("handler_settings", &self.handler_settings) .finish() } } diff --git a/src/ext.rs b/src/ext.rs index 5b05901..6ca144d 100644 --- a/src/ext.rs +++ b/src/ext.rs @@ -1,4 +1,5 @@ //! Extensions +use super::*; use std::{ iter, ops::{ @@ -162,3 +163,21 @@ impl DerefMut for AssertNotSend &mut self.0 } } + +pub trait ChunkStreamExt: Sized +{ + fn chunk_into>>(self, sz: usize) -> chunking::ChunkingStream; + fn chunk(self, sz: usize) -> chunking::ChunkingStream + { + self.chunk_into(sz) + } +} + +impl ChunkStreamExt for S +where S: Stream +{ + fn chunk_into>>(self, sz: usize) -> chunking::ChunkingStream + { + chunking::ChunkingStream::new(self, sz) + } +} diff --git a/src/feed.rs b/src/feed.rs index ef91bc5..f1e5266 100644 --- a/src/feed.rs +++ b/src/feed.rs @@ -2,8 +2,10 @@ use super::*; #[cfg(any(feature="feed-sentance", feature="split-sentance"))] use sanitise::Sentance; +use futures::stream; -pub const DEFAULT_FEED_BOUNDS: std::ops::RangeFrom = 2..; //TODO: Add to config somehow + +pub const DEFAULT_FEED_BOUNDS: std::ops::RangeFrom = 2..; /// Feed `what` into `chain`, at least `bounds` tokens. /// @@ -56,7 +58,7 @@ pub fn feed(chain: &mut Chain, what: impl AsRef, bounds: impl std:: } debug_assert!(!bounds.contains(&0), "Cannot allow 0 size feeds"); if bounds.contains(&map.len()) { - debug!("Feeding chain {} items", map.len()); + //debug!("Feeding chain {} items", map.len()); chain.feed(map); } else { @@ -73,12 +75,12 @@ pub async fn full(who: &IpAddr, state: State, body: impl Unpin + Stream { + ($buffer:expr) => { { let buffer = $buffer; - feed($chain, &buffer, bounds) + state.chain_write(buffer).await.map_err(|_| FillBodyError)?; } } } @@ -101,44 +103,42 @@ pub async fn full(who: &IpAddr, state: State, body: impl Unpin + Stream {:?}", who, buffer); - let mut chain = state.chain().write().await; cfg_if! { if #[cfg(feature="split-newlines")] { - for buffer in buffer.split('\n').filter(|line| !line.trim().is_empty()) { - feed!(&mut chain, buffer); - } + feed!(stream::iter(buffer.split('\n').filter(|line| !line.trim().is_empty()) + .map(|x| x.to_owned()))) } else { - feed!(&mut chain, buffer); - + feed!(stream::once(async move{buffer.into_owned()})); } } } else { use tokio::prelude::*; let reader = chunking::StreamReader::new(body.filter_map(|x| x.map(|mut x| x.to_bytes()).ok())); - let mut lines = reader.lines(); - - #[cfg(feature="hog-buffer")] - let mut chain = state.chain().write().await; - while let Some(line) = lines.next_line().await.map_err(|_| FillBodyError)? { + let lines = reader.lines(); + + feed!(lines.filter_map(|x| x.ok().and_then(|line| { let line = state.inbound_filter().filter_cow(&line); let line = line.trim(); + if !line.is_empty() { - #[cfg(not(feature="hog-buffer"))] - let mut chain = state.chain().write().await; // Acquire mutex once per line? Is this right? - - feed!(&mut chain, line); + //#[cfg(not(feature="hog-buffer"))] + //let mut chain = state.chain().write().await; // Acquire mutex once per line? Is this right? + info!("{} -> {:?}", who, line); + written+=line.len(); + Some(line.to_owned()) + } else { + None } - written+=line.len(); - } + + }))); } } - if_debug!{ + if_debug! { trace!("Write took {}ms", timer.elapsed().as_millis()); } - state.notify_save(); Ok(written) } diff --git a/src/gen.rs b/src/gen.rs index 5e75bcb..27575e0 100644 --- a/src/gen.rs +++ b/src/gen.rs @@ -1,34 +1,46 @@ //! Generating the strings use super::*; +use tokio::sync::mpsc::error::SendError; +use futures::StreamExt; -#[derive(Debug)] -pub struct GenBodyError(pub String); +#[derive(Debug, Default)] +pub struct GenBodyError(Option); impl error::Error for GenBodyError{} impl fmt::Display for GenBodyError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "failed to write {:?} to body", self.0) + if let Some(z) = &self.0 { + write!(f, "failed to write read string {:?} to body", z) + } else { + write!(f, "failed to read string from chain. it might be empty.") + } } } pub async fn body(state: State, num: Option, mut output: mpsc::Sender) -> Result<(), GenBodyError> { - let chain = state.chain().read().await; - if !chain.is_empty() { - let filter = state.outbound_filter(); - match num { - Some(num) if num < state.config().max_gen_size => { - //This could DoS `full_body` and writes, potentially. - for string in chain.str_iter_for(num) { - output.send(filter.filter_owned(string)).await.map_err(|e| GenBodyError(e.0))?; - } - }, - _ => output.send(filter.filter_owned(chain.generate_str())).await.map_err(|e| GenBodyError(e.0))?, - } + let mut chain = state.chain_read(); + let filter = state.outbound_filter(); + match num { + Some(num) if num < state.config().max_gen_size => { + let mut chain = chain.take(num); + while let Some(string) = chain.next().await { + output.send(filter.filter_owned(string)).await?; + } + }, + _ => output.send(filter.filter_owned(chain.next().await.ok_or_else(GenBodyError::default)?)).await?, } Ok(()) } + +impl From> for GenBodyError +{ + #[inline] fn from(from: SendError) -> Self + { + Self(Some(from.0)) + } +} diff --git a/src/handle.rs b/src/handle.rs new file mode 100644 index 0000000..10590c6 --- /dev/null +++ b/src/handle.rs @@ -0,0 +1,392 @@ +//! Chain handler. +use super::*; +use std::{ + marker::Send, + sync::Weak, + num::NonZeroUsize, + task::{Poll, Context,}, + pin::Pin, +}; +use tokio::{ + sync::{ + RwLock, + RwLockReadGuard, + mpsc::{ + self, + error::SendError, + }, + watch, + Notify, + }, + task::JoinHandle, + time::{ + self, + Duration, + }, +}; +use futures::StreamExt; + +pub const DEFAULT_TIMEOUT: Duration= Duration::from_secs(5); + +/// Settings for chain handler +#[derive(Debug, Clone, PartialEq)] +pub struct Settings +{ + pub backlog: usize, + pub internal_backlog: usize, + pub capacity: usize, + pub timeout: Duration, + pub throttle: Option, + pub bounds: range::DynRange, +} + +impl Settings +{ + /// Should we keep this string. + #[inline] fn matches(&self, _s: &str) -> bool + { + true + } +} + +impl Default for Settings +{ + #[inline] + fn default() -> Self + { + Self { + backlog: 32, + internal_backlog: 8, + capacity: 4, + timeout: Duration::from_secs(5), + throttle: Some(Duration::from_millis(200)), + bounds: feed::DEFAULT_FEED_BOUNDS.into(), + } + } +} + + +#[derive(Debug)] +struct HostInner +{ + input: mpsc::Receiver>, + shutdown: watch::Receiver, +} + +#[derive(Debug)] +struct Handle +{ + chain: RwLock>, + input: mpsc::Sender>, + opt: Settings, + notify_write: Arc, + push_now: Arc, + shutdown: watch::Sender, + + /// Data used only for the worker task. + host: msg::Once>, +} + +#[derive(Clone, Debug)] +pub struct ChainHandle(Arc>>); + +impl ChainHandle +{ + pub fn with_settings(chain: chain::Chain, opt: Settings) -> Self + { + let (shutdown_tx, shutdown) = watch::channel(false); + let (itx, irx) = mpsc::channel(opt.backlog); + Self(Arc::new(Box::new(Handle{ + chain: RwLock::new(chain), + input: itx, + opt, + push_now: Arc::new(Notify::new()), + notify_write: Arc::new(Notify::new()), + shutdown: shutdown_tx, + + host: msg::Once::new(HostInner{ + input: irx, + shutdown, + }) + }))) + } + + /// Acquire the chain read lock + async fn chain(&self) -> RwLockReadGuard<'_, chain::Chain> + { + self.0.chain.read().await + } + + /// A reference to the chain + pub fn chain_ref(&self) -> &RwLock> + { + &self.0.chain + } + + /// Create a stream that reads generated values forever. + pub fn read(&self) -> ChainStream + { + ChainStream{ + chain: Arc::downgrade(&self.0), + buffer: Vec::with_capacity(self.0.opt.backlog), + } + } + + /// Send this buffer to the chain + pub fn write(&self, buf: Vec) -> impl futures::Future>>> + 'static + { + let mut write = self.0.input.clone(); + async move { + write.send(buf).await + } + } + + /// Send this stream buffer to the chain + pub fn write_stream<'a, I: Stream>(&self, buf: I) -> impl futures::Future>>> + 'a + where I: 'a + { + let mut write = self.0.input.clone(); + async move { + write.send(buf.collect().await).await + } + } + + /// Send this buffer to the chain + pub async fn write_in_place(&self, buf: Vec) -> Result<(), SendError>> + { + self.0.input.clone().send(buf).await + } + + /// A referencer for the notifier + pub fn notify_when(&self) -> &Arc + { + &self.0.notify_write + } + + /// Force the pending buffers to be written to the chain now + pub fn push_now(&self) + { + self.0.push_now.notify(); + } + + /// Hang the worker thread, preventing it from taking any more inputs and also flushing it. + /// + /// # Panics + /// If there was no worker thread. + pub fn hang(&self) + { + trace!("Communicating hang request"); + self.0.shutdown.broadcast(true).expect("Failed to communicate hang"); + } +} + +impl ChainHandle +{ + #[deprecated = "use read() pls"] + pub async fn generate_body(&self, state: &state::State, num: Option, mut output: mpsc::Sender) -> Result<(), SendError> + { + let chain = self.chain().await; + if !chain.is_empty() { + let filter = state.outbound_filter(); + match num { + Some(num) if num < state.config().max_gen_size => { + //This could DoS writes, potentially. + for string in chain.str_iter_for(num) { + output.send(filter.filter_owned(string)).await?; + } + }, + _ => output.send(filter.filter_owned(chain.generate_str())).await?, + } + } + Ok(()) + } +} + +/// Host this handle on the current task. +/// +/// # Panics +/// If `from` has already been hosted. +pub async fn host(from: ChainHandle) +{ + let opt = from.0.opt.clone(); + let mut data = from.0.host.unwrap().await; + + let (mut tx, mut child) = { + // The `real` input channel. + let from = from.clone(); + let opt = opt.clone(); + let (tx, rx) = mpsc::channel::>>(opt.internal_backlog); + (tx, tokio::spawn(async move { + let mut rx = if let Some(thr) = opt.throttle { + time::throttle(thr, rx).boxed() + } else { + rx.boxed() + }; + trace!("child: Begin waiting on parent"); + while let Some(item) = rx.next().await { + if item.len() > 0 { + info!("Write lock acq"); + let mut lock = from.0.chain.write().await; + for item in item.into_iter() + { + use std::ops::DerefMut; + for item in item.into_iter() { + feed::feed(lock.deref_mut(), item, &from.0.opt.bounds); + } + } + trace!("Signalling write"); + from.0.notify_write.notify(); + } + } + trace!("child: exiting"); + })) + }; + + trace!("Begin polling on child"); + tokio::select!{ + v = &mut child => { + match v { + #[cold] Ok(_) => {warn!("Child exited before we have? This should probably never happen.")},//Should never happen. + Err(e) => {error!("Child exited abnormally. Aborting: {}", e)}, //Child panic or cancel. + } + }, + _ = async move { + let mut rx = data.input.chunk(opt.capacity); //we don't even need this tbh, oh well. + + if !data.shutdown.recv().await.unwrap_or(true) { //first shutdown we get for free + while Arc::strong_count(&from.0) > 2 { + if *data.shutdown.borrow() { + break; + } + + tokio::select!{ + Some(true) = data.shutdown.recv() => { + debug!("Got shutdown (hang) request. Sending now then breaking"); + + let mut rest = { + let irx = rx.get_mut(); + irx.close(); //accept no more inputs + let mut output = Vec::with_capacity(opt.capacity); + while let Ok(item) = irx.try_recv() { + output.push(item); + } + output + }; + rest.extend(rx.take_now()); + if rest.len() > 0 { + if let Err(err) = tx.send(rest).await { + error!("Failed to force send buffer, exiting now: {}", err); + } + } + break; + } + _ = time::delay_for(opt.timeout) => { + trace!("Setting push now"); + rx.push_now(); + } + _ = from.0.push_now.notified() => { + debug!("Got force push signal"); + let take =rx.take_now(); + rx.push_now(); + if take.len() > 0 { + if let Err(err) = tx.send(take).await { + error!("Failed to force send buffer: {}", err); + break; + } + } + } + Some(buffer) = rx.next() => { + debug!("Sending {} (cap {})", buffer.len(), buffer.capacity()); + if let Err(err) = tx.send(buffer).await { + // Receive closed? + // + // This probably shouldn't happen, as we `select!` for it up there and child never calls `close()` on `rx`. + // In any case, it means we should abort. + #[cold] error!("Failed to send buffer: {}", err); + break; + } + } + } + } + } + let last = rx.into_buffer(); + if last.len() > 0 { + if let Err(err) = tx.send(last).await { + error!("Failed to force send last part of buffer: {}", err); + } else { + trace!("Sent rest of buffer"); + } + } + } => { + // Normal exit + trace!("Normal exit") + }, + } + trace!("Waiting on child"); + // No more handles except child, no more possible inputs. + child.await.expect("Child panic"); + trace!("Returning"); +} + +/// Spawn a new chain handler for this chain. +pub fn spawn(from: chain::Chain, opt: Settings) -> (JoinHandle<()>, ChainHandle) +{ + debug!("Spawning with opt: {:?}", opt); + let handle = ChainHandle::with_settings(from, opt); + (tokio::spawn(host(handle.clone())), handle) +} + +#[derive(Debug)] +pub struct ChainStream +{ + chain: Weak>>, + buffer: Vec, +} + +impl ChainStream +{ + async fn try_pull(&mut self, n: usize) -> Option + { + if n == 0 { + return None; + } + if let Some(read) = self.chain.upgrade() { + let chain = read.chain.read().await; + if chain.is_empty() { + return None; + } + + let n = if n == 1 { + self.buffer.push(chain.generate_str()); + 1 + } else { + self.buffer.extend(chain.str_iter_for(n)); + n //for now + }; + Some(unsafe{NonZeroUsize::new_unchecked(n)}) + } else { + None + } + } +} + +impl Stream for ChainStream +{ + type Item = String; + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + use futures::Future; + let this = self.get_mut(); + + if this.buffer.len() == 0 { + let pull = this.try_pull(this.buffer.capacity()); + tokio::pin!(pull); + match pull.poll(cx) { + Poll::Ready(Some(_)) => {}, + Poll::Pending => return Poll::Pending, + _ => return Poll::Ready(None), + }; + } + debug_assert!(this.buffer.len()>0); + Poll::Ready(Some(this.buffer.remove(0))) + } +} diff --git a/src/main.rs b/src/main.rs index a0b6f30..368668f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -78,6 +78,7 @@ use state::State; mod save; mod forwarded_list; use forwarded_list::XForwardedFor; +mod handle; mod feed; mod gen; @@ -134,7 +135,7 @@ async fn main() { debug!("Using config {:?}", config); trace!("With config cached: {:?}", ccache); - let chain = Arc::new(RwLock::new(match save::load(&config.file).await { + let (chain_handle, chain) = handle::spawn(match save::load(&config.file).await { Ok(chain) => { info!("Loaded chain from {:?}", config.file); chain @@ -144,16 +145,15 @@ async fn main() { trace!("Error: {}", e); Chain::new() }, - })); + }, ccache.handler_settings.clone()); { let mut tasks = Vec::>::new(); + tasks.push(chain_handle.map(|res| res.expect("Chain handle panicked")).boxed()); let (state, chain) = { - let save_when = Arc::new(Notify::new()); let state = State::new(config, ccache, - Arc::clone(&chain), - Arc::clone(&save_when)); + chain); let state2 = state.clone(); let saver = tokio::spawn(save::host(Box::new(state.clone()))); let chain = warp::any().map(move || state.clone()); diff --git a/src/msg.rs b/src/msg.rs index 42ba70b..cd9daaf 100644 --- a/src/msg.rs +++ b/src/msg.rs @@ -3,6 +3,7 @@ use super::*; use tokio::{ sync::{ watch, + Mutex, }, }; use std::{ @@ -12,7 +13,9 @@ use std::{ error, }; use futures::{ - future::Future, + future::{ + Future, + }, }; #[derive(Debug)] @@ -160,3 +163,48 @@ impl Future for Initialiser uhh.poll(cx) } } + +/// A value that can be consumed once. +#[derive(Debug)] +pub struct Once(Mutex>); + +impl Once +{ + /// Create a new instance + pub fn new(from: T) -> Self + { + Self(Mutex::new(Some(from))) + } + /// Consume into the instance from behind a potentially shared reference. + pub async fn consume_shared(self: Arc) -> Option + { + match Arc::try_unwrap(self) { + Ok(x) => x.0.into_inner(), + Err(x) => x.0.lock().await.take(), + } + } + + /// Consume from a shared reference and panic if the value has already been consumed. + pub async fn unwrap_shared(self: Arc) -> T + { + self.consume_shared().await.unwrap() + } + + /// Consume into the instance. + pub async fn consume(&self) -> Option + { + self.0.lock().await.take() + } + + /// Consume and panic if the value has already been consumed. + pub async fn unwrap(&self) -> T + { + self.consume().await.unwrap() + } + + /// Consume into the inner value + pub fn into_inner(self) -> Option + { + self.0.into_inner() + } +} diff --git a/src/sanitise/filter.rs b/src/sanitise/filter.rs index cbc896b..ba4f79c 100644 --- a/src/sanitise/filter.rs +++ b/src/sanitise/filter.rs @@ -272,6 +272,6 @@ mod tests let string = "abcdef ghi jk1\nhian"; assert_eq!(filter.filter_str(&string).to_string(), filter.filter_cow(&string).to_string()); - assert_eq!(filter.filter_cow(&string).to_string(), filter.filter(string.chars()).collect::()); + assert_eq!(filter.filter_cow(&string).to_string(), filter.filter_iter(string.chars()).collect::()); } } diff --git a/src/sanitise/sentance.rs b/src/sanitise/sentance.rs index edae602..dc4ee22 100644 --- a/src/sanitise/sentance.rs +++ b/src/sanitise/sentance.rs @@ -25,7 +25,7 @@ macro_rules! new { }; } -const DEFAULT_BOUNDARIES: &[char] = &['\n', '.', ':', '!', '?']; +const DEFAULT_BOUNDARIES: &[char] = &['\n', '.', ':', '!', '?', '~']; lazy_static! { static ref BOUNDARIES: smallmap::Map = { diff --git a/src/sanitise/word.rs b/src/sanitise/word.rs index c50a5fc..92f639d 100644 --- a/src/sanitise/word.rs +++ b/src/sanitise/word.rs @@ -25,7 +25,7 @@ macro_rules! new { }; } -const DEFAULT_BOUNDARIES: &[char] = &['!', '.', ',']; +const DEFAULT_BOUNDARIES: &[char] = &['!', '.', ',', '*']; lazy_static! { static ref BOUNDARIES: smallmap::Map = { diff --git a/src/save.rs b/src/save.rs index f083b0c..75f3f58 100644 --- a/src/save.rs +++ b/src/save.rs @@ -43,7 +43,7 @@ type Decompressor = BzDecoder; pub async fn save_now(state: &State) -> io::Result<()> { - let chain = state.chain().read().await; + let chain = state.chain_ref().read().await; use std::ops::Deref; let to = &state.config().file; save_now_to(chain.deref(),to).await @@ -77,12 +77,13 @@ pub async fn host(mut state: Box) { let to = state.config().file.to_owned(); let interval = state.config().save_interval(); + let when = Arc::clone(state.when_ref()); trace!("Setup oke. Waiting on init"); if state.on_init().await.is_ok() { debug!("Begin save handler"); - while Arc::strong_count(state.when()) > 1 { + while Arc::strong_count(&when) > 1 { { - let chain = state.chain().read().await; + let chain = state.chain_ref().read().await; use std::ops::Deref; if let Err(e) = save_now_to(chain.deref(), &to).await { error!("Failed to save chain: {}", e); @@ -97,7 +98,7 @@ pub async fn host(mut state: Box) break; } } - state.when().notified().await; + when.notified().await; if state.has_shutdown() { break; } diff --git a/src/sentance.rs b/src/sentance.rs index 41dd16d..36d3fdf 100644 --- a/src/sentance.rs +++ b/src/sentance.rs @@ -1,17 +1,19 @@ //! /sentance/ use super::*; +use futures::StreamExt; pub async fn body(state: State, num: Option, mut output: mpsc::Sender) -> Result<(), gen::GenBodyError> { let string = { - let chain = state.chain().read().await; - if chain.is_empty() { - return Ok(()); - } + let mut chain = state.chain_read(); match num { - None => chain.generate_str(), - Some(num) => (0..num).map(|_| chain.generate_str()).join("\n"), + None => chain.next().await.ok_or_else(gen::GenBodyError::default)?, + Some(num) if num < state.config().max_gen_size => {//(0..num).map(|_| chain.generate_str()).join("\n"), + let chain = chain.take(num); + chain.collect::>().await.join("\n")//TODO: Stream version of JoinStrExt + }, + _ => return Err(Default::default()), } }; @@ -20,14 +22,14 @@ pub async fn body(state: State, num: Option, mut output: mpsc::Sender x, #[cold] None => return Ok(()), - }.to_owned())).await.map_err(|e| gen::GenBodyError(e.0))?; + }.to_owned())).await?; } Ok(()) } diff --git a/src/signals.rs b/src/signals.rs index 672ec3f..2b16db7 100644 --- a/src/signals.rs +++ b/src/signals.rs @@ -12,8 +12,9 @@ use tokio::{ pub async fn handle(mut state: State) { let mut usr1 = unix::signal(SignalKind::user_defined1()).expect("Failed to hook SIGUSR1"); - let mut usr2 = unix::signal(SignalKind::user_defined2()).expect("Failed to hook SIGUSR2"); + let mut usr2 = unix::signal(SignalKind::user_defined2()).expect("Failed to hook SIGUSR2"); let mut quit = unix::signal(SignalKind::quit()).expect("Failed to hook SIGQUIT"); + let mut io = unix::signal(SignalKind::io()).expect("Failed to hook IO"); trace!("Setup oke. Waiting on init"); if state.on_init().await.is_ok() { @@ -24,19 +25,15 @@ pub async fn handle(mut state: State) break; } _ = usr1.recv() => { - info!("Got SIGUSR1. Saving chain immediately."); - if let Err(e) = save::save_now(&state).await { - error!("Failed to save chain: {}", e); - } else{ - trace!("Saved chain okay"); - } + info!("Got SIGUSR1. Causing chain write."); + state.push_now(); }, _ = usr2.recv() => { - info!("Got SIGUSR1. Loading chain immediately."); + info!("Got SIGUSR2. Loading chain immediately."); match save::load(&state.config().file).await { Ok(new) => { { - let mut chain = state.chain().write().await; + let mut chain = state.chain_ref().write().await; *chain = new; } trace!("Replaced with read chain"); @@ -46,6 +43,15 @@ pub async fn handle(mut state: State) }, } }, + + _ = io.recv() => { + info!("Got SIGIO. Saving chain immediately."); + if let Err(e) = save::save_now(&state).await { + error!("Failed to save chain: {}", e); + } else{ + trace!("Saved chain okay"); + } + }, _ = quit.recv() => { warn!("Got SIGQUIT. Saving chain then aborting."); if let Err(e) = save::save_now(&state).await { diff --git a/src/state.rs b/src/state.rs index f5cd137..9a6bae7 100644 --- a/src/state.rs +++ b/src/state.rs @@ -3,6 +3,7 @@ use super::*; use tokio::{ sync::{ watch, + mpsc::error::SendError, }, }; use config::Config; @@ -25,8 +26,8 @@ impl fmt::Display for ShutdownError pub struct State { config: Arc>, //to avoid cloning config - chain: Arc>>, - save: Arc, + chain: handle::ChainHandle, + //save: Arc, begin: Initialiser, shutdown: Arc>, @@ -78,13 +79,12 @@ impl State &self.config_cache().outbound_filter } - pub fn new(config: Config, cache: config::Cache, chain: Arc>>, save: Arc) -> Self + pub fn new(config: Config, cache: config::Cache, chain: handle::ChainHandle) -> Self { let (shutdown, shutdown_recv) = watch::channel(false); Self { config: Arc::new(Box::new((config, cache))), chain, - save, begin: Initialiser::new(), shutdown: Arc::new(shutdown), shutdown_recv, @@ -101,25 +101,48 @@ impl State &self.config.as_ref().1 } - pub fn notify_save(&self) + /*pub fn notify_save(&self) { - self.save.notify(); + self.save.notify(); +}*/ + + /*pub fn chain(&self) -> &RwLock> + { + &self.chain.as_ref() +}*/ + pub fn chain_ref(&self) -> &RwLock> + { + &self.chain.chain_ref() + } + + pub fn chain_read(&self) -> handle::ChainStream + { + self.chain.read() } - pub fn chain(&self) -> &RwLock> + /// Write to this chain + pub async fn chain_write<'a, T: Stream>(&'a self, buffer: T) -> Result<(), SendError>> + { + self.chain.write_stream(buffer).await + } + + + pub fn when_ref(&self) -> &Arc { - &self.chain.as_ref() + &self.chain.notify_when() } - pub fn when(&self) -> &Arc + /// Force the chain to push through now + pub fn push_now(&self) { - &self.save + self.chain.push_now() } pub fn shutdown(self) { self.shutdown.broadcast(true).expect("Failed to communicate shutdown"); - self.save.notify(); + self.chain.hang(); + self.when_ref().notify(); } pub fn has_shutdown(&self) -> bool