diff --git a/.gitignore b/.gitignore index e2a3069..ba830fc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /target *~ +chain.dat diff --git a/Cargo.lock b/Cargo.lock index 1a9bde9..583238c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -108,6 +108,15 @@ version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e4cec68f03f32e44924783795810fa50a7035d8c8ebe78580ad7e6c703fba38" +[[package]] +name = "cc" +version = "1.0.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef611cc68ff783f18535d77ddd080185275713d852c4f5cbb6122c462a7a825c" +dependencies = [ + "jobserver", +] + [[package]] name = "cfg-if" version = "0.1.10" @@ -226,6 +235,7 @@ checksum = "5d8e3078b7b2a8a671cb7a3d17b4760e4181ea243227776ba83fd043b4ca034e" dependencies = [ "futures-channel", "futures-core", + "futures-executor", "futures-io", "futures-sink", "futures-task", @@ -248,12 +258,35 @@ version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d674eaa0056896d5ada519900dbf97ead2e46a7b6621e8160d79e2f2e1e2784b" +[[package]] +name = "futures-executor" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc709ca1da6f66143b8c9bec8e6260181869893714e9b5a490b169b0414144ab" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + [[package]] name = "futures-io" version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5fc94b64bb39543b4e432f1790b6bf18e3ee3b74653c5449f63310e9a74b123c" +[[package]] +name = "futures-macro" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f57ed14da4603b2554682e9f2ff3c65d7567b53188db96cb71538217fc64581b" +dependencies = [ + "proc-macro-hack", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.6" @@ -275,11 +308,17 @@ version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a894a0acddba51a2d49a6f4263b1e64b8c579ece8af50fa86503d52cd1eea34" dependencies = [ + "futures-channel", "futures-core", + "futures-io", + "futures-macro", "futures-sink", "futures-task", + "memchr", "pin-project", "pin-utils", + "proc-macro-hack", + "proc-macro-nested", "slab", ] @@ -341,6 +380,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "half" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d36fab90f82edc3c747f9d438e06cf0a491055896f2a279638bb5beed6c40177" + [[package]] name = "hashbrown" version = "0.9.1" @@ -501,6 +546,15 @@ version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc6f3ad7b9d11a0c00842ff8de1b60ee58661048eb8049ed33c73594f359d7e6" +[[package]] +name = "jobserver" +version = "0.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c71313ebb9439f74b00d9d2dcec36440beaf57a6aa0623068441dd7cd81a7f2" +dependencies = [ + "libc", +] + [[package]] name = "kernel32-sys" version = "0.2.2" @@ -538,16 +592,32 @@ dependencies = [ "cfg-if 0.1.10", ] +[[package]] +name = "lzzzz" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ba777d9f7fe8793f196dcc7b6cd43a74fb94a98e9e01d5c4f14753a589f9029" +dependencies = [ + "cc", + "pin-project", + "tokio", +] + [[package]] name = "markov" -version = "0.1.2" +version = "0.2.0" dependencies = [ "cfg-if 1.0.0", + "futures", "hyper", "log", + "lzzzz", "markov 1.1.0", "pretty_env_logger", + "serde", + "serde_cbor", "tokio", + "toml", "warp", ] @@ -779,6 +849,18 @@ dependencies = [ "log", ] +[[package]] +name = "proc-macro-hack" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99c605b9a0adc77b7211c6b1f722dcb613d68d66859a44f3d485a6da332b0598" + +[[package]] +name = "proc-macro-nested" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eba180dafb9038b050a4c280019bbedf9f2467b61e5d892dcad585bb57aadc5a" + [[package]] name = "proc-macro2" version = "1.0.24" @@ -1015,6 +1097,19 @@ name = "serde" version = "1.0.116" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "96fe57af81d28386a513cbc6858332abc6117cfdb5999647c6444b8f43a370a5" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_cbor" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e18acfa2f90e8b735b2836ab8d538de304cbb6729a7360729ea5a895d15a622" +dependencies = [ + "half", + "serde", +] [[package]] name = "serde_derive" @@ -1237,6 +1332,15 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffc92d160b1eef40665be3a05630d003936a3bc7da7421277846c2613e92c71a" +dependencies = [ + "serde", +] + [[package]] name = "tower-service" version = "0.3.0" diff --git a/Cargo.toml b/Cargo.toml index 886cb6e..0bd60bc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,16 +1,12 @@ [package] name = "markov" -version = "0.1.2" +version = "0.2.0" description = "Generate string of text from Markov chain fed by stdin" authors = ["Avril "] edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html -[features] -# Trust X-Forwarded-For as real IP(s) -trust-x-forwarded-for = [] - [dependencies] chain = {package = "markov", version = "1.1.0"} tokio = {version = "0.2", features=["full"]} @@ -19,3 +15,8 @@ pretty_env_logger = "0.4.0" hyper = "0.13.8" log = "0.4.11" cfg-if = "1.0.0" +futures = "0.3.6" +serde_cbor = "0.11.1" +lzzzz = {version = "0.2", features=["tokio-io"]} +serde = {version ="1.0", features=["derive"]} +toml = "0.5.6" diff --git a/markov.toml b/markov.toml new file mode 100644 index 0000000..6090632 --- /dev/null +++ b/markov.toml @@ -0,0 +1,6 @@ +bindpoint = '127.0.0.1:8001' +file = 'chain.dat' +max_content_length = 4194304 +max_gen_size = 256 +#save_interval_secs = 2 +trust_x_forwarded_for = false diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..2940f49 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,108 @@ +//! Server config +use super::*; +use std::{ + net::SocketAddr, + path::Path, + io, + borrow::Cow, + num::NonZeroU64, +}; +use tokio::{ + fs::OpenOptions, + prelude::*, + time::Duration, + io::BufReader, +}; + +pub const DEFAULT_FILE_LOCATION: &'static str = "markov.toml"; + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, Serialize, Deserialize)] +pub struct Config +{ + pub bindpoint: SocketAddr, + pub file: String, + pub max_content_length: u64, + pub max_gen_size: usize, + pub save_interval_secs: Option, + pub trust_x_forwarded_for: bool, +} + +impl Default for Config +{ + #[inline] + fn default() -> Self + { + Self { + bindpoint: ([127,0,0,1], 8001).into(), + file: "chain.dat".to_owned(), + max_content_length: 1024 * 1024 * 4, + max_gen_size: 256, + save_interval_secs: Some(unsafe{NonZeroU64::new_unchecked(2)}), + trust_x_forwarded_for: false, + } + } +} + +impl Config +{ + pub fn save_interval(&self) -> Option + { + self.save_interval_secs.map(|x| Duration::from_secs(x.into())) + } + pub async fn load(from: impl AsRef) -> io::Result + { + let file = OpenOptions::new() + .read(true) + .open(from).await?; + + let mut buffer= String::new(); + let reader = BufReader::new(file); + let mut lines = reader.lines(); + while let Some(line) = lines.next_line().await? { + buffer.push_str(&line[..]); + buffer.push('\n'); + } + toml::de::from_str(&buffer[..]).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e)) + } + + pub async fn save(&self, to: impl AsRef) -> io::Result<()> + { + let config = toml::ser::to_string_pretty(self).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + let mut file = OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(to).await?; + file.write_all(config.as_bytes()).await?; + file.shutdown().await?; + Ok(()) + } +} + +/// Try to load config file specified by args, or default config file +pub fn load() -> impl futures::future::Future> +{ + load_args(std::env::args().skip(1)) +} + +async fn load_args>(mut from: I) -> Option +{ + let place = if let Some(arg) = from.next() { + trace!("File {:?} provided", arg); + Cow::Owned(arg) + } else { + warn!("No config file provided. Using default location {:?}", DEFAULT_FILE_LOCATION); + Cow::Borrowed(DEFAULT_FILE_LOCATION) + }; + + match Config::load(place.as_ref()).await { + Ok(cfg) => { + info!("Loaded config file {:?}", place); + Some(cfg) + }, + Err(err) => { + error!("Failed to load config file from {:?}: {}", place, err); + None + }, + } +} diff --git a/src/main.rs b/src/main.rs index 478f236..42652d7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -24,10 +24,14 @@ use tokio::{ sync::{ RwLock, mpsc, + Notify, }, stream::{Stream,StreamExt,}, }; -use cfg_if::cfg_if; +use serde::{ + Serialize, + Deserialize +}; macro_rules! status { ($code:expr) => { @@ -35,14 +39,13 @@ macro_rules! status { }; } -#[cfg(feature="trust-x-forwarded-for")] +mod config; +mod state; +use state::State; +mod save; mod forwarded_list; -#[cfg(feature="trust-x-forwarded-for")] use forwarded_list::XForwardedFor; -const MAX_CONTENT_LENGTH: u64 = 1024 * 1024 * 4; //4MB -const MAX_GEN_SIZE: usize = 256; - #[derive(Debug)] pub struct FillBodyError; @@ -57,7 +60,7 @@ impl fmt::Display for FillBodyError } -async fn full_body(who: &IpAddr, chain: Arc>>, mut body: impl Unpin + Stream>) -> Result { +async fn full_body(who: &IpAddr, state: State, mut body: impl Unpin + Stream>) -> Result { let mut buffer = Vec::new(); let mut written = 0usize; @@ -73,8 +76,12 @@ async fn full_body(who: &IpAddr, chain: Arc>>, mut body: im let buffer = std::str::from_utf8(&buffer[..]).map_err(|_| FillBodyError)?; info!("{} -> {:?}", who, buffer); - let mut chain = chain.write().await; - chain.feed_str(buffer); + let mut chain = state.chain().write().await; + chain.feed(&buffer.split_whitespace() + .filter(|word| !word.is_empty()) + .map(|s| s.to_owned()).collect::>()); + + state.notify_save(); Ok(written) } @@ -91,79 +98,124 @@ impl fmt::Display for GenBodyError } -async fn gen_body(chain: Arc>>, num: Option, mut output: mpsc::Sender) -> Result<(), GenBodyError> +async fn gen_body(state: State, num: Option, mut output: mpsc::Sender) -> Result<(), GenBodyError> { - let chain = chain.read().await; - if !chain.is_empty() { - match num { - Some(num) if num < MAX_GEN_SIZE => { - //This could DoS `full_body` and writes, potentially. - for string in chain.str_iter_for(num) { - output.send(string).await.map_err(|e| GenBodyError(e.0))?; - } - }, - _ => output.send(chain.generate_str()).await.map_err(|e| GenBodyError(e.0))?, - } + let chain = state.chain().read().await; + if !chain.is_empty() { + match num { + Some(num) if num < state.config().max_gen_size => { + //This could DoS `full_body` and writes, potentially. + for string in chain.str_iter_for(num) { + output.send(string).await.map_err(|e| GenBodyError(e.0))?; + } + }, + _ => output.send(chain.generate_str()).await.map_err(|e| GenBodyError(e.0))?, } + } Ok(()) } - #[tokio::main] async fn main() { pretty_env_logger::init(); - - let chain = Arc::new(RwLock::new(Chain::new())); - let chain = warp::any().map(move || Arc::clone(&chain)); - cfg_if!{ - if #[cfg(feature="trust-x-forwarded-for")] { - let client_ip = - warp::header("x-forwarded-for") + let config = match config::load().await { + Some(v) => v, + _ => { + let cfg = config::Config::default(); + #[cfg(debug_assertions)] + { + if let Err(err) = cfg.save(config::DEFAULT_FILE_LOCATION).await { + error!("Failed to create default config file: {}", err); + } + } + cfg + }, + }; + trace!("Using config {:?}", config); + + let chain = Arc::new(RwLock::new(match save::load(&config.file).await { + Ok(chain) => { + info!("Loaded chain from {:?}", config.file); + chain + }, + Err(e) => { + warn!("Failed to load chain, creating new"); + trace!("Error: {}", e); + Chain::new() + }, + })); + { + let (state, chain, saver) = { + let save_when = Arc::new(Notify::new()); + + let state = State::new(config, + Arc::clone(&chain), + Arc::clone(&save_when)); + let state2 = state.clone(); + let saver = tokio::spawn(save::host(state.clone())); + let chain = warp::any().map(move || state.clone()); + (state2, chain, saver) + }; + + let client_ip = if state.config().trust_x_forwarded_for { + warp::header("x-forwarded-for") .map(|ip: XForwardedFor| ip) .and_then(|x: XForwardedFor| async move { x.into_first().ok_or_else(|| warp::reject::not_found()) }) .or(warp::filters::addr::remote() .and_then(|x: Option| async move { x.map(|x| x.ip()).ok_or_else(|| warp::reject::not_found()) })) - .unify(); + .unify().boxed() } else { - let client_ip = warp::filters::addr::remote().and_then(|x: Option| async move {x.map(|x| x.ip()).ok_or_else(|| warp::reject::not_found())}); + warp::filters::addr::remote().and_then(|x: Option| async move {x.map(|x| x.ip()).ok_or_else(|| warp::reject::not_found())}).boxed() + }; + + let push = warp::put() + .and(chain.clone()) + .and(warp::path("put")) + .and(client_ip.clone()) + .and(warp::body::content_length_limit(state.config().max_content_length)) + .and(warp::body::stream()) + .and_then(|state: State, host: IpAddr, buf| { + async move { + full_body(&host, state, buf).await + .map(|_| warp::reply::with_status(warp::reply(), status!(201))) + .map_err(warp::reject::custom) + } + }) + .with(warp::log("markov::put")); + + let read = warp::get() + .and(chain.clone()) + .and(warp::path("get")) + .and(client_ip.clone()) + .and(warp::path::param().map(|opt: usize| Some(opt)).or(warp::any().map(|| Option::::None)).unify()) + .and_then(|state: State, host: IpAddr, num: Option| { + async move { + let (tx, rx) = mpsc::channel(state.config().max_gen_size); + tokio::spawn(gen_body(state, num, tx)); + Ok::<_, std::convert::Infallible>(Response::new(Body::wrap_stream(rx.map(move |x| { + info!("{} <- {:?}", host, x); + Ok::<_, std::convert::Infallible>(x) + })))) + } + }) + .with(warp::log("markov::read")); + + let (addr, server) = warp::serve(push + .or(read)) + .bind_with_graceful_shutdown(state.config().bindpoint, async move { + tokio::signal::ctrl_c().await.unwrap(); + state.shutdown(); + }); + + info!("Server bound on {:?}", addr); + server.await; + + // Cleanup + async move { + trace!("Cleanup"); + + saver.await.expect("Saver panicked"); } - } - let push = warp::put() - .and(chain.clone()) - .and(warp::path("put")) - .and(client_ip.clone()) - .and(warp::body::content_length_limit(MAX_CONTENT_LENGTH)) - .and(warp::body::stream()) - .and_then(|chain: Arc>>, host: IpAddr, buf| { - async move { - full_body(&host, chain, buf).await - .map(|_| warp::reply::with_status(warp::reply(), status!(201))) - .map_err(warp::reject::custom) - } - }) - .with(warp::log("markov::put")); - - let read = warp::get() - .and(chain.clone()) - .and(warp::path("get")) - .and(client_ip.clone()) - .and(warp::path::param().map(|opt: usize| Some(opt)).or(warp::any().map(|| Option::::None)).unify()) - .and_then(|chain: Arc>>, host: IpAddr, num: Option| { - async move { - let (tx, rx) = mpsc::channel(MAX_GEN_SIZE); - tokio::spawn(gen_body(chain, num, tx)); - Ok::<_, std::convert::Infallible>(Response::new(Body::wrap_stream(rx.map(move |x| { - info!("{} <- {:?}", host, x); - Ok::<_, std::convert::Infallible>(x) - })))) - } - }) - .with(warp::log("markov::read")); - - let (addr, server) = warp::serve(push - .or(read)) - .bind_with_graceful_shutdown(([127,0,0,1], 8001), async { tokio::signal::ctrl_c().await.unwrap(); }); - - println!("Server bound on {:?}", addr); - server.await + }.await; + info!("Shut down gracefully") } diff --git a/src/save.rs b/src/save.rs new file mode 100644 index 0000000..b7f8b8a --- /dev/null +++ b/src/save.rs @@ -0,0 +1,89 @@ +//! Saving and loading chain +use super::*; +use std::{ + sync::Arc, + path::{ + Path, + }, + io, +}; +use tokio::{ + time::{ + self, + Duration, + }, + fs::{ + OpenOptions, + }, + prelude::*, +}; +use futures::{ + future::{ + OptionFuture, + }, +}; +use lzzzz::{ + lz4f::{ + self, + AsyncWriteCompressor, + PreferencesBuilder, + AsyncReadDecompressor, + }, +}; + +const SAVE_INTERVAL: Option = Some(Duration::from_secs(2)); + +pub async fn save_now(chain: &Chain, to: impl AsRef) -> io::Result<()> +{ + debug!("Saving chain to {:?}", to.as_ref()); + let file = OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(to).await?; + let chain = serde_cbor::to_vec(chain).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + + let mut file = AsyncWriteCompressor::new(file, PreferencesBuilder::new() + .compression_level(lz4f::CLEVEL_HIGH).build())?; + file.write_all(&chain[..]).await?; + file.shutdown().await?; + Ok(()) +} + +/// Start the save loop for this chain +pub async fn host(state: State) +{ + let to = &state.config().file; + let interval = state.config().save_interval(); + while Arc::strong_count(state.when()) > 1 { + { + let chain = state.chain().read().await; + use std::ops::Deref; + if let Err(e) = save_now(chain.deref(), &to).await { + error!("Failed to save chain: {}", e); + } else { + info!("Saved chain to {:?}", to); + } + } + if state.has_shutdown() { + break; + } + OptionFuture::from(interval.map(|interval| time::delay_for(interval))).await; + state.when().notified().await; + } + trace!("Saver exiting"); +} + +/// Try to load a chain from this path +pub async fn load(from: impl AsRef) -> io::Result> +{ + debug!("Loading chain from {:?}", from.as_ref()); + let file = OpenOptions::new() + .read(true) + .open(from).await?; + let mut whole = Vec::new(); + let mut file = AsyncReadDecompressor::new(file)?; + tokio::io::copy(&mut file, &mut whole).await?; + serde_cbor::from_slice(&whole[..]) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e)) +} diff --git a/src/state.rs b/src/state.rs new file mode 100644 index 0000000..5cc8d56 --- /dev/null +++ b/src/state.rs @@ -0,0 +1,74 @@ +//! State +use super::*; +use tokio::{ + sync::{ + watch, + }, +}; +use config::Config; + +#[derive(Debug, Clone)] +pub struct State +{ + config: Arc, //to avoid cloning config + chain: Arc>>, + save: Arc, + + shutdown: Arc>, + shutdown_recv: watch::Receiver, +} + +impl State +{ + pub fn new(config: Config, chain: Arc>>, save: Arc) -> Self + { + let (shutdown, shutdown_recv) = watch::channel(false); + Self { + config: Arc::new(config), + chain, + save, + shutdown: Arc::new(shutdown), + shutdown_recv, + } + } + + pub fn config(&self) -> &Config + { + self.config.as_ref() + } + + pub fn notify_save(&self) + { + self.save.notify(); + } + + pub fn chain(&self) -> &RwLock> + { + &self.chain.as_ref() + } + + pub fn when(&self) -> &Arc + { + &self.save + } + + pub fn shutdown(self) + { + self.shutdown.broadcast(true).expect("Failed to communicate shutdown"); + self.save.notify(); + } + + pub fn has_shutdown(&self) -> bool + { + *self.shutdown_recv.borrow() + } + + pub async fn on_shutdown(mut self) + { + if !self.has_shutdown() { + while let Some(false) = self.shutdown_recv.recv().await { + + } + } + } +}