From 5ba673e64f530d8853a5c11828a37090bc85f740 Mon Sep 17 00:00:00 2001 From: Avril Date: Mon, 12 Oct 2020 19:25:30 +0100 Subject: [PATCH] safe main-thread panics; save cannot happen until server initialised --- Cargo.toml | 2 +- src/ext.rs | 12 ++++++ src/main.rs | 17 +++++--- src/msg.rs | 108 +++++++++++++++++++++++++++++++++++++++++++++++++++ src/save.rs | 39 +++++++++++-------- src/state.rs | 51 +++++++++++++++++++++++- 6 files changed, 205 insertions(+), 24 deletions(-) create mode 100644 src/msg.rs diff --git a/Cargo.toml b/Cargo.toml index 43e1b01..c9b5940 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "markov" -version = "0.7.1" +version = "0.7.2" description = "Generate string of text from Markov chain fed by stdin" authors = ["Avril "] edition = "2018" diff --git a/src/ext.rs b/src/ext.rs index ad80d79..5b05901 100644 --- a/src/ext.rs +++ b/src/ext.rs @@ -137,6 +137,18 @@ impl AssertNotSend t } +/// Require a value implements a specific trait +#[macro_export] macro_rules! require_impl { + ($t:path: $val:expr) => { + { + #[inline(always)] fn require_impl(val: T) -> T { + val + } + require_impl($val) + } + } +} + impl Deref for AssertNotSend { type Target = T; diff --git a/src/main.rs b/src/main.rs index 80af06f..73f6f67 100644 --- a/src/main.rs +++ b/src/main.rs @@ -71,6 +71,7 @@ mod api; #[cfg(target_family="unix")] mod signals; mod config; +mod msg; mod state; use state::State; mod save; @@ -139,7 +140,7 @@ async fn main() { Arc::clone(&chain), Arc::clone(&save_when)); let state2 = state.clone(); - let saver = tokio::spawn(save::host(state.clone())); + let saver = tokio::spawn(save::host(Box::new(state.clone()))); let chain = warp::any().map(move || state.clone()); tasks.push(saver.map(|res| res.expect("Saver panicked")).boxed()); @@ -256,8 +257,8 @@ async fn main() { #[cfg(target_family="unix")] tasks.push(tokio::spawn(signals::handle(state.clone())).map(|res| res.expect("Signal handler panicked")).boxed()); - require_send(async { - let server = { + require_impl!(Send: async { + let (server, init) = { let s2 = AssertNotSend::new(state.clone()); //temp clone the Arcs here for shutdown if server fails to bind, assert they cannot remain cloned across an await boundary. match bind::try_serve(warp::serve(push .or(read)), @@ -268,7 +269,7 @@ async fn main() { }) { Ok((addr, server)) => { info!("Server bound on {:?}", addr); - server + (server, s2.into_inner().into_save_initialiser()) }, Err(err) => { error!("Failed to bind server: {}", err); @@ -277,7 +278,13 @@ async fn main() { }, } }; - server.await; + tokio::join![ + server, + async move { + trace!("Init set"); + init.set().expect("Failed to initialise saver") + }, + ]; }).await; // Cleanup diff --git a/src/msg.rs b/src/msg.rs new file mode 100644 index 0000000..a5bfc08 --- /dev/null +++ b/src/msg.rs @@ -0,0 +1,108 @@ +//! Message passing things +use super::*; +use tokio::{ + sync::{ + watch, + }, +}; +use std::{ + task::{Poll, Context}, + pin::Pin, + fmt, + error, +}; +use futures::{ + future::Future, +}; + +#[derive(Debug)] +pub struct InitError; +#[derive(Debug)] +pub struct InitWaitError; + +impl error::Error for InitError{} +impl fmt::Display for InitError +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + write!(f, "failed to set init value") + } +} + +impl error::Error for InitWaitError{} +impl fmt::Display for InitWaitError +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + write!(f, "failed to receive init value") + } +} + +#[derive(Clone, Debug)] +pub struct Initialiser +{ + tx: Arc>, + rx: watch::Receiver +} + +impl Initialiser +{ + pub fn new() -> Self + { + let (tx, rx) = watch::channel(false); + Self { + tx: Arc::new(tx), + rx, + } + } + + pub fn new_set() -> Self + { + let (tx, rx) = watch::channel(true); + Self { + tx: Arc::new(tx), + rx, + } + } + + pub async fn wait(&mut self) -> Result<(), InitWaitError> + { + if !*self.rx.borrow() { + self.rx.recv().await + .ok_or_else(|| InitWaitError) + .and_then(|x| if x {Ok(())} else {Err(InitWaitError)}) + } else { + Ok(()) + } + } + + pub fn is_set(&self) -> bool + { + *self.rx.borrow() + } + + pub fn set(self) -> Result<(), InitError> + { + if !*self.rx.borrow() { + self.tx.broadcast(true).map_err(|_| InitError) + } else { + Ok(()) + } + } +} + +impl Future for Initialiser +{ + type Output = Result<(), InitWaitError>; + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + + if !*self.rx.borrow() { + let rx = self.rx.recv(); + tokio::pin!(rx); + rx.poll(cx).map(|x| x.ok_or_else(|| InitWaitError) + .and_then(|x| if x {Ok(())} else {Err(InitWaitError)})) + } else { + Poll::Ready(Ok(())) + } + } +} diff --git a/src/save.rs b/src/save.rs index c7e3c41..2ad51d0 100644 --- a/src/save.rs +++ b/src/save.rs @@ -73,31 +73,36 @@ async fn save_now_to(chain: &Chain, to: impl AsRef) -> io::Result< } /// Start the save loop for this chain -pub async fn host(mut state: State) +pub async fn host(mut state: Box) { let to = state.config().file.to_owned(); let interval = state.config().save_interval(); - while Arc::strong_count(state.when()) > 1 { - { - let chain = state.chain().read().await; - use std::ops::Deref; - if let Err(e) = save_now_to(chain.deref(), &to).await { - error!("Failed to save chain: {}", e); - } else { - info!("Saved chain to {:?}", to); + if state.on_init_save().await.is_ok() { + trace!("Init get"); + while Arc::strong_count(state.when()) > 1 { + { + let chain = state.chain().read().await; + use std::ops::Deref; + if let Err(e) = save_now_to(chain.deref(), &to).await { + error!("Failed to save chain: {}", e); + } else { + info!("Saved chain to {:?}", to); + } } - } - tokio::select!{ - _ = OptionFuture::from(interval.map(|interval| time::delay_for(interval))) => {}, - _ = state.on_shutdown() => { + tokio::select!{ + _ = OptionFuture::from(interval.map(|interval| time::delay_for(interval))) => {}, + _ = state.on_shutdown() => { + break; + } + } + state.when().notified().await; + if state.has_shutdown() { break; } } - state.when().notified().await; - if state.has_shutdown() { - break; - } + } else { + trace!("Shutdown called before init completed"); } trace!("Saver exiting"); } diff --git a/src/state.rs b/src/state.rs index 459431e..4888dd4 100644 --- a/src/state.rs +++ b/src/state.rs @@ -6,6 +6,20 @@ use tokio::{ }, }; use config::Config; +use msg::Initialiser; + +#[derive(Debug)] +pub struct ShutdownError; + +impl error::Error for ShutdownError{} +impl fmt::Display for ShutdownError +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + write!(f, "shutdown signal caught") + } +} + #[derive(Debug, Clone)] pub struct State @@ -14,6 +28,7 @@ pub struct State exclude: Arc<(sanitise::filter::Filter, sanitise::filter::Filter)>, chain: Arc>>, save: Arc, + save_begin: Initialiser, shutdown: Arc>, shutdown_recv: watch::Receiver, @@ -21,11 +36,44 @@ pub struct State impl State { + /// Consume this `state` into its initialiser + pub fn into_save_initialiser(self) -> Initialiser + { + self.save_begin + } + + /// Allow the saver task to start work + pub fn init_save(self) -> Result<(), msg::InitError> + { + self.save_begin.set() + } + + /// Has `init_save` been called? + pub fn is_init_save(&self) -> bool + { + self.save_begin.is_set() + } + + /// A future that completes either when `init_save` is called, or `shutdown`. + pub async fn on_init_save(&mut self) -> Result<(), ShutdownError> + { + tokio::select!{ + Ok(()) = self.save_begin.wait() => { + Ok(()) + } + Some(true) = self.shutdown_recv.recv() => { + debug!("on_init_save(): shutdown received"); + Err(ShutdownError) + } + else => Err(ShutdownError) + } + } + pub fn inbound_filter(&self) -> &sanitise::filter::Filter { &self.exclude.0 } - pub fn outbound_filter(&self) -> &sanitise::filter::Filter + pub fn outbound_filter(&self) -> &sanitise::filter::Filter { &self.exclude.1 } @@ -39,6 +87,7 @@ impl State config: Arc::new(config), chain, save, + save_begin: Initialiser::new(), shutdown: Arc::new(shutdown), shutdown_recv, }