You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
genmarkov/src/main.rs

220 lines
5.1 KiB

#![allow(dead_code)]
#[macro_use] extern crate log;
use chain::{
Chain,
};
use warp::{
Filter,
Buf,
reply::Response,
};
use hyper::Body;
use std::{
sync::Arc,
fmt,
error,
net::{
SocketAddr,
IpAddr,
},
};
use tokio::{
sync::{
RwLock,
mpsc,
Notify,
},
stream::{Stream,StreamExt,},
};
use serde::{
Serialize,
Deserialize
};
use futures::{
future::{
FutureExt,
BoxFuture,
join_all,
},
};
use cfg_if::cfg_if;
/// Expand the wrapped tokens only when `debug_assertions` is enabled.
///
/// Without `debug_assertions` (the usual release profile) the tokens are
/// dropped and the invocation expands to nothing.
macro_rules! if_debug {
    ($($body:tt)*) => {
        cfg_if::cfg_if! {
            if #[cfg(debug_assertions)] {
                $($body)*
            }
        }
    };
}
/// Turn a numeric HTTP status (e.g. `status!(201)`) into a `warp` `StatusCode`.
///
/// Panics (through `Result::unwrap`) if the number is rejected by
/// `StatusCode::from_u16`.
macro_rules! status {
    ($code:expr) => {{
        let parsed = ::warp::http::StatusCode::from_u16($code);
        ::std::result::Result::unwrap(parsed)
    }};
}
mod bytes;
mod chunking;
#[cfg(feature="api")]
mod api;
#[cfg(target_family="unix")]
mod signals;
mod config;
mod state;
use state::State;
mod save;
mod forwarded_list;
use forwarded_list::XForwardedFor;
mod feed;
mod gen;
#[tokio::main]
async fn main() {
pretty_env_logger::init();
let config = match config::load().await {
Some(v) => v,
_ => {
let cfg = config::Config::default();
#[cfg(debug_assertions)]
{
if let Err(err) = cfg.save(config::DEFAULT_FILE_LOCATION).await {
error!("Failed to create default config file: {}", err);
}
}
cfg
},
};
trace!("Using config {:?}", config);
let chain = Arc::new(RwLock::new(match save::load(&config.file).await {
Ok(chain) => {
info!("Loaded chain from {:?}", config.file);
chain
},
Err(e) => {
warn!("Failed to load chain, creating new");
trace!("Error: {}", e);
Chain::new()
},
}));
{
let mut tasks = Vec::<BoxFuture<'static, ()>>::new();
let (state, chain) = {
let save_when = Arc::new(Notify::new());
let state = State::new(config,
Arc::clone(&chain),
Arc::clone(&save_when));
let state2 = state.clone();
let saver = tokio::spawn(save::host(state.clone()));
let chain = warp::any().map(move || state.clone());
tasks.push(saver.map(|res| res.expect("Saver panicked")).boxed());
(state2, chain)
};
let client_ip = if state.config().trust_x_forwarded_for {
warp::header("x-forwarded-for")
.map(|ip: XForwardedFor| ip)
.and_then(|x: XForwardedFor| async move { x.into_first().ok_or_else(|| warp::reject::not_found()) })
.or(warp::filters::addr::remote()
.and_then(|x: Option<SocketAddr>| async move { x.map(|x| x.ip()).ok_or_else(|| warp::reject::not_found()) }))
.unify().boxed()
} else {
warp::filters::addr::remote().and_then(|x: Option<SocketAddr>| async move {x.map(|x| x.ip()).ok_or_else(|| warp::reject::not_found())}).boxed()
};
let push = warp::put()
.and(chain.clone())
.and(warp::path("put"))
.and(client_ip.clone())
.and(warp::body::content_length_limit(state.config().max_content_length))
.and(warp::body::stream())
.and_then(|state: State, host: IpAddr, buf| {
async move {
feed::full(&host, state, buf).await
.map(|_| warp::reply::with_status(warp::reply(), status!(201)))
.map_err(|_| warp::reject::not_found()) //(warp::reject::custom) //TODO: Recover rejection filter down below for custom error return
}
})
.with(warp::log("markov::put"));
cfg_if!{
if #[cfg(feature="api")] {
let api = {
let api_single = {
let msz = state.config().max_gen_size;
warp::post()
.and(warp::path("single"))
.and(client_ip.clone())
.and(warp::path::param()
.map(move |sz: usize| {
if sz == 0 || (2..=msz).contains(&sz) {
Some(sz)
} else {
None
}
})
.or(warp::any().map(|| None))
.unify())
.and(warp::body::content_length_limit(state.config().max_content_length))
.and(warp::body::aggregate())
.and_then(api::single)
.with(warp::log("markov::api::single"))
};
warp::path("api")
.and(api_single)
};
}
}
let read = warp::get()
.and(chain.clone())
.and(warp::path("get"))
.and(client_ip.clone())
.and(warp::path::param().map(|opt: usize| Some(opt)).or(warp::any().map(|| Option::<usize>::None)).unify())
.and_then(|state: State, host: IpAddr, num: Option<usize>| {
async move {
let (tx, rx) = mpsc::channel(state.config().max_gen_size);
tokio::spawn(gen::body(state, num, tx));
Ok::<_, std::convert::Infallible>(Response::new(Body::wrap_stream(rx.map(move |x| {
info!("{} <- {:?}", host, x);
Ok::<_, std::convert::Infallible>(x)
}))))
}
})
.with(warp::log("markov::read"));
#[cfg(feature="api")]
let read = read.or(api);
#[cfg(target_family="unix")]
tasks.push(tokio::spawn(signals::handle(state.clone())).map(|res| res.expect("Signal handler panicked")).boxed());
let (addr, server) = warp::serve(push
.or(read))
.bind_with_graceful_shutdown(state.config().bindpoint, async move {
tokio::signal::ctrl_c().await.unwrap();
state.shutdown();
});
info!("Server bound on {:?}", addr);
server.await;
// Cleanup
async move {
trace!("Cleanup");
join_all(tasks).await;
}
}.await;
info!("Shut down gracefully")
}