diff --git a/Cargo.toml b/Cargo.toml index e22f386..2884f2b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "markov" -version = "0.7.4" +version = "0.8.0" description = "Generate string of text from Markov chain fed by stdin" authors = ["Avril "] edition = "2018" diff --git a/Makefile b/Makefile index bc333f4..ea25b54 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -FEATURES:="api,always-aggregate" +FEATURES:="api,always-aggregate,split-sentance" VERSION:=`cargo read-manifest | rematch - 'version":"([0-9\.]+)"' 1` markov: diff --git a/markov.toml b/markov.toml index f85b87f..6b24665 100644 --- a/markov.toml +++ b/markov.toml @@ -4,7 +4,8 @@ max_content_length = 4194304 max_gen_size = 256 save_interval_secs = 2 trust_x_forwarded_for = false +feed_bounds = '2..' [filter] -inbound = '' +inbound = '<>/\\' outbound = '' diff --git a/src/config.rs b/src/config.rs index b69d0bf..12706ca 100644 --- a/src/config.rs +++ b/src/config.rs @@ -6,6 +6,8 @@ use std::{ io, borrow::Cow, num::NonZeroU64, + error, + fmt, }; use tokio::{ fs::OpenOptions, @@ -27,6 +29,8 @@ pub struct Config pub trust_x_forwarded_for: bool, #[serde(default)] pub filter: FilterConfig, + #[serde(default)] + pub feed_bounds: String, } #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Hash, Serialize, Deserialize)] @@ -39,7 +43,7 @@ pub struct FilterConfig impl FilterConfig { - pub fn get_inbound_filter(&self) -> sanitise::filter::Filter + fn get_inbound_filter(&self) -> sanitise::filter::Filter { let filt: sanitise::filter::Filter = self.inbound.parse().unwrap(); if !filt.is_empty() @@ -48,7 +52,7 @@ impl FilterConfig } filt } - pub fn get_outbound_filter(&self) -> sanitise::filter::Filter + fn get_outbound_filter(&self) -> sanitise::filter::Filter { let filt: sanitise::filter::Filter = self.outbound.parse().unwrap(); if !filt.is_empty() @@ -72,12 +76,45 @@ impl Default for Config save_interval_secs: Some(unsafe{NonZeroU64::new_unchecked(2)}), trust_x_forwarded_for: false, filter: Default::default(), + feed_bounds: "2..".to_owned(), } } } impl Config { + /// Try to generate a config cache for this instance. + pub fn try_gen_cache(&self) -> Result + { + macro_rules! section { + ($name:literal, $expr:expr) => { + match $expr { + Ok(v) => Ok(v), + Err(e) => Err(InvalidConfigError($name, Box::new(e))), + } + } + } + use std::ops::RangeBounds; + + Ok(Cache { + feed_bounds: section!("feed_bounds", self.parse_feed_bounds()).and_then(|bounds| if bounds.contains(&0) { + Err(InvalidConfigError("feed_bounds", Box::new(opaque_error!("Bounds not allowed to contains 0 (they were `{}`)", bounds)))) + } else { + Ok(bounds) + })?, + inbound_filter: self.filter.get_inbound_filter(), + outbound_filter: self.filter.get_outbound_filter(), + }) + } + /// Try to parse the `feed_bounds` + fn parse_feed_bounds(&self) -> Result, range::ParseError> + { + if self.feed_bounds.len() == 0 { + Ok(feed::DEFAULT_FEED_BOUNDS.into()) + } else { + self.feed_bounds.parse() + } + } pub fn save_interval(&self) -> Option { self.save_interval_secs.map(|x| Duration::from_secs(x.into())) @@ -139,3 +176,52 @@ async fn load_args>(mut from: I) -> Option }, } } + +#[derive(Debug)] +pub struct InvalidConfigError(&'static str, Box); + +impl InvalidConfigError +{ + pub fn field(&self) -> &str + { + &self.0[..] + } +} + +impl error::Error for InvalidConfigError +{ + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + Some(self.1.as_ref()) + } +} + +impl fmt::Display for InvalidConfigError +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + write!(f,"failed to parse field `{}`: {}", self.0, self.1) + } +} + + +/// Caches some parsed config arguments +#[derive(Clone, PartialEq, Eq)] +pub struct Cache +{ + pub feed_bounds: range::DynRange, + pub inbound_filter: sanitise::filter::Filter, + pub outbound_filter: sanitise::filter::Filter, +} + +impl fmt::Debug for Cache +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + f.debug_struct("Cache") + .field("feed_bounds", &self.feed_bounds) + .field("inbound_filter", &self.inbound_filter.iter().collect::()) + .field("outbound_filter", &self.outbound_filter.iter().collect::()) + .finish() + } +} + diff --git a/src/feed.rs b/src/feed.rs index aaa2dec..ef91bc5 100644 --- a/src/feed.rs +++ b/src/feed.rs @@ -3,7 +3,7 @@ use super::*; #[cfg(any(feature="feed-sentance", feature="split-sentance"))] use sanitise::Sentance; -const FEED_BOUNDS: std::ops::RangeFrom = 2..; //TODO: Add to config somehow +pub const DEFAULT_FEED_BOUNDS: std::ops::RangeFrom = 2..; //TODO: Add to config somehow /// Feed `what` into `chain`, at least `bounds` tokens. /// @@ -35,10 +35,11 @@ pub fn feed(chain: &mut Chain, what: impl AsRef, bounds: impl std:: debug_assert!(!bounds.contains(&0), "Cannot allow 0 size feeds"); for map in map {// feed each sentance seperately if bounds.contains(&map.len()) { + debug!("Feeding chain {} items", map.len()); chain.feed(map); } else { - debug!("Ignoring feed of invalid length {}", map.len()); + debug!("Ignoring feed of invalid length {}: {:?}", map.len(), map); } } } else { @@ -49,16 +50,17 @@ pub fn feed(chain: &mut Chain, what: impl AsRef, bounds: impl std:: .flatten() // add all into one buffer .map(|s| s.to_owned()).collect::>(); } else { - let map: Vec<_> = sanitise::Word::new_iter(what.as_ref()).map(ToOwned::to_owned) + let map: Vec<_> = sanitise::words(what.as_ref()).map(ToOwned::to_owned) .collect(); } } debug_assert!(!bounds.contains(&0), "Cannot allow 0 size feeds"); if bounds.contains(&map.len()) { + debug!("Feeding chain {} items", map.len()); chain.feed(map); } else { - debug!("Ignoring feed of invalid length {}", map.len()); + debug!("Ignoring feed of invalid length {}: {:?}", map.len(), map); } } @@ -71,11 +73,12 @@ pub async fn full(who: &IpAddr, state: State, body: impl Unpin + Stream { + ($chain:expr, $buffer:ident) => { { let buffer = $buffer; - feed($chain, &buffer, $bounds) + feed($chain, &buffer, bounds) } } } @@ -102,10 +105,10 @@ pub async fn full(who: &IpAddr, state: State, body: impl Unpin + Stream {:?}", who, line); } written+=line.len(); diff --git a/src/main.rs b/src/main.rs index 82758b1..a0b6f30 100644 --- a/src/main.rs +++ b/src/main.rs @@ -63,6 +63,7 @@ macro_rules! status { mod ext; use ext::*; mod util; +mod range; mod sanitise; mod bytes; mod chunking; @@ -105,8 +106,19 @@ fn init_log() async fn main() { init_log(); - let config = match config::load().await { - Some(v) => v, + let (config, ccache) = match config::load().await { + Some(v) => { + let cache = match v.try_gen_cache() { + Ok(c) => c, + Err(e) => { + error!("Invalid config, cannot continue"); + error!("{}", e); + debug!("{:?}", e); + return; + }, + }; + (v, cache) + }, _ => { let cfg = config::Config::default(); #[cfg(debug_assertions)] @@ -115,10 +127,12 @@ async fn main() { error!("Failed to create default config file: {}", err); } } - cfg + let cache= cfg.try_gen_cache().unwrap(); + (cfg, cache) }, }; - trace!("Using config {:?}", config); + debug!("Using config {:?}", config); + trace!("With config cached: {:?}", ccache); let chain = Arc::new(RwLock::new(match save::load(&config.file).await { Ok(chain) => { @@ -137,6 +151,7 @@ async fn main() { let save_when = Arc::new(Notify::new()); let state = State::new(config, + ccache, Arc::clone(&chain), Arc::clone(&save_when)); let state2 = state.clone(); diff --git a/src/range.rs b/src/range.rs new file mode 100644 index 0000000..a1e86e4 --- /dev/null +++ b/src/range.rs @@ -0,0 +1,287 @@ +//! Workarounds for ridiculously janky `std::ops::Range*` polymorphism +use super::*; +use std::{ + ops::{ + Range, + RangeFrom, + RangeInclusive, + RangeTo, + RangeToInclusive, + RangeFull, + + Bound, + RangeBounds, + }, + str::{ + FromStr, + }, + fmt, + error, +}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum DynRange +{ + Range(Range), + From(RangeFrom), + Inclusive(RangeInclusive), + To(RangeTo), + ToInclusive(RangeToInclusive), + Full(RangeFull), +} + +#[macro_export] macro_rules! impl_from { + (Full, RangeFull) => { + impl From for DynRange + { + #[inline] fn from(from: RangeFull) -> Self + { + Self::Full(from) + } + } + }; + ($name:ident, $range:tt) => { + + impl From<$range > for DynRange + { + #[inline] fn from(from: $range) -> Self + { + Self::$name(from) + } + } + }; +} + +impl_from!(Range, Range); +impl_from!(From, RangeFrom); +impl_from!(Inclusive, RangeInclusive); +impl_from!(To, RangeTo); +impl_from!(ToInclusive, RangeToInclusive); +impl_from!(Full, RangeFull); + +macro_rules! bounds { + ($self:ident, $bound:ident) => { + match $self { + DynRange::Range(from) => from.$bound(), + DynRange::From(from) => from.$bound(), + DynRange::Inclusive(i) => i.$bound(), + DynRange::To(i) => i.$bound(), + DynRange::ToInclusive(i) => i.$bound(), + DynRange::Full(_) => (..).$bound(), + } + }; +} + +impl RangeBounds for DynRange +{ + fn start_bound(&self) -> Bound<&T> { + bounds!(self, start_bound) + } + fn end_bound(&self) -> Bound<&T> { + bounds!(self, end_bound) + } +} + +impl<'a, T> RangeBounds for &'a DynRange +{ + fn start_bound(&self) -> Bound<&T> { + bounds!(self, start_bound) + } + fn end_bound(&self) -> Bound<&T> { + bounds!(self, end_bound) + } +} + +impl fmt::Display for DynRange +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + match self { + Self::Range(from) => write!(f, "{}..{}", from.start, from.end), + Self::From(from) => write!(f, "{}..", from.start), + Self::Inclusive(from) => write!(f, "{}..={}", from.start(), from.end()), + Self::To(from) => write!(f, "..{}", from.end), + Self::ToInclusive(from) => write!(f, "..={}", from.end), + Self::Full(_) => write!(f, ".."), + } + } +} + +use std::any::{ + Any, +}; + +impl DynRange +{ + + fn into_inner(self) -> Box + { + match self { + Self::Range(from) => Box::new(from), + Self::From(from) => Box::new(from), + Self::Inclusive(from) => Box::new(from), + Self::To(from) => Box::new(from), + Self::ToInclusive(from) => Box::new(from), + Self::Full(_) => Box::new(..), + } + } + fn inner_mut(&mut self) -> &mut dyn Any + { + match self { + Self::Range(from) => from, + Self::From(from) => from, + Self::Inclusive(from) => from, + Self::To(from) => from, + Self::ToInclusive(from) => from, + Self::Full(f) => f, + } + } + fn inner_ref(&self) -> &dyn Any + { + match self { + Self::Range(from) => from, + Self::From(from) => from, + Self::Inclusive(from) => from, + Self::To(from) => from, + Self::ToInclusive(from) => from, + Self::Full(_) => &(..), + } + } + pub fn downcast_ref + 'static>(&self) -> Option<&R> + { + self.inner_ref().downcast_ref() + } + pub fn downcast_mut + 'static>(&mut self) -> Option<&mut R> + { + self.inner_mut().downcast_mut() + } + pub fn downcast + 'static>(self) -> Result + { + Box::new(self).downcast() + } +} + +#[derive(Debug)] +pub struct ParseError(DynRange<()>, Option>); + +impl ParseError +{ + fn new>>(which: R, err: impl error::Error + 'static) -> Self + { + Self(which.into(), Some(Box::new(err))) + } + fn none(which: impl Into>) -> Self + { + Self(which.into(), None) + } + fn map>>(self, to: T) -> Self + { + Self (to.into(), self.1) + } +} + +impl error::Error for ParseError +{ + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + if let Some(this) = self.1.as_ref() { + Some(this.as_ref()) + } else { + None + } + } +} + +impl fmt::Display for ParseError +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + write!(f, "failed to parse range in format `{:?}`", self.0)?; + if let Some(this) = self.1.as_ref() { + write!(f, ": {}", this)?; + } + Ok(()) + } +} + + +impl FromStr for DynRange +where T::Err: error::Error + 'static +{ + type Err = ParseError; + fn from_str(s: &str) -> Result { + if s== ".." { + Ok(Self::Full(..)) + } else if s.starts_with("..=") { + Ok(Self::ToInclusive(..=T::from_str(&s[3..]).map_err(|x| ParseError::new(..=(), x))?)) + } else if s.starts_with("..") { + Ok(Self::To(..(T::from_str(&s[2..])).map_err(|x| ParseError::new(..(), x))?)) + } else if s.ends_with("..") { + Ok(Self::From(T::from_str(&s[..s.len()-2]).map_err(|x| ParseError::new(().., x))?..)) + } else { + fn try_next_incl<'a, T: FromStr>(m: &mut impl Iterator) -> Result, ParseError> + where T::Err: error::Error + 'static + { + let (first, second) = if let Some(first) = m.next() { + if let Some(seocond) = m.next() { + (first,seocond) + } else { + return Err(ParseError::none(()..=())); + } + } else { + return Err(ParseError::none(()..=())); + }; + + let first: T = first.parse().map_err(|x| ParseError::new(()..=(), x))?; + let second: T = second.parse().map_err(|x| ParseError::new(()..=(), x))?; + + Ok(first..=second) + } + + fn try_next<'a, T: FromStr>(m: &mut impl Iterator) -> Result, ParseError> + where T::Err: error::Error + 'static + { + let (first, second) = if let Some(first) = m.next() { + if let Some(seocond) = m.next() { + (first,seocond) + } else { + return Err(ParseError::none(()..())); + } + } else { + return Err(ParseError::none(()..())); + }; + + let first: T = first.parse().map_err(|x| ParseError::new(()..(), x))?; + let second: T = second.parse().map_err(|x| ParseError::new(()..(), x))?; + + Ok(first..second) + } + + + let mut split = s.split("..=").fuse(); + + let mut last_err = ParseError::none(()..()); + match loop { + match try_next_incl(&mut split) { + Err(ParseError(_, None)) => break Err(last_err), // iter empty + Err(other) => last_err = other, + Ok(value) => break Ok(Self::Inclusive(value)), + } + } { + Ok(v) => return Ok(v), + Err(e) => last_err = e, + }; + + let mut split = s.split("..").fuse(); + match loop { + match try_next(&mut split) { + Err(ParseError(_, None)) => break Err(last_err), // iter empty + Err(other) => last_err = other, + Ok(value) => break Ok(Self::Range(value)), + } + } { + Ok(v) => Ok(v), + Err(e) => Err(e), + } + } + } +} diff --git a/src/sanitise/word.rs b/src/sanitise/word.rs index 320fe88..c50a5fc 100644 --- a/src/sanitise/word.rs +++ b/src/sanitise/word.rs @@ -139,3 +139,11 @@ impl AsRef for Word self } } + +pub fn words(input: &str) -> impl Iterator +{ + input.split_inclusive(is_word_boundary) + .map(|x| x.trim()) + .filter(|x| !x.is_empty()) + .map(|x| new!(x)) +} diff --git a/src/state.rs b/src/state.rs index 971353e..f5cd137 100644 --- a/src/state.rs +++ b/src/state.rs @@ -24,8 +24,7 @@ impl fmt::Display for ShutdownError #[derive(Debug, Clone)] pub struct State { - config: Arc, //to avoid cloning config - exclude: Arc<(sanitise::filter::Filter, sanitise::filter::Filter)>, + config: Arc>, //to avoid cloning config chain: Arc>>, save: Arc, begin: Initialiser, @@ -72,20 +71,18 @@ impl State pub fn inbound_filter(&self) -> &sanitise::filter::Filter { - &self.exclude.0 + &self.config_cache().inbound_filter } pub fn outbound_filter(&self) -> &sanitise::filter::Filter { - &self.exclude.1 + &self.config_cache().outbound_filter } - pub fn new(config: Config, chain: Arc>>, save: Arc) -> Self + pub fn new(config: Config, cache: config::Cache, chain: Arc>>, save: Arc) -> Self { let (shutdown, shutdown_recv) = watch::channel(false); Self { - exclude: Arc::new((config.filter.get_inbound_filter(), - config.filter.get_outbound_filter())), - config: Arc::new(config), + config: Arc::new(Box::new((config, cache))), chain, save, begin: Initialiser::new(), @@ -96,7 +93,12 @@ impl State pub fn config(&self) -> &Config { - self.config.as_ref() + &self.config.as_ref().0 + } + + pub fn config_cache(&self) -> &config::Cache + { + &self.config.as_ref().1 } pub fn notify_save(&self) diff --git a/src/util.rs b/src/util.rs index 2a5d5a3..ec1f3e7 100644 --- a/src/util.rs +++ b/src/util.rs @@ -39,3 +39,54 @@ pub fn hint_cap(iter: &I) -> T (_, Some(x)) | (x, _) => T::with_capacity(x) } } + +#[macro_export] macro_rules! opaque_error { + ($msg:literal) => { + { + #[derive(Debug)] + struct OpaqueError; + + impl ::std::error::Error for OpaqueError{} + impl ::std::fmt::Display for OpaqueError + { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result + { + write!(f, $msg) + } + } + OpaqueError + } + }; + ($msg:literal $($tt:tt)*) => { + { + #[derive(Debug)] + struct OpaqueError(String); + + impl ::std::error::Error for OpaqueError{} + impl ::std::fmt::Display for OpaqueError + { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result + { + write!(f, "{}", self.0) + } + } + OpaqueError(format!($msg $($tt)*)) + } + }; + (yield $msg:literal $($tt:tt)*) => { + { + #[derive(Debug)] + struct OpaqueError<'a>(fmt::Arguments<'a>); + + impl ::std::error::Error for OpaqueError{} + impl ::std::fmt::Display for OpaqueError + { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result + { + write!(f, "{}", self.0) + } + } + OpaqueError(format_args!($msg $($tt)*)) + } + }; +}