//! Chain handler.
//!
//! A [`ChainHandle`] wraps a chain behind an `RwLock`. Writes are sent on an input
//! channel; the [`host`] task reads that channel (through the project's `chunk`
//! adapter), forwards the collected buffers to a child task, and the child takes the
//! write lock and feeds every item into the chain. Reads either lock the chain
//! directly or go through a [`ChainStream`].

// NOTE(review): generic type arguments (`<...>`) appear to have been stripped from
// this copy of the file, so it will not compile as shown. For example,
// `throttle: Option` is presumably `Option<Duration>` (see
// `Some(Duration::from_millis(200))` in `Default`), and `watch::Receiver` /
// `watch::Sender` presumably carry `bool` (see `watch::channel(false)` and
// `broadcast(true)`). Restore them from version control rather than guessing.
use super::*;
use std::{
    marker::Send,
    sync::Weak,
    num::NonZeroUsize,
    task::{Poll, Context,},
    pin::Pin,
};
use tokio::{
    sync::{
        RwLock,
        RwLockReadGuard,
        mpsc::{
            self,
            error::SendError,
        },
        watch,
        Notify,
    },
    task::JoinHandle,
    time::{
        self,
        Duration,
    },
};
use futures::StreamExt;

/// Default idle timeout (5 seconds).
///
/// NOTE(review): `Settings::default()` repeats the literal `Duration::from_secs(5)`
/// instead of using this constant, and nothing visible in this file reads it.
pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);

/// Settings for chain handler
#[derive(Debug, Clone, PartialEq)]
pub struct Settings {
    /// Capacity of the input channel created by `ChainHandle::with_settings`; also the
    /// capacity requested for the read-ahead buffer of each stream from `read()`.
    pub backlog: usize,
    /// Capacity of the internal channel between `host` and its child writer task.
    pub internal_backlog: usize,
    /// Passed to `chunk()` on the input receiver in `host` (presumably the number of
    /// buffers grouped before forwarding — confirm), and used as the initial capacity
    /// of the drain buffer on a hang request.
    pub capacity: usize,
    /// Idle time after which `host` calls `push_now()` on the chunked input. The timer
    /// is recreated on every iteration of `host`'s select loop, so any other event
    /// restarts it.
    pub timeout: Duration,
    /// If set, the child writer task receives forwarded buffers through
    /// `time::throttle` with this period.
    pub throttle: Option,
    /// Passed to `feed::feed` for every item written into the chain.
    pub bounds: range::DynRange,
}

impl Settings {
    /// Should we keep this string.
    ///
    /// Currently always returns `true`; it is not called anywhere in the visible code.
    #[inline]
    fn matches(&self, _s: &str) -> bool {
        true
    }
}

impl Default for Settings {
    /// Defaults: backlog 32, internal backlog 8, capacity 4, 5 second timeout,
    /// 200 ms throttle, and `feed::DEFAULT_FEED_BOUNDS` as the feed bounds.
    #[inline]
    fn default() -> Self {
        Self {
            backlog: 32,
            internal_backlog: 8,
            capacity: 4,
            timeout: Duration::from_secs(5),
            throttle: Some(Duration::from_millis(200)),
            bounds: feed::DEFAULT_FEED_BOUNDS.into(),
        }
    }
}

/// Worker-only state, moved out of the handle by [`host`].
#[derive(Debug)]
struct HostInner {
    /// Receiving end of the handle's input channel.
    input: mpsc::Receiver>,
    /// Receiving end of the hang/shutdown signal sent by `ChainHandle::hang`.
    shutdown: watch::Receiver,
}

/// Shared state behind every [`ChainHandle`] clone.
#[derive(Debug)]
struct Handle {
    /// The chain itself; the child task in `host` takes the write lock to feed it.
    chain: RwLock>,
    /// Sending end of the input channel (capacity `opt.backlog`).
    input: mpsc::Sender>,
    opt: Settings,
    /// Signalled by the child writer task after each non-empty batch is fed.
    notify_write: Arc,
    /// Signalled by `ChainHandle::push_now`; `host` listens for it.
    push_now: Arc,
    /// Sending end of the hang/shutdown signal.
    shutdown: watch::Sender,

    /// Data used only for the worker task.
    host: msg::Once>,
}

/// Cloneable, reference-counted handle to a chain and its write pipeline.
///
/// `host` keeps running while clones other than its own two exist (it checks
/// `Arc::strong_count`), so dropping every external clone eventually stops it.
#[derive(Clone, Debug)]
pub struct ChainHandle(Arc>>);

impl ChainHandle {
    /// Create a handle around `chain` using `opt`.
    ///
    /// This does not start a worker. Until the handle is passed to [`host`] (or
    /// created through [`spawn`]), writes only queue in the input channel
    /// (capacity `opt.backlog`) and never reach the chain.
    pub fn with_settings(chain: chain::Chain, opt: Settings) -> Self {
        // The initial `false` means "no hang requested".
        let (shutdown_tx, shutdown) = watch::channel(false);
        let (itx, irx) = mpsc::channel(opt.backlog);
        Self(Arc::new(Box::new(Handle{
            chain: RwLock::new(chain),
            input: itx,
            opt,
            push_now: Arc::new(Notify::new()),
            notify_write: Arc::new(Notify::new()),
            shutdown: shutdown_tx,
            // The receiving ends are parked here until `host` takes them.
            host: msg::Once::new(HostInner{
                input: irx,
                shutdown,
            })
        })))
    }

    /// Acquire the chain read lock
    async fn chain(&self) -> RwLockReadGuard<'_, chain::Chain> {
        self.0.chain.read().await
    }

    /// A reference to the chain
    ///
    /// Gives direct access to the lock; holding its write guard blocks the child
    /// writer task in `host`.
    pub fn chain_ref(&self) -> &RwLock> {
        &self.0.chain
    }

    /// Create a stream that reads generated values forever.
    ///
    /// The stream holds only a `Weak` reference to this handle. It ends once every
    /// strong handle has been dropped, or when the chain is empty at the moment the
    /// stream needs to refill its buffer (see `ChainStream::try_pull`).
    pub fn read(&self) -> ChainStream {
        ChainStream{
            chain: Arc::downgrade(&self.0),
            buffer: Vec::with_capacity(self.0.opt.backlog),
        }
    }

    /// Send this buffer to the chain
    ///
    /// The returned future owns a clone of the input sender, so it does not borrow
    /// `self`. It resolves once the buffer is queued on the input channel, not once it
    /// has been fed into the chain; use `notify_when` to observe completed writes.
    pub fn write(&self, buf: Vec) -> impl futures::Future>>> + 'static {
        let mut write = self.0.input.clone();
        async move {
            write.send(buf).await
        }
    }

    /// Send this stream buffer to the chain
    ///
    /// The whole stream is collected into a single buffer before it is sent.
    pub fn write_stream<'a, I: Stream>(&self, buf: I) -> impl futures::Future>>> + 'a
    where I: 'a
    {
        let mut write = self.0.input.clone();
        async move {
            write.send(buf.collect().await).await
        }
    }

    /// Send this buffer to the chain
    ///
    /// Like `write`, but as an `async fn` that borrows `self` for the duration.
    pub async fn write_in_place(&self, buf: Vec) -> Result<(), SendError>> {
        self.0.input.clone().send(buf).await
    }

    /// A referencer for the notifier
    ///
    /// The child writer task in `host` calls `notify()` on it after each non-empty
    /// batch has been fed into the chain.
    pub fn notify_when(&self) -> &Arc {
        &self.0.notify_write
    }

    /// Force the pending buffers to be written to the chain now
    ///
    /// Wakes `host`, which forwards whatever its chunked input currently holds.
    pub fn push_now(&self) {
        self.0.push_now.notify();
    }

    /// Hang the worker thread, preventing it from taking any more inputs and also flushing it.
    ///
    /// # Panics
    /// If there was no worker thread.
    ///
    /// NOTE(review): the shutdown receiver lives in `HostInner` until `host` takes it,
    /// so `broadcast` presumably only fails (and this panics) after the `host` task
    /// has dropped it, i.e. once it has exited. Calling this before hosting
    /// presumably does not panic. Confirm against `msg::Once`.
    pub fn hang(&self) {
        trace!("Communicating hang request");
        self.0.shutdown.broadcast(true).expect("Failed to communicate hang");
    }
}

impl ChainHandle {
    /// Generate text into `output`, passing each string through
    /// `state.outbound_filter()`.
    ///
    /// If `num` is `Some(n)` with `n < state.config().max_gen_size`, sends `n` strings
    /// (via `str_iter_for`); otherwise sends a single string. Sends nothing if the
    /// chain is empty.
    ///
    /// The chain read lock is held for the whole call, including while awaiting
    /// `output.send`, so a slow consumer stalls writers.
    ///
    /// # Errors
    /// Returns the `SendError` if `output` has been closed.
    #[deprecated = "use read() pls"]
    pub async fn generate_body(&self, state: &state::State, num: Option, mut output: mpsc::Sender) -> Result<(), SendError> {
        let chain = self.chain().await;
        if !chain.is_empty() {
            let filter = state.outbound_filter();
            match num {
                Some(num) if num < state.config().max_gen_size => {
                    //This could DoS writes, potentially.
                    for string in chain.str_iter_for(num) {
                        output.send(filter.filter_owned(string)).await?;
                    }
                },
                _ => output.send(filter.filter_owned(chain.generate_str())).await?,
            }
        }
        Ok(())
    }
}

/// Host this handle on the current task.
///
/// Takes the worker-only data out of `from` and spawns a child task that owns the
/// write side of the chain. It then runs the input loop until one of these happens:
/// * a hang request arrives via [`ChainHandle::hang`] (the input channel is then
///   closed and drained);
/// * only the two handles held by this function and the child remain;
/// * forwarding to the child fails;
/// * the child task exits first.
///
/// Any input still buffered is then forwarded, the forwarding sender is dropped (so
/// the child's receive loop ends), and the child task is awaited.
///
/// # Panics
/// If `from` has already been hosted. Also panics if the child task panicked or was
/// cancelled (`expect` on its `JoinHandle`).
pub async fn host(from: ChainHandle) {
    let opt = from.0.opt.clone();
    let mut data = from.0.host.unwrap().await;

    let (mut tx, mut child) = {
        // The `real` input channel.
        // `tx` carries chunks of input buffers from the loop below to the child task.
        let from = from.clone();
        let opt = opt.clone();
        let (tx, rx) = mpsc::channel::>>(opt.internal_backlog);
        (tx, tokio::spawn(async move {
            // `.boxed()` gives both branches the same stream type.
            let mut rx = if let Some(thr) = opt.throttle {
                time::throttle(thr, rx).boxed()
            } else {
                rx.boxed()
            };

            trace!("child: Begin waiting on parent");
            while let Some(item) = rx.next().await {
                if item.len() > 0 {
                    info!("Write lock acq");
                    // One write lock per forwarded chunk; every inner item is fed separately.
                    let mut lock = from.0.chain.write().await;
                    for item in item.into_iter() {
                        use std::ops::DerefMut;
                        for item in item.into_iter() {
                            feed::feed(lock.deref_mut(), item, &from.0.opt.bounds);
                        }
                    }
                    trace!("Signalling write");
                    from.0.notify_write.notify();
                }
            }
            // Reached once every sender for `rx` has been dropped.
            trace!("child: exiting");
        }))
    };

    trace!("Begin polling on child");
    tokio::select!{
        v = &mut child => {
            match v {
                /*#[cold]*/ Ok(_) => {warn!("Child exited before we have? This should probably never happen.")},//Should never happen.
                Err(e) => {error!("Child exited abnormally. Aborting: {}", e)}, //Child panic or cancel.
            }
        },
        _ = async move {
            let mut rx = data.input.chunk(opt.capacity); //we don't even need this tbh, oh well.

            // tokio 0.2: the first `recv()` on a watch receiver completes immediately with
            // the current value (initially `false`). The loop is therefore skipped only if a
            // hang was already requested or the sender is gone (`None` -> `true`).
            if !data.shutdown.recv().await.unwrap_or(true) { //first shutdown we get for free
                // This function's `from` plus the child's clone account for 2, so the loop
                // runs while any other strong `ChainHandle` exists (`ChainStream` only holds a
                // `Weak`). The count is re-checked only between select iterations, which the
                // `delay_for` arm bounds to about `opt.timeout`.
                while Arc::strong_count(&from.0) > 2 {
                    if *data.shutdown.borrow() {
                        break;
                    }

                    tokio::select!{
                        Some(true) = data.shutdown.recv() => {
                            // NOTE(review): in this copy the message literal below is split across a
                            // line break; confirm whether it should contain a newline or a space.
                            debug!("Got shutdown (hang) request.
Sending now then breaking");
                            let mut rest = {
                                let irx = rx.get_mut();
                                irx.close(); //accept no more inputs
                                let mut output = Vec::with_capacity(opt.capacity);
                                while let Ok(item) = irx.try_recv() {
                                    output.push(item);
                                }
                                output
                            };
                            // NOTE(review): `take_now()` presumably returns buffers the chunker had
                            // already collected, i.e. ones that arrived before those drained above.
                            // Appending them here puts them after newer input. Confirm whether order
                            // matters to `feed`.
                            rest.extend(rx.take_now());
                            if rest.len() > 0 {
                                if let Err(err) = tx.send(rest).await {
                                    error!("Failed to force send buffer, exiting now: {}", err);
                                }
                            }
                            break;
                        }
                        // Recreated on every iteration: fires only after `opt.timeout` with no other event.
                        _ = time::delay_for(opt.timeout) => {
                            trace!("Setting push now");
                            rx.push_now();
                        }
                        _ = from.0.push_now.notified() => {
                            debug!("Got force push signal");
                            let take =rx.take_now();
                            rx.push_now();
                            if take.len() > 0 {
                                if let Err(err) = tx.send(take).await {
                                    error!("Failed to force send buffer: {}", err);
                                    break;
                                }
                            }
                        }
                        Some(buffer) = rx.next() => {
                            debug!("Sending {} (cap {})", buffer.len(), buffer.capacity());
                            if let Err(err) = tx.send(buffer).await {
                                // Receive closed?
                                //
                                // This probably shouldn't happen, as we `select!` for it up there and child never calls `close()` on `rx`.
                                // In any case, it means we should abort.
                                /*#[cold]*/
                                error!("Failed to send buffer: {}", err);
                                break;
                            }
                        }
                    }
                }
            }
            // Flush whatever the chunker still holds. `tx` is dropped when this block ends.
            let last = rx.into_buffer();
            if last.len() > 0 {
                if let Err(err) = tx.send(last).await {
                    error!("Failed to force send last part of buffer: {}", err);
                } else {
                    trace!("Sent rest of buffer");
                }
            }
        } => {
            // Normal exit
            trace!("Normal exit")
        },
    }
    trace!("Waiting on child");
    // No more handles except child, no more possible inputs.
    // NOTE(review): if the `child` arm of the select above completed, this `JoinHandle`
    // has already returned its output. Polling a future again after it returned
    // `Ready` is outside the `Future` contract, so this path is presumably unsound or
    // panics. Confirm, and skip this await in that case.
    child.await.expect("Child panic");
    trace!("Returning");
}

/// Spawn a new chain handler for this chain.
///
/// Wraps `from` in a [`ChainHandle`] configured by `opt` and spawns [`host`] for it
/// with `tokio::spawn`. Returns the `JoinHandle` of that task together with the
/// handle.
//
// NOTE(review): generic type arguments (`<...>`) appear to have been stripped from
// this copy, e.g. `-> Option {` in `try_pull` is presumably `Option<NonZeroUsize>`
// and `-> Poll> {` is the `Stream` signature `Poll<Option<Self::Item>>`. Restore them
// from version control.
pub fn spawn(from: chain::Chain, opt: Settings) -> (JoinHandle<()>, ChainHandle) {
    debug!("Spawning with opt: {:?}", opt);
    let handle = ChainHandle::with_settings(from, opt);
    (tokio::spawn(host(handle.clone())), handle)
}

/// Stream of generated strings, created by `ChainHandle::read`.
///
/// Strings are generated in batches into `buffer` and handed out one at a time.
#[derive(Debug)]
pub struct ChainStream {
    /// Weak, so an outstanding stream does not keep the handle (or `host`'s loop) alive.
    chain: Weak>>,
    /// Read-ahead buffer; its capacity is the batch size requested on each refill.
    buffer: Vec,
}

impl ChainStream {
    /// Append up to `n` generated strings to `self.buffer`.
    ///
    /// Returns `None`, without touching the buffer, if `n == 0`, if every strong handle
    /// has been dropped, or if the chain is empty. Otherwise returns `n`, which is not
    /// the number of strings actually pushed (see the note below).
    async fn try_pull(&mut self, n: usize) -> Option {
        if n == 0 {
            return None;
        }
        if let Some(read) = self.chain.upgrade() {
            let chain = read.chain.read().await;
            if chain.is_empty() {
                return None;
            }

            let n = if n == 1 {
                self.buffer.push(chain.generate_str());
                1
            } else {
                // NOTE(review): assumes `str_iter_for(n)` yields exactly `n` items. If it can
                // yield none, `poll_next` will call `remove(0)` on an empty buffer and panic.
                self.buffer.extend(chain.str_iter_for(n));
                n //for now
            };
            // SAFETY: `n` is non-zero. The `n == 0` case returned early above, and the
            // inner `n` is either `1` or the outer `n`.
            Some(unsafe{NonZeroUsize::new_unchecked(n)})
        } else {
            None
        }
    }
}

impl Stream for ChainStream {
    type Item = String;

    /// Yield the next buffered string, refilling the buffer with up to
    /// `buffer.capacity()` strings when it is empty.
    ///
    /// Ends the stream (`Ready(None)`) when a refill yields nothing, i.e. when the
    /// handle is gone or the chain is empty.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> {
        use futures::Future;
        let this = self.get_mut();

        if this.buffer.len() == 0 {
            // NOTE(review): `pull` is created, polled once and dropped on every call. If the
            // read-lock acquisition inside it is pending (a writer holds the lock), dropping
            // the future presumably also removes its waker from the lock's queue. Nothing
            // may then wake this task when the lock is released. Confirm against the tokio
            // version in use; keeping the in-flight future in `ChainStream` across polls
            // would avoid this.
            let pull = this.try_pull(this.buffer.capacity());
            tokio::pin!(pull);
            match pull.poll(cx) {
                Poll::Ready(Some(_)) => {},
                Poll::Pending => return Poll::Pending,
                _ => return Poll::Ready(None),
            };
        }
        debug_assert!(this.buffer.len()>0);
        // `Vec::remove(0)` shifts every remaining element, so each item costs O(len).
        Poll::Ready(Some(this.buffer.remove(0)))
    }
}