handle okay

feed
Avril 4 years ago
parent 6453392758
commit 5dc10547d5
Signed by: flanchan
GPG Key ID: 284488987C31F630

@ -6,6 +6,7 @@ use std::{
Context,
},
pin::Pin,
marker::PhantomData,
};
use tokio::{
io::{
@ -173,3 +174,86 @@ mod tests
assert_eq!(&output[..], "Hello world\nHow are you");
}
}
/// A stream adaptor that groups the items of an inner stream into chunks.
///
/// Items are accumulated in an internal buffer until `cap` items have been
/// collected, the inner stream ends, or a flush is requested via `push_now()`.
/// Each chunk is handed out as `Into`, built from the buffered `Vec<T>`.
#[pin_project]
pub struct ChunkingStream<S, T, Into=Vec<T>>
{
    /// The fused inner stream items are pulled from.
    #[pin] stream: Fuse<S>,
    /// Items collected so far for the chunk currently being built.
    buf: Vec<T>,
    /// Number of buffered items at which a chunk is emitted.
    cap: usize,
    /// Marker for the chunk output type.
    _output: PhantomData<Into>,
    /// Set when the next poll should emit the buffer even if it isn't full.
    push_now: bool,
}

impl<S, T, Into> ChunkingStream<S,T, Into>
where S: Stream<Item=T>,
      Into: From<Vec<T>>
{
    /// Wrap `stream`, emitting chunks of up to `sz` items.
    pub fn new(stream: S, sz: usize) -> Self
    {
        let buf = Vec::with_capacity(sz);
        let stream = stream.fuse();
        Self {
            stream,
            buf,
            cap: sz,
            _output: PhantomData,
            push_now: false,
        }
    }

    /// Consume this adaptor and return the wrapped stream.
    ///
    /// Any items still sitting in the buffer are dropped with `self`.
    pub fn into_inner(self) -> S
    {
        self.stream.into_inner()
    }

    /// The chunk size this stream was created with.
    pub fn cap(&self) -> usize
    {
        self.cap
    }

    /// The items buffered so far for the chunk in progress.
    pub fn buffer(&self) -> &[T]
    {
        self.buf.as_slice()
    }

    /// Force the next read to send the buffer even if it's not full.
    ///
    /// # Note
    /// The buffer still won't send if it's empty.
    pub fn push_now(&mut self)
    {
        self.push_now = true;
    }
}
impl<S, T, Into> Stream for ChunkingStream<S,T, Into>
where S: Stream<Item=T>,
      Into: From<Vec<T>>
{
    type Item = Into;

    /// Fill the buffer from the inner stream and emit it as one chunk.
    ///
    /// A chunk is emitted once the buffer holds `cap` items, once the inner
    /// stream has ended with items still buffered, or once a flush was
    /// requested and the buffer is non-empty. The stream ends (`None`) when
    /// filling stops with an empty buffer. With `cap == 0` nothing is ever
    /// pulled, so the stream ends immediately.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();

        loop {
            // A requested flush only takes effect once there is something to send.
            let flush_requested = *this.push_now && !this.buf.is_empty();
            if flush_requested || this.buf.len() >= *this.cap {
                break;
            }
            // Buffer isn't full, keep filling.
            match this.stream.as_mut().poll_next(cx) {
                Poll::Ready(Some(item)) => this.buf.push(item),
                // Inner stream is over; emit whatever we have.
                Poll::Ready(None) => break,
                Poll::Pending => return Poll::Pending,
            }
        }

        // Buffer is full, a flush was requested, or we reached end of stream.
        if this.buf.is_empty() {
            return Poll::Ready(None);
        }
        *this.push_now = false;
        let chunk = std::mem::replace(this.buf, Vec::with_capacity(*this.cap));
        Poll::Ready(Some(chunk.into()))
    }
}

@ -1,4 +1,5 @@
//! Extensions
use super::*;
use std::{
iter,
ops::{
@ -162,3 +163,21 @@ impl<T> DerefMut for AssertNotSend<T>
&mut self.0
}
}
/// Extension methods for wrapping a value in a `chunking::ChunkingStream`.
pub trait ChunkStreamExt<T>: Sized
{
    /// Chunk `self` into groups of up to `sz` items, each converted into `I`.
    fn chunk_into<I: From<Vec<T>>>(self, sz: usize) -> chunking::ChunkingStream<Self,T,I>;

    /// Chunk `self` into `Vec<T>`s of up to `sz` items.
    fn chunk(self, sz: usize) -> chunking::ChunkingStream<Self, T>
    {
        self.chunk_into::<Vec<T>>(sz)
    }
}
/// Every stream yielding `T` can be chunked.
impl<S, T> ChunkStreamExt<T> for S
where S: Stream<Item=T>
{
    /// Wrap `self` in a `ChunkingStream` with chunk size `sz` and output type `I`.
    fn chunk_into<I: From<Vec<T>>>(self, sz: usize) -> chunking::ChunkingStream<Self,T,I>
    {
        chunking::ChunkingStream::<Self, T, I>::new(self, sz)
    }
}

@ -2,8 +2,10 @@
use super::*;
#[cfg(any(feature="feed-sentance", feature="split-sentance"))]
use sanitise::Sentance;
use std::iter;
pub const DEFAULT_FEED_BOUNDS: std::ops::RangeFrom<usize> = 2..; //TODO: Add to config somehow
pub const DEFAULT_FEED_BOUNDS: std::ops::RangeFrom<usize> = 2..;
/// Feed `what` into `chain`, at least `bounds` tokens.
///
@ -73,12 +75,12 @@ pub async fn full(who: &IpAddr, state: State, body: impl Unpin + Stream<Item = R
if_debug! {
let timer = std::time::Instant::now();
}
let bounds = &state.config_cache().feed_bounds;
//let bounds = &state.config_cache().feed_bounds;
macro_rules! feed {
($chain:expr, $buffer:ident) => {
($buffer:expr) => {
{
let buffer = $buffer;
feed($chain, &buffer, bounds)
state.chain_write(buffer.map(ToOwned::to_owned)).await.map_err(|_| FillBodyError)?;
}
}
}
@ -101,15 +103,11 @@ pub async fn full(who: &IpAddr, state: State, body: impl Unpin + Stream<Item = R
let buffer = std::str::from_utf8(&buffer[..]).map_err(|_| FillBodyError)?;
let buffer = state.inbound_filter().filter_cow(buffer);
info!("{} -> {:?}", who, buffer);
let mut chain = state.chain().write().await;
cfg_if! {
if #[cfg(feature="split-newlines")] {
for buffer in buffer.split('\n').filter(|line| !line.trim().is_empty()) {
feed!(&mut chain, buffer);
}
feed!(buffer.split('\n').filter(|line| !line.trim().is_empty()))
} else {
feed!(&mut chain, buffer);
feed!(iter::once(buffer));
}
}
} else {
@ -124,10 +122,10 @@ pub async fn full(who: &IpAddr, state: State, body: impl Unpin + Stream<Item = R
let line = state.inbound_filter().filter_cow(&line);
let line = line.trim();
if !line.is_empty() {
#[cfg(not(feature="hog-buffer"))]
let mut chain = state.chain().write().await; // Acquire mutex once per line? Is this right?
//#[cfg(not(feature="hog-buffer"))]
//let mut chain = state.chain().write().await; // Acquire mutex once per line? Is this right?
feed!(&mut chain, line);
feed!(iter::once(line));
info!("{} -> {:?}", who, line);
}
written+=line.len();

@ -1,34 +1,46 @@
//! Generating the strings
use super::*;
use tokio::sync::mpsc::error::SendError;
use futures::StreamExt;
#[derive(Debug)]
pub struct GenBodyError(pub String);
#[derive(Debug, Default)]
pub struct GenBodyError(Option<String>);
impl error::Error for GenBodyError{}
impl fmt::Display for GenBodyError
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
{
write!(f, "failed to write {:?} to body", self.0)
if let Some(z) = &self.0 {
write!(f, "failed to write read string {:?} to body", z)
} else {
write!(f, "failed to read string from chain. it might be empty.")
}
}
}
pub async fn body(state: State, num: Option<usize>, mut output: mpsc::Sender<String>) -> Result<(), GenBodyError>
{
let chain = state.chain().read().await;
if !chain.is_empty() {
let mut chain = state.chain_read();
let filter = state.outbound_filter();
match num {
Some(num) if num < state.config().max_gen_size => {
//This could DoS `full_body` and writes, potentially.
for string in chain.str_iter_for(num) {
output.send(filter.filter_owned(string)).await.map_err(|e| GenBodyError(e.0))?;
let mut chain = chain.take(num);
while let Some(string) = chain.next().await {
output.send(filter.filter_owned(string)).await?;
}
},
_ => output.send(filter.filter_owned(chain.generate_str())).await.map_err(|e| GenBodyError(e.0))?,
}
_ => output.send(filter.filter_owned(chain.next().await.ok_or_else(GenBodyError::default)?)).await?,
}
Ok(())
}
impl From<SendError<String>> for GenBodyError
{
#[inline] fn from(from: SendError<String>) -> Self
{
Self(Some(from.0))
}
}

@ -0,0 +1,285 @@
//! Chain handler.
use super::*;
use std::{
marker::Send,
sync::Weak,
num::NonZeroUsize,
task::{Poll, Context,},
pin::Pin,
};
use tokio::{
sync::{
RwLock,
RwLockReadGuard,
mpsc::{
self,
error::SendError,
},
},
task::JoinHandle,
time::{
self,
Duration,
},
};
use futures::StreamExt;
/// Settings for chain handler
#[derive(Debug, Clone, PartialEq)]
pub struct Settings
{
pub backlog: usize,
pub capacity: usize,
pub timeout: Duration,
pub throttle: Option<Duration>,
pub bounds: range::DynRange<usize>,
}
impl Settings
{
/// Should we keep this string.
#[inline] fn matches(&self, s: &str) -> bool
{
true
}
}
impl Default for Settings
{
#[inline]
fn default() -> Self
{
Self {
backlog: 32,
capacity: 4,
timeout: Duration::from_secs(5),
throttle: Some(Duration::from_millis(200)),
bounds: feed::DEFAULT_FEED_BOUNDS.into(),
}
}
}
/// State that is handed over to the hosting task when the handle is hosted.
#[derive(Debug)]
struct HostInner<T>
{
    /// Receiving end of the handle's input channel.
    input: mpsc::Receiver<Vec<T>>,
}
/// Shared state behind every `ChainHandle` clone.
#[derive(Debug)]
struct Handle<T: Send+ chain::Chainable>
{
    /// The chain itself, behind an async read/write lock.
    chain: RwLock<chain::Chain<T>>,
    /// Sending end of the input channel; cloned for every write.
    input: mpsc::Sender<Vec<T>>,
    /// Settings this handle was created with.
    opt: Settings,

    /// Data used only for the worker task.
    host: msg::Once<HostInner<T>>,
}
#[derive(Clone, Debug)]
pub struct ChainHandle<T: Send + chain::Chainable>(Arc<Box<Handle<T>>>);
impl<T: Send+ chain::Chainable> ChainHandle<T>
{
#[inline] pub fn new(chain: chain::Chain<T>) -> Self
{
Self::with_settings(chain, Default::default())
}
pub fn with_settings(chain: chain::Chain<T>, opt: Settings) -> Self
{
let (itx, irx) = mpsc::channel(opt.backlog);
Self(Arc::new(Box::new(Handle{
chain: RwLock::new(chain),
input: itx,
opt,
host: msg::Once::new(HostInner{
input: irx,
})
})))
}
/// Acquire the chain read lock
async fn chain(&self) -> RwLockReadGuard<'_, chain::Chain<T>>
{
self.0.chain.read().await
}
/// A reference to the chain
pub fn chain_ref(&self) -> &RwLock<chain::Chain<T>>
{
&self.0.chain
}
/// Create a stream that reads generated values forever.
pub fn read(&self) -> ChainStream<T>
{
ChainStream{
chain: Arc::downgrade(&self.0),
buffer: Vec::with_capacity(self.0.opt.backlog),
}
}
/// Send this buffer to the chain
pub async fn write(&self, buf: Vec<T>) -> Result<(), SendError<Vec<T>>>
{
self.0.input.clone().send(buf).await
}
}
impl ChainHandle<String>
{
    /// Generate strings from the chain and send them, filtered, to `output`.
    ///
    /// Does nothing if the chain is empty. When `num` is given and below the
    /// configured `max_gen_size`, that many strings are generated; otherwise
    /// a single string is generated.
    ///
    /// # Errors
    /// Returns the `SendError` of the first send that fails.
    #[deprecated = "use read() pls"]
    pub async fn generate_body(&self, state: &state::State, num: Option<usize>, mut output: mpsc::Sender<String>) -> Result<(), SendError<String>>
    {
        let chain = self.chain().await;
        if chain.is_empty() {
            return Ok(());
        }

        let filter = state.outbound_filter();
        match num {
            Some(count) if count < state.config().max_gen_size => {
                //This could DoS writes, potentially.
                for generated in chain.str_iter_for(count) {
                    let filtered = filter.filter_owned(generated);
                    output.send(filtered).await?;
                }
            },
            _ => {
                let filtered = filter.filter_owned(chain.generate_str());
                output.send(filtered).await?;
            },
        }
        Ok(())
    }
}
/// Host this handle on the current task.
///
/// Takes the receiving end of the handle's input channel, groups incoming
/// buffers into chunks of up to `opt.capacity`, and forwards each chunk to a
/// spawned child task, which feeds every item into the chain under the write lock.
/// Returns when the child exits, when forwarding to the child fails, or when
/// the handle's strong count drops to 2 or below.
///
/// # Panics
/// If `from` has already been hosted.
pub async fn host(from: ChainHandle<String>)
{
let opt = from.0.opt.clone();
// Take the receiver out of the `Once`; `unwrap()` panics if it was already consumed.
let data = from.0.host.unwrap().await;
let (mut tx, child) = {
// The `real` input channel.
let from = from.clone();
let opt = opt.clone();
// Each message on this channel is one chunk: a `Vec` of input buffers.
let (tx, rx) = mpsc::channel::<Vec<Vec<_>>>(opt.backlog);
(tx, tokio::spawn(async move {
// Apply the throttle to the receiving end only when one is configured.
let mut rx = if let Some(thr) = opt.throttle {
time::throttle(thr, rx).boxed()
} else {
rx.boxed()
};
trace!("child: Begin waiting on parent");
while let Some(item) = rx.next().await {
// Take the write lock once per chunk, then feed every item in it.
let mut lock = from.0.chain.write().await;
for item in item.into_iter()
{
use std::ops::DerefMut;
for item in item.into_iter() {
feed::feed(lock.deref_mut(), item, &from.0.opt.bounds);
}
}
}
trace!("child: exiting");
}))
};
trace!("Begin polling on child");
// Run until either the child finishes or the forwarding loop below ends.
tokio::select!{
v = child => {
match v {
#[cold] Ok(_) => {warn!("Child exited before we have? This should probably never happen.")},//Should never happen.
Err(e) => {error!("Child exited abnormally. Aborting: {}", e)}, //Child panic or cancel.
}
},
_ = async move {
let mut rx = data.input.chunk(opt.capacity); //we don't even need this tbh
// NOTE(review): the two remaining references are presumably this task's `from`
// and the clone moved into the child — confirm.
while Arc::strong_count(&from.0) > 2 {
tokio::select!{
// No full chunk arrived within `opt.timeout`: ask the chunker to flush
// whatever it has buffered on its next read.
_ = time::delay_for(opt.timeout) => {
rx.push_now();
}
Some(buffer) = rx.next() => {
if let Err(err) = tx.send(buffer).await {
// Receive closed?
//
// This probably shouldn't happen, as we `select!` for it up there and child never calls `close()` on `rx`.
// In any case, it means we should abort.
error!("Failed to send buffer: {}", err);
break;
}
}
}
}
} => {
// Normal exit
trace!("Normal exit")
},
}
// No more handles except child, no more possible inputs.
trace!("Returning");
}
/// Spawn a new chain handler for this chain.
///
/// Returns the join handle of the spawned hosting task together with a
/// `ChainHandle` for reading from and writing to the chain.
pub fn spawn(from: chain::Chain<String>, opt: Settings) -> (JoinHandle<()>, ChainHandle<String>)
{
    let handle = ChainHandle::with_settings(from, opt);
    let worker = tokio::spawn(host(handle.clone()));
    (worker, handle)
}
/// A stream of values generated from a chain.
#[derive(Debug)]
pub struct ChainStream<T: Send + chain::Chainable>
{
    /// Weak reference to the handle's shared state; it does not keep the handle alive.
    chain: Weak<Box<Handle<T>>>,
    /// Generated values that have not been yielded yet.
    buffer: Vec<T>,
}
impl ChainStream<String>
{
    /// Try to generate `n` strings from the chain into the internal buffer.
    ///
    /// Returns the requested count on success. Returns `None` when `n` is 0,
    /// when the handle this stream was created from no longer exists, or when
    /// the chain is empty.
    async fn try_pull(&mut self, n: usize) -> Option<NonZeroUsize>
    {
        // Reject a zero-sized request up front. This also yields the non-zero
        // count returned below without needing `unsafe`.
        let requested = NonZeroUsize::new(n)?;
        // The handle may already be gone; we only hold a weak reference to it.
        let read = self.chain.upgrade()?;
        let chain = read.chain.read().await;
        if chain.is_empty() {
            return None;
        }

        if n == 1 {
            self.buffer.push(chain.generate_str());
        } else {
            self.buffer.extend(chain.str_iter_for(n));
        }
        // For now this reports the requested count, not the number of strings
        // actually appended.
        Some(requested)
    }
}
impl Stream for ChainStream<String>
{
    type Item = String;

    /// Yield the next generated string, refilling the buffer when it is empty.
    ///
    /// A refill requests as many strings as the buffer has capacity for. The
    /// stream ends when a refill yields nothing (see `try_pull`).
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        use futures::Future;
        let this = self.get_mut();

        if this.buffer.is_empty() {
            let wanted = this.buffer.capacity();
            let pull = this.try_pull(wanted);
            tokio::pin!(pull);
            match pull.poll(cx) {
                Poll::Ready(Some(_)) => {},
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Pending => {
                    // The `pull` future is dropped when we return, which also
                    // cancels any wakeup it registered while waiting for the
                    // chain lock. Re-arm the waker ourselves so this task is
                    // polled again instead of stalling.
                    // TODO: keep the in-flight future in the stream so we can
                    // wait on the lock instead of re-polling.
                    cx.waker().wake_by_ref();
                    return Poll::Pending;
                },
            }
        }

        debug_assert!(!this.buffer.is_empty());
        // Yield in generation order. NOTE: front removal shifts the rest of the buffer.
        Poll::Ready(Some(this.buffer.remove(0)))
    }
}

@ -78,6 +78,7 @@ use state::State;
mod save;
mod forwarded_list;
use forwarded_list::XForwardedFor;
mod handle;
mod feed;
mod gen;
@ -134,7 +135,7 @@ async fn main() {
debug!("Using config {:?}", config);
trace!("With config cached: {:?}", ccache);
let chain = Arc::new(RwLock::new(match save::load(&config.file).await {
let (chain_handle, chain) = handle::spawn(match save::load(&config.file).await {
Ok(chain) => {
info!("Loaded chain from {:?}", config.file);
chain
@ -144,7 +145,7 @@ async fn main() {
trace!("Error: {}", e);
Chain::new()
},
}));
}, Default::default()/*TODO*/);
{
let mut tasks = Vec::<BoxFuture<'static, ()>>::new();
let (state, chain) = {
@ -152,7 +153,7 @@ async fn main() {
let state = State::new(config,
ccache,
Arc::clone(&chain),
chain,
Arc::clone(&save_when));
let state2 = state.clone();
let saver = tokio::spawn(save::host(Box::new(state.clone())));

@ -3,6 +3,7 @@ use super::*;
use tokio::{
sync::{
watch,
Mutex,
},
};
use std::{
@ -12,7 +13,9 @@ use std::{
error,
};
use futures::{
future::Future,
future::{
Future,
},
};
#[derive(Debug)]
@ -160,3 +163,48 @@ impl Future for Initialiser
uhh.poll(cx)
}
}
/// A value that can be consumed once.
///
/// The value sits behind an async mutex, so it can be taken through a shared
/// reference; every attempt after the first observes `None`.
#[derive(Debug)]
pub struct Once<T>(Mutex<Option<T>>);

impl<T> Once<T>
{
    /// Create a new instance
    pub fn new(from: T) -> Self
    {
        let slot = Mutex::new(Some(from));
        Self(slot)
    }

    /// Consume into the instance from behind a potentially shared reference.
    ///
    /// If this is the only `Arc`, the value is taken without locking;
    /// otherwise the lock is acquired and the value taken through it.
    pub async fn consume_shared(self: Arc<Self>) -> Option<T>
    {
        match Arc::try_unwrap(self) {
            Ok(sole) => sole.into_inner(),
            Err(shared) => shared.consume().await,
        }
    }

    /// Consume from a shared reference and panic if the value has already been consumed.
    pub async fn unwrap_shared(self: Arc<Self>) -> T
    {
        let taken = self.consume_shared().await;
        taken.unwrap()
    }

    /// Consume into the instance.
    pub async fn consume(&self) -> Option<T>
    {
        let mut slot = self.0.lock().await;
        slot.take()
    }

    /// Consume and panic if the value has already been consumed.
    pub async fn unwrap(&self) -> T
    {
        let taken = self.consume().await;
        taken.unwrap()
    }

    /// Consume into the inner value
    pub fn into_inner(self) -> Option<T>
    {
        let Self(slot) = self;
        slot.into_inner()
    }
}

@ -272,6 +272,6 @@ mod tests
let string = "abcdef ghi jk1\nhian";
assert_eq!(filter.filter_str(&string).to_string(), filter.filter_cow(&string).to_string());
assert_eq!(filter.filter_cow(&string).to_string(), filter.filter(string.chars()).collect::<String>());
assert_eq!(filter.filter_cow(&string).to_string(), filter.filter_iter(string.chars()).collect::<String>());
}
}

@ -25,7 +25,7 @@ macro_rules! new {
};
}
const DEFAULT_BOUNDARIES: &[char] = &['\n', '.', ':', '!', '?'];
const DEFAULT_BOUNDARIES: &[char] = &['\n', '.', ':', '!', '?', '~'];
lazy_static! {
static ref BOUNDARIES: smallmap::Map<char, ()> = {

@ -25,7 +25,7 @@ macro_rules! new {
};
}
const DEFAULT_BOUNDARIES: &[char] = &['!', '.', ','];
const DEFAULT_BOUNDARIES: &[char] = &['!', '.', ',', '*'];
lazy_static! {
static ref BOUNDARIES: smallmap::Map<char, ()> = {

@ -43,7 +43,7 @@ type Decompressor<T> = BzDecoder<T>;
pub async fn save_now(state: &State) -> io::Result<()>
{
let chain = state.chain().read().await;
let chain = state.chain_ref().read().await;
use std::ops::Deref;
let to = &state.config().file;
save_now_to(chain.deref(),to).await
@ -82,7 +82,7 @@ pub async fn host(mut state: Box<State>)
debug!("Begin save handler");
while Arc::strong_count(state.when()) > 1 {
{
let chain = state.chain().read().await;
let chain = state.chain_ref().read().await;
use std::ops::Deref;
if let Err(e) = save_now_to(chain.deref(), &to).await {
error!("Failed to save chain: {}", e);

@ -1,17 +1,19 @@
//! /sentance/
use super::*;
use futures::StreamExt;
pub async fn body(state: State, num: Option<usize>, mut output: mpsc::Sender<String>) -> Result<(), gen::GenBodyError>
{
let string = {
let chain = state.chain().read().await;
if chain.is_empty() {
return Ok(());
}
let mut chain = state.chain_read();
match num {
None => chain.generate_str(),
Some(num) => (0..num).map(|_| chain.generate_str()).join("\n"),
None => chain.next().await.ok_or_else(gen::GenBodyError::default)?,
Some(num) if num < state.config().max_gen_size => {//(0..num).map(|_| chain.generate_str()).join("\n"),
let chain = chain.take(num);
chain.collect::<Vec<_>>().await.join("\n")//TODO: Stream version of JoinStrExt
},
_ => return Err(Default::default()),
}
};
@ -20,14 +22,14 @@ pub async fn body(state: State, num: Option<usize>, mut output: mpsc::Sender<Str
if let Some(num) = num {
for sen in sanitise::Sentance::new_iter(&string).take(num)
{
output.send(filter.filter_owned(sen.to_owned())).await.map_err(|e| gen::GenBodyError(e.0))?;
output.send(filter.filter_owned(sen.to_owned())).await?;
}
} else {
output.send(filter.filter_owned(match sanitise::Sentance::new_iter(&string)
.max_by_key(|x| x.len()) {
Some(x) => x,
#[cold] None => return Ok(()),
}.to_owned())).await.map_err(|e| gen::GenBodyError(e.0))?;
}.to_owned())).await?;
}
Ok(())
}

@ -36,7 +36,7 @@ pub async fn handle(mut state: State)
match save::load(&state.config().file).await {
Ok(new) => {
{
let mut chain = state.chain().write().await;
let mut chain = state.chain_ref().write().await;
*chain = new;
}
trace!("Replaced with read chain");

@ -3,6 +3,7 @@ use super::*;
use tokio::{
sync::{
watch,
mpsc::error::SendError,
},
};
use config::Config;
@ -25,7 +26,7 @@ impl fmt::Display for ShutdownError
pub struct State
{
config: Arc<Box<(Config, config::Cache)>>, //to avoid cloning config
chain: Arc<RwLock<Chain<String>>>,
chain: handle::ChainHandle<String>,
save: Arc<Notify>,
begin: Initialiser,
@ -78,7 +79,7 @@ impl State
&self.config_cache().outbound_filter
}
pub fn new(config: Config, cache: config::Cache, chain: Arc<RwLock<Chain<String>>>, save: Arc<Notify>) -> Self
pub fn new(config: Config, cache: config::Cache, chain: handle::ChainHandle<String>, save: Arc<Notify>) -> Self
{
let (shutdown, shutdown_recv) = watch::channel(false);
Self {
@ -106,9 +107,23 @@ impl State
self.save.notify();
}
pub fn chain(&self) -> &RwLock<Chain<String>>
/*pub fn chain(&self) -> &RwLock<Chain<String>>
{
&self.chain.as_ref()
}*/
/// The lock guarding the chain, as exposed by the chain handle.
pub fn chain_ref(&self) -> &RwLock<Chain<String>>
{
    self.chain.chain_ref()
}
/// Create a new stream of generated strings from the chain handle.
pub fn chain_read(&self) -> handle::ChainStream<String>
{
    let handle = &self.chain;
    handle.read()
}
/// Collect `buffer` into a `Vec` and send it to the chain handle.
///
/// # Errors
/// Returns the collected buffer back inside a `SendError` if the send fails.
pub async fn chain_write(&self, buffer: impl IntoIterator<Item = String>) -> Result<(), SendError<Vec<String>>>
{
    let collected: Vec<String> = buffer.into_iter().collect();
    self.chain.write(collected).await
}
pub fn when(&self) -> &Arc<Notify>

Loading…
Cancel
Save