config! saveload!

serve
Avril 4 years ago
parent 35f1de6c5e
commit ecc8854e44
Signed by: flanchan
GPG Key ID: 284488987C31F630

1
.gitignore vendored

@ -1,2 +1,3 @@
/target
*~
chain.dat

106
Cargo.lock generated

@ -108,6 +108,15 @@ version = "0.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e4cec68f03f32e44924783795810fa50a7035d8c8ebe78580ad7e6c703fba38"
[[package]]
name = "cc"
version = "1.0.60"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef611cc68ff783f18535d77ddd080185275713d852c4f5cbb6122c462a7a825c"
dependencies = [
"jobserver",
]
[[package]]
name = "cfg-if"
version = "0.1.10"
@ -226,6 +235,7 @@ checksum = "5d8e3078b7b2a8a671cb7a3d17b4760e4181ea243227776ba83fd043b4ca034e"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
@ -248,12 +258,35 @@ version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d674eaa0056896d5ada519900dbf97ead2e46a7b6621e8160d79e2f2e1e2784b"
[[package]]
name = "futures-executor"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc709ca1da6f66143b8c9bec8e6260181869893714e9b5a490b169b0414144ab"
dependencies = [
"futures-core",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-io"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5fc94b64bb39543b4e432f1790b6bf18e3ee3b74653c5449f63310e9a74b123c"
[[package]]
name = "futures-macro"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f57ed14da4603b2554682e9f2ff3c65d7567b53188db96cb71538217fc64581b"
dependencies = [
"proc-macro-hack",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "futures-sink"
version = "0.3.6"
@ -275,11 +308,17 @@ version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a894a0acddba51a2d49a6f4263b1e64b8c579ece8af50fa86503d52cd1eea34"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
"pin-project",
"pin-utils",
"proc-macro-hack",
"proc-macro-nested",
"slab",
]
@ -341,6 +380,12 @@ dependencies = [
"tracing",
]
[[package]]
name = "half"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d36fab90f82edc3c747f9d438e06cf0a491055896f2a279638bb5beed6c40177"
[[package]]
name = "hashbrown"
version = "0.9.1"
@ -501,6 +546,15 @@ version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc6f3ad7b9d11a0c00842ff8de1b60ee58661048eb8049ed33c73594f359d7e6"
[[package]]
name = "jobserver"
version = "0.1.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c71313ebb9439f74b00d9d2dcec36440beaf57a6aa0623068441dd7cd81a7f2"
dependencies = [
"libc",
]
[[package]]
name = "kernel32-sys"
version = "0.2.2"
@ -538,16 +592,32 @@ dependencies = [
"cfg-if 0.1.10",
]
[[package]]
name = "lzzzz"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ba777d9f7fe8793f196dcc7b6cd43a74fb94a98e9e01d5c4f14753a589f9029"
dependencies = [
"cc",
"pin-project",
"tokio",
]
[[package]]
name = "markov"
version = "0.1.2"
version = "0.2.0"
dependencies = [
"cfg-if 1.0.0",
"futures",
"hyper",
"log",
"lzzzz",
"markov 1.1.0",
"pretty_env_logger",
"serde",
"serde_cbor",
"tokio",
"toml",
"warp",
]
@ -779,6 +849,18 @@ dependencies = [
"log",
]
[[package]]
name = "proc-macro-hack"
version = "0.5.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99c605b9a0adc77b7211c6b1f722dcb613d68d66859a44f3d485a6da332b0598"
[[package]]
name = "proc-macro-nested"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eba180dafb9038b050a4c280019bbedf9f2467b61e5d892dcad585bb57aadc5a"
[[package]]
name = "proc-macro2"
version = "1.0.24"
@ -1015,6 +1097,19 @@ name = "serde"
version = "1.0.116"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96fe57af81d28386a513cbc6858332abc6117cfdb5999647c6444b8f43a370a5"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_cbor"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e18acfa2f90e8b735b2836ab8d538de304cbb6729a7360729ea5a895d15a622"
dependencies = [
"half",
"serde",
]
[[package]]
name = "serde_derive"
@ -1237,6 +1332,15 @@ dependencies = [
"tokio",
]
[[package]]
name = "toml"
version = "0.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ffc92d160b1eef40665be3a05630d003936a3bc7da7421277846c2613e92c71a"
dependencies = [
"serde",
]
[[package]]
name = "tower-service"
version = "0.3.0"

@ -1,16 +1,12 @@
[package]
name = "markov"
version = "0.1.2"
version = "0.2.0"
description = "Generate string of text from Markov chain fed by stdin"
authors = ["Avril <flanchan@cumallover.me>"]
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
# Trust X-Forwarded-For as real IP(s)
trust-x-forwarded-for = []
[dependencies]
chain = {package = "markov", version = "1.1.0"}
tokio = {version = "0.2", features=["full"]}
@ -19,3 +15,8 @@ pretty_env_logger = "0.4.0"
hyper = "0.13.8"
log = "0.4.11"
cfg-if = "1.0.0"
futures = "0.3.6"
serde_cbor = "0.11.1"
lzzzz = {version = "0.2", features=["tokio-io"]}
serde = {version ="1.0", features=["derive"]}
toml = "0.5.6"

@ -0,0 +1,6 @@
bindpoint = '127.0.0.1:8001'
file = 'chain.dat'
max_content_length = 4194304
max_gen_size = 256
#save_interval_secs = 2
trust_x_forwarded_for = false

@ -0,0 +1,108 @@
//! Server config
use super::*;
use std::{
net::SocketAddr,
path::Path,
io,
borrow::Cow,
num::NonZeroU64,
};
use tokio::{
fs::OpenOptions,
prelude::*,
time::Duration,
io::BufReader,
};
/// Config file path used when no path is given on the command line.
pub const DEFAULT_FILE_LOCATION: &str = "markov.toml";
/// Server configuration.
///
/// Serializable with serde; `Config::load` / `Config::save` store it as TOML.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, Serialize, Deserialize)]
pub struct Config
{
/// Socket address the HTTP server binds to.
pub bindpoint: SocketAddr,
/// Path of the file the chain is loaded from and saved to.
pub file: String,
/// Upper bound handed to the request body content-length limit.
pub max_content_length: u64,
/// Bound on the number of strings a single generate request may ask for;
/// also used as the capacity of the output channel.
pub max_gen_size: usize,
/// Seconds the saver waits between saves; `None` means no delay.
pub save_interval_secs: Option<NonZeroU64>,
/// Whether the client IP may be taken from the `x-forwarded-for` header.
pub trust_x_forwarded_for: bool,
}
impl Default for Config
{
    /// Built-in configuration used when no config file could be loaded.
    ///
    /// Binds to `127.0.0.1:8001`, stores the chain in `chain.dat`, limits
    /// request bodies to 4 MiB and generation requests to 256, waits 2
    /// seconds between saves, and does not trust `X-Forwarded-For`.
    #[inline]
    fn default() -> Self
    {
        Self {
            bindpoint: ([127, 0, 0, 1], 8001).into(),
            file: "chain.dat".to_owned(),
            max_content_length: 1024 * 1024 * 4,
            max_gen_size: 256,
            // `NonZeroU64::new` already returns `Option<NonZeroU64>`, so the
            // checked constructor gives `Some(2)` without any `unsafe`.
            save_interval_secs: NonZeroU64::new(2),
            trust_x_forwarded_for: false,
        }
    }
}
impl Config
{
pub fn save_interval(&self) -> Option<Duration>
{
self.save_interval_secs.map(|x| Duration::from_secs(x.into()))
}
pub async fn load(from: impl AsRef<Path>) -> io::Result<Self>
{
let file = OpenOptions::new()
.read(true)
.open(from).await?;
let mut buffer= String::new();
let reader = BufReader::new(file);
let mut lines = reader.lines();
while let Some(line) = lines.next_line().await? {
buffer.push_str(&line[..]);
buffer.push('\n');
}
toml::de::from_str(&buffer[..]).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
}
pub async fn save(&self, to: impl AsRef<Path>) -> io::Result<()>
{
let config = toml::ser::to_string_pretty(self).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
let mut file = OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(to).await?;
file.write_all(config.as_bytes()).await?;
file.shutdown().await?;
Ok(())
}
}
/// Load the config named by the first command line argument, falling back to
/// the default config file location when no argument is given.
///
/// Resolves to `None` if the file could not be loaded.
pub fn load() -> impl futures::future::Future<Output =Option<Config>>
{
    // Skip the program name; only the real arguments are of interest.
    let args = std::env::args().skip(1);
    load_args(args)
}
async fn load_args<I: Iterator<Item=String>>(mut from: I) -> Option<Config>
{
let place = if let Some(arg) = from.next() {
trace!("File {:?} provided", arg);
Cow::Owned(arg)
} else {
warn!("No config file provided. Using default location {:?}", DEFAULT_FILE_LOCATION);
Cow::Borrowed(DEFAULT_FILE_LOCATION)
};
match Config::load(place.as_ref()).await {
Ok(cfg) => {
info!("Loaded config file {:?}", place);
Some(cfg)
},
Err(err) => {
error!("Failed to load config file from {:?}: {}", place, err);
None
},
}
}

@ -24,10 +24,14 @@ use tokio::{
sync::{
RwLock,
mpsc,
Notify,
},
stream::{Stream,StreamExt,},
};
use cfg_if::cfg_if;
use serde::{
Serialize,
Deserialize
};
macro_rules! status {
($code:expr) => {
@ -35,14 +39,13 @@ macro_rules! status {
};
}
#[cfg(feature="trust-x-forwarded-for")]
mod config;
mod state;
use state::State;
mod save;
mod forwarded_list;
#[cfg(feature="trust-x-forwarded-for")]
use forwarded_list::XForwardedFor;
const MAX_CONTENT_LENGTH: u64 = 1024 * 1024 * 4; //4MB
const MAX_GEN_SIZE: usize = 256;
#[derive(Debug)]
pub struct FillBodyError;
@ -57,7 +60,7 @@ impl fmt::Display for FillBodyError
}
async fn full_body(who: &IpAddr, chain: Arc<RwLock<Chain<String>>>, mut body: impl Unpin + Stream<Item = Result<impl Buf, impl std::error::Error + 'static>>) -> Result<usize, FillBodyError> {
async fn full_body(who: &IpAddr, state: State, mut body: impl Unpin + Stream<Item = Result<impl Buf, impl std::error::Error + 'static>>) -> Result<usize, FillBodyError> {
let mut buffer = Vec::new();
let mut written = 0usize;
@ -73,8 +76,12 @@ async fn full_body(who: &IpAddr, chain: Arc<RwLock<Chain<String>>>, mut body: im
let buffer = std::str::from_utf8(&buffer[..]).map_err(|_| FillBodyError)?;
info!("{} -> {:?}", who, buffer);
let mut chain = chain.write().await;
chain.feed_str(buffer);
let mut chain = state.chain().write().await;
chain.feed(&buffer.split_whitespace()
.filter(|word| !word.is_empty())
.map(|s| s.to_owned()).collect::<Vec<_>>());
state.notify_save();
Ok(written)
}
@ -91,12 +98,12 @@ impl fmt::Display for GenBodyError
}
async fn gen_body(chain: Arc<RwLock<Chain<String>>>, num: Option<usize>, mut output: mpsc::Sender<String>) -> Result<(), GenBodyError>
async fn gen_body(state: State, num: Option<usize>, mut output: mpsc::Sender<String>) -> Result<(), GenBodyError>
{
let chain = chain.read().await;
let chain = state.chain().read().await;
if !chain.is_empty() {
match num {
Some(num) if num < MAX_GEN_SIZE => {
Some(num) if num < state.config().max_gen_size => {
//This could DoS `full_body` and writes, potentially.
for string in chain.str_iter_for(num) {
output.send(string).await.map_err(|e| GenBodyError(e.0))?;
@ -107,36 +114,69 @@ async fn gen_body(chain: Arc<RwLock<Chain<String>>>, num: Option<usize>, mut out
}
Ok(())
}
#[tokio::main]
async fn main() {
pretty_env_logger::init();
let chain = Arc::new(RwLock::new(Chain::new()));
let chain = warp::any().map(move || Arc::clone(&chain));
let config = match config::load().await {
Some(v) => v,
_ => {
let cfg = config::Config::default();
#[cfg(debug_assertions)]
{
if let Err(err) = cfg.save(config::DEFAULT_FILE_LOCATION).await {
error!("Failed to create default config file: {}", err);
}
}
cfg
},
};
trace!("Using config {:?}", config);
let chain = Arc::new(RwLock::new(match save::load(&config.file).await {
Ok(chain) => {
info!("Loaded chain from {:?}", config.file);
chain
},
Err(e) => {
warn!("Failed to load chain, creating new");
trace!("Error: {}", e);
Chain::new()
},
}));
{
let (state, chain, saver) = {
let save_when = Arc::new(Notify::new());
let state = State::new(config,
Arc::clone(&chain),
Arc::clone(&save_when));
let state2 = state.clone();
let saver = tokio::spawn(save::host(state.clone()));
let chain = warp::any().map(move || state.clone());
(state2, chain, saver)
};
cfg_if!{
if #[cfg(feature="trust-x-forwarded-for")] {
let client_ip =
let client_ip = if state.config().trust_x_forwarded_for {
warp::header("x-forwarded-for")
.map(|ip: XForwardedFor| ip)
.and_then(|x: XForwardedFor| async move { x.into_first().ok_or_else(|| warp::reject::not_found()) })
.or(warp::filters::addr::remote()
.and_then(|x: Option<SocketAddr>| async move { x.map(|x| x.ip()).ok_or_else(|| warp::reject::not_found()) }))
.unify();
.unify().boxed()
} else {
let client_ip = warp::filters::addr::remote().and_then(|x: Option<SocketAddr>| async move {x.map(|x| x.ip()).ok_or_else(|| warp::reject::not_found())});
}
}
warp::filters::addr::remote().and_then(|x: Option<SocketAddr>| async move {x.map(|x| x.ip()).ok_or_else(|| warp::reject::not_found())}).boxed()
};
let push = warp::put()
.and(chain.clone())
.and(warp::path("put"))
.and(client_ip.clone())
.and(warp::body::content_length_limit(MAX_CONTENT_LENGTH))
.and(warp::body::content_length_limit(state.config().max_content_length))
.and(warp::body::stream())
.and_then(|chain: Arc<RwLock<Chain<String>>>, host: IpAddr, buf| {
.and_then(|state: State, host: IpAddr, buf| {
async move {
full_body(&host, chain, buf).await
full_body(&host, state, buf).await
.map(|_| warp::reply::with_status(warp::reply(), status!(201)))
.map_err(warp::reject::custom)
}
@ -148,10 +188,10 @@ async fn main() {
.and(warp::path("get"))
.and(client_ip.clone())
.and(warp::path::param().map(|opt: usize| Some(opt)).or(warp::any().map(|| Option::<usize>::None)).unify())
.and_then(|chain: Arc<RwLock<Chain<String>>>, host: IpAddr, num: Option<usize>| {
.and_then(|state: State, host: IpAddr, num: Option<usize>| {
async move {
let (tx, rx) = mpsc::channel(MAX_GEN_SIZE);
tokio::spawn(gen_body(chain, num, tx));
let (tx, rx) = mpsc::channel(state.config().max_gen_size);
tokio::spawn(gen_body(state, num, tx));
Ok::<_, std::convert::Infallible>(Response::new(Body::wrap_stream(rx.map(move |x| {
info!("{} <- {:?}", host, x);
Ok::<_, std::convert::Infallible>(x)
@ -162,8 +202,20 @@ async fn main() {
let (addr, server) = warp::serve(push
.or(read))
.bind_with_graceful_shutdown(([127,0,0,1], 8001), async { tokio::signal::ctrl_c().await.unwrap(); });
.bind_with_graceful_shutdown(state.config().bindpoint, async move {
tokio::signal::ctrl_c().await.unwrap();
state.shutdown();
});
info!("Server bound on {:?}", addr);
server.await;
println!("Server bound on {:?}", addr);
server.await
// Cleanup
async move {
trace!("Cleanup");
saver.await.expect("Saver panicked");
}
}.await;
info!("Shut down gracefully")
}

@ -0,0 +1,89 @@
//! Saving and loading chain
use super::*;
use std::{
sync::Arc,
path::{
Path,
},
io,
};
use tokio::{
time::{
self,
Duration,
},
fs::{
OpenOptions,
},
prelude::*,
};
use futures::{
future::{
OptionFuture,
},
};
use lzzzz::{
lz4f::{
self,
AsyncWriteCompressor,
PreferencesBuilder,
AsyncReadDecompressor,
},
};
// Fixed save interval of two seconds.
// NOTE(review): nothing in the visible part of this module reads this constant;
// `host` takes its interval from the config instead. Confirm whether it can be removed.
const SAVE_INTERVAL: Option<Duration> = Some(Duration::from_secs(2));
pub async fn save_now(chain: &Chain<String>, to: impl AsRef<Path>) -> io::Result<()>
{
debug!("Saving chain to {:?}", to.as_ref());
let file = OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(to).await?;
let chain = serde_cbor::to_vec(chain).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
let mut file = AsyncWriteCompressor::new(file, PreferencesBuilder::new()
.compression_level(lz4f::CLEVEL_HIGH).build())?;
file.write_all(&chain[..]).await?;
file.shutdown().await?;
Ok(())
}
/// Start the save loop for this chain
pub async fn host(state: State)
{
let to = &state.config().file;
let interval = state.config().save_interval();
while Arc::strong_count(state.when()) > 1 {
{
let chain = state.chain().read().await;
use std::ops::Deref;
if let Err(e) = save_now(chain.deref(), &to).await {
error!("Failed to save chain: {}", e);
} else {
info!("Saved chain to {:?}", to);
}
}
if state.has_shutdown() {
break;
}
OptionFuture::from(interval.map(|interval| time::delay_for(interval))).await;
state.when().notified().await;
}
trace!("Saver exiting");
}
/// Read the file at `from`, decompress it and decode the CBOR payload into a
/// chain.
///
/// # Errors
/// Returns the underlying I/O error if opening, decompressing or reading
/// fails, and an `InvalidInput` error wrapping the CBOR error if decoding
/// fails.
pub async fn load(from: impl AsRef<Path>) -> io::Result<Chain<String>>
{
    debug!("Loading chain from {:?}", from.as_ref());

    let source = OpenOptions::new()
        .read(true)
        .open(from).await?;
    let mut decompressor = AsyncReadDecompressor::new(source)?;

    // Pull the whole decompressed stream into memory before decoding.
    let mut raw = Vec::new();
    tokio::io::copy(&mut decompressor, &mut raw).await?;

    match serde_cbor::from_slice(&raw[..]) {
        Ok(chain) => Ok(chain),
        Err(e) => Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
    }
}

@ -0,0 +1,74 @@
//! State
use super::*;
use tokio::{
sync::{
watch,
},
};
use config::Config;
/// Shared server state handed to request handlers and to the saver task.
///
/// Every field is a handle to shared data, so clones refer to the same
/// config, chain, save notifier and shutdown channel.
#[derive(Debug, Clone)]
pub struct State
{
config: Arc<Config>, //to avoid cloning config
// The chain, behind an async read/write lock.
chain: Arc<RwLock<Chain<String>>>,
// Notifier used to request a save (see `notify_save` / `when`).
save: Arc<Notify>,
// Sending half of the shutdown flag channel (see `shutdown`).
shutdown: Arc<watch::Sender<bool>>,
// Receiving half of the shutdown flag channel (see `has_shutdown` / `on_shutdown`).
shutdown_recv: watch::Receiver<bool>,
}
impl State
{
    /// Build a state from a config, a shared chain and a save notifier.
    ///
    /// A fresh shutdown channel is created with its flag initially `false`.
    pub fn new(config: Config, chain: Arc<RwLock<Chain<String>>>, save: Arc<Notify>) -> Self
    {
        let (tx, rx) = watch::channel(false);
        Self {
            config: Arc::new(config),
            chain,
            save,
            shutdown: Arc::new(tx),
            shutdown_recv: rx,
        }
    }

    /// The server configuration.
    pub fn config(&self) -> &Config
    {
        &self.config
    }

    /// Signal the save notifier.
    pub fn notify_save(&self)
    {
        self.save.notify();
    }

    /// The lock guarding the chain.
    pub fn chain(&self) -> &RwLock<Chain<String>>
    {
        &self.chain
    }

    /// The shared save notifier.
    pub fn when(&self) -> &Arc<Notify>
    {
        &self.save
    }

    /// Broadcast the shutdown flag, then signal the save notifier.
    ///
    /// # Panics
    /// Panics if the shutdown flag cannot be broadcast.
    pub fn shutdown(self)
    {
        self.shutdown.broadcast(true).expect("Failed to communicate shutdown");
        self.save.notify();
    }

    /// Whether the shutdown flag is currently set.
    pub fn has_shutdown(&self) -> bool
    {
        let shut_down: bool = *self.shutdown_recv.borrow();
        shut_down
    }

    /// Wait until the shutdown flag is set or the shutdown channel is closed.
    ///
    /// Returns immediately if the flag is already set.
    pub async fn on_shutdown(mut self)
    {
        if self.has_shutdown() {
            return;
        }
        loop {
            match self.shutdown_recv.recv().await {
                // Flag still unset: keep waiting for the next value.
                Some(false) => continue,
                _ => break,
            }
        }
    }
}
Loading…
Cancel
Save