working implementation of handler

feed
Avril 4 years ago
parent 5dc10547d5
commit 75730cbe0f
Signed by: flanchan
GPG Key ID: 284488987C31F630

2
Cargo.lock generated

@ -616,7 +616,7 @@ dependencies = [
[[package]] [[package]]
name = "markov" name = "markov"
version = "0.8.2" version = "0.9.0"
dependencies = [ dependencies = [
"async-compression", "async-compression",
"bzip2-sys", "bzip2-sys",

@ -1,6 +1,6 @@
[package] [package]
name = "markov" name = "markov"
version = "0.8.2" version = "0.9.0"
description = "Generate string of text from Markov chain fed by stdin" description = "Generate string of text from Markov chain fed by stdin"
authors = ["Avril <flanchan@cumallover.me>"] authors = ["Avril <flanchan@cumallover.me>"]
edition = "2018" edition = "2018"
@ -36,14 +36,7 @@ split-sentance = []
# NOTE: This does nothing if `split-newlines` is not enabled # NOTE: This does nothing if `split-newlines` is not enabled
always-aggregate = [] always-aggregate = []
# Feeds will hog the buffer lock until the whole body has been fed, instead of acquiring lock every time # Does nothing, legacy thing.
# This will make feeds of many lines faster but can potentially cause DoS
#
# With: ~169ms
# Without: ~195ms
#
# NOTE:
# This does nothing if `always-aggregate` is enabled and/or `split-newlines` is not enabled
hog-buffer = [] hog-buffer = []
# Enable the /api/ route # Enable the /api/ route

@ -7,5 +7,12 @@ trust_x_forwarded_for = false
feed_bounds = '2..' feed_bounds = '2..'
[filter] [filter]
inbound = '<>/\\' inbound = ''
outbound = '' outbound = ''
[writer]
backlog = 32
internal_backlog = 8
capacity = 4
timeout_ms = 5000
throttle_ms = 50

@ -215,6 +215,16 @@ where S: Stream<Item=T>,
&self.buf[..] &self.buf[..]
} }
pub fn get_ref(&self) -> &S
{
self.stream.get_ref()
}
pub fn get_mut(&mut self)-> &mut S
{
self.stream.get_mut()
}
/// Force the next read to send the buffer even if it's not full. /// Force the next read to send the buffer even if it's not full.
/// ///
/// # Note /// # Note
@ -223,6 +233,18 @@ where S: Stream<Item=T>,
{ {
self.push_now= true; self.push_now= true;
} }
/// Consume into the current held buffer
pub fn into_buffer(self) -> Vec<T>
{
self.buf
}
/// Take the buffer now
pub fn take_now(&mut self) -> Into
{
std::mem::replace(&mut self.buf, Vec::with_capacity(self.cap)).into()
}
} }
impl<S, T, Into> Stream for ChunkingStream<S,T, Into> impl<S, T, Into> Stream for ChunkingStream<S,T, Into>
@ -246,6 +268,7 @@ where S: Stream<Item=T>,
_ => return Poll::Pending, _ => return Poll::Pending,
} }
} }
debug!("Sending buffer of {} (cap {})", self.buf.len(), self.cap);
// Buffer is full or we reach end of stream // Buffer is full or we reach end of stream
Poll::Ready(if self.buf.len() == 0 { Poll::Ready(if self.buf.len() == 0 {
None None

@ -28,9 +28,11 @@ pub struct Config
pub save_interval_secs: Option<NonZeroU64>, pub save_interval_secs: Option<NonZeroU64>,
pub trust_x_forwarded_for: bool, pub trust_x_forwarded_for: bool,
#[serde(default)] #[serde(default)]
pub feed_bounds: String,
#[serde(default)]
pub filter: FilterConfig, pub filter: FilterConfig,
#[serde(default)] #[serde(default)]
pub feed_bounds: String, pub writer: WriterConfig,
} }
#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Hash, Serialize, Deserialize)] #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Hash, Serialize, Deserialize)]
@ -41,6 +43,49 @@ pub struct FilterConfig
outbound: String, outbound: String,
} }
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Hash, Serialize, Deserialize)]
pub struct WriterConfig
{
pub backlog: usize,
pub internal_backlog: usize,
pub capacity: usize,
pub timeout_ms: Option<u64>,
pub throttle_ms: Option<u64>,
}
impl Default for WriterConfig
{
#[inline]
fn default() -> Self
{
Self {
backlog: 32,
internal_backlog: 8,
capacity: 4,
timeout_ms: None,
throttle_ms: None,
}
}
}
impl WriterConfig
{
fn create_settings(self, bounds: range::DynRange<usize>) -> handle::Settings
{
handle::Settings{
backlog: self.backlog,
internal_backlog: self.internal_backlog,
capacity: self.capacity,
timeout: self.timeout_ms.map(tokio::time::Duration::from_millis).unwrap_or(handle::DEFAULT_TIMEOUT),
throttle: self.throttle_ms.map(tokio::time::Duration::from_millis),
bounds,
}
}
}
impl FilterConfig impl FilterConfig
{ {
fn get_inbound_filter(&self) -> sanitise::filter::Filter fn get_inbound_filter(&self) -> sanitise::filter::Filter
@ -77,6 +122,7 @@ impl Default for Config
trust_x_forwarded_for: false, trust_x_forwarded_for: false,
filter: Default::default(), filter: Default::default(),
feed_bounds: "2..".to_owned(), feed_bounds: "2..".to_owned(),
writer: Default::default(),
} }
} }
} }
@ -96,12 +142,14 @@ impl Config
} }
use std::ops::RangeBounds; use std::ops::RangeBounds;
Ok(Cache { let feed_bounds = section!("feed_bounds", self.parse_feed_bounds()).and_then(|bounds| if bounds.contains(&0) {
feed_bounds: section!("feed_bounds", self.parse_feed_bounds()).and_then(|bounds| if bounds.contains(&0) {
Err(InvalidConfigError("feed_bounds", Box::new(opaque_error!("Bounds not allowed to contains 0 (they were `{}`)", bounds)))) Err(InvalidConfigError("feed_bounds", Box::new(opaque_error!("Bounds not allowed to contains 0 (they were `{}`)", bounds))))
} else { } else {
Ok(bounds) Ok(bounds)
})?, })?;
Ok(Cache {
handler_settings: self.writer.create_settings(feed_bounds.clone()),
feed_bounds,
inbound_filter: self.filter.get_inbound_filter(), inbound_filter: self.filter.get_inbound_filter(),
outbound_filter: self.filter.get_outbound_filter(), outbound_filter: self.filter.get_outbound_filter(),
}) })
@ -205,12 +253,13 @@ impl fmt::Display for InvalidConfigError
/// Caches some parsed config arguments /// Caches some parsed config arguments
#[derive(Clone, PartialEq, Eq)] #[derive(Clone, PartialEq)]
pub struct Cache pub struct Cache
{ {
pub feed_bounds: range::DynRange<usize>, pub feed_bounds: range::DynRange<usize>,
pub inbound_filter: sanitise::filter::Filter, pub inbound_filter: sanitise::filter::Filter,
pub outbound_filter: sanitise::filter::Filter, pub outbound_filter: sanitise::filter::Filter,
pub handler_settings: handle::Settings,
} }
impl fmt::Debug for Cache impl fmt::Debug for Cache
@ -221,6 +270,7 @@ impl fmt::Debug for Cache
.field("feed_bounds", &self.feed_bounds) .field("feed_bounds", &self.feed_bounds)
.field("inbound_filter", &self.inbound_filter.iter().collect::<String>()) .field("inbound_filter", &self.inbound_filter.iter().collect::<String>())
.field("outbound_filter", &self.outbound_filter.iter().collect::<String>()) .field("outbound_filter", &self.outbound_filter.iter().collect::<String>())
.field("handler_settings", &self.handler_settings)
.finish() .finish()
} }
} }

@ -2,7 +2,7 @@
use super::*; use super::*;
#[cfg(any(feature="feed-sentance", feature="split-sentance"))] #[cfg(any(feature="feed-sentance", feature="split-sentance"))]
use sanitise::Sentance; use sanitise::Sentance;
use std::iter; use futures::stream;
pub const DEFAULT_FEED_BOUNDS: std::ops::RangeFrom<usize> = 2..; pub const DEFAULT_FEED_BOUNDS: std::ops::RangeFrom<usize> = 2..;
@ -58,7 +58,7 @@ pub fn feed(chain: &mut Chain<String>, what: impl AsRef<str>, bounds: impl std::
} }
debug_assert!(!bounds.contains(&0), "Cannot allow 0 size feeds"); debug_assert!(!bounds.contains(&0), "Cannot allow 0 size feeds");
if bounds.contains(&map.len()) { if bounds.contains(&map.len()) {
debug!("Feeding chain {} items", map.len()); //debug!("Feeding chain {} items", map.len());
chain.feed(map); chain.feed(map);
} }
else { else {
@ -80,7 +80,7 @@ pub async fn full(who: &IpAddr, state: State, body: impl Unpin + Stream<Item = R
($buffer:expr) => { ($buffer:expr) => {
{ {
let buffer = $buffer; let buffer = $buffer;
state.chain_write(buffer.map(ToOwned::to_owned)).await.map_err(|_| FillBodyError)?; state.chain_write(buffer).await.map_err(|_| FillBodyError)?;
} }
} }
} }
@ -105,38 +105,40 @@ pub async fn full(who: &IpAddr, state: State, body: impl Unpin + Stream<Item = R
info!("{} -> {:?}", who, buffer); info!("{} -> {:?}", who, buffer);
cfg_if! { cfg_if! {
if #[cfg(feature="split-newlines")] { if #[cfg(feature="split-newlines")] {
feed!(buffer.split('\n').filter(|line| !line.trim().is_empty())) feed!(stream::iter(buffer.split('\n').filter(|line| !line.trim().is_empty())
.map(|x| x.to_owned())))
} else { } else {
feed!(iter::once(buffer)); feed!(stream::once(async move{buffer.into_owned()}));
} }
} }
} else { } else {
use tokio::prelude::*; use tokio::prelude::*;
let reader = chunking::StreamReader::new(body.filter_map(|x| x.map(|mut x| x.to_bytes()).ok())); let reader = chunking::StreamReader::new(body.filter_map(|x| x.map(|mut x| x.to_bytes()).ok()));
let mut lines = reader.lines(); let lines = reader.lines();
#[cfg(feature="hog-buffer")] feed!(lines.filter_map(|x| x.ok().and_then(|line| {
let mut chain = state.chain().write().await;
while let Some(line) = lines.next_line().await.map_err(|_| FillBodyError)? {
let line = state.inbound_filter().filter_cow(&line); let line = state.inbound_filter().filter_cow(&line);
let line = line.trim(); let line = line.trim();
if !line.is_empty() { if !line.is_empty() {
//#[cfg(not(feature="hog-buffer"))] //#[cfg(not(feature="hog-buffer"))]
//let mut chain = state.chain().write().await; // Acquire mutex once per line? Is this right? //let mut chain = state.chain().write().await; // Acquire mutex once per line? Is this right?
feed!(iter::once(line));
info!("{} -> {:?}", who, line); info!("{} -> {:?}", who, line);
}
written+=line.len(); written+=line.len();
Some(line.to_owned())
} else {
None
} }
})));
} }
} }
if_debug!{ if_debug! {
trace!("Write took {}ms", timer.elapsed().as_millis()); trace!("Write took {}ms", timer.elapsed().as_millis());
} }
state.notify_save();
Ok(written) Ok(written)
} }

@ -15,6 +15,8 @@ use tokio::{
self, self,
error::SendError, error::SendError,
}, },
watch,
Notify,
}, },
task::JoinHandle, task::JoinHandle,
time::{ time::{
@ -24,11 +26,14 @@ use tokio::{
}; };
use futures::StreamExt; use futures::StreamExt;
pub const DEFAULT_TIMEOUT: Duration= Duration::from_secs(5);
/// Settings for chain handler /// Settings for chain handler
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub struct Settings pub struct Settings
{ {
pub backlog: usize, pub backlog: usize,
pub internal_backlog: usize,
pub capacity: usize, pub capacity: usize,
pub timeout: Duration, pub timeout: Duration,
pub throttle: Option<Duration>, pub throttle: Option<Duration>,
@ -38,7 +43,7 @@ pub struct Settings
impl Settings impl Settings
{ {
/// Should we keep this string. /// Should we keep this string.
#[inline] fn matches(&self, s: &str) -> bool #[inline] fn matches(&self, _s: &str) -> bool
{ {
true true
} }
@ -51,6 +56,7 @@ impl Default for Settings
{ {
Self { Self {
backlog: 32, backlog: 32,
internal_backlog: 8,
capacity: 4, capacity: 4,
timeout: Duration::from_secs(5), timeout: Duration::from_secs(5),
throttle: Some(Duration::from_millis(200)), throttle: Some(Duration::from_millis(200)),
@ -64,6 +70,7 @@ impl Default for Settings
struct HostInner<T> struct HostInner<T>
{ {
input: mpsc::Receiver<Vec<T>>, input: mpsc::Receiver<Vec<T>>,
shutdown: watch::Receiver<bool>,
} }
#[derive(Debug)] #[derive(Debug)]
@ -72,6 +79,9 @@ struct Handle<T: Send+ chain::Chainable>
chain: RwLock<chain::Chain<T>>, chain: RwLock<chain::Chain<T>>,
input: mpsc::Sender<Vec<T>>, input: mpsc::Sender<Vec<T>>,
opt: Settings, opt: Settings,
notify_write: Arc<Notify>,
push_now: Arc<Notify>,
shutdown: watch::Sender<bool>,
/// Data used only for the worker task. /// Data used only for the worker task.
host: msg::Once<HostInner<T>>, host: msg::Once<HostInner<T>>,
@ -80,22 +90,23 @@ struct Handle<T: Send+ chain::Chainable>
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct ChainHandle<T: Send + chain::Chainable>(Arc<Box<Handle<T>>>); pub struct ChainHandle<T: Send + chain::Chainable>(Arc<Box<Handle<T>>>);
impl<T: Send+ chain::Chainable> ChainHandle<T> impl<T: Send+ chain::Chainable + 'static> ChainHandle<T>
{ {
#[inline] pub fn new(chain: chain::Chain<T>) -> Self
{
Self::with_settings(chain, Default::default())
}
pub fn with_settings(chain: chain::Chain<T>, opt: Settings) -> Self pub fn with_settings(chain: chain::Chain<T>, opt: Settings) -> Self
{ {
let (shutdown_tx, shutdown) = watch::channel(false);
let (itx, irx) = mpsc::channel(opt.backlog); let (itx, irx) = mpsc::channel(opt.backlog);
Self(Arc::new(Box::new(Handle{ Self(Arc::new(Box::new(Handle{
chain: RwLock::new(chain), chain: RwLock::new(chain),
input: itx, input: itx,
opt, opt,
push_now: Arc::new(Notify::new()),
notify_write: Arc::new(Notify::new()),
shutdown: shutdown_tx,
host: msg::Once::new(HostInner{ host: msg::Once::new(HostInner{
input: irx, input: irx,
shutdown,
}) })
}))) })))
} }
@ -122,10 +133,51 @@ impl<T: Send+ chain::Chainable> ChainHandle<T>
} }
/// Send this buffer to the chain /// Send this buffer to the chain
pub async fn write(&self, buf: Vec<T>) -> Result<(), SendError<Vec<T>>> pub fn write(&self, buf: Vec<T>) -> impl futures::Future<Output = Result<(), SendError<Vec<T>>>> + 'static
{
let mut write = self.0.input.clone();
async move {
write.send(buf).await
}
}
/// Send this stream buffer to the chain
pub fn write_stream<'a, I: Stream<Item=T>>(&self, buf: I) -> impl futures::Future<Output = Result<(), SendError<Vec<T>>>> + 'a
where I: 'a
{
let mut write = self.0.input.clone();
async move {
write.send(buf.collect().await).await
}
}
/// Send this buffer to the chain
pub async fn write_in_place(&self, buf: Vec<T>) -> Result<(), SendError<Vec<T>>>
{ {
self.0.input.clone().send(buf).await self.0.input.clone().send(buf).await
} }
/// A referencer for the notifier
pub fn notify_when(&self) -> &Arc<Notify>
{
&self.0.notify_write
}
/// Force the pending buffers to be written to the chain now
pub fn push_now(&self)
{
self.0.push_now.notify();
}
/// Hang the worker thread, preventing it from taking any more inputs and also flushing it.
///
/// # Panics
/// If there was no worker thread.
pub fn hang(&self)
{
trace!("Communicating hang request");
self.0.shutdown.broadcast(true).expect("Failed to communicate hang");
}
} }
impl ChainHandle<String> impl ChainHandle<String>
@ -157,13 +209,13 @@ impl ChainHandle<String>
pub async fn host(from: ChainHandle<String>) pub async fn host(from: ChainHandle<String>)
{ {
let opt = from.0.opt.clone(); let opt = from.0.opt.clone();
let data = from.0.host.unwrap().await; let mut data = from.0.host.unwrap().await;
let (mut tx, child) = { let (mut tx, mut child) = {
// The `real` input channel. // The `real` input channel.
let from = from.clone(); let from = from.clone();
let opt = opt.clone(); let opt = opt.clone();
let (tx, rx) = mpsc::channel::<Vec<Vec<_>>>(opt.backlog); let (tx, rx) = mpsc::channel::<Vec<Vec<_>>>(opt.internal_backlog);
(tx, tokio::spawn(async move { (tx, tokio::spawn(async move {
let mut rx = if let Some(thr) = opt.throttle { let mut rx = if let Some(thr) = opt.throttle {
time::throttle(thr, rx).boxed() time::throttle(thr, rx).boxed()
@ -172,6 +224,8 @@ pub async fn host(from: ChainHandle<String>)
}; };
trace!("child: Begin waiting on parent"); trace!("child: Begin waiting on parent");
while let Some(item) = rx.next().await { while let Some(item) = rx.next().await {
if item.len() > 0 {
info!("Write lock acq");
let mut lock = from.0.chain.write().await; let mut lock = from.0.chain.write().await;
for item in item.into_iter() for item in item.into_iter()
{ {
@ -180,6 +234,9 @@ pub async fn host(from: ChainHandle<String>)
feed::feed(lock.deref_mut(), item, &from.0.opt.bounds); feed::feed(lock.deref_mut(), item, &from.0.opt.bounds);
} }
} }
trace!("Signalling write");
from.0.notify_write.notify();
}
} }
trace!("child: exiting"); trace!("child: exiting");
})) }))
@ -187,44 +244,94 @@ pub async fn host(from: ChainHandle<String>)
trace!("Begin polling on child"); trace!("Begin polling on child");
tokio::select!{ tokio::select!{
v = child => { v = &mut child => {
match v { match v {
#[cold] Ok(_) => {warn!("Child exited before we have? This should probably never happen.")},//Should never happen. #[cold] Ok(_) => {warn!("Child exited before we have? This should probably never happen.")},//Should never happen.
Err(e) => {error!("Child exited abnormally. Aborting: {}", e)}, //Child panic or cancel. Err(e) => {error!("Child exited abnormally. Aborting: {}", e)}, //Child panic or cancel.
} }
}, },
_ = async move { _ = async move {
let mut rx = data.input.chunk(opt.capacity); //we don't even need this tbh let mut rx = data.input.chunk(opt.capacity); //we don't even need this tbh, oh well.
if !data.shutdown.recv().await.unwrap_or(true) { //first shutdown we get for free
while Arc::strong_count(&from.0) > 2 { while Arc::strong_count(&from.0) > 2 {
if *data.shutdown.borrow() {
break;
}
tokio::select!{ tokio::select!{
Some(true) = data.shutdown.recv() => {
debug!("Got shutdown (hang) request. Sending now then breaking");
let mut rest = {
let irx = rx.get_mut();
irx.close(); //accept no more inputs
let mut output = Vec::with_capacity(opt.capacity);
while let Ok(item) = irx.try_recv() {
output.push(item);
}
output
};
rest.extend(rx.take_now());
if rest.len() > 0 {
if let Err(err) = tx.send(rest).await {
error!("Failed to force send buffer, exiting now: {}", err);
}
}
break;
}
_ = time::delay_for(opt.timeout) => { _ = time::delay_for(opt.timeout) => {
trace!("Setting push now");
rx.push_now();
}
_ = from.0.push_now.notified() => {
debug!("Got force push signal");
let take =rx.take_now();
rx.push_now(); rx.push_now();
if take.len() > 0 {
if let Err(err) = tx.send(take).await {
error!("Failed to force send buffer: {}", err);
break;
}
}
} }
Some(buffer) = rx.next() => { Some(buffer) = rx.next() => {
debug!("Sending {} (cap {})", buffer.len(), buffer.capacity());
if let Err(err) = tx.send(buffer).await { if let Err(err) = tx.send(buffer).await {
// Receive closed? // Receive closed?
// //
// This probably shouldn't happen, as we `select!` for it up there and child never calls `close()` on `rx`. // This probably shouldn't happen, as we `select!` for it up there and child never calls `close()` on `rx`.
// In any case, it means we should abort. // In any case, it means we should abort.
error!("Failed to send buffer: {}", err); #[cold] error!("Failed to send buffer: {}", err);
break; break;
} }
} }
} }
} }
}
let last = rx.into_buffer();
if last.len() > 0 {
if let Err(err) = tx.send(last).await {
error!("Failed to force send last part of buffer: {}", err);
} else {
trace!("Sent rest of buffer");
}
}
} => { } => {
// Normal exit // Normal exit
trace!("Normal exit") trace!("Normal exit")
}, },
} }
trace!("Waiting on child");
// No more handles except child, no more possible inputs. // No more handles except child, no more possible inputs.
child.await.expect("Child panic");
trace!("Returning"); trace!("Returning");
} }
/// Spawn a new chain handler for this chain. /// Spawn a new chain handler for this chain.
pub fn spawn(from: chain::Chain<String>, opt: Settings) -> (JoinHandle<()>, ChainHandle<String>) pub fn spawn(from: chain::Chain<String>, opt: Settings) -> (JoinHandle<()>, ChainHandle<String>)
{ {
debug!("Spawning with opt: {:?}", opt);
let handle = ChainHandle::with_settings(from, opt); let handle = ChainHandle::with_settings(from, opt);
(tokio::spawn(host(handle.clone())), handle) (tokio::spawn(host(handle.clone())), handle)
} }

@ -145,16 +145,15 @@ async fn main() {
trace!("Error: {}", e); trace!("Error: {}", e);
Chain::new() Chain::new()
}, },
}, Default::default()/*TODO*/); }, ccache.handler_settings.clone());
{ {
let mut tasks = Vec::<BoxFuture<'static, ()>>::new(); let mut tasks = Vec::<BoxFuture<'static, ()>>::new();
tasks.push(chain_handle.map(|res| res.expect("Chain handle panicked")).boxed());
let (state, chain) = { let (state, chain) = {
let save_when = Arc::new(Notify::new());
let state = State::new(config, let state = State::new(config,
ccache, ccache,
chain, chain);
Arc::clone(&save_when));
let state2 = state.clone(); let state2 = state.clone();
let saver = tokio::spawn(save::host(Box::new(state.clone()))); let saver = tokio::spawn(save::host(Box::new(state.clone())));
let chain = warp::any().map(move || state.clone()); let chain = warp::any().map(move || state.clone());

@ -77,10 +77,11 @@ pub async fn host(mut state: Box<State>)
{ {
let to = state.config().file.to_owned(); let to = state.config().file.to_owned();
let interval = state.config().save_interval(); let interval = state.config().save_interval();
let when = Arc::clone(state.when_ref());
trace!("Setup oke. Waiting on init"); trace!("Setup oke. Waiting on init");
if state.on_init().await.is_ok() { if state.on_init().await.is_ok() {
debug!("Begin save handler"); debug!("Begin save handler");
while Arc::strong_count(state.when()) > 1 { while Arc::strong_count(&when) > 1 {
{ {
let chain = state.chain_ref().read().await; let chain = state.chain_ref().read().await;
use std::ops::Deref; use std::ops::Deref;
@ -97,7 +98,7 @@ pub async fn host(mut state: Box<State>)
break; break;
} }
} }
state.when().notified().await; when.notified().await;
if state.has_shutdown() { if state.has_shutdown() {
break; break;
} }

@ -14,6 +14,7 @@ pub async fn handle(mut state: State)
let mut usr1 = unix::signal(SignalKind::user_defined1()).expect("Failed to hook SIGUSR1"); let mut usr1 = unix::signal(SignalKind::user_defined1()).expect("Failed to hook SIGUSR1");
let mut usr2 = unix::signal(SignalKind::user_defined2()).expect("Failed to hook SIGUSR2"); let mut usr2 = unix::signal(SignalKind::user_defined2()).expect("Failed to hook SIGUSR2");
let mut quit = unix::signal(SignalKind::quit()).expect("Failed to hook SIGQUIT"); let mut quit = unix::signal(SignalKind::quit()).expect("Failed to hook SIGQUIT");
let mut io = unix::signal(SignalKind::io()).expect("Failed to hook IO");
trace!("Setup oke. Waiting on init"); trace!("Setup oke. Waiting on init");
if state.on_init().await.is_ok() { if state.on_init().await.is_ok() {
@ -24,15 +25,11 @@ pub async fn handle(mut state: State)
break; break;
} }
_ = usr1.recv() => { _ = usr1.recv() => {
info!("Got SIGUSR1. Saving chain immediately."); info!("Got SIGUSR1. Causing chain write.");
if let Err(e) = save::save_now(&state).await { state.push_now();
error!("Failed to save chain: {}", e);
} else{
trace!("Saved chain okay");
}
}, },
_ = usr2.recv() => { _ = usr2.recv() => {
info!("Got SIGUSR1. Loading chain immediately."); info!("Got SIGUSR2. Loading chain immediately.");
match save::load(&state.config().file).await { match save::load(&state.config().file).await {
Ok(new) => { Ok(new) => {
{ {
@ -46,6 +43,15 @@ pub async fn handle(mut state: State)
}, },
} }
}, },
_ = io.recv() => {
info!("Got SIGIO. Saving chain immediately.");
if let Err(e) = save::save_now(&state).await {
error!("Failed to save chain: {}", e);
} else{
trace!("Saved chain okay");
}
},
_ = quit.recv() => { _ = quit.recv() => {
warn!("Got SIGQUIT. Saving chain then aborting."); warn!("Got SIGQUIT. Saving chain then aborting.");
if let Err(e) = save::save_now(&state).await { if let Err(e) = save::save_now(&state).await {

@ -27,7 +27,7 @@ pub struct State
{ {
config: Arc<Box<(Config, config::Cache)>>, //to avoid cloning config config: Arc<Box<(Config, config::Cache)>>, //to avoid cloning config
chain: handle::ChainHandle<String>, chain: handle::ChainHandle<String>,
save: Arc<Notify>, //save: Arc<Notify>,
begin: Initialiser, begin: Initialiser,
shutdown: Arc<watch::Sender<bool>>, shutdown: Arc<watch::Sender<bool>>,
@ -79,13 +79,12 @@ impl State
&self.config_cache().outbound_filter &self.config_cache().outbound_filter
} }
pub fn new(config: Config, cache: config::Cache, chain: handle::ChainHandle<String>, save: Arc<Notify>) -> Self pub fn new(config: Config, cache: config::Cache, chain: handle::ChainHandle<String>) -> Self
{ {
let (shutdown, shutdown_recv) = watch::channel(false); let (shutdown, shutdown_recv) = watch::channel(false);
Self { Self {
config: Arc::new(Box::new((config, cache))), config: Arc::new(Box::new((config, cache))),
chain, chain,
save,
begin: Initialiser::new(), begin: Initialiser::new(),
shutdown: Arc::new(shutdown), shutdown: Arc::new(shutdown),
shutdown_recv, shutdown_recv,
@ -102,10 +101,10 @@ impl State
&self.config.as_ref().1 &self.config.as_ref().1
} }
pub fn notify_save(&self) /*pub fn notify_save(&self)
{ {
self.save.notify(); self.save.notify();
} }*/
/*pub fn chain(&self) -> &RwLock<Chain<String>> /*pub fn chain(&self) -> &RwLock<Chain<String>>
{ {
@ -121,20 +120,29 @@ impl State
self.chain.read() self.chain.read()
} }
pub async fn chain_write(&self, buffer: impl IntoIterator<Item = String>) -> Result<(), SendError<Vec<String>>> /// Write to this chain
pub async fn chain_write<'a, T: Stream<Item = String>>(&'a self, buffer: T) -> Result<(), SendError<Vec<String>>>
{ {
self.chain.write(buffer.into_iter().collect()).await self.chain.write_stream(buffer).await
} }
pub fn when(&self) -> &Arc<Notify>
pub fn when_ref(&self) -> &Arc<Notify>
{ {
&self.save &self.chain.notify_when()
}
/// Force the chain to push through now
pub fn push_now(&self)
{
self.chain.push_now()
} }
pub fn shutdown(self) pub fn shutdown(self)
{ {
self.shutdown.broadcast(true).expect("Failed to communicate shutdown"); self.shutdown.broadcast(true).expect("Failed to communicate shutdown");
self.save.notify(); self.chain.hang();
self.when_ref().notify();
} }
pub fn has_shutdown(&self) -> bool pub fn has_shutdown(&self) -> bool

Loading…
Cancel
Save