commit
acf2ac605e
@ -1,34 +1,46 @@
|
||||
//! Generating the strings
|
||||
use super::*;
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
use futures::StreamExt;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct GenBodyError(pub String);
|
||||
#[derive(Debug, Default)]
|
||||
pub struct GenBodyError(Option<String>);
|
||||
|
||||
impl error::Error for GenBodyError{}
|
||||
impl fmt::Display for GenBodyError
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
|
||||
{
|
||||
write!(f, "failed to write {:?} to body", self.0)
|
||||
if let Some(z) = &self.0 {
|
||||
write!(f, "failed to write read string {:?} to body", z)
|
||||
} else {
|
||||
write!(f, "failed to read string from chain. it might be empty.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub async fn body(state: State, num: Option<usize>, mut output: mpsc::Sender<String>) -> Result<(), GenBodyError>
|
||||
{
|
||||
let chain = state.chain().read().await;
|
||||
if !chain.is_empty() {
|
||||
let mut chain = state.chain_read();
|
||||
let filter = state.outbound_filter();
|
||||
match num {
|
||||
Some(num) if num < state.config().max_gen_size => {
|
||||
//This could DoS `full_body` and writes, potentially.
|
||||
for string in chain.str_iter_for(num) {
|
||||
output.send(filter.filter_owned(string)).await.map_err(|e| GenBodyError(e.0))?;
|
||||
let mut chain = chain.take(num);
|
||||
while let Some(string) = chain.next().await {
|
||||
output.send(filter.filter_owned(string)).await?;
|
||||
}
|
||||
},
|
||||
_ => output.send(filter.filter_owned(chain.generate_str())).await.map_err(|e| GenBodyError(e.0))?,
|
||||
}
|
||||
_ => output.send(filter.filter_owned(chain.next().await.ok_or_else(GenBodyError::default)?)).await?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
impl From<SendError<String>> for GenBodyError
|
||||
{
|
||||
#[inline] fn from(from: SendError<String>) -> Self
|
||||
{
|
||||
Self(Some(from.0))
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,392 @@
|
||||
//! Chain handler.
|
||||
use super::*;
|
||||
use std::{
|
||||
marker::Send,
|
||||
sync::Weak,
|
||||
num::NonZeroUsize,
|
||||
task::{Poll, Context,},
|
||||
pin::Pin,
|
||||
};
|
||||
use tokio::{
|
||||
sync::{
|
||||
RwLock,
|
||||
RwLockReadGuard,
|
||||
mpsc::{
|
||||
self,
|
||||
error::SendError,
|
||||
},
|
||||
watch,
|
||||
Notify,
|
||||
},
|
||||
task::JoinHandle,
|
||||
time::{
|
||||
self,
|
||||
Duration,
|
||||
},
|
||||
};
|
||||
use futures::StreamExt;
|
||||
|
||||
pub const DEFAULT_TIMEOUT: Duration= Duration::from_secs(5);
|
||||
|
||||
/// Settings for chain handler
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct Settings
|
||||
{
|
||||
pub backlog: usize,
|
||||
pub internal_backlog: usize,
|
||||
pub capacity: usize,
|
||||
pub timeout: Duration,
|
||||
pub throttle: Option<Duration>,
|
||||
pub bounds: range::DynRange<usize>,
|
||||
}
|
||||
|
||||
impl Settings
|
||||
{
|
||||
/// Should we keep this string.
|
||||
#[inline] fn matches(&self, _s: &str) -> bool
|
||||
{
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Settings
|
||||
{
|
||||
#[inline]
|
||||
fn default() -> Self
|
||||
{
|
||||
Self {
|
||||
backlog: 32,
|
||||
internal_backlog: 8,
|
||||
capacity: 4,
|
||||
timeout: Duration::from_secs(5),
|
||||
throttle: Some(Duration::from_millis(200)),
|
||||
bounds: feed::DEFAULT_FEED_BOUNDS.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug)]
|
||||
struct HostInner<T>
|
||||
{
|
||||
input: mpsc::Receiver<Vec<T>>,
|
||||
shutdown: watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Handle<T: Send+ chain::Chainable>
|
||||
{
|
||||
chain: RwLock<chain::Chain<T>>,
|
||||
input: mpsc::Sender<Vec<T>>,
|
||||
opt: Settings,
|
||||
notify_write: Arc<Notify>,
|
||||
push_now: Arc<Notify>,
|
||||
shutdown: watch::Sender<bool>,
|
||||
|
||||
/// Data used only for the worker task.
|
||||
host: msg::Once<HostInner<T>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ChainHandle<T: Send + chain::Chainable>(Arc<Box<Handle<T>>>);
|
||||
|
||||
impl<T: Send+ chain::Chainable + 'static> ChainHandle<T>
|
||||
{
|
||||
pub fn with_settings(chain: chain::Chain<T>, opt: Settings) -> Self
|
||||
{
|
||||
let (shutdown_tx, shutdown) = watch::channel(false);
|
||||
let (itx, irx) = mpsc::channel(opt.backlog);
|
||||
Self(Arc::new(Box::new(Handle{
|
||||
chain: RwLock::new(chain),
|
||||
input: itx,
|
||||
opt,
|
||||
push_now: Arc::new(Notify::new()),
|
||||
notify_write: Arc::new(Notify::new()),
|
||||
shutdown: shutdown_tx,
|
||||
|
||||
host: msg::Once::new(HostInner{
|
||||
input: irx,
|
||||
shutdown,
|
||||
})
|
||||
})))
|
||||
}
|
||||
|
||||
/// Acquire the chain read lock
|
||||
async fn chain(&self) -> RwLockReadGuard<'_, chain::Chain<T>>
|
||||
{
|
||||
self.0.chain.read().await
|
||||
}
|
||||
|
||||
/// A reference to the chain
|
||||
pub fn chain_ref(&self) -> &RwLock<chain::Chain<T>>
|
||||
{
|
||||
&self.0.chain
|
||||
}
|
||||
|
||||
/// Create a stream that reads generated values forever.
|
||||
pub fn read(&self) -> ChainStream<T>
|
||||
{
|
||||
ChainStream{
|
||||
chain: Arc::downgrade(&self.0),
|
||||
buffer: Vec::with_capacity(self.0.opt.backlog),
|
||||
}
|
||||
}
|
||||
|
||||
/// Send this buffer to the chain
|
||||
pub fn write(&self, buf: Vec<T>) -> impl futures::Future<Output = Result<(), SendError<Vec<T>>>> + 'static
|
||||
{
|
||||
let mut write = self.0.input.clone();
|
||||
async move {
|
||||
write.send(buf).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Send this stream buffer to the chain
|
||||
pub fn write_stream<'a, I: Stream<Item=T>>(&self, buf: I) -> impl futures::Future<Output = Result<(), SendError<Vec<T>>>> + 'a
|
||||
where I: 'a
|
||||
{
|
||||
let mut write = self.0.input.clone();
|
||||
async move {
|
||||
write.send(buf.collect().await).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Send this buffer to the chain
|
||||
pub async fn write_in_place(&self, buf: Vec<T>) -> Result<(), SendError<Vec<T>>>
|
||||
{
|
||||
self.0.input.clone().send(buf).await
|
||||
}
|
||||
|
||||
/// A referencer for the notifier
|
||||
pub fn notify_when(&self) -> &Arc<Notify>
|
||||
{
|
||||
&self.0.notify_write
|
||||
}
|
||||
|
||||
/// Force the pending buffers to be written to the chain now
|
||||
pub fn push_now(&self)
|
||||
{
|
||||
self.0.push_now.notify();
|
||||
}
|
||||
|
||||
/// Hang the worker thread, preventing it from taking any more inputs and also flushing it.
|
||||
///
|
||||
/// # Panics
|
||||
/// If there was no worker thread.
|
||||
pub fn hang(&self)
|
||||
{
|
||||
trace!("Communicating hang request");
|
||||
self.0.shutdown.broadcast(true).expect("Failed to communicate hang");
|
||||
}
|
||||
}
|
||||
|
||||
impl ChainHandle<String>
|
||||
{
|
||||
#[deprecated = "use read() pls"]
|
||||
pub async fn generate_body(&self, state: &state::State, num: Option<usize>, mut output: mpsc::Sender<String>) -> Result<(), SendError<String>>
|
||||
{
|
||||
let chain = self.chain().await;
|
||||
if !chain.is_empty() {
|
||||
let filter = state.outbound_filter();
|
||||
match num {
|
||||
Some(num) if num < state.config().max_gen_size => {
|
||||
//This could DoS writes, potentially.
|
||||
for string in chain.str_iter_for(num) {
|
||||
output.send(filter.filter_owned(string)).await?;
|
||||
}
|
||||
},
|
||||
_ => output.send(filter.filter_owned(chain.generate_str())).await?,
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Host this handle on the current task.
|
||||
///
|
||||
/// # Panics
|
||||
/// If `from` has already been hosted.
|
||||
pub async fn host(from: ChainHandle<String>)
|
||||
{
|
||||
let opt = from.0.opt.clone();
|
||||
let mut data = from.0.host.unwrap().await;
|
||||
|
||||
let (mut tx, mut child) = {
|
||||
// The `real` input channel.
|
||||
let from = from.clone();
|
||||
let opt = opt.clone();
|
||||
let (tx, rx) = mpsc::channel::<Vec<Vec<_>>>(opt.internal_backlog);
|
||||
(tx, tokio::spawn(async move {
|
||||
let mut rx = if let Some(thr) = opt.throttle {
|
||||
time::throttle(thr, rx).boxed()
|
||||
} else {
|
||||
rx.boxed()
|
||||
};
|
||||
trace!("child: Begin waiting on parent");
|
||||
while let Some(item) = rx.next().await {
|
||||
if item.len() > 0 {
|
||||
info!("Write lock acq");
|
||||
let mut lock = from.0.chain.write().await;
|
||||
for item in item.into_iter()
|
||||
{
|
||||
use std::ops::DerefMut;
|
||||
for item in item.into_iter() {
|
||||
feed::feed(lock.deref_mut(), item, &from.0.opt.bounds);
|
||||
}
|
||||
}
|
||||
trace!("Signalling write");
|
||||
from.0.notify_write.notify();
|
||||
}
|
||||
}
|
||||
trace!("child: exiting");
|
||||
}))
|
||||
};
|
||||
|
||||
trace!("Begin polling on child");
|
||||
tokio::select!{
|
||||
v = &mut child => {
|
||||
match v {
|
||||
#[cold] Ok(_) => {warn!("Child exited before we have? This should probably never happen.")},//Should never happen.
|
||||
Err(e) => {error!("Child exited abnormally. Aborting: {}", e)}, //Child panic or cancel.
|
||||
}
|
||||
},
|
||||
_ = async move {
|
||||
let mut rx = data.input.chunk(opt.capacity); //we don't even need this tbh, oh well.
|
||||
|
||||
if !data.shutdown.recv().await.unwrap_or(true) { //first shutdown we get for free
|
||||
while Arc::strong_count(&from.0) > 2 {
|
||||
if *data.shutdown.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
tokio::select!{
|
||||
Some(true) = data.shutdown.recv() => {
|
||||
debug!("Got shutdown (hang) request. Sending now then breaking");
|
||||
|
||||
let mut rest = {
|
||||
let irx = rx.get_mut();
|
||||
irx.close(); //accept no more inputs
|
||||
let mut output = Vec::with_capacity(opt.capacity);
|
||||
while let Ok(item) = irx.try_recv() {
|
||||
output.push(item);
|
||||
}
|
||||
output
|
||||
};
|
||||
rest.extend(rx.take_now());
|
||||
if rest.len() > 0 {
|
||||
if let Err(err) = tx.send(rest).await {
|
||||
error!("Failed to force send buffer, exiting now: {}", err);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
_ = time::delay_for(opt.timeout) => {
|
||||
trace!("Setting push now");
|
||||
rx.push_now();
|
||||
}
|
||||
_ = from.0.push_now.notified() => {
|
||||
debug!("Got force push signal");
|
||||
let take =rx.take_now();
|
||||
rx.push_now();
|
||||
if take.len() > 0 {
|
||||
if let Err(err) = tx.send(take).await {
|
||||
error!("Failed to force send buffer: {}", err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(buffer) = rx.next() => {
|
||||
debug!("Sending {} (cap {})", buffer.len(), buffer.capacity());
|
||||
if let Err(err) = tx.send(buffer).await {
|
||||
// Receive closed?
|
||||
//
|
||||
// This probably shouldn't happen, as we `select!` for it up there and child never calls `close()` on `rx`.
|
||||
// In any case, it means we should abort.
|
||||
#[cold] error!("Failed to send buffer: {}", err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let last = rx.into_buffer();
|
||||
if last.len() > 0 {
|
||||
if let Err(err) = tx.send(last).await {
|
||||
error!("Failed to force send last part of buffer: {}", err);
|
||||
} else {
|
||||
trace!("Sent rest of buffer");
|
||||
}
|
||||
}
|
||||
} => {
|
||||
// Normal exit
|
||||
trace!("Normal exit")
|
||||
},
|
||||
}
|
||||
trace!("Waiting on child");
|
||||
// No more handles except child, no more possible inputs.
|
||||
child.await.expect("Child panic");
|
||||
trace!("Returning");
|
||||
}
|
||||
|
||||
/// Spawn a new chain handler for this chain.
|
||||
pub fn spawn(from: chain::Chain<String>, opt: Settings) -> (JoinHandle<()>, ChainHandle<String>)
|
||||
{
|
||||
debug!("Spawning with opt: {:?}", opt);
|
||||
let handle = ChainHandle::with_settings(from, opt);
|
||||
(tokio::spawn(host(handle.clone())), handle)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ChainStream<T: Send + chain::Chainable>
|
||||
{
|
||||
chain: Weak<Box<Handle<T>>>,
|
||||
buffer: Vec<T>,
|
||||
}
|
||||
|
||||
impl ChainStream<String>
|
||||
{
|
||||
async fn try_pull(&mut self, n: usize) -> Option<NonZeroUsize>
|
||||
{
|
||||
if n == 0 {
|
||||
return None;
|
||||
}
|
||||
if let Some(read) = self.chain.upgrade() {
|
||||
let chain = read.chain.read().await;
|
||||
if chain.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let n = if n == 1 {
|
||||
self.buffer.push(chain.generate_str());
|
||||
1
|
||||
} else {
|
||||
self.buffer.extend(chain.str_iter_for(n));
|
||||
n //for now
|
||||
};
|
||||
Some(unsafe{NonZeroUsize::new_unchecked(n)})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for ChainStream<String>
|
||||
{
|
||||
type Item = String;
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
use futures::Future;
|
||||
let this = self.get_mut();
|
||||
|
||||
if this.buffer.len() == 0 {
|
||||
let pull = this.try_pull(this.buffer.capacity());
|
||||
tokio::pin!(pull);
|
||||
match pull.poll(cx) {
|
||||
Poll::Ready(Some(_)) => {},
|
||||
Poll::Pending => return Poll::Pending,
|
||||
_ => return Poll::Ready(None),
|
||||
};
|
||||
}
|
||||
debug_assert!(this.buffer.len()>0);
|
||||
Poll::Ready(Some(this.buffer.remove(0)))
|
||||
}
|
||||
}
|
Loading…
Reference in new issue