diff --git a/Cargo.toml b/Cargo.toml index 0bd60bc..5d8526a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "markov" -version = "0.2.0" +version = "0.3.0" description = "Generate string of text from Markov chain fed by stdin" authors = ["Avril "] edition = "2018" diff --git a/src/main.rs b/src/main.rs index 1eb0a8f..248d603 100644 --- a/src/main.rs +++ b/src/main.rs @@ -32,6 +32,14 @@ use serde::{ Serialize, Deserialize }; +use futures::{ + future::{ + FutureExt, + BoxFuture, + join_all, + }, +}; + macro_rules! status { ($code:expr) => { @@ -39,6 +47,8 @@ macro_rules! status { }; } +#[cfg(target_family="unix")] +mod signals; mod config; mod state; use state::State; @@ -80,7 +90,8 @@ async fn main() { }, })); { - let (state, chain, saver) = { + let mut tasks = Vec::>::new(); + let (state, chain) = { let save_when = Arc::new(Notify::new()); let state = State::new(config, @@ -89,7 +100,9 @@ async fn main() { let state2 = state.clone(); let saver = tokio::spawn(save::host(state.clone())); let chain = warp::any().map(move || state.clone()); - (state2, chain, saver) + + tasks.push(saver.map(|res| res.expect("Saver panicked")).boxed()); + (state2, chain) }; let client_ip = if state.config().trust_x_forwarded_for { @@ -134,6 +147,10 @@ async fn main() { } }) .with(warp::log("markov::read")); + + + #[cfg(target_family="unix")] + tasks.push(tokio::spawn(signals::handle(state.clone())).map(|res| res.expect("Signal handler panicked")).boxed()); let (addr, server) = warp::serve(push .or(read)) @@ -141,15 +158,14 @@ async fn main() { tokio::signal::ctrl_c().await.unwrap(); state.shutdown(); }); - info!("Server bound on {:?}", addr); server.await; // Cleanup async move { trace!("Cleanup"); - - saver.await.expect("Saver panicked"); + + join_all(tasks).await; } }.await; info!("Shut down gracefully") diff --git a/src/save.rs b/src/save.rs index b7f8b8a..fa8461c 100644 --- a/src/save.rs +++ b/src/save.rs @@ -33,7 +33,16 @@ use lzzzz::{ const SAVE_INTERVAL: Option = Some(Duration::from_secs(2)); -pub async fn save_now(chain: &Chain, to: impl AsRef) -> io::Result<()> + +pub async fn save_now(state: &State) -> io::Result<()> +{ + let chain = state.chain().read().await; + use std::ops::Deref; + let to = &state.config().file; + save_now_to(chain.deref(),to).await +} + +async fn save_now_to(chain: &Chain, to: impl AsRef) -> io::Result<()> { debug!("Saving chain to {:?}", to.as_ref()); let file = OpenOptions::new() @@ -59,7 +68,7 @@ pub async fn host(state: State) { let chain = state.chain().read().await; use std::ops::Deref; - if let Err(e) = save_now(chain.deref(), &to).await { + if let Err(e) = save_now_to(chain.deref(), &to).await { error!("Failed to save chain: {}", e); } else { info!("Saved chain to {:?}", to); diff --git a/src/signals.rs b/src/signals.rs new file mode 100644 index 0000000..8a8a4d0 --- /dev/null +++ b/src/signals.rs @@ -0,0 +1,59 @@ +//! Unix signals +use super::*; +use tokio::{ + signal::unix::{ + self, + SignalKind, + }, +}; + + + +pub async fn handle(mut state: State) +{ + let mut usr1 = unix::signal(SignalKind::user_defined1()).expect("Failed to hook SIGUSR1"); + let mut usr2 = unix::signal(SignalKind::user_defined2()).expect("Failed to hook SIGUSR2"); + let mut quit = unix::signal(SignalKind::quit()).expect("Failed to hook SIGQUIT"); + + loop { + tokio::select! { + _ = state.on_shutdown() => { + break; + } + _ = usr1.recv() => { + info!("Got SIGUSR1. Saving chain immediately."); + if let Err(e) = save::save_now(&state).await { + error!("Failed to save chain: {}", e); + } else{ + trace!("Saved chain okay"); + } + }, + _ = usr2.recv() => { + info!("Got SIGUSR1. Loading chain immediately."); + match save::load(&state.config().file).await { + Ok(new) => { + { + let mut chain = state.chain().write().await; + *chain = new; + } + trace!("Replaced with read chain"); + }, + Err(e) => { + error!("Failed to load chain from file, keeping current: {}", e); + }, + } + }, + _ = quit.recv() => { + warn!("Got SIGQUIT. Saving chain then aborting."); + if let Err(e) = save::save_now(&state).await { + error!("Failed to save chain: {}", e); + } else{ + trace!("Saved chain okay."); + } + error!("Aborting"); + std::process::abort() + }, + } + } + trace!("Graceful shutdown"); +} diff --git a/src/state.rs b/src/state.rs index 5cc8d56..6ec7883 100644 --- a/src/state.rs +++ b/src/state.rs @@ -63,7 +63,7 @@ impl State *self.shutdown_recv.borrow() } - pub async fn on_shutdown(mut self) + pub async fn on_shutdown(&mut self) { if !self.has_shutdown() { while let Some(false) = self.shutdown_recv.recv().await {