#![allow(dead_code)] #[macro_use] extern crate log; use chain::{ Chain, }; use warp::{ Filter, Buf, reply::Response, http, }; use hyper::Body; use std::{ sync::Arc, fmt, error, net::{ SocketAddr, IpAddr, }, }; use tokio::{ sync::{ RwLock, mpsc, }, stream::{Stream,StreamExt,}, }; use cfg_if::cfg_if; #[cfg(feature="trust-x-forwarded-for")] mod forwarded_list; #[cfg(feature="trust-x-forwarded-for")] use forwarded_list::XForwardedFor; const MAX_CONTENT_LENGTH: u64 = 1024 * 16; const MAX_GEN_SIZE: usize = 256; #[derive(Debug)] pub struct FillBodyError; impl error::Error for FillBodyError{} impl warp::reject::Reject for FillBodyError{} impl fmt::Display for FillBodyError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "failed to feed chain with this data") } } async fn full_body(who: &IpAddr, chain: Arc>>, mut body: impl Unpin + Stream>) -> Result { let mut buffer = Vec::new(); let mut written = 0usize; while let Some(buf) = body.next().await { let mut body = buf.map_err(|_| FillBodyError)?; while body.has_remaining() { buffer.extend_from_slice(body.bytes()); let cnt = body.bytes().len(); body.advance(cnt); written += cnt; } } let buffer = std::str::from_utf8(&buffer[..]).map_err(|_| FillBodyError)?; info!("{} -> {:?}", who, buffer); let mut chain = chain.write().await; chain.feed_str(buffer); Ok(written) } #[derive(Debug)] pub struct GenBodyError(String); impl error::Error for GenBodyError{} impl fmt::Display for GenBodyError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "failed to write {:?} to body", self.0) } } async fn gen_body(chain: Arc>>, num: Option, mut output: mpsc::Sender) -> Result<(), GenBodyError> { let chain = chain.read().await; if !chain.is_empty() { match num { Some(num) if num < MAX_GEN_SIZE => { for string in chain.str_iter_for(num) { output.send(string).await.map_err(|e| GenBodyError(e.0))?; } }, _ => output.send(chain.generate_str()).await.map_err(|e| GenBodyError(e.0))?, } } Ok(()) } #[tokio::main] async fn main() { pretty_env_logger::init(); let chain = Arc::new(RwLock::new(Chain::new())); let chain = warp::any().map(move || Arc::clone(&chain)); cfg_if!{ if #[cfg(feature="trust-x-forwarded-for")] { let client_ip = warp::header("x-forwarded-for") .map(|ip: XForwardedFor| ip) .and_then(|x: XForwardedFor| async move { x.into_first().ok_or_else(|| warp::reject::not_found()) }) .or(warp::filters::addr::remote() .and_then(|x: Option| async move { x.map(|x| x.ip()).ok_or_else(|| warp::reject::not_found()) })) .unify(); } else { let client_ip = warp::filters::addr::remote().and_then(|x: Option| async move {x.map(|x| x.ip()).ok_or_else(|| warp::reject::not_found())}); } } let push = warp::put() .and(chain.clone()) .and(warp::path("put")) .and(client_ip.clone()) .and(warp::body::content_length_limit(MAX_CONTENT_LENGTH)) .and(warp::body::stream()) .and_then(|chain: Arc>>, host: IpAddr, buf| { async move { full_body(&host, chain, buf).await .map(|x| format!("{} bytes fed", x)) .map_err(warp::reject::custom) } }) .map(|x| http::Response::builder() .status(201) .body(x) .unwrap()) .with(warp::log("markov::put")); let read = warp::get() .and(chain.clone()) .and(warp::path("get")) .and(client_ip.clone()) .and(warp::path::param().map(|opt: usize| Some(opt)).or(warp::any().map(|| Option::::None)).unify()) .and_then(|chain: Arc>>, host: IpAddr, num: Option| { async move { let (tx, rx) = mpsc::channel(16); tokio::spawn(gen_body(chain, num, tx)); Ok::<_, std::convert::Infallible>(Response::new(Body::wrap_stream(rx.map(move |x| { info!("{} <- {:?}", host, x); Ok::<_, std::convert::Infallible>(x) })))) } }) .with(warp::log("markov::read")); let (addr, server) = warp::serve(push .or(read)) .bind_with_graceful_shutdown(([127,0,0,1], 8001), async { tokio::signal::ctrl_c().await.unwrap(); }); println!("Server bound on {:?}", addr); server.await }