diff --git a/Cargo.toml b/Cargo.toml index 2e31c38..91e2b5c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "markov" -version = "0.3.4" +version = "0.4.0" description = "Generate string of text from Markov chain fed by stdin" authors = ["Avril "] edition = "2018" @@ -9,8 +9,13 @@ edition = "2018" [features] default = ["compress-chain"] + +# Compress the chain data file when saved to disk compress-chain = ["async-compression"] +# Enable the /api/ route +api = [] + [profile.release] opt-level = 3 lto = "fat" diff --git a/src/api.rs b/src/api.rs new file mode 100644 index 0000000..95b1415 --- /dev/null +++ b/src/api.rs @@ -0,0 +1,83 @@ +//! For API calls if enabled +use super::*; +use std::{ + fmt, + error, + iter, + convert::Infallible, +}; +use futures::{ + stream::{ + self, + BoxStream, + StreamExt, + }, +}; + +fn aggregate(mut body: impl Buf) -> Result +{ + let mut output = Vec::new(); + while body.has_remaining() { + let bytes = body.bytes(); + output.extend_from_slice(&bytes[..]); + let cnt = bytes.len(); + body.advance(cnt); + } + + String::from_utf8(output) +} + +pub async fn single(host: IpAddr, num: Option, body: impl Buf) -> Result +{ + single_stream(host, num, body).await + .map(|rx| Response::new(Body::wrap_stream(rx.map(move |x| { + info!("{} <- {:?}", host, x); + x + })))) + .map_err(warp::reject::custom) +} + +async fn single_stream(host: IpAddr, num: Option, body: impl Buf) -> Result>, ApiError> +{ + let body = aggregate(body)?; + info!("{} <- {:?}", host, &body[..]); + + let mut chain = Chain::new(); + feed::feed(&mut chain, body); + match num { + None => Ok(stream::iter(iter::once(Ok(chain.generate_str()))).boxed()), + Some(num) => { + let (mut tx, rx) = mpsc::channel(num); + tokio::spawn(async move { + for string in chain.str_iter_for(num) { + tx.send(string).await.expect("Failed to send string to body"); + } + }); + Ok(StreamExt::map(rx, |x| Ok::<_, Infallible>(x)).boxed()) + } + } +} + +#[derive(Debug)] +pub enum ApiError { + Body, +} +impl warp::reject::Reject for ApiError{} +impl error::Error for ApiError{} +impl std::fmt::Display for ApiError +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + match self { + Self::Body => write!(f, "invalid data in request body"), + } + } +} + +impl From for ApiError +{ + fn from(_: std::string::FromUtf8Error) -> Self + { + Self::Body + } +} diff --git a/src/feed.rs b/src/feed.rs index 0cbc5f0..2b90da3 100644 --- a/src/feed.rs +++ b/src/feed.rs @@ -1,7 +1,7 @@ //! Feeding the chain use super::*; -fn feed(chain: &mut Chain, what: impl AsRef) +pub fn feed(chain: &mut Chain, what: impl AsRef) { chain.feed(what.as_ref().split_whitespace() .filter(|word| !word.is_empty()) diff --git a/src/main.rs b/src/main.rs index 248d603..278db34 100644 --- a/src/main.rs +++ b/src/main.rs @@ -39,14 +39,15 @@ use futures::{ join_all, }, }; - +use cfg_if::cfg_if; macro_rules! status { ($code:expr) => { ::warp::http::status::StatusCode::from_u16($code).unwrap() }; } - +#[cfg(feature="api")] +mod api; #[cfg(target_family="unix")] mod signals; mod config; @@ -131,6 +132,37 @@ async fn main() { }) .with(warp::log("markov::put")); + + cfg_if!{ + if #[cfg(feature="api")] { + let api = { + let api_single = { + let msz = state.config().max_gen_size; + warp::post() + .and(warp::path("single")) + .and(client_ip.clone()) + .and(warp::path::param() + .map(move |sz: usize| { + if sz == 0 || (2..=msz).contains(&sz) { + Some(sz) + } else { + None + } + }) + .or(warp::any().map(|| None)) + .unify()) + .and(warp::body::content_length_limit(state.config().max_content_length)) + .and(warp::body::aggregate()) + .and_then(api::single) + .with(warp::log("markov::api::single")) + }; + + warp::path("api") + .and(api_single) + }; + } + } + let read = warp::get() .and(chain.clone()) .and(warp::path("get")) @@ -147,7 +179,9 @@ async fn main() { } }) .with(warp::log("markov::read")); - + + #[cfg(feature="api")] + let read = read.or(api); #[cfg(target_family="unix")] tasks.push(tokio::spawn(signals::handle(state.clone())).map(|res| res.expect("Signal handler panicked")).boxed());