diff --git a/Cargo.lock b/Cargo.lock index 3308926..3c8e63a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -114,6 +114,12 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822" +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + [[package]] name = "cloudabi" version = "0.0.3" @@ -311,7 +317,7 @@ version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc587bc0ec293155d5bfa6b9891ec18a1e330c234f896ea47fbada4cadbe47e6" dependencies = [ - "cfg-if", + "cfg-if 0.1.10", "libc", "wasi 0.9.0+wasi-snapshot-preview1", ] @@ -529,13 +535,16 @@ version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4fabed175da42fed1fa0746b0ea71f412aa9d35e76e95e59b192c64b9dc2bf8b" dependencies = [ - "cfg-if", + "cfg-if 0.1.10", ] [[package]] name = "markov" version = "0.1.1" dependencies = [ + "cfg-if 1.0.0", + "hyper", + "log", "markov 1.1.0", "pretty_env_logger", "tokio", @@ -591,7 +600,7 @@ version = "0.6.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fce347092656428bc8eaf6201042cb551b8d67855af7374542a92a0fbfcac430" dependencies = [ - "cfg-if", + "cfg-if 0.1.10", "fuchsia-zircon", "fuchsia-zircon-sys", "iovec", @@ -673,7 +682,7 @@ version = "0.2.35" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3ebc3ec692ed7c9a255596c67808dee269f64655d8baf7b4f0638e51ba1d6853" dependencies = [ - "cfg-if", + "cfg-if 0.1.10", "libc", "winapi 0.3.9", ] @@ -1072,7 +1081,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "170a36ea86c864a3f16dd2687712dd6646f7019f301e57537c7f4dc9f5916770" dependencies = [ "block-buffer 0.9.0", - "cfg-if", + "cfg-if 0.1.10", "cpuid-bool", "digest 0.9.0", "opaque-debug 
0.3.0", @@ -1100,7 +1109,7 @@ version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1fa70dc5c8104ec096f4fe7ede7a221d35ae13dcd19ba1ad9a81d2cab9a1c44" dependencies = [ - "cfg-if", + "cfg-if 0.1.10", "libc", "redox_syscall", "winapi 0.3.9", @@ -1123,7 +1132,7 @@ version = "3.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a6e24d9338a0a5be79593e2fa15a648add6138caa803e2d5bc782c371732ca9" dependencies = [ - "cfg-if", + "cfg-if 0.1.10", "libc", "rand 0.7.3", "redox_syscall", @@ -1240,7 +1249,7 @@ version = "0.1.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b0987850db3733619253fe60e17cb59b82d37c7e6c0236bb81e4d6b87c879f27" dependencies = [ - "cfg-if", + "cfg-if 0.1.10", "log", "pin-project-lite", "tracing-core", diff --git a/Cargo.toml b/Cargo.toml index b8a9e57..60d719e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,8 +7,15 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[features] +# Trust X-Forwarded-For as real IP(s) +trust-x-forwarded-for = [] + [dependencies] chain = {package = "markov", version = "1.1.0"} tokio = {version = "0.2", features=["full"]} warp = "0.2" pretty_env_logger = "0.4.0" +hyper = "0.13.8" +log = "0.4.11" +cfg-if = "1.0.0" diff --git a/README b/README index 5c0318f..c17a163 100644 --- a/README +++ b/README @@ -1,6 +1,16 @@ -Generate strings from markov chain of stdin +HTTP server connecting to a Markov chain -Usage: +Feeding: +# PUT /put +Request body is fed to the chain -$ cat corpus | markov -$ cat corpus | markov +NOTE: Strings fed to the chain must be valid UTF-8 and below 16 KB in size + +Generating: +# GET /get +Generate a string from the chain + +# GET /get/<number> +Generate <number> strings from the chain + +NOTE: Number must be lower than 256 \ No newline at end of file diff --git a/src/forwarded_list.rs b/src/forwarded_list.rs new file mode 100644 index 0000000..302be9f --- 
/dev/null
+++ b/src/forwarded_list.rs
@@ -0,0 +1,74 @@
+use std::{
+    net::{
+        IpAddr,
+        AddrParseError,
+    },
+    str,
+    error,
+    fmt,
+};
+/// Error reported when an `X-Forwarded-For` value is not in the expected format.
+#[derive(Debug)]
+pub struct XFormatError;
+
+impl error::Error for XFormatError{}
+
+impl fmt::Display for XFormatError
+{
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
+    {
+        write!(f, "X-Forwarded-For was not in the correct format")
+    }
+}
+/// Ordered list of addresses from an `X-Forwarded-For` value. NOTE(review): generic arguments (e.g. the element type of `Vec`) look stripped in this copy of the patch; `addrs` suggests `IpAddr` -- confirm against the repository.
+#[derive(Debug, Clone, PartialOrd, PartialEq, Eq, Default)]
+pub struct XForwardedFor(Vec);
+
+impl XForwardedFor
+{
+    pub fn new() -> Self // empty list
+    {
+        Self(Vec::new())
+    }
+    pub fn single(ip: impl Into) -> Self // list containing only `ip`
+    {
+        Self(vec![ip.into()])
+    }
+    pub fn addrs(&self) -> &[IpAddr] // borrow every address as a slice
+    {
+        &self.0[..]
+    }
+    /// Consumes the list, returning its first address or `None` when it is empty.
+    pub fn into_first(self) -> Option
+    {
+        self.0.into_iter().next()
+    }
+    /// Consumes the list and returns the inner vector.
+    pub fn into_addrs(self) -> Vec
+    {
+        self.0
+    }
+}
+/// Parses a comma-separated list: each piece is trimmed and parsed, and the first failure aborts the whole parse.
+impl str::FromStr for XForwardedFor
+{
+    type Err = XFormatError;
+
+    fn from_str(s: &str) -> Result {
+        let mut output = Vec::new();
+        for next in s.split(',')
+        {
+            output.push(next.trim().parse()?) // `?` relies on the `From` impl for `XFormatError` further down
+        }
+        Ok(Self(output))
+    }
+}
+/// Allows `?` to convert an address parse failure into `XFormatError`.
+impl From for XFormatError
+{
+    #[inline(always)] fn from(_: AddrParseError) -> Self
+    {
+        Self
+    }
+}
+
diff --git a/src/main.rs b/src/main.rs
index d5c3f68..c4de034 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,38 +1,103 @@
+#![allow(dead_code)] // NOTE(review): crate-wide lint suppression; the reason is not visible in this patch
+
+#[macro_use] extern crate log; // provides the `info!` macro used below
+
 use chain::{
     Chain,
 };
 use warp::{
     Filter,
     Buf,
+    reply::Response,
 };
+use hyper::Body;
 use std::{
     sync::Arc,
+    fmt,
+    error,
+    net::{
+        SocketAddr,
+        IpAddr,
+    },
 };
 use tokio::{
     sync::{
         RwLock,
+        mpsc,
     },
     stream::{Stream,StreamExt,},
-    prelude::*,
 };
+use cfg_if::cfg_if;
+// The X-Forwarded-For parser is compiled only when the `trust-x-forwarded-for` feature is enabled.
+#[cfg(feature="trust-x-forwarded-for")]
+mod forwarded_list;
+#[cfg(feature="trust-x-forwarded-for")]
+use forwarded_list::XForwardedFor;
 const MAX_CONTENT_LENGTH: u64 = 1024 * 16;
 const MAX_GEN_SIZE: usize = 256;
-async fn full_body(chain: &mut Chain, mut body: impl Unpin + Stream>) -> Result<(), Box> {
+#[derive(Debug)]
+pub struct FillBodyError; // returned by `full_body`; implements `Reject` so it can be passed to `warp::reject::custom`
+
+impl error::Error for FillBodyError{}
+impl warp::reject::Reject for FillBodyError{}
+impl fmt::Display for FillBodyError
+{
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
+    {
+        write!(f, "failed to feed chain with this data")
+    }
+}
+/// Reads the whole request body, checks that it is UTF-8, logs it together with `who`, and feeds it to the chain.
+/// Returns the number of body bytes read; a stream error or invalid UTF-8 both become `FillBodyError`.
+async fn full_body(who: &IpAddr, chain: Arc>>, mut body: impl Unpin + Stream>) -> Result {
     let mut buffer = Vec::new();
-
+
+    let mut written = 0usize;
     while let Some(buf) = body.next().await {
-        let mut body = buf?;
+        let mut body = buf.map_err(|_| FillBodyError)?;
         while body.has_remaining() {
             buffer.extend_from_slice(body.bytes());
             let cnt = body.bytes().len();
             body.advance(cnt);
+            written += cnt; // running total of body bytes copied into `buffer`
         }
     }
-    let buffer = std::str::from_utf8(&buffer[..])?;
+    let buffer = std::str::from_utf8(&buffer[..]).map_err(|_| FillBodyError)?;
+    info!("{} -> {:?}", who, buffer);
+    let mut chain = chain.write().await; // write lock is requested only after the body has been fully read and validated
     chain.feed_str(buffer);
+    Ok(written)
+}
+/// Error returned by `gen_body`; wraps the string taken from the failed channel send (`e.0`).
+#[derive(Debug)]
+pub struct GenBodyError(String);
+
+impl error::Error for GenBodyError{}
+impl fmt::Display for GenBodyError
+{
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
+    {
+        write!(f, "failed to write {:?} to body", self.0)
+    }
+}
+/// Generates strings from the chain and sends each one to `output`; sends nothing when the chain is empty.
+/// `Some(n)` with `n < MAX_GEN_SIZE` iterates `str_iter_for(n)`; any other `num` sends a single `generate_str()` result.
+async fn gen_body(chain: Arc>>, num: Option, mut output: mpsc::Sender) -> Result<(), GenBodyError>
+{
+    let chain = chain.read().await; // NOTE(review): this read guard stays alive across the `send().await` calls below -- confirm a slow consumer cannot stall writers
+    if !chain.is_empty() {
+        match num {
+            Some(num) if num < MAX_GEN_SIZE => {
+                for string in chain.str_iter_for(num) {
+                    output.send(string).await.map_err(|e| GenBodyError(e.0))?;
+                }
+            },
+            _ => output.send(chain.generate_str()).await.map_err(|e| GenBodyError(e.0))?,
+        }
+    }
     Ok(())
 }
@@ -43,39 +108,53 @@ async fn main() {
     let chain = Arc::new(RwLock::new(Chain::new()));
     let chain = warp::any().map(move || Arc::clone(&chain));
+    cfg_if!{ // selects at compile time how the client address filter is built
+        if #[cfg(feature="trust-x-forwarded-for")] { // first `x-forwarded-for` address, otherwise the socket peer address
+            let client_ip =
+                warp::header("x-forwarded-for")
+                .map(|ip: XForwardedFor| ip) // identity map; only annotates the parsed header type
+                .and_then(|x: XForwardedFor| async move { x.into_first().ok_or_else(|| warp::reject::not_found()) })
+                .or(warp::filters::addr::remote()
+                    .and_then(|x: Option| async move { x.map(|x| x.ip()).ok_or_else(|| warp::reject::not_found()) }))
+                .unify();
+        } else { // feature disabled: only the socket peer address is used
+            let client_ip = warp::filters::addr::remote().and_then(|x: Option| async move {x.map(|x| x.ip()).ok_or_else(|| warp::reject::not_found())});
+        }
+    }
     let push = warp::put()
         .and(chain.clone())
         .and(warp::path("put"))
+        .and(client_ip.clone())
         .and(warp::body::content_length_limit(MAX_CONTENT_LENGTH))
         .and(warp::body::stream())
-        .and_then(|chain: Arc>>, buf| {
+        .and_then(|chain: Arc>>, host: IpAddr, buf| {
             async move {
-                use std::ops::DerefMut;
-                let res = format!("{:?}", full_body(chain.write().await.deref_mut(), buf).await);
-                Ok::(res)
+                full_body(&host, chain, buf).await
+                    .map(|x| format!("{} bytes fed", x)) // reply text reports the byte count returned by `full_body`
+                    .map_err(warp::reject::custom) // a failure becomes a custom rejection
             }
         });
     let read = warp::get()
         .and(chain.clone())
         .and(warp::path("get"))
+        .and(client_ip.clone())
         .and(warp::path::param().map(|opt: usize| Some(opt)).or(warp::any().map(|| Option::::None)).unify())
-        .and_then(|chain: Arc>>, num: Option| {
+        .and_then(|chain: Arc>>, host: IpAddr, num: Option| {
             async move {
-                let chain = chain.read().await;
-                if chain.is_empty() {
-                    Ok(String::default())
-                } else {
-                    match num {
-                        Some(num) if num < MAX_GEN_SIZE => Ok(chain.str_iter_for(num).collect()),
-                        _ => Ok::(chain.generate_str()),
-                    }
-                }
+                let (tx, rx) = mpsc::channel(16); // bounded channel with capacity 16 between generator and response body
+                tokio::spawn(gen_body(chain, num, tx)); // generator runs as its own task; its `Result` is not observed here
+                Ok::<_, std::convert::Infallible>(Response::new(Body::wrap_stream(rx.map(move |x| { // response body is streamed from the channel
+                    info!("{} <- {:?}", host, x);
+                    Ok::<_, std::convert::Infallible>(x)
+                }))))
             }
         });
-    warp::serve(push
-                .or(read))
-        .bind_with_graceful_shutdown(([127,0,0,1], 8777), async { tokio::signal::ctrl_c().await.unwrap(); }).1
-        .await
+    let (addr, server) = warp::serve(push
+                                     .or(read))
+        .bind_with_graceful_shutdown(([127,0,0,1], 8777), async { tokio::signal::ctrl_c().await.unwrap(); });
+    // Report the bound address before awaiting the server future.
+    println!("Server bound on {:?}", addr);
+    server.await
 }