Merge branch 'feed' into master

dedup
Avril 4 years ago
commit acf2ac605e
Signed by: flanchan
GPG Key ID: 284488987C31F630

2
Cargo.lock generated

@ -616,7 +616,7 @@ dependencies = [
[[package]] [[package]]
name = "markov" name = "markov"
version = "0.8.2" version = "0.9.0"
dependencies = [ dependencies = [
"async-compression", "async-compression",
"bzip2-sys", "bzip2-sys",

@ -1,6 +1,6 @@
[package] [package]
name = "markov" name = "markov"
version = "0.8.2" version = "0.9.0"
description = "Generate string of text from Markov chain fed by stdin" description = "Generate string of text from Markov chain fed by stdin"
authors = ["Avril <flanchan@cumallover.me>"] authors = ["Avril <flanchan@cumallover.me>"]
edition = "2018" edition = "2018"
@ -36,14 +36,7 @@ split-sentance = []
# NOTE: This does nothing if `split-newlines` is not enabled # NOTE: This does nothing if `split-newlines` is not enabled
always-aggregate = [] always-aggregate = []
# Feeds will hog the buffer lock until the whole body has been fed, instead of acquiring lock every time # Does nothing, legacy thing.
# This will make feeds of many lines faster but can potentially cause DoS
#
# With: ~169ms
# Without: ~195ms
#
# NOTE:
# This does nothing if `always-aggregate` is enabled and/or `split-newlines` is not enabled
hog-buffer = [] hog-buffer = []
# Enable the /api/ route # Enable the /api/ route

@ -7,5 +7,12 @@ trust_x_forwarded_for = false
feed_bounds = '2..' feed_bounds = '2..'
[filter] [filter]
inbound = '<>/\\' inbound = ''
outbound = '' outbound = ''
[writer]
backlog = 32
internal_backlog = 8
capacity = 4
timeout_ms = 5000
throttle_ms = 50

@ -6,6 +6,7 @@ use std::{
Context, Context,
}, },
pin::Pin, pin::Pin,
marker::PhantomData,
}; };
use tokio::{ use tokio::{
io::{ io::{
@ -173,3 +174,109 @@ mod tests
assert_eq!(&output[..], "Hello world\nHow are you"); assert_eq!(&output[..], "Hello world\nHow are you");
} }
} }
/// A stream that chunks its input.
#[pin_project]
pub struct ChunkingStream<S, T, Into=Vec<T>>
{
#[pin] stream: Fuse<S>,
buf: Vec<T>,
cap: usize,
_output: PhantomData<Into>,
push_now: bool,
}
impl<S, T, Into> ChunkingStream<S,T, Into>
where S: Stream<Item=T>,
Into: From<Vec<T>>
{
pub fn new(stream: S, sz: usize) -> Self
{
Self {
stream: stream.fuse(),
buf: Vec::with_capacity(sz),
cap: sz,
_output: PhantomData,
push_now: false,
}
}
pub fn into_inner(self) -> S
{
self.stream.into_inner()
}
pub fn cap(&self) -> usize
{
self.cap
}
pub fn buffer(&self) -> &[T]
{
&self.buf[..]
}
pub fn get_ref(&self) -> &S
{
self.stream.get_ref()
}
pub fn get_mut(&mut self)-> &mut S
{
self.stream.get_mut()
}
/// Force the next read to send the buffer even if it's not full.
///
/// # Note
/// The buffer still won't send if it's empty.
pub fn push_now(&mut self)
{
self.push_now= true;
}
/// Consume into the current held buffer
pub fn into_buffer(self) -> Vec<T>
{
self.buf
}
/// Take the buffer now
pub fn take_now(&mut self) -> Into
{
std::mem::replace(&mut self.buf, Vec::with_capacity(self.cap)).into()
}
}
impl<S, T, Into> Stream for ChunkingStream<S,T, Into>
where S: Stream<Item=T>,
Into: From<Vec<T>>
{
type Item = Into;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
while !(self.push_now && !self.buf.is_empty()) && self.buf.len() < self.cap {
// Buffer isn't full, keep filling
let this = self.as_mut().project();
match this.stream.poll_next(cx) {
Poll::Ready(None) => {
// Stream is over
break;
},
Poll::Ready(Some(item)) => {
this.buf.push(item);
},
_ => return Poll::Pending,
}
}
debug!("Sending buffer of {} (cap {})", self.buf.len(), self.cap);
// Buffer is full or we reach end of stream
Poll::Ready(if self.buf.len() == 0 {
None
} else {
let this = self.project();
*this.push_now = false;
let output = std::mem::replace(this.buf, Vec::with_capacity(*this.cap));
Some(output.into())
})
}
}

@ -28,9 +28,11 @@ pub struct Config
pub save_interval_secs: Option<NonZeroU64>, pub save_interval_secs: Option<NonZeroU64>,
pub trust_x_forwarded_for: bool, pub trust_x_forwarded_for: bool,
#[serde(default)] #[serde(default)]
pub feed_bounds: String,
#[serde(default)]
pub filter: FilterConfig, pub filter: FilterConfig,
#[serde(default)] #[serde(default)]
pub feed_bounds: String, pub writer: WriterConfig,
} }
#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Hash, Serialize, Deserialize)] #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Hash, Serialize, Deserialize)]
@ -41,6 +43,49 @@ pub struct FilterConfig
outbound: String, outbound: String,
} }
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Hash, Serialize, Deserialize)]
pub struct WriterConfig
{
pub backlog: usize,
pub internal_backlog: usize,
pub capacity: usize,
pub timeout_ms: Option<u64>,
pub throttle_ms: Option<u64>,
}
impl Default for WriterConfig
{
#[inline]
fn default() -> Self
{
Self {
backlog: 32,
internal_backlog: 8,
capacity: 4,
timeout_ms: None,
throttle_ms: None,
}
}
}
impl WriterConfig
{
fn create_settings(self, bounds: range::DynRange<usize>) -> handle::Settings
{
handle::Settings{
backlog: self.backlog,
internal_backlog: self.internal_backlog,
capacity: self.capacity,
timeout: self.timeout_ms.map(tokio::time::Duration::from_millis).unwrap_or(handle::DEFAULT_TIMEOUT),
throttle: self.throttle_ms.map(tokio::time::Duration::from_millis),
bounds,
}
}
}
impl FilterConfig impl FilterConfig
{ {
fn get_inbound_filter(&self) -> sanitise::filter::Filter fn get_inbound_filter(&self) -> sanitise::filter::Filter
@ -77,6 +122,7 @@ impl Default for Config
trust_x_forwarded_for: false, trust_x_forwarded_for: false,
filter: Default::default(), filter: Default::default(),
feed_bounds: "2..".to_owned(), feed_bounds: "2..".to_owned(),
writer: Default::default(),
} }
} }
} }
@ -96,12 +142,14 @@ impl Config
} }
use std::ops::RangeBounds; use std::ops::RangeBounds;
Ok(Cache { let feed_bounds = section!("feed_bounds", self.parse_feed_bounds()).and_then(|bounds| if bounds.contains(&0) {
feed_bounds: section!("feed_bounds", self.parse_feed_bounds()).and_then(|bounds| if bounds.contains(&0) {
Err(InvalidConfigError("feed_bounds", Box::new(opaque_error!("Bounds not allowed to contains 0 (they were `{}`)", bounds)))) Err(InvalidConfigError("feed_bounds", Box::new(opaque_error!("Bounds not allowed to contains 0 (they were `{}`)", bounds))))
} else { } else {
Ok(bounds) Ok(bounds)
})?, })?;
Ok(Cache {
handler_settings: self.writer.create_settings(feed_bounds.clone()),
feed_bounds,
inbound_filter: self.filter.get_inbound_filter(), inbound_filter: self.filter.get_inbound_filter(),
outbound_filter: self.filter.get_outbound_filter(), outbound_filter: self.filter.get_outbound_filter(),
}) })
@ -205,12 +253,13 @@ impl fmt::Display for InvalidConfigError
/// Caches some parsed config arguments /// Caches some parsed config arguments
#[derive(Clone, PartialEq, Eq)] #[derive(Clone, PartialEq)]
pub struct Cache pub struct Cache
{ {
pub feed_bounds: range::DynRange<usize>, pub feed_bounds: range::DynRange<usize>,
pub inbound_filter: sanitise::filter::Filter, pub inbound_filter: sanitise::filter::Filter,
pub outbound_filter: sanitise::filter::Filter, pub outbound_filter: sanitise::filter::Filter,
pub handler_settings: handle::Settings,
} }
impl fmt::Debug for Cache impl fmt::Debug for Cache
@ -221,6 +270,7 @@ impl fmt::Debug for Cache
.field("feed_bounds", &self.feed_bounds) .field("feed_bounds", &self.feed_bounds)
.field("inbound_filter", &self.inbound_filter.iter().collect::<String>()) .field("inbound_filter", &self.inbound_filter.iter().collect::<String>())
.field("outbound_filter", &self.outbound_filter.iter().collect::<String>()) .field("outbound_filter", &self.outbound_filter.iter().collect::<String>())
.field("handler_settings", &self.handler_settings)
.finish() .finish()
} }
} }

@ -1,4 +1,5 @@
//! Extensions //! Extensions
use super::*;
use std::{ use std::{
iter, iter,
ops::{ ops::{
@ -162,3 +163,21 @@ impl<T> DerefMut for AssertNotSend<T>
&mut self.0 &mut self.0
} }
} }
pub trait ChunkStreamExt<T>: Sized
{
fn chunk_into<I: From<Vec<T>>>(self, sz: usize) -> chunking::ChunkingStream<Self,T,I>;
fn chunk(self, sz: usize) -> chunking::ChunkingStream<Self, T>
{
self.chunk_into(sz)
}
}
impl<S, T> ChunkStreamExt<T> for S
where S: Stream<Item=T>
{
fn chunk_into<I: From<Vec<T>>>(self, sz: usize) -> chunking::ChunkingStream<Self,T,I>
{
chunking::ChunkingStream::new(self, sz)
}
}

@ -2,8 +2,10 @@
use super::*; use super::*;
#[cfg(any(feature="feed-sentance", feature="split-sentance"))] #[cfg(any(feature="feed-sentance", feature="split-sentance"))]
use sanitise::Sentance; use sanitise::Sentance;
use futures::stream;
pub const DEFAULT_FEED_BOUNDS: std::ops::RangeFrom<usize> = 2..; //TODO: Add to config somehow
pub const DEFAULT_FEED_BOUNDS: std::ops::RangeFrom<usize> = 2..;
/// Feed `what` into `chain`, at least `bounds` tokens. /// Feed `what` into `chain`, at least `bounds` tokens.
/// ///
@ -56,7 +58,7 @@ pub fn feed(chain: &mut Chain<String>, what: impl AsRef<str>, bounds: impl std::
} }
debug_assert!(!bounds.contains(&0), "Cannot allow 0 size feeds"); debug_assert!(!bounds.contains(&0), "Cannot allow 0 size feeds");
if bounds.contains(&map.len()) { if bounds.contains(&map.len()) {
debug!("Feeding chain {} items", map.len()); //debug!("Feeding chain {} items", map.len());
chain.feed(map); chain.feed(map);
} }
else { else {
@ -73,12 +75,12 @@ pub async fn full(who: &IpAddr, state: State, body: impl Unpin + Stream<Item = R
if_debug! { if_debug! {
let timer = std::time::Instant::now(); let timer = std::time::Instant::now();
} }
let bounds = &state.config_cache().feed_bounds; //let bounds = &state.config_cache().feed_bounds;
macro_rules! feed { macro_rules! feed {
($chain:expr, $buffer:ident) => { ($buffer:expr) => {
{ {
let buffer = $buffer; let buffer = $buffer;
feed($chain, &buffer, bounds) state.chain_write(buffer).await.map_err(|_| FillBodyError)?;
} }
} }
} }
@ -101,44 +103,42 @@ pub async fn full(who: &IpAddr, state: State, body: impl Unpin + Stream<Item = R
let buffer = std::str::from_utf8(&buffer[..]).map_err(|_| FillBodyError)?; let buffer = std::str::from_utf8(&buffer[..]).map_err(|_| FillBodyError)?;
let buffer = state.inbound_filter().filter_cow(buffer); let buffer = state.inbound_filter().filter_cow(buffer);
info!("{} -> {:?}", who, buffer); info!("{} -> {:?}", who, buffer);
let mut chain = state.chain().write().await;
cfg_if! { cfg_if! {
if #[cfg(feature="split-newlines")] { if #[cfg(feature="split-newlines")] {
for buffer in buffer.split('\n').filter(|line| !line.trim().is_empty()) { feed!(stream::iter(buffer.split('\n').filter(|line| !line.trim().is_empty())
feed!(&mut chain, buffer); .map(|x| x.to_owned())))
}
} else { } else {
feed!(&mut chain, buffer); feed!(stream::once(async move{buffer.into_owned()}));
} }
} }
} else { } else {
use tokio::prelude::*; use tokio::prelude::*;
let reader = chunking::StreamReader::new(body.filter_map(|x| x.map(|mut x| x.to_bytes()).ok())); let reader = chunking::StreamReader::new(body.filter_map(|x| x.map(|mut x| x.to_bytes()).ok()));
let mut lines = reader.lines(); let lines = reader.lines();
#[cfg(feature="hog-buffer")] feed!(lines.filter_map(|x| x.ok().and_then(|line| {
let mut chain = state.chain().write().await;
while let Some(line) = lines.next_line().await.map_err(|_| FillBodyError)? {
let line = state.inbound_filter().filter_cow(&line); let line = state.inbound_filter().filter_cow(&line);
let line = line.trim(); let line = line.trim();
if !line.is_empty() { if !line.is_empty() {
#[cfg(not(feature="hog-buffer"))] //#[cfg(not(feature="hog-buffer"))]
let mut chain = state.chain().write().await; // Acquire mutex once per line? Is this right? //let mut chain = state.chain().write().await; // Acquire mutex once per line? Is this right?
feed!(&mut chain, line);
info!("{} -> {:?}", who, line); info!("{} -> {:?}", who, line);
}
written+=line.len(); written+=line.len();
Some(line.to_owned())
} else {
None
} }
})));
} }
} }
if_debug! { if_debug! {
trace!("Write took {}ms", timer.elapsed().as_millis()); trace!("Write took {}ms", timer.elapsed().as_millis());
} }
state.notify_save();
Ok(written) Ok(written)
} }

@ -1,34 +1,46 @@
//! Generating the strings //! Generating the strings
use super::*; use super::*;
use tokio::sync::mpsc::error::SendError;
use futures::StreamExt;
#[derive(Debug)] #[derive(Debug, Default)]
pub struct GenBodyError(pub String); pub struct GenBodyError(Option<String>);
impl error::Error for GenBodyError{} impl error::Error for GenBodyError{}
impl fmt::Display for GenBodyError impl fmt::Display for GenBodyError
{ {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
{ {
write!(f, "failed to write {:?} to body", self.0) if let Some(z) = &self.0 {
write!(f, "failed to write read string {:?} to body", z)
} else {
write!(f, "failed to read string from chain. it might be empty.")
}
} }
} }
pub async fn body(state: State, num: Option<usize>, mut output: mpsc::Sender<String>) -> Result<(), GenBodyError> pub async fn body(state: State, num: Option<usize>, mut output: mpsc::Sender<String>) -> Result<(), GenBodyError>
{ {
let chain = state.chain().read().await; let mut chain = state.chain_read();
if !chain.is_empty() {
let filter = state.outbound_filter(); let filter = state.outbound_filter();
match num { match num {
Some(num) if num < state.config().max_gen_size => { Some(num) if num < state.config().max_gen_size => {
//This could DoS `full_body` and writes, potentially. let mut chain = chain.take(num);
for string in chain.str_iter_for(num) { while let Some(string) = chain.next().await {
output.send(filter.filter_owned(string)).await.map_err(|e| GenBodyError(e.0))?; output.send(filter.filter_owned(string)).await?;
} }
}, },
_ => output.send(filter.filter_owned(chain.generate_str())).await.map_err(|e| GenBodyError(e.0))?, _ => output.send(filter.filter_owned(chain.next().await.ok_or_else(GenBodyError::default)?)).await?,
}
} }
Ok(()) Ok(())
} }
impl From<SendError<String>> for GenBodyError
{
#[inline] fn from(from: SendError<String>) -> Self
{
Self(Some(from.0))
}
}

@ -0,0 +1,392 @@
//! Chain handler.
use super::*;
use std::{
marker::Send,
sync::Weak,
num::NonZeroUsize,
task::{Poll, Context,},
pin::Pin,
};
use tokio::{
sync::{
RwLock,
RwLockReadGuard,
mpsc::{
self,
error::SendError,
},
watch,
Notify,
},
task::JoinHandle,
time::{
self,
Duration,
},
};
use futures::StreamExt;
pub const DEFAULT_TIMEOUT: Duration= Duration::from_secs(5);
/// Settings for chain handler
#[derive(Debug, Clone, PartialEq)]
pub struct Settings
{
pub backlog: usize,
pub internal_backlog: usize,
pub capacity: usize,
pub timeout: Duration,
pub throttle: Option<Duration>,
pub bounds: range::DynRange<usize>,
}
impl Settings
{
/// Should we keep this string.
#[inline] fn matches(&self, _s: &str) -> bool
{
true
}
}
impl Default for Settings
{
#[inline]
fn default() -> Self
{
Self {
backlog: 32,
internal_backlog: 8,
capacity: 4,
timeout: Duration::from_secs(5),
throttle: Some(Duration::from_millis(200)),
bounds: feed::DEFAULT_FEED_BOUNDS.into(),
}
}
}
#[derive(Debug)]
struct HostInner<T>
{
input: mpsc::Receiver<Vec<T>>,
shutdown: watch::Receiver<bool>,
}
#[derive(Debug)]
struct Handle<T: Send+ chain::Chainable>
{
chain: RwLock<chain::Chain<T>>,
input: mpsc::Sender<Vec<T>>,
opt: Settings,
notify_write: Arc<Notify>,
push_now: Arc<Notify>,
shutdown: watch::Sender<bool>,
/// Data used only for the worker task.
host: msg::Once<HostInner<T>>,
}
#[derive(Clone, Debug)]
pub struct ChainHandle<T: Send + chain::Chainable>(Arc<Box<Handle<T>>>);
impl<T: Send+ chain::Chainable + 'static> ChainHandle<T>
{
pub fn with_settings(chain: chain::Chain<T>, opt: Settings) -> Self
{
let (shutdown_tx, shutdown) = watch::channel(false);
let (itx, irx) = mpsc::channel(opt.backlog);
Self(Arc::new(Box::new(Handle{
chain: RwLock::new(chain),
input: itx,
opt,
push_now: Arc::new(Notify::new()),
notify_write: Arc::new(Notify::new()),
shutdown: shutdown_tx,
host: msg::Once::new(HostInner{
input: irx,
shutdown,
})
})))
}
/// Acquire the chain read lock
async fn chain(&self) -> RwLockReadGuard<'_, chain::Chain<T>>
{
self.0.chain.read().await
}
/// A reference to the chain
pub fn chain_ref(&self) -> &RwLock<chain::Chain<T>>
{
&self.0.chain
}
/// Create a stream that reads generated values forever.
pub fn read(&self) -> ChainStream<T>
{
ChainStream{
chain: Arc::downgrade(&self.0),
buffer: Vec::with_capacity(self.0.opt.backlog),
}
}
/// Send this buffer to the chain
pub fn write(&self, buf: Vec<T>) -> impl futures::Future<Output = Result<(), SendError<Vec<T>>>> + 'static
{
let mut write = self.0.input.clone();
async move {
write.send(buf).await
}
}
/// Send this stream buffer to the chain
pub fn write_stream<'a, I: Stream<Item=T>>(&self, buf: I) -> impl futures::Future<Output = Result<(), SendError<Vec<T>>>> + 'a
where I: 'a
{
let mut write = self.0.input.clone();
async move {
write.send(buf.collect().await).await
}
}
/// Send this buffer to the chain
pub async fn write_in_place(&self, buf: Vec<T>) -> Result<(), SendError<Vec<T>>>
{
self.0.input.clone().send(buf).await
}
/// A referencer for the notifier
pub fn notify_when(&self) -> &Arc<Notify>
{
&self.0.notify_write
}
/// Force the pending buffers to be written to the chain now
pub fn push_now(&self)
{
self.0.push_now.notify();
}
/// Hang the worker thread, preventing it from taking any more inputs and also flushing it.
///
/// # Panics
/// If there was no worker thread.
pub fn hang(&self)
{
trace!("Communicating hang request");
self.0.shutdown.broadcast(true).expect("Failed to communicate hang");
}
}
impl ChainHandle<String>
{
#[deprecated = "use read() pls"]
pub async fn generate_body(&self, state: &state::State, num: Option<usize>, mut output: mpsc::Sender<String>) -> Result<(), SendError<String>>
{
let chain = self.chain().await;
if !chain.is_empty() {
let filter = state.outbound_filter();
match num {
Some(num) if num < state.config().max_gen_size => {
//This could DoS writes, potentially.
for string in chain.str_iter_for(num) {
output.send(filter.filter_owned(string)).await?;
}
},
_ => output.send(filter.filter_owned(chain.generate_str())).await?,
}
}
Ok(())
}
}
/// Host this handle on the current task.
///
/// # Panics
/// If `from` has already been hosted.
pub async fn host(from: ChainHandle<String>)
{
let opt = from.0.opt.clone();
let mut data = from.0.host.unwrap().await;
let (mut tx, mut child) = {
// The `real` input channel.
let from = from.clone();
let opt = opt.clone();
let (tx, rx) = mpsc::channel::<Vec<Vec<_>>>(opt.internal_backlog);
(tx, tokio::spawn(async move {
let mut rx = if let Some(thr) = opt.throttle {
time::throttle(thr, rx).boxed()
} else {
rx.boxed()
};
trace!("child: Begin waiting on parent");
while let Some(item) = rx.next().await {
if item.len() > 0 {
info!("Write lock acq");
let mut lock = from.0.chain.write().await;
for item in item.into_iter()
{
use std::ops::DerefMut;
for item in item.into_iter() {
feed::feed(lock.deref_mut(), item, &from.0.opt.bounds);
}
}
trace!("Signalling write");
from.0.notify_write.notify();
}
}
trace!("child: exiting");
}))
};
trace!("Begin polling on child");
tokio::select!{
v = &mut child => {
match v {
#[cold] Ok(_) => {warn!("Child exited before we have? This should probably never happen.")},//Should never happen.
Err(e) => {error!("Child exited abnormally. Aborting: {}", e)}, //Child panic or cancel.
}
},
_ = async move {
let mut rx = data.input.chunk(opt.capacity); //we don't even need this tbh, oh well.
if !data.shutdown.recv().await.unwrap_or(true) { //first shutdown we get for free
while Arc::strong_count(&from.0) > 2 {
if *data.shutdown.borrow() {
break;
}
tokio::select!{
Some(true) = data.shutdown.recv() => {
debug!("Got shutdown (hang) request. Sending now then breaking");
let mut rest = {
let irx = rx.get_mut();
irx.close(); //accept no more inputs
let mut output = Vec::with_capacity(opt.capacity);
while let Ok(item) = irx.try_recv() {
output.push(item);
}
output
};
rest.extend(rx.take_now());
if rest.len() > 0 {
if let Err(err) = tx.send(rest).await {
error!("Failed to force send buffer, exiting now: {}", err);
}
}
break;
}
_ = time::delay_for(opt.timeout) => {
trace!("Setting push now");
rx.push_now();
}
_ = from.0.push_now.notified() => {
debug!("Got force push signal");
let take =rx.take_now();
rx.push_now();
if take.len() > 0 {
if let Err(err) = tx.send(take).await {
error!("Failed to force send buffer: {}", err);
break;
}
}
}
Some(buffer) = rx.next() => {
debug!("Sending {} (cap {})", buffer.len(), buffer.capacity());
if let Err(err) = tx.send(buffer).await {
// Receive closed?
//
// This probably shouldn't happen, as we `select!` for it up there and child never calls `close()` on `rx`.
// In any case, it means we should abort.
#[cold] error!("Failed to send buffer: {}", err);
break;
}
}
}
}
}
let last = rx.into_buffer();
if last.len() > 0 {
if let Err(err) = tx.send(last).await {
error!("Failed to force send last part of buffer: {}", err);
} else {
trace!("Sent rest of buffer");
}
}
} => {
// Normal exit
trace!("Normal exit")
},
}
trace!("Waiting on child");
// No more handles except child, no more possible inputs.
child.await.expect("Child panic");
trace!("Returning");
}
/// Spawn a new chain handler for this chain.
pub fn spawn(from: chain::Chain<String>, opt: Settings) -> (JoinHandle<()>, ChainHandle<String>)
{
debug!("Spawning with opt: {:?}", opt);
let handle = ChainHandle::with_settings(from, opt);
(tokio::spawn(host(handle.clone())), handle)
}
#[derive(Debug)]
pub struct ChainStream<T: Send + chain::Chainable>
{
chain: Weak<Box<Handle<T>>>,
buffer: Vec<T>,
}
impl ChainStream<String>
{
async fn try_pull(&mut self, n: usize) -> Option<NonZeroUsize>
{
if n == 0 {
return None;
}
if let Some(read) = self.chain.upgrade() {
let chain = read.chain.read().await;
if chain.is_empty() {
return None;
}
let n = if n == 1 {
self.buffer.push(chain.generate_str());
1
} else {
self.buffer.extend(chain.str_iter_for(n));
n //for now
};
Some(unsafe{NonZeroUsize::new_unchecked(n)})
} else {
None
}
}
}
impl Stream for ChainStream<String>
{
type Item = String;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
use futures::Future;
let this = self.get_mut();
if this.buffer.len() == 0 {
let pull = this.try_pull(this.buffer.capacity());
tokio::pin!(pull);
match pull.poll(cx) {
Poll::Ready(Some(_)) => {},
Poll::Pending => return Poll::Pending,
_ => return Poll::Ready(None),
};
}
debug_assert!(this.buffer.len()>0);
Poll::Ready(Some(this.buffer.remove(0)))
}
}

@ -78,6 +78,7 @@ use state::State;
mod save; mod save;
mod forwarded_list; mod forwarded_list;
use forwarded_list::XForwardedFor; use forwarded_list::XForwardedFor;
mod handle;
mod feed; mod feed;
mod gen; mod gen;
@ -134,7 +135,7 @@ async fn main() {
debug!("Using config {:?}", config); debug!("Using config {:?}", config);
trace!("With config cached: {:?}", ccache); trace!("With config cached: {:?}", ccache);
let chain = Arc::new(RwLock::new(match save::load(&config.file).await { let (chain_handle, chain) = handle::spawn(match save::load(&config.file).await {
Ok(chain) => { Ok(chain) => {
info!("Loaded chain from {:?}", config.file); info!("Loaded chain from {:?}", config.file);
chain chain
@ -144,16 +145,15 @@ async fn main() {
trace!("Error: {}", e); trace!("Error: {}", e);
Chain::new() Chain::new()
}, },
})); }, ccache.handler_settings.clone());
{ {
let mut tasks = Vec::<BoxFuture<'static, ()>>::new(); let mut tasks = Vec::<BoxFuture<'static, ()>>::new();
tasks.push(chain_handle.map(|res| res.expect("Chain handle panicked")).boxed());
let (state, chain) = { let (state, chain) = {
let save_when = Arc::new(Notify::new());
let state = State::new(config, let state = State::new(config,
ccache, ccache,
Arc::clone(&chain), chain);
Arc::clone(&save_when));
let state2 = state.clone(); let state2 = state.clone();
let saver = tokio::spawn(save::host(Box::new(state.clone()))); let saver = tokio::spawn(save::host(Box::new(state.clone())));
let chain = warp::any().map(move || state.clone()); let chain = warp::any().map(move || state.clone());

@ -3,6 +3,7 @@ use super::*;
use tokio::{ use tokio::{
sync::{ sync::{
watch, watch,
Mutex,
}, },
}; };
use std::{ use std::{
@ -12,7 +13,9 @@ use std::{
error, error,
}; };
use futures::{ use futures::{
future::Future, future::{
Future,
},
}; };
#[derive(Debug)] #[derive(Debug)]
@ -160,3 +163,48 @@ impl Future for Initialiser
uhh.poll(cx) uhh.poll(cx)
} }
} }
/// A value that can be consumed once.
#[derive(Debug)]
pub struct Once<T>(Mutex<Option<T>>);
impl<T> Once<T>
{
/// Create a new instance
pub fn new(from: T) -> Self
{
Self(Mutex::new(Some(from)))
}
/// Consume into the instance from behind a potentially shared reference.
pub async fn consume_shared(self: Arc<Self>) -> Option<T>
{
match Arc::try_unwrap(self) {
Ok(x) => x.0.into_inner(),
Err(x) => x.0.lock().await.take(),
}
}
/// Consume from a shared reference and panic if the value has already been consumed.
pub async fn unwrap_shared(self: Arc<Self>) -> T
{
self.consume_shared().await.unwrap()
}
/// Consume into the instance.
pub async fn consume(&self) -> Option<T>
{
self.0.lock().await.take()
}
/// Consume and panic if the value has already been consumed.
pub async fn unwrap(&self) -> T
{
self.consume().await.unwrap()
}
/// Consume into the inner value
pub fn into_inner(self) -> Option<T>
{
self.0.into_inner()
}
}

@ -272,6 +272,6 @@ mod tests
let string = "abcdef ghi jk1\nhian"; let string = "abcdef ghi jk1\nhian";
assert_eq!(filter.filter_str(&string).to_string(), filter.filter_cow(&string).to_string()); assert_eq!(filter.filter_str(&string).to_string(), filter.filter_cow(&string).to_string());
assert_eq!(filter.filter_cow(&string).to_string(), filter.filter(string.chars()).collect::<String>()); assert_eq!(filter.filter_cow(&string).to_string(), filter.filter_iter(string.chars()).collect::<String>());
} }
} }

@ -25,7 +25,7 @@ macro_rules! new {
}; };
} }
const DEFAULT_BOUNDARIES: &[char] = &['\n', '.', ':', '!', '?']; const DEFAULT_BOUNDARIES: &[char] = &['\n', '.', ':', '!', '?', '~'];
lazy_static! { lazy_static! {
static ref BOUNDARIES: smallmap::Map<char, ()> = { static ref BOUNDARIES: smallmap::Map<char, ()> = {

@ -25,7 +25,7 @@ macro_rules! new {
}; };
} }
const DEFAULT_BOUNDARIES: &[char] = &['!', '.', ',']; const DEFAULT_BOUNDARIES: &[char] = &['!', '.', ',', '*'];
lazy_static! { lazy_static! {
static ref BOUNDARIES: smallmap::Map<char, ()> = { static ref BOUNDARIES: smallmap::Map<char, ()> = {

@ -43,7 +43,7 @@ type Decompressor<T> = BzDecoder<T>;
pub async fn save_now(state: &State) -> io::Result<()> pub async fn save_now(state: &State) -> io::Result<()>
{ {
let chain = state.chain().read().await; let chain = state.chain_ref().read().await;
use std::ops::Deref; use std::ops::Deref;
let to = &state.config().file; let to = &state.config().file;
save_now_to(chain.deref(),to).await save_now_to(chain.deref(),to).await
@ -77,12 +77,13 @@ pub async fn host(mut state: Box<State>)
{ {
let to = state.config().file.to_owned(); let to = state.config().file.to_owned();
let interval = state.config().save_interval(); let interval = state.config().save_interval();
let when = Arc::clone(state.when_ref());
trace!("Setup oke. Waiting on init"); trace!("Setup oke. Waiting on init");
if state.on_init().await.is_ok() { if state.on_init().await.is_ok() {
debug!("Begin save handler"); debug!("Begin save handler");
while Arc::strong_count(state.when()) > 1 { while Arc::strong_count(&when) > 1 {
{ {
let chain = state.chain().read().await; let chain = state.chain_ref().read().await;
use std::ops::Deref; use std::ops::Deref;
if let Err(e) = save_now_to(chain.deref(), &to).await { if let Err(e) = save_now_to(chain.deref(), &to).await {
error!("Failed to save chain: {}", e); error!("Failed to save chain: {}", e);
@ -97,7 +98,7 @@ pub async fn host(mut state: Box<State>)
break; break;
} }
} }
state.when().notified().await; when.notified().await;
if state.has_shutdown() { if state.has_shutdown() {
break; break;
} }

@ -1,17 +1,19 @@
//! /sentance/ //! /sentance/
use super::*; use super::*;
use futures::StreamExt;
pub async fn body(state: State, num: Option<usize>, mut output: mpsc::Sender<String>) -> Result<(), gen::GenBodyError> pub async fn body(state: State, num: Option<usize>, mut output: mpsc::Sender<String>) -> Result<(), gen::GenBodyError>
{ {
let string = { let string = {
let chain = state.chain().read().await; let mut chain = state.chain_read();
if chain.is_empty() {
return Ok(());
}
match num { match num {
None => chain.generate_str(), None => chain.next().await.ok_or_else(gen::GenBodyError::default)?,
Some(num) => (0..num).map(|_| chain.generate_str()).join("\n"), Some(num) if num < state.config().max_gen_size => {//(0..num).map(|_| chain.generate_str()).join("\n"),
let chain = chain.take(num);
chain.collect::<Vec<_>>().await.join("\n")//TODO: Stream version of JoinStrExt
},
_ => return Err(Default::default()),
} }
}; };
@ -20,14 +22,14 @@ pub async fn body(state: State, num: Option<usize>, mut output: mpsc::Sender<Str
if let Some(num) = num { if let Some(num) = num {
for sen in sanitise::Sentance::new_iter(&string).take(num) for sen in sanitise::Sentance::new_iter(&string).take(num)
{ {
output.send(filter.filter_owned(sen.to_owned())).await.map_err(|e| gen::GenBodyError(e.0))?; output.send(filter.filter_owned(sen.to_owned())).await?;
} }
} else { } else {
output.send(filter.filter_owned(match sanitise::Sentance::new_iter(&string) output.send(filter.filter_owned(match sanitise::Sentance::new_iter(&string)
.max_by_key(|x| x.len()) { .max_by_key(|x| x.len()) {
Some(x) => x, Some(x) => x,
#[cold] None => return Ok(()), #[cold] None => return Ok(()),
}.to_owned())).await.map_err(|e| gen::GenBodyError(e.0))?; }.to_owned())).await?;
} }
Ok(()) Ok(())
} }

@ -14,6 +14,7 @@ pub async fn handle(mut state: State)
let mut usr1 = unix::signal(SignalKind::user_defined1()).expect("Failed to hook SIGUSR1"); let mut usr1 = unix::signal(SignalKind::user_defined1()).expect("Failed to hook SIGUSR1");
let mut usr2 = unix::signal(SignalKind::user_defined2()).expect("Failed to hook SIGUSR2"); let mut usr2 = unix::signal(SignalKind::user_defined2()).expect("Failed to hook SIGUSR2");
let mut quit = unix::signal(SignalKind::quit()).expect("Failed to hook SIGQUIT"); let mut quit = unix::signal(SignalKind::quit()).expect("Failed to hook SIGQUIT");
let mut io = unix::signal(SignalKind::io()).expect("Failed to hook IO");
trace!("Setup oke. Waiting on init"); trace!("Setup oke. Waiting on init");
if state.on_init().await.is_ok() { if state.on_init().await.is_ok() {
@ -24,19 +25,15 @@ pub async fn handle(mut state: State)
break; break;
} }
_ = usr1.recv() => { _ = usr1.recv() => {
info!("Got SIGUSR1. Saving chain immediately."); info!("Got SIGUSR1. Causing chain write.");
if let Err(e) = save::save_now(&state).await { state.push_now();
error!("Failed to save chain: {}", e);
} else{
trace!("Saved chain okay");
}
}, },
_ = usr2.recv() => { _ = usr2.recv() => {
info!("Got SIGUSR1. Loading chain immediately."); info!("Got SIGUSR2. Loading chain immediately.");
match save::load(&state.config().file).await { match save::load(&state.config().file).await {
Ok(new) => { Ok(new) => {
{ {
let mut chain = state.chain().write().await; let mut chain = state.chain_ref().write().await;
*chain = new; *chain = new;
} }
trace!("Replaced with read chain"); trace!("Replaced with read chain");
@ -46,6 +43,15 @@ pub async fn handle(mut state: State)
}, },
} }
}, },
_ = io.recv() => {
info!("Got SIGIO. Saving chain immediately.");
if let Err(e) = save::save_now(&state).await {
error!("Failed to save chain: {}", e);
} else{
trace!("Saved chain okay");
}
},
_ = quit.recv() => { _ = quit.recv() => {
warn!("Got SIGQUIT. Saving chain then aborting."); warn!("Got SIGQUIT. Saving chain then aborting.");
if let Err(e) = save::save_now(&state).await { if let Err(e) = save::save_now(&state).await {

@ -3,6 +3,7 @@ use super::*;
use tokio::{ use tokio::{
sync::{ sync::{
watch, watch,
mpsc::error::SendError,
}, },
}; };
use config::Config; use config::Config;
@ -25,8 +26,8 @@ impl fmt::Display for ShutdownError
pub struct State pub struct State
{ {
config: Arc<Box<(Config, config::Cache)>>, //to avoid cloning config config: Arc<Box<(Config, config::Cache)>>, //to avoid cloning config
chain: Arc<RwLock<Chain<String>>>, chain: handle::ChainHandle<String>,
save: Arc<Notify>, //save: Arc<Notify>,
begin: Initialiser, begin: Initialiser,
shutdown: Arc<watch::Sender<bool>>, shutdown: Arc<watch::Sender<bool>>,
@ -78,13 +79,12 @@ impl State
&self.config_cache().outbound_filter &self.config_cache().outbound_filter
} }
pub fn new(config: Config, cache: config::Cache, chain: Arc<RwLock<Chain<String>>>, save: Arc<Notify>) -> Self pub fn new(config: Config, cache: config::Cache, chain: handle::ChainHandle<String>) -> Self
{ {
let (shutdown, shutdown_recv) = watch::channel(false); let (shutdown, shutdown_recv) = watch::channel(false);
Self { Self {
config: Arc::new(Box::new((config, cache))), config: Arc::new(Box::new((config, cache))),
chain, chain,
save,
begin: Initialiser::new(), begin: Initialiser::new(),
shutdown: Arc::new(shutdown), shutdown: Arc::new(shutdown),
shutdown_recv, shutdown_recv,
@ -101,25 +101,48 @@ impl State
&self.config.as_ref().1 &self.config.as_ref().1
} }
pub fn notify_save(&self) /*pub fn notify_save(&self)
{ {
self.save.notify(); self.save.notify();
} }*/
pub fn chain(&self) -> &RwLock<Chain<String>> /*pub fn chain(&self) -> &RwLock<Chain<String>>
{ {
&self.chain.as_ref() &self.chain.as_ref()
}*/
pub fn chain_ref(&self) -> &RwLock<Chain<String>>
{
&self.chain.chain_ref()
} }
pub fn when(&self) -> &Arc<Notify> pub fn chain_read(&self) -> handle::ChainStream<String>
{ {
&self.save self.chain.read()
}
/// Write to this chain
pub async fn chain_write<'a, T: Stream<Item = String>>(&'a self, buffer: T) -> Result<(), SendError<Vec<String>>>
{
self.chain.write_stream(buffer).await
}
pub fn when_ref(&self) -> &Arc<Notify>
{
&self.chain.notify_when()
}
/// Force the chain to push through now
pub fn push_now(&self)
{
self.chain.push_now()
} }
pub fn shutdown(self) pub fn shutdown(self)
{ {
self.shutdown.broadcast(true).expect("Failed to communicate shutdown"); self.shutdown.broadcast(true).expect("Failed to communicate shutdown");
self.save.notify(); self.chain.hang();
self.when_ref().notify();
} }
pub fn has_shutdown(&self) -> bool pub fn has_shutdown(&self) -> bool

Loading…
Cancel
Save