|
|
|
@ -24,10 +24,14 @@ use tokio::{
|
|
|
|
|
sync::{
|
|
|
|
|
RwLock,
|
|
|
|
|
mpsc,
|
|
|
|
|
Notify,
|
|
|
|
|
},
|
|
|
|
|
stream::{Stream,StreamExt,},
|
|
|
|
|
};
|
|
|
|
|
use cfg_if::cfg_if;
|
|
|
|
|
use serde::{
|
|
|
|
|
Serialize,
|
|
|
|
|
Deserialize
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
macro_rules! status {
|
|
|
|
|
($code:expr) => {
|
|
|
|
@ -35,14 +39,13 @@ macro_rules! status {
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[cfg(feature="trust-x-forwarded-for")]
|
|
|
|
|
mod config;
|
|
|
|
|
mod state;
|
|
|
|
|
use state::State;
|
|
|
|
|
mod save;
|
|
|
|
|
mod forwarded_list;
|
|
|
|
|
#[cfg(feature="trust-x-forwarded-for")]
|
|
|
|
|
use forwarded_list::XForwardedFor;
|
|
|
|
|
|
|
|
|
|
const MAX_CONTENT_LENGTH: u64 = 1024 * 1024 * 4; //4MB
|
|
|
|
|
const MAX_GEN_SIZE: usize = 256;
|
|
|
|
|
|
|
|
|
|
#[derive(Debug)]
|
|
|
|
|
pub struct FillBodyError;
|
|
|
|
|
|
|
|
|
@ -57,7 +60,7 @@ impl fmt::Display for FillBodyError
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async fn full_body(who: &IpAddr, chain: Arc<RwLock<Chain<String>>>, mut body: impl Unpin + Stream<Item = Result<impl Buf, impl std::error::Error + 'static>>) -> Result<usize, FillBodyError> {
|
|
|
|
|
async fn full_body(who: &IpAddr, state: State, mut body: impl Unpin + Stream<Item = Result<impl Buf, impl std::error::Error + 'static>>) -> Result<usize, FillBodyError> {
|
|
|
|
|
let mut buffer = Vec::new();
|
|
|
|
|
|
|
|
|
|
let mut written = 0usize;
|
|
|
|
@ -73,8 +76,12 @@ async fn full_body(who: &IpAddr, chain: Arc<RwLock<Chain<String>>>, mut body: im
|
|
|
|
|
|
|
|
|
|
let buffer = std::str::from_utf8(&buffer[..]).map_err(|_| FillBodyError)?;
|
|
|
|
|
info!("{} -> {:?}", who, buffer);
|
|
|
|
|
let mut chain = chain.write().await;
|
|
|
|
|
chain.feed_str(buffer);
|
|
|
|
|
let mut chain = state.chain().write().await;
|
|
|
|
|
chain.feed(&buffer.split_whitespace()
|
|
|
|
|
.filter(|word| !word.is_empty())
|
|
|
|
|
.map(|s| s.to_owned()).collect::<Vec<_>>());
|
|
|
|
|
|
|
|
|
|
state.notify_save();
|
|
|
|
|
Ok(written)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -91,79 +98,124 @@ impl fmt::Display for GenBodyError
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async fn gen_body(chain: Arc<RwLock<Chain<String>>>, num: Option<usize>, mut output: mpsc::Sender<String>) -> Result<(), GenBodyError>
|
|
|
|
|
async fn gen_body(state: State, num: Option<usize>, mut output: mpsc::Sender<String>) -> Result<(), GenBodyError>
|
|
|
|
|
{
|
|
|
|
|
let chain = chain.read().await;
|
|
|
|
|
if !chain.is_empty() {
|
|
|
|
|
match num {
|
|
|
|
|
Some(num) if num < MAX_GEN_SIZE => {
|
|
|
|
|
//This could DoS `full_body` and writes, potentially.
|
|
|
|
|
for string in chain.str_iter_for(num) {
|
|
|
|
|
output.send(string).await.map_err(|e| GenBodyError(e.0))?;
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
_ => output.send(chain.generate_str()).await.map_err(|e| GenBodyError(e.0))?,
|
|
|
|
|
}
|
|
|
|
|
let chain = state.chain().read().await;
|
|
|
|
|
if !chain.is_empty() {
|
|
|
|
|
match num {
|
|
|
|
|
Some(num) if num < state.config().max_gen_size => {
|
|
|
|
|
//This could DoS `full_body` and writes, potentially.
|
|
|
|
|
for string in chain.str_iter_for(num) {
|
|
|
|
|
output.send(string).await.map_err(|e| GenBodyError(e.0))?;
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
_ => output.send(chain.generate_str()).await.map_err(|e| GenBodyError(e.0))?,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[tokio::main]
|
|
|
|
|
async fn main() {
|
|
|
|
|
pretty_env_logger::init();
|
|
|
|
|
|
|
|
|
|
let chain = Arc::new(RwLock::new(Chain::new()));
|
|
|
|
|
let chain = warp::any().map(move || Arc::clone(&chain));
|
|
|
|
|
|
|
|
|
|
cfg_if!{
|
|
|
|
|
if #[cfg(feature="trust-x-forwarded-for")] {
|
|
|
|
|
let client_ip =
|
|
|
|
|
warp::header("x-forwarded-for")
|
|
|
|
|
let config = match config::load().await {
|
|
|
|
|
Some(v) => v,
|
|
|
|
|
_ => {
|
|
|
|
|
let cfg = config::Config::default();
|
|
|
|
|
#[cfg(debug_assertions)]
|
|
|
|
|
{
|
|
|
|
|
if let Err(err) = cfg.save(config::DEFAULT_FILE_LOCATION).await {
|
|
|
|
|
error!("Failed to create default config file: {}", err);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
cfg
|
|
|
|
|
},
|
|
|
|
|
};
|
|
|
|
|
trace!("Using config {:?}", config);
|
|
|
|
|
|
|
|
|
|
let chain = Arc::new(RwLock::new(match save::load(&config.file).await {
|
|
|
|
|
Ok(chain) => {
|
|
|
|
|
info!("Loaded chain from {:?}", config.file);
|
|
|
|
|
chain
|
|
|
|
|
},
|
|
|
|
|
Err(e) => {
|
|
|
|
|
warn!("Failed to load chain, creating new");
|
|
|
|
|
trace!("Error: {}", e);
|
|
|
|
|
Chain::new()
|
|
|
|
|
},
|
|
|
|
|
}));
|
|
|
|
|
{
|
|
|
|
|
let (state, chain, saver) = {
|
|
|
|
|
let save_when = Arc::new(Notify::new());
|
|
|
|
|
|
|
|
|
|
let state = State::new(config,
|
|
|
|
|
Arc::clone(&chain),
|
|
|
|
|
Arc::clone(&save_when));
|
|
|
|
|
let state2 = state.clone();
|
|
|
|
|
let saver = tokio::spawn(save::host(state.clone()));
|
|
|
|
|
let chain = warp::any().map(move || state.clone());
|
|
|
|
|
(state2, chain, saver)
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let client_ip = if state.config().trust_x_forwarded_for {
|
|
|
|
|
warp::header("x-forwarded-for")
|
|
|
|
|
.map(|ip: XForwardedFor| ip)
|
|
|
|
|
.and_then(|x: XForwardedFor| async move { x.into_first().ok_or_else(|| warp::reject::not_found()) })
|
|
|
|
|
.or(warp::filters::addr::remote()
|
|
|
|
|
.and_then(|x: Option<SocketAddr>| async move { x.map(|x| x.ip()).ok_or_else(|| warp::reject::not_found()) }))
|
|
|
|
|
.unify();
|
|
|
|
|
.unify().boxed()
|
|
|
|
|
} else {
|
|
|
|
|
let client_ip = warp::filters::addr::remote().and_then(|x: Option<SocketAddr>| async move {x.map(|x| x.ip()).ok_or_else(|| warp::reject::not_found())});
|
|
|
|
|
warp::filters::addr::remote().and_then(|x: Option<SocketAddr>| async move {x.map(|x| x.ip()).ok_or_else(|| warp::reject::not_found())}).boxed()
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let push = warp::put()
|
|
|
|
|
.and(chain.clone())
|
|
|
|
|
.and(warp::path("put"))
|
|
|
|
|
.and(client_ip.clone())
|
|
|
|
|
.and(warp::body::content_length_limit(state.config().max_content_length))
|
|
|
|
|
.and(warp::body::stream())
|
|
|
|
|
.and_then(|state: State, host: IpAddr, buf| {
|
|
|
|
|
async move {
|
|
|
|
|
full_body(&host, state, buf).await
|
|
|
|
|
.map(|_| warp::reply::with_status(warp::reply(), status!(201)))
|
|
|
|
|
.map_err(warp::reject::custom)
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
.with(warp::log("markov::put"));
|
|
|
|
|
|
|
|
|
|
let read = warp::get()
|
|
|
|
|
.and(chain.clone())
|
|
|
|
|
.and(warp::path("get"))
|
|
|
|
|
.and(client_ip.clone())
|
|
|
|
|
.and(warp::path::param().map(|opt: usize| Some(opt)).or(warp::any().map(|| Option::<usize>::None)).unify())
|
|
|
|
|
.and_then(|state: State, host: IpAddr, num: Option<usize>| {
|
|
|
|
|
async move {
|
|
|
|
|
let (tx, rx) = mpsc::channel(state.config().max_gen_size);
|
|
|
|
|
tokio::spawn(gen_body(state, num, tx));
|
|
|
|
|
Ok::<_, std::convert::Infallible>(Response::new(Body::wrap_stream(rx.map(move |x| {
|
|
|
|
|
info!("{} <- {:?}", host, x);
|
|
|
|
|
Ok::<_, std::convert::Infallible>(x)
|
|
|
|
|
}))))
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
.with(warp::log("markov::read"));
|
|
|
|
|
|
|
|
|
|
let (addr, server) = warp::serve(push
|
|
|
|
|
.or(read))
|
|
|
|
|
.bind_with_graceful_shutdown(state.config().bindpoint, async move {
|
|
|
|
|
tokio::signal::ctrl_c().await.unwrap();
|
|
|
|
|
state.shutdown();
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
info!("Server bound on {:?}", addr);
|
|
|
|
|
server.await;
|
|
|
|
|
|
|
|
|
|
// Cleanup
|
|
|
|
|
async move {
|
|
|
|
|
trace!("Cleanup");
|
|
|
|
|
|
|
|
|
|
saver.await.expect("Saver panicked");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
let push = warp::put()
|
|
|
|
|
.and(chain.clone())
|
|
|
|
|
.and(warp::path("put"))
|
|
|
|
|
.and(client_ip.clone())
|
|
|
|
|
.and(warp::body::content_length_limit(MAX_CONTENT_LENGTH))
|
|
|
|
|
.and(warp::body::stream())
|
|
|
|
|
.and_then(|chain: Arc<RwLock<Chain<String>>>, host: IpAddr, buf| {
|
|
|
|
|
async move {
|
|
|
|
|
full_body(&host, chain, buf).await
|
|
|
|
|
.map(|_| warp::reply::with_status(warp::reply(), status!(201)))
|
|
|
|
|
.map_err(warp::reject::custom)
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
.with(warp::log("markov::put"));
|
|
|
|
|
|
|
|
|
|
let read = warp::get()
|
|
|
|
|
.and(chain.clone())
|
|
|
|
|
.and(warp::path("get"))
|
|
|
|
|
.and(client_ip.clone())
|
|
|
|
|
.and(warp::path::param().map(|opt: usize| Some(opt)).or(warp::any().map(|| Option::<usize>::None)).unify())
|
|
|
|
|
.and_then(|chain: Arc<RwLock<Chain<String>>>, host: IpAddr, num: Option<usize>| {
|
|
|
|
|
async move {
|
|
|
|
|
let (tx, rx) = mpsc::channel(MAX_GEN_SIZE);
|
|
|
|
|
tokio::spawn(gen_body(chain, num, tx));
|
|
|
|
|
Ok::<_, std::convert::Infallible>(Response::new(Body::wrap_stream(rx.map(move |x| {
|
|
|
|
|
info!("{} <- {:?}", host, x);
|
|
|
|
|
Ok::<_, std::convert::Infallible>(x)
|
|
|
|
|
}))))
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
.with(warp::log("markov::read"));
|
|
|
|
|
|
|
|
|
|
let (addr, server) = warp::serve(push
|
|
|
|
|
.or(read))
|
|
|
|
|
.bind_with_graceful_shutdown(([127,0,0,1], 8001), async { tokio::signal::ctrl_c().await.unwrap(); });
|
|
|
|
|
|
|
|
|
|
println!("Server bound on {:?}", addr);
|
|
|
|
|
server.await
|
|
|
|
|
}.await;
|
|
|
|
|
info!("Shut down gracefully")
|
|
|
|
|
}
|
|
|
|
|