From 75730cbe0ffc8a2fb9f73607a38a0657eb5391a5 Mon Sep 17 00:00:00 2001 From: Avril Date: Wed, 14 Oct 2020 01:08:27 +0100 Subject: [PATCH] working implementation of handler --- Cargo.lock | 2 +- Cargo.toml | 11 +--- markov.toml | 9 ++- src/chunking.rs | 23 +++++++ src/config.rs | 62 ++++++++++++++++-- src/feed.rs | 34 +++++----- src/handle.rs | 167 +++++++++++++++++++++++++++++++++++++++--------- src/main.rs | 7 +- src/save.rs | 5 +- src/signals.rs | 22 ++++--- src/state.rs | 32 ++++++---- 11 files changed, 285 insertions(+), 89 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c7df56f..6ba6139 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -616,7 +616,7 @@ dependencies = [ [[package]] name = "markov" -version = "0.8.2" +version = "0.9.0" dependencies = [ "async-compression", "bzip2-sys", diff --git a/Cargo.toml b/Cargo.toml index 6df6cd5..62d6627 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "markov" -version = "0.8.2" +version = "0.9.0" description = "Generate string of text from Markov chain fed by stdin" authors = ["Avril "] edition = "2018" @@ -36,14 +36,7 @@ split-sentance = [] # NOTE: This does nothing if `split-newlines` is not enabled always-aggregate = [] -# Feeds will hog the buffer lock until the whole body has been fed, instead of acquiring lock every time -# This will make feeds of many lines faster but can potentially cause DoS -# -# With: ~169ms -# Without: ~195ms -# -# NOTE: -# This does nothing if `always-aggregate` is enabled and/or `split-newlines` is not enabled +# Does nothing, legacy thing. hog-buffer = [] # Enable the /api/ route diff --git a/markov.toml b/markov.toml index 6b24665..d0b324f 100644 --- a/markov.toml +++ b/markov.toml @@ -7,5 +7,12 @@ trust_x_forwarded_for = false feed_bounds = '2..' [filter] -inbound = '<>/\\' +inbound = '' outbound = '' + +[writer] +backlog = 32 +internal_backlog = 8 +capacity = 4 +timeout_ms = 5000 +throttle_ms = 50 diff --git a/src/chunking.rs b/src/chunking.rs index 0658f08..e447ba1 100644 --- a/src/chunking.rs +++ b/src/chunking.rs @@ -215,6 +215,16 @@ where S: Stream, &self.buf[..] } + pub fn get_ref(&self) -> &S + { + self.stream.get_ref() + } + + pub fn get_mut(&mut self)-> &mut S + { + self.stream.get_mut() + } + /// Force the next read to send the buffer even if it's not full. /// /// # Note @@ -223,6 +233,18 @@ where S: Stream, { self.push_now= true; } + + /// Consume into the current held buffer + pub fn into_buffer(self) -> Vec + { + self.buf + } + + /// Take the buffer now + pub fn take_now(&mut self) -> Into + { + std::mem::replace(&mut self.buf, Vec::with_capacity(self.cap)).into() + } } impl Stream for ChunkingStream @@ -246,6 +268,7 @@ where S: Stream, _ => return Poll::Pending, } } + debug!("Sending buffer of {} (cap {})", self.buf.len(), self.cap); // Buffer is full or we reach end of stream Poll::Ready(if self.buf.len() == 0 { None diff --git a/src/config.rs b/src/config.rs index 12706ca..3587e32 100644 --- a/src/config.rs +++ b/src/config.rs @@ -28,9 +28,11 @@ pub struct Config pub save_interval_secs: Option, pub trust_x_forwarded_for: bool, #[serde(default)] + pub feed_bounds: String, + #[serde(default)] pub filter: FilterConfig, #[serde(default)] - pub feed_bounds: String, + pub writer: WriterConfig, } #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Hash, Serialize, Deserialize)] @@ -41,6 +43,49 @@ pub struct FilterConfig outbound: String, } +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Hash, Serialize, Deserialize)] +pub struct WriterConfig +{ + pub backlog: usize, + pub internal_backlog: usize, + pub capacity: usize, + pub timeout_ms: Option, + pub throttle_ms: Option, +} + +impl Default for WriterConfig +{ + #[inline] + fn default() -> Self + { + Self { + backlog: 32, + internal_backlog: 8, + capacity: 4, + timeout_ms: None, + throttle_ms: None, + } + } +} + +impl WriterConfig +{ + fn create_settings(self, bounds: range::DynRange) -> handle::Settings + { + + handle::Settings{ + backlog: self.backlog, + internal_backlog: self.internal_backlog, + capacity: self.capacity, + timeout: self.timeout_ms.map(tokio::time::Duration::from_millis).unwrap_or(handle::DEFAULT_TIMEOUT), + throttle: self.throttle_ms.map(tokio::time::Duration::from_millis), + bounds, + } + + } +} + + impl FilterConfig { fn get_inbound_filter(&self) -> sanitise::filter::Filter @@ -77,6 +122,7 @@ impl Default for Config trust_x_forwarded_for: false, filter: Default::default(), feed_bounds: "2..".to_owned(), + writer: Default::default(), } } } @@ -95,13 +141,15 @@ impl Config } } use std::ops::RangeBounds; - - Ok(Cache { - feed_bounds: section!("feed_bounds", self.parse_feed_bounds()).and_then(|bounds| if bounds.contains(&0) { + + let feed_bounds = section!("feed_bounds", self.parse_feed_bounds()).and_then(|bounds| if bounds.contains(&0) { Err(InvalidConfigError("feed_bounds", Box::new(opaque_error!("Bounds not allowed to contains 0 (they were `{}`)", bounds)))) } else { Ok(bounds) - })?, + })?; + Ok(Cache { + handler_settings: self.writer.create_settings(feed_bounds.clone()), + feed_bounds, inbound_filter: self.filter.get_inbound_filter(), outbound_filter: self.filter.get_outbound_filter(), }) @@ -205,12 +253,13 @@ impl fmt::Display for InvalidConfigError /// Caches some parsed config arguments -#[derive(Clone, PartialEq, Eq)] +#[derive(Clone, PartialEq)] pub struct Cache { pub feed_bounds: range::DynRange, pub inbound_filter: sanitise::filter::Filter, pub outbound_filter: sanitise::filter::Filter, + pub handler_settings: handle::Settings, } impl fmt::Debug for Cache @@ -221,6 +270,7 @@ impl fmt::Debug for Cache .field("feed_bounds", &self.feed_bounds) .field("inbound_filter", &self.inbound_filter.iter().collect::()) .field("outbound_filter", &self.outbound_filter.iter().collect::()) + .field("handler_settings", &self.handler_settings) .finish() } } diff --git a/src/feed.rs b/src/feed.rs index 166f0d1..f1e5266 100644 --- a/src/feed.rs +++ b/src/feed.rs @@ -2,7 +2,7 @@ use super::*; #[cfg(any(feature="feed-sentance", feature="split-sentance"))] use sanitise::Sentance; -use std::iter; +use futures::stream; pub const DEFAULT_FEED_BOUNDS: std::ops::RangeFrom = 2..; @@ -58,7 +58,7 @@ pub fn feed(chain: &mut Chain, what: impl AsRef, bounds: impl std:: } debug_assert!(!bounds.contains(&0), "Cannot allow 0 size feeds"); if bounds.contains(&map.len()) { - debug!("Feeding chain {} items", map.len()); + //debug!("Feeding chain {} items", map.len()); chain.feed(map); } else { @@ -80,7 +80,7 @@ pub async fn full(who: &IpAddr, state: State, body: impl Unpin + Stream { { let buffer = $buffer; - state.chain_write(buffer.map(ToOwned::to_owned)).await.map_err(|_| FillBodyError)?; + state.chain_write(buffer).await.map_err(|_| FillBodyError)?; } } } @@ -105,38 +105,40 @@ pub async fn full(who: &IpAddr, state: State, body: impl Unpin + Stream {:?}", who, buffer); cfg_if! { if #[cfg(feature="split-newlines")] { - feed!(buffer.split('\n').filter(|line| !line.trim().is_empty())) + feed!(stream::iter(buffer.split('\n').filter(|line| !line.trim().is_empty()) + .map(|x| x.to_owned()))) } else { - feed!(iter::once(buffer)); + feed!(stream::once(async move{buffer.into_owned()})); } } } else { use tokio::prelude::*; let reader = chunking::StreamReader::new(body.filter_map(|x| x.map(|mut x| x.to_bytes()).ok())); - let mut lines = reader.lines(); - - #[cfg(feature="hog-buffer")] - let mut chain = state.chain().write().await; - while let Some(line) = lines.next_line().await.map_err(|_| FillBodyError)? { + let lines = reader.lines(); + + feed!(lines.filter_map(|x| x.ok().and_then(|line| { let line = state.inbound_filter().filter_cow(&line); let line = line.trim(); + if !line.is_empty() { //#[cfg(not(feature="hog-buffer"))] //let mut chain = state.chain().write().await; // Acquire mutex once per line? Is this right? - - feed!(iter::once(line)); + info!("{} -> {:?}", who, line); + written+=line.len(); + Some(line.to_owned()) + } else { + None } - written+=line.len(); - } + + }))); } } - if_debug!{ + if_debug! { trace!("Write took {}ms", timer.elapsed().as_millis()); } - state.notify_save(); Ok(written) } diff --git a/src/handle.rs b/src/handle.rs index 0669c15..10590c6 100644 --- a/src/handle.rs +++ b/src/handle.rs @@ -15,6 +15,8 @@ use tokio::{ self, error::SendError, }, + watch, + Notify, }, task::JoinHandle, time::{ @@ -24,11 +26,14 @@ use tokio::{ }; use futures::StreamExt; +pub const DEFAULT_TIMEOUT: Duration= Duration::from_secs(5); + /// Settings for chain handler #[derive(Debug, Clone, PartialEq)] pub struct Settings { pub backlog: usize, + pub internal_backlog: usize, pub capacity: usize, pub timeout: Duration, pub throttle: Option, @@ -38,7 +43,7 @@ pub struct Settings impl Settings { /// Should we keep this string. - #[inline] fn matches(&self, s: &str) -> bool + #[inline] fn matches(&self, _s: &str) -> bool { true } @@ -51,6 +56,7 @@ impl Default for Settings { Self { backlog: 32, + internal_backlog: 8, capacity: 4, timeout: Duration::from_secs(5), throttle: Some(Duration::from_millis(200)), @@ -64,6 +70,7 @@ impl Default for Settings struct HostInner { input: mpsc::Receiver>, + shutdown: watch::Receiver, } #[derive(Debug)] @@ -72,7 +79,10 @@ struct Handle chain: RwLock>, input: mpsc::Sender>, opt: Settings, - + notify_write: Arc, + push_now: Arc, + shutdown: watch::Sender, + /// Data used only for the worker task. host: msg::Once>, } @@ -80,22 +90,23 @@ struct Handle #[derive(Clone, Debug)] pub struct ChainHandle(Arc>>); -impl ChainHandle +impl ChainHandle { - #[inline] pub fn new(chain: chain::Chain) -> Self - { - Self::with_settings(chain, Default::default()) - } pub fn with_settings(chain: chain::Chain, opt: Settings) -> Self { + let (shutdown_tx, shutdown) = watch::channel(false); let (itx, irx) = mpsc::channel(opt.backlog); Self(Arc::new(Box::new(Handle{ chain: RwLock::new(chain), input: itx, opt, + push_now: Arc::new(Notify::new()), + notify_write: Arc::new(Notify::new()), + shutdown: shutdown_tx, host: msg::Once::new(HostInner{ input: irx, + shutdown, }) }))) } @@ -122,10 +133,51 @@ impl ChainHandle } /// Send this buffer to the chain - pub async fn write(&self, buf: Vec) -> Result<(), SendError>> + pub fn write(&self, buf: Vec) -> impl futures::Future>>> + 'static + { + let mut write = self.0.input.clone(); + async move { + write.send(buf).await + } + } + + /// Send this stream buffer to the chain + pub fn write_stream<'a, I: Stream>(&self, buf: I) -> impl futures::Future>>> + 'a + where I: 'a + { + let mut write = self.0.input.clone(); + async move { + write.send(buf.collect().await).await + } + } + + /// Send this buffer to the chain + pub async fn write_in_place(&self, buf: Vec) -> Result<(), SendError>> { self.0.input.clone().send(buf).await } + + /// A referencer for the notifier + pub fn notify_when(&self) -> &Arc + { + &self.0.notify_write + } + + /// Force the pending buffers to be written to the chain now + pub fn push_now(&self) + { + self.0.push_now.notify(); + } + + /// Hang the worker thread, preventing it from taking any more inputs and also flushing it. + /// + /// # Panics + /// If there was no worker thread. + pub fn hang(&self) + { + trace!("Communicating hang request"); + self.0.shutdown.broadcast(true).expect("Failed to communicate hang"); + } } impl ChainHandle @@ -157,13 +209,13 @@ impl ChainHandle pub async fn host(from: ChainHandle) { let opt = from.0.opt.clone(); - let data = from.0.host.unwrap().await; + let mut data = from.0.host.unwrap().await; - let (mut tx, child) = { + let (mut tx, mut child) = { // The `real` input channel. let from = from.clone(); let opt = opt.clone(); - let (tx, rx) = mpsc::channel::>>(opt.backlog); + let (tx, rx) = mpsc::channel::>>(opt.internal_backlog); (tx, tokio::spawn(async move { let mut rx = if let Some(thr) = opt.throttle { time::throttle(thr, rx).boxed() @@ -172,13 +224,18 @@ pub async fn host(from: ChainHandle) }; trace!("child: Begin waiting on parent"); while let Some(item) = rx.next().await { - let mut lock = from.0.chain.write().await; - for item in item.into_iter() - { - use std::ops::DerefMut; - for item in item.into_iter() { - feed::feed(lock.deref_mut(), item, &from.0.opt.bounds); + if item.len() > 0 { + info!("Write lock acq"); + let mut lock = from.0.chain.write().await; + for item in item.into_iter() + { + use std::ops::DerefMut; + for item in item.into_iter() { + feed::feed(lock.deref_mut(), item, &from.0.opt.bounds); + } } + trace!("Signalling write"); + from.0.notify_write.notify(); } } trace!("child: exiting"); @@ -187,44 +244,94 @@ pub async fn host(from: ChainHandle) trace!("Begin polling on child"); tokio::select!{ - v = child => { + v = &mut child => { match v { #[cold] Ok(_) => {warn!("Child exited before we have? This should probably never happen.")},//Should never happen. Err(e) => {error!("Child exited abnormally. Aborting: {}", e)}, //Child panic or cancel. } }, _ = async move { - let mut rx = data.input.chunk(opt.capacity); //we don't even need this tbh + let mut rx = data.input.chunk(opt.capacity); //we don't even need this tbh, oh well. - while Arc::strong_count(&from.0) > 2 { - tokio::select!{ - _ = time::delay_for(opt.timeout) => { - rx.push_now(); + if !data.shutdown.recv().await.unwrap_or(true) { //first shutdown we get for free + while Arc::strong_count(&from.0) > 2 { + if *data.shutdown.borrow() { + break; } - Some(buffer) = rx.next() => { - if let Err(err) = tx.send(buffer).await { - // Receive closed? - // - // This probably shouldn't happen, as we `select!` for it up there and child never calls `close()` on `rx`. - // In any case, it means we should abort. - error!("Failed to send buffer: {}", err); + + tokio::select!{ + Some(true) = data.shutdown.recv() => { + debug!("Got shutdown (hang) request. Sending now then breaking"); + + let mut rest = { + let irx = rx.get_mut(); + irx.close(); //accept no more inputs + let mut output = Vec::with_capacity(opt.capacity); + while let Ok(item) = irx.try_recv() { + output.push(item); + } + output + }; + rest.extend(rx.take_now()); + if rest.len() > 0 { + if let Err(err) = tx.send(rest).await { + error!("Failed to force send buffer, exiting now: {}", err); + } + } break; } + _ = time::delay_for(opt.timeout) => { + trace!("Setting push now"); + rx.push_now(); + } + _ = from.0.push_now.notified() => { + debug!("Got force push signal"); + let take =rx.take_now(); + rx.push_now(); + if take.len() > 0 { + if let Err(err) = tx.send(take).await { + error!("Failed to force send buffer: {}", err); + break; + } + } + } + Some(buffer) = rx.next() => { + debug!("Sending {} (cap {})", buffer.len(), buffer.capacity()); + if let Err(err) = tx.send(buffer).await { + // Receive closed? + // + // This probably shouldn't happen, as we `select!` for it up there and child never calls `close()` on `rx`. + // In any case, it means we should abort. + #[cold] error!("Failed to send buffer: {}", err); + break; + } + } } } } + let last = rx.into_buffer(); + if last.len() > 0 { + if let Err(err) = tx.send(last).await { + error!("Failed to force send last part of buffer: {}", err); + } else { + trace!("Sent rest of buffer"); + } + } } => { // Normal exit trace!("Normal exit") }, } + trace!("Waiting on child"); // No more handles except child, no more possible inputs. + child.await.expect("Child panic"); trace!("Returning"); } /// Spawn a new chain handler for this chain. pub fn spawn(from: chain::Chain, opt: Settings) -> (JoinHandle<()>, ChainHandle) { + debug!("Spawning with opt: {:?}", opt); let handle = ChainHandle::with_settings(from, opt); (tokio::spawn(host(handle.clone())), handle) } diff --git a/src/main.rs b/src/main.rs index 547a2da..368668f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -145,16 +145,15 @@ async fn main() { trace!("Error: {}", e); Chain::new() }, - }, Default::default()/*TODO*/); + }, ccache.handler_settings.clone()); { let mut tasks = Vec::>::new(); + tasks.push(chain_handle.map(|res| res.expect("Chain handle panicked")).boxed()); let (state, chain) = { - let save_when = Arc::new(Notify::new()); let state = State::new(config, ccache, - chain, - Arc::clone(&save_when)); + chain); let state2 = state.clone(); let saver = tokio::spawn(save::host(Box::new(state.clone()))); let chain = warp::any().map(move || state.clone()); diff --git a/src/save.rs b/src/save.rs index 5f6b74c..75f3f58 100644 --- a/src/save.rs +++ b/src/save.rs @@ -77,10 +77,11 @@ pub async fn host(mut state: Box) { let to = state.config().file.to_owned(); let interval = state.config().save_interval(); + let when = Arc::clone(state.when_ref()); trace!("Setup oke. Waiting on init"); if state.on_init().await.is_ok() { debug!("Begin save handler"); - while Arc::strong_count(state.when()) > 1 { + while Arc::strong_count(&when) > 1 { { let chain = state.chain_ref().read().await; use std::ops::Deref; @@ -97,7 +98,7 @@ pub async fn host(mut state: Box) break; } } - state.when().notified().await; + when.notified().await; if state.has_shutdown() { break; } diff --git a/src/signals.rs b/src/signals.rs index 3a9be9e..2b16db7 100644 --- a/src/signals.rs +++ b/src/signals.rs @@ -12,8 +12,9 @@ use tokio::{ pub async fn handle(mut state: State) { let mut usr1 = unix::signal(SignalKind::user_defined1()).expect("Failed to hook SIGUSR1"); - let mut usr2 = unix::signal(SignalKind::user_defined2()).expect("Failed to hook SIGUSR2"); + let mut usr2 = unix::signal(SignalKind::user_defined2()).expect("Failed to hook SIGUSR2"); let mut quit = unix::signal(SignalKind::quit()).expect("Failed to hook SIGQUIT"); + let mut io = unix::signal(SignalKind::io()).expect("Failed to hook IO"); trace!("Setup oke. Waiting on init"); if state.on_init().await.is_ok() { @@ -24,15 +25,11 @@ pub async fn handle(mut state: State) break; } _ = usr1.recv() => { - info!("Got SIGUSR1. Saving chain immediately."); - if let Err(e) = save::save_now(&state).await { - error!("Failed to save chain: {}", e); - } else{ - trace!("Saved chain okay"); - } + info!("Got SIGUSR1. Causing chain write."); + state.push_now(); }, _ = usr2.recv() => { - info!("Got SIGUSR1. Loading chain immediately."); + info!("Got SIGUSR2. Loading chain immediately."); match save::load(&state.config().file).await { Ok(new) => { { @@ -46,6 +43,15 @@ pub async fn handle(mut state: State) }, } }, + + _ = io.recv() => { + info!("Got SIGIO. Saving chain immediately."); + if let Err(e) = save::save_now(&state).await { + error!("Failed to save chain: {}", e); + } else{ + trace!("Saved chain okay"); + } + }, _ = quit.recv() => { warn!("Got SIGQUIT. Saving chain then aborting."); if let Err(e) = save::save_now(&state).await { diff --git a/src/state.rs b/src/state.rs index 6bb78c2..9a6bae7 100644 --- a/src/state.rs +++ b/src/state.rs @@ -27,7 +27,7 @@ pub struct State { config: Arc>, //to avoid cloning config chain: handle::ChainHandle, - save: Arc, + //save: Arc, begin: Initialiser, shutdown: Arc>, @@ -79,13 +79,12 @@ impl State &self.config_cache().outbound_filter } - pub fn new(config: Config, cache: config::Cache, chain: handle::ChainHandle, save: Arc) -> Self + pub fn new(config: Config, cache: config::Cache, chain: handle::ChainHandle) -> Self { let (shutdown, shutdown_recv) = watch::channel(false); Self { config: Arc::new(Box::new((config, cache))), chain, - save, begin: Initialiser::new(), shutdown: Arc::new(shutdown), shutdown_recv, @@ -102,14 +101,14 @@ impl State &self.config.as_ref().1 } - pub fn notify_save(&self) + /*pub fn notify_save(&self) { - self.save.notify(); - } + self.save.notify(); +}*/ /*pub fn chain(&self) -> &RwLock> { - &self.chain.as_ref() + &self.chain.as_ref() }*/ pub fn chain_ref(&self) -> &RwLock> { @@ -121,20 +120,29 @@ impl State self.chain.read() } - pub async fn chain_write(&self, buffer: impl IntoIterator) -> Result<(), SendError>> + /// Write to this chain + pub async fn chain_write<'a, T: Stream>(&'a self, buffer: T) -> Result<(), SendError>> + { + self.chain.write_stream(buffer).await + } + + + pub fn when_ref(&self) -> &Arc { - self.chain.write(buffer.into_iter().collect()).await + &self.chain.notify_when() } - pub fn when(&self) -> &Arc + /// Force the chain to push through now + pub fn push_now(&self) { - &self.save + self.chain.push_now() } pub fn shutdown(self) { self.shutdown.broadcast(true).expect("Failed to communicate shutdown"); - self.save.notify(); + self.chain.hang(); + self.when_ref().notify(); } pub fn has_shutdown(&self) -> bool