You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
genmarkov/src/handle.rs

286 lines
6.4 KiB

//! Chain handler.
use super::*;
use std::{
marker::Send,
sync::Weak,
num::NonZeroUsize,
task::{Poll, Context,},
pin::Pin,
};
use tokio::{
sync::{
RwLock,
RwLockReadGuard,
mpsc::{
self,
error::SendError,
},
},
task::JoinHandle,
time::{
self,
Duration,
},
};
use futures::StreamExt;
/// Settings for chain handler.
///
/// Start from [`Default::default`] and override the fields you need.
#[derive(Debug, Clone, PartialEq)]
pub struct Settings
{
	/// Capacity of the handle's bounded input channel; also the read-ahead
	/// buffer size of streams returned by `ChainHandle::read`.
	pub backlog: usize,
	/// Chunk size handed to the input chunker in [`host`].
	pub capacity: usize,
	/// How long [`host`] waits for a chunk before forcing a flush (`push_now`).
	pub timeout: Duration,
	/// Optional throttle applied to the stream of chunks fed into the chain.
	pub throttle: Option<Duration>,
	/// Bounds passed to `feed::feed` for every fed string.
	pub bounds: range::DynRange<usize>,
}
impl Settings
{
	/// Should we keep this string.
	///
	/// Currently a placeholder that accepts every string. The argument is
	/// deliberately unused (hence the leading underscore, which silences the
	/// `unused_variables` lint) until a real filtering rule is implemented.
	#[inline]
	fn matches(&self, _s: &str) -> bool
	{
		true
	}
}
impl Default for Settings
{
	/// A 32-deep backlog, chunks of 4, a 5 second flush timeout, a 200 ms
	/// throttle, and the feed module's default bounds.
	#[inline]
	fn default() -> Self
	{
		let throttle = Duration::from_millis(200);
		Self {
			backlog: 32,
			capacity: 4,
			timeout: Duration::from_millis(5_000),
			throttle: Some(throttle),
			bounds: feed::DEFAULT_FEED_BOUNDS.into(),
		}
	}
}
/// State that only the hosting task (see [`host`]) uses.
#[derive(Debug)]
struct HostInner<T>
{
	/// Receiving end of the handle's input channel; buffers passed to
	/// `ChainHandle::write` arrive here.
	input: mpsc::Receiver<Vec<T>>,
}
/// Shared state behind every clone of a [`ChainHandle`].
#[derive(Debug)]
struct Handle<T: Send + chain::Chainable>
{
	/// The chain: readers generate from it, the host's child task writes into it.
	chain: RwLock<chain::Chain<T>>,
	/// Sending end of the input channel drained by the host task.
	input: mpsc::Sender<Vec<T>>,
	/// Settings this handle was created with.
	opt: Settings,
	/// Data used only for the worker task; taken a single time by [`host`].
	host: msg::Once<HostInner<T>>,
}
/// Reference-counted handle to a chain; clones share the same chain and input queue.
#[derive(Clone, Debug)]
pub struct ChainHandle<T: Send + chain::Chainable>(Arc<Box<Handle<T>>>);
impl<T: Send+ chain::Chainable> ChainHandle<T>
{
#[inline] pub fn new(chain: chain::Chain<T>) -> Self
{
Self::with_settings(chain, Default::default())
}
pub fn with_settings(chain: chain::Chain<T>, opt: Settings) -> Self
{
let (itx, irx) = mpsc::channel(opt.backlog);
Self(Arc::new(Box::new(Handle{
chain: RwLock::new(chain),
input: itx,
opt,
host: msg::Once::new(HostInner{
input: irx,
})
})))
}
/// Acquire the chain read lock
async fn chain(&self) -> RwLockReadGuard<'_, chain::Chain<T>>
{
self.0.chain.read().await
}
/// A reference to the chain
pub fn chain_ref(&self) -> &RwLock<chain::Chain<T>>
{
&self.0.chain
}
/// Create a stream that reads generated values forever.
pub fn read(&self) -> ChainStream<T>
{
ChainStream{
chain: Arc::downgrade(&self.0),
buffer: Vec::with_capacity(self.0.opt.backlog),
}
}
/// Send this buffer to the chain
pub async fn write(&self, buf: Vec<T>) -> Result<(), SendError<Vec<T>>>
{
self.0.input.clone().send(buf).await
}
}
impl ChainHandle<String>
{
	/// Generate strings from the chain into `output`, passing each through the
	/// state's outbound filter.
	///
	/// With `Some(num)` below the configured `max_gen_size`, `num` strings are
	/// requested from the chain; otherwise exactly one string is generated.
	/// Nothing is sent when the chain is empty.
	///
	/// # Errors
	/// Fails if `output` has been closed.
	#[deprecated = "use read() pls"]
	pub async fn generate_body(&self, state: &state::State, num: Option<usize>, mut output: mpsc::Sender<String>) -> Result<(), SendError<String>>
	{
		let chain = self.chain().await;
		if chain.is_empty() {
			return Ok(());
		}
		let filter = state.outbound_filter();
		match num {
			// The read lock stays held for the whole batch, which may starve writers.
			Some(count) if count < state.config().max_gen_size => {
				for line in chain.str_iter_for(count) {
					let filtered = filter.filter_owned(line);
					output.send(filtered).await?;
				}
			},
			_ => {
				let line = chain.generate_str();
				let filtered = filter.filter_owned(line);
				output.send(filtered).await?;
			},
		}
		Ok(())
	}
}
/// Host this handle on the current task.
///
/// Takes the handle's one-time worker data, then runs two things at once:
/// - a spawned child task that receives chunks of input buffers (optionally
///   throttled by `opt.throttle`) and feeds every string into the chain while
///   holding the chain's write lock for the whole chunk;
/// - a forwarding loop on this task that groups the handle's input channel into
///   chunks of `opt.capacity`, forcing a flush (`push_now`) whenever
///   `opt.timeout` passes without a chunk, and sends each chunk to the child.
///
/// Returns when the child task finishes, or when the forwarding loop sees that
/// only this function's handle and the child's clone are still alive.
///
/// NOTE(review): when the forwarding loop exits, whatever the chunker still
/// holds and any buffers still queued in the input channel appear to be dropped
/// without being fed. The child's `JoinHandle` is also dropped rather than
/// awaited, so the child may still be feeding after this returns. Confirm
/// this is acceptable.
///
/// # Panics
/// If `from` has already been hosted.
pub async fn host(from: ChainHandle<String>)
{
let opt = from.0.opt.clone();
// Take the worker-only data (the input receiver); this is what panics on a second host.
let data = from.0.host.unwrap().await;
let (mut tx, child) = {
// The `real` input channel.
let from = from.clone();
let opt = opt.clone();
// Each message is one chunk: several input buffers, each a Vec of strings.
let (tx, rx) = mpsc::channel::<Vec<Vec<_>>>(opt.backlog);
(tx, tokio::spawn(async move {
let mut rx = if let Some(thr) = opt.throttle {
time::throttle(thr, rx).boxed()
} else {
rx.boxed()
};
trace!("child: Begin waiting on parent");
while let Some(item) = rx.next().await {
// Write lock is held while the whole chunk is fed.
let mut lock = from.0.chain.write().await;
for item in item.into_iter()
{
use std::ops::DerefMut;
for item in item.into_iter() {
feed::feed(lock.deref_mut(), item, &from.0.opt.bounds);
}
}
}
trace!("child: exiting");
}))
};
trace!("Begin polling on child");
tokio::select!{
v = child => {
match v {
// NOTE(review): `#[cold]` only affects functions; on a match arm it appears to have no effect.
#[cold] Ok(_) => {warn!("Child exited before we have? This should probably never happen.")},//Should never happen.
Err(e) => {error!("Child exited abnormally. Aborting: {}", e)}, //Child panic or cancel.
}
},
_ = async move {
let mut rx = data.input.chunk(opt.capacity); //we don't even need this tbh
// Two strong references belong to the hosting machinery: `from` (moved into
// this block) and the clone held by the child. Any count above 2 means an
// outside handle still exists. Checked only between chunks/timeouts.
while Arc::strong_count(&from.0) > 2 {
tokio::select!{
_ = time::delay_for(opt.timeout) => {
// No chunk within `opt.timeout` (the delay restarts every iteration): flush what is buffered.
rx.push_now();
}
Some(buffer) = rx.next() => {
if let Err(err) = tx.send(buffer).await {
// Receive closed?
//
// This probably shouldn't happen, as we `select!` for it up there and child never calls `close()` on `rx`.
// In any case, it means we should abort.
error!("Failed to send buffer: {}", err);
break;
}
}
}
}
// `tx` is dropped when this block ends, which lets the child's receive loop finish.
} => {
// Normal exit
trace!("Normal exit")
},
}
// No more handles except child, no more possible inputs.
trace!("Returning");
}
/// Spawn a new chain handler for this chain.
///
/// Returns the join handle of the task running [`host`], together with a
/// handle for writing to and reading from the chain.
pub fn spawn(from: chain::Chain<String>, opt: Settings) -> (JoinHandle<()>, ChainHandle<String>)
{
	let handle = ChainHandle::with_settings(from, opt);
	let worker = tokio::spawn(host(handle.clone()));
	(worker, handle)
}
/// A stream of values generated on demand from a chain.
///
/// Created by `ChainHandle::read`; it holds only a weak reference to the handle.
#[derive(Debug)]
pub struct ChainStream<T: Send + chain::Chainable>
{
	/// Weak link to the shared handle; the stream ends once it cannot be upgraded.
	chain: Weak<Box<Handle<T>>>,
	/// Values generated ahead of time that have not been yielded yet.
	buffer: Vec<T>,
}
impl ChainStream<String>
{
	/// Refill the read-ahead buffer with up to `n` freshly generated strings.
	///
	/// Only called by `poll_next` when the buffer is empty. The new strings are
	/// stored in *reverse* generation order, so the next one to yield sits at the
	/// end and can be taken with an O(1) `pop`.
	///
	/// Returns how many strings were actually added. Returns `None` if `n` is
	/// zero, every strong handle is gone, the chain is empty, or nothing was
	/// produced.
	async fn try_pull(&mut self, n: usize) -> Option<NonZeroUsize>
	{
		if n == 0 {
			return None;
		}
		let handle = self.chain.upgrade()?;
		let chain = handle.chain.read().await;
		if chain.is_empty() {
			return None;
		}
		let before = self.buffer.len();
		if n == 1 {
			self.buffer.push(chain.generate_str());
		} else {
			self.buffer.extend(chain.str_iter_for(n));
		}
		// Reverse just the new tail so that popping from the end yields generation order.
		self.buffer[before..].reverse();
		// Report what was really produced, not what was requested. A short iterator
		// then cannot break the caller's "buffer is non-empty" assumption.
		NonZeroUsize::new(self.buffer.len() - before)
	}
}
impl Stream for ChainStream<String>
{
	type Item = String;

	/// Yield the next buffered string, generating a new batch when the buffer is empty.
	///
	/// The stream ends (`Ready(None)`) once the chain is empty or all strong handles are dropped.
	fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
		use futures::Future;
		let this = self.get_mut();
		if this.buffer.is_empty() {
			let batch = this.buffer.capacity();
			let pull = this.try_pull(batch);
			tokio::pin!(pull);
			match pull.poll(cx) {
				Poll::Ready(Some(_)) => {},
				Poll::Ready(None) => return Poll::Ready(None),
				Poll::Pending => {
					// `pull` is dropped when we return. Dropping it also drops the queued
					// read-lock waiter that registered `cx`'s waker, so nothing would
					// ever wake this task again. Ask to be polled again right away; the
					// next poll retries the lock. This busy-retries while a writer holds
					// the lock. Storing the future in the stream would avoid that, but
					// needs a new field.
					cx.waker().wake_by_ref();
					return Poll::Pending;
				},
			}
		}
		// `try_pull` stores batches reversed, so the next item is at the end.
		Poll::Ready(this.buffer.pop())
	}
}