diff --git a/Cargo.toml b/Cargo.toml index a256ae6..66c22fd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "markov" -version = "0.4.1" +version = "0.5.0" description = "Generate string of text from Markov chain fed by stdin" authors = ["Avril "] edition = "2018" @@ -8,7 +8,7 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = ["compress-chain"] +default = ["compress-chain", "split-newlines"] # Compress the chain data file when saved to disk compress-chain = ["async-compression"] @@ -16,6 +16,24 @@ compress-chain = ["async-compression"] # Treat each new line as a new set to feed instead of feeding the whole data at once split-newlines = [] +# Always aggregate incoming buffer instead of streaming them +# This will make feeds faster but allocate full buffers for the aggregated body +# +# Large write: ~95ms +# +# NOTE: This does nothing if `split-newlines` is not enabled +always-aggregate = [] + +# Feeds will hog the buffer lock until the whole body has been fed, instead of acquiring lock every time +# This will make feeds of many lines faster but can potentially cause DoS +# +# With: ~169ms +# Without: ~195ms +# +# NOTE: +# This does nothing if `always-aggregate` is enabled and/or `split-newlines` is not enabled +hog-buffer = [] + # Enable the /api/ route api = [] diff --git a/Makefile b/Makefile index ddfe99b..23a29d5 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -FEATURES:="api,split-newlines" +FEATURES:="api,always-aggregate" markov: cargo build --release --features $(FEATURES) diff --git a/src/api.rs b/src/api.rs index 54e36ba..98cee79 100644 --- a/src/api.rs +++ b/src/api.rs @@ -14,17 +14,17 @@ use futures::{ }, }; -fn aggregate(mut body: impl Buf) -> Result +#[inline] fn aggregate(mut body: impl Buf) -> Result { - let mut output = Vec::new(); + /*let mut output = Vec::new(); while body.has_remaining() { - let bytes = body.bytes(); - output.extend_from_slice(&bytes[..]); - let cnt = bytes.len(); - body.advance(cnt); - } + let bytes = body.bytes(); + output.extend_from_slice(&bytes[..]); + let cnt = bytes.len(); + body.advance(cnt); +}*/ - String::from_utf8(output) + std::str::from_utf8(&body.to_bytes()).map(ToOwned::to_owned) } pub async fn single(host: IpAddr, num: Option, body: impl Buf) -> Result @@ -37,13 +37,18 @@ pub async fn single(host: IpAddr, num: Option, body: impl Buf) -> Result< .map_err(warp::reject::custom) } +//TODO: Change to stream impl like normal `feed` has, instead of taking aggregate? async fn single_stream(host: IpAddr, num: Option, body: impl Buf) -> Result>, ApiError> { let body = aggregate(body)?; info!("{} <- {:?}", host, &body[..]); let mut chain = Chain::new(); - cfg_if!{ + + if_debug! { + let timer = std::time::Instant::now(); + } + cfg_if! { if #[cfg(feature="split-newlines")] { for body in body.split('\n').filter(|line| !line.trim().is_empty()) { feed::feed(&mut chain, body); @@ -52,6 +57,9 @@ async fn single_stream(host: IpAddr, num: Option, body: impl Buf) -> Resu feed::feed(&mut chain, body); } } + if_debug!{ + trace!("Write took {}ms", timer.elapsed().as_millis()); + } match num { None => Ok(stream::iter(iter::once(Ok(chain.generate_str()))).boxed()), Some(num) => { @@ -82,9 +90,9 @@ impl std::fmt::Display for ApiError } } -impl From for ApiError +impl From for ApiError { - fn from(_: std::string::FromUtf8Error) -> Self + fn from(_: std::str::Utf8Error) -> Self { Self::Body } diff --git a/src/feed.rs b/src/feed.rs index 4b2281a..fee939c 100644 --- a/src/feed.rs +++ b/src/feed.rs @@ -11,37 +11,68 @@ pub fn feed(chain: &mut Chain, what: impl AsRef) } } -pub async fn full(who: &IpAddr, state: State, mut body: impl Unpin + Stream>) -> Result { - let mut buffer = Vec::new(); +pub async fn full(who: &IpAddr, state: State, body: impl Unpin + Stream>) -> Result { let mut written = 0usize; - //TODO: Change to pushing lines to mpsc channel, instead of manually aggregating. - while let Some(buf) = body.next().await { - let mut body = buf.map_err(|_| FillBodyError)?; - while body.has_remaining() { - if body.bytes().len() > 0 { - buffer.extend_from_slice(body.bytes()); - let cnt = body.bytes().len(); - body.advance(cnt); - written += cnt; - } - } + if_debug! { + let timer = std::time::Instant::now(); } - let buffer = std::str::from_utf8(&buffer[..]).map_err(|_| FillBodyError)?; - info!("{} -> {:?}", who, buffer); - let mut chain = state.chain().write().await; - cfg_if! { - if #[cfg(feature="split-newlines")] { - for buffer in buffer.split('\n').filter(|line| !line.trim().is_empty()) { - feed(&mut chain, buffer); + cfg_if!{ + if #[cfg(any(not(feature="split-newlines"), feature="always-aggregate"))] { + let mut body = body; + let mut buffer = Vec::new(); + while let Some(buf) = body.next().await { + let mut body = buf.map_err(|_| FillBodyError)?; + while body.has_remaining() { + if body.bytes().len() > 0 { + buffer.extend_from_slice(body.bytes()); + let cnt = body.bytes().len(); + body.advance(cnt); + written += cnt; + } + } + } + let buffer = std::str::from_utf8(&buffer[..]).map_err(|_| FillBodyError)?; + info!("{} -> {:?}", who, buffer); + let mut chain = state.chain().write().await; + cfg_if! { + if #[cfg(feature="split-newlines")] { + for buffer in buffer.split('\n').filter(|line| !line.trim().is_empty()) { + feed(&mut chain, buffer); + } + } else { + feed(&mut chain, buffer); + + } } } else { - feed(&mut chain, buffer); + use tokio::prelude::*; + let reader = chunking::StreamReader::new(body.map(|x| x.map(|mut x| x.to_bytes()).unwrap_or_default())); + let mut lines = reader.lines(); + + #[cfg(feature="hog-buffer")] + let mut chain = state.chain().write().await; + while let Some(line) = lines.next_line().await.map_err(|_| FillBodyError)? { + let line = line.trim(); + if !line.is_empty() { + #[cfg(not(feature="hog-buffer"))] + let mut chain = state.chain().write().await; // Acquire mutex once per line? Is this right? + + feed(&mut chain, line); + info!("{} -> {:?}", who, line); + } + written+=line.len(); + } } } + + if_debug!{ + trace!("Write took {}ms", timer.elapsed().as_millis()); + } state.notify_save(); Ok(written) + } diff --git a/src/main.rs b/src/main.rs index d85414a..6925d8d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -41,6 +41,16 @@ use futures::{ }; use cfg_if::cfg_if; +macro_rules! if_debug { + ($($tt:tt)*) => { + cfg_if::cfg_if!{ + if #[cfg(debug_assertions)] { + $($tt)* + } + } + } +} + macro_rules! status { ($code:expr) => { ::warp::http::status::StatusCode::from_u16($code).unwrap()