From ef5dc3cbf1bdc83c0437e546c1efcbcfbf2c450a Mon Sep 17 00:00:00 2001 From: Avril Date: Mon, 12 Oct 2020 04:15:06 +0100 Subject: [PATCH] accepts AF_UNIX --- Cargo.lock | 2 +- Cargo.toml | 5 +- TODO | 3 - markov.toml | 6 +- src/bind.rs | 171 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/config.rs | 4 +- src/ext.rs | 58 ++++++++++++++++- src/main.rs | 56 ++++++++++++++--- 8 files changed, 283 insertions(+), 22 deletions(-) delete mode 100644 TODO create mode 100644 src/bind.rs diff --git a/Cargo.lock b/Cargo.lock index 2306f39..25d5512 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -639,7 +639,7 @@ dependencies = [ [[package]] name = "markov" -version = "0.6.3" +version = "0.7.0" dependencies = [ "async-compression", "cfg-if 1.0.0", diff --git a/Cargo.toml b/Cargo.toml index 92748c0..b7431b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,9 +1,10 @@ [package] name = "markov" -version = "0.6.4" +version = "0.7.0" description = "Generate string of text from Markov chain fed by stdin" authors = ["Avril "] edition = "2018" +license = "gpl-3.0-or-later" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -16,7 +17,7 @@ compress-chain = ["async-compression"] # Treat each new line as a new set to feed instead of feeding the whole data at once split-newlines = [] -# Feed each sentance seperately, instead of just each line / whole body +# Feed each sentance seperately with default /get api, instead of just each line / whole body # Maybe better without `split-newlines`? # Kinda experimental split-sentance = [] diff --git a/TODO b/TODO deleted file mode 100644 index eeabf8c..0000000 --- a/TODO +++ /dev/null @@ -1,3 +0,0 @@ -Maybe see if `split-sentance` is stable enough to be enabled in prod now? -Allow Unix domain socket for bind - diff --git a/markov.toml b/markov.toml index df5fde0..f85b87f 100644 --- a/markov.toml +++ b/markov.toml @@ -2,9 +2,9 @@ bindpoint = '127.0.0.1:8001' file = 'chain.dat' max_content_length = 4194304 max_gen_size = 256 -#save_interval_secs = 2 +save_interval_secs = 2 trust_x_forwarded_for = false [filter] -inbound = "<>)([]/" -outbound = "*" +inbound = '' +outbound = '' diff --git a/src/bind.rs b/src/bind.rs new file mode 100644 index 0000000..810ec44 --- /dev/null +++ b/src/bind.rs @@ -0,0 +1,171 @@ +//! For binding to sockets +use super::*; +use futures::{ + prelude::*, +}; +use std::{ + marker::{ + Send, + Unpin, + }, + fmt, + error, + path::{ + Path, + PathBuf, + }, +}; +use tokio::{ + io::{ + self, + AsyncRead, + AsyncWrite, + }, +}; + +#[derive(Debug)] +pub enum BindError +{ + IO(io::Error), + Warp(warp::Error), + Other(E), +} + +impl error::Error for BindError +{ + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + Some(match &self { + Self::IO(io) => io, + Self::Other(o) => o, + Self::Warp(w) => w, + }) + } +} +impl fmt::Display for BindError +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + match self { + Self::IO(io) => write!(f, "io error: {}", io), + Self::Other(other) => write!(f, "{}", other), + Self::Warp(warp) => write!(f, "server error: {}", warp), + } + } +} + + +#[derive(Debug)] +pub struct BindpointParseError; + +impl error::Error for BindpointParseError{} +impl fmt::Display for BindpointParseError +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + write!(f, "Failed to parse bindpoint as IP or unix domain socket") + } +} + + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd)] +pub enum Bindpoint +{ + Unix(PathBuf), + TCP(SocketAddr), +} + +impl fmt::Display for Bindpoint +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + match self { + Self::Unix(unix) => write!(f, "unix:/{}", unix.to_string_lossy()), + Self::TCP(tcp) => write!(f, "{}", tcp), + } + } +} + +impl std::str::FromStr for Bindpoint +{ + type Err = BindpointParseError; + fn from_str(s: &str) -> Result { + Ok(if let Ok(ip) = s.parse::() { + Self::TCP(ip) + } else if s.starts_with("unix:/") { + Self::Unix(PathBuf::from(&s[6..].to_owned())) + } else { + return Err(BindpointParseError); + }) + } +} + +fn bind_unix(to: impl AsRef) -> io::Result>>> +{ + debug!("Binding to AF_UNIX: {:?}", to.as_ref()); + let listener = tokio::net::UnixListener::bind(to)?; + Ok(listener) +} + +pub fn serve(server: warp::Server, bind: Bindpoint, signal: impl Future + Send + 'static) -> Result<(Bindpoint, BoxFuture<'static, ()>), BindError> +where F: Filter + Clone + Send + Sync + 'static, +::Ok: warp::Reply, +{ + Ok(match bind { + Bindpoint::TCP(sock) => server.try_bind_with_graceful_shutdown(sock, signal).map(|(sock, fut)| (Bindpoint::TCP(sock), fut.boxed())).map_err(BindError::Warp)?, + Bindpoint::Unix(unix) => { + (Bindpoint::Unix(unix.clone()), + server.serve_incoming_with_graceful_shutdown(bind_unix(unix).map_err(BindError::IO)?, signal).boxed()) + }, + }) +} + +impl From for Bindpoint +{ + fn from(from: SocketAddr) -> Self + { + Self::TCP(from) + } +} + +pub fn try_serve(server: warp::Server, bind: impl TryBindpoint, signal: impl Future + Send + 'static) -> Result<(Bindpoint, BoxFuture<'static, ()>), BindError> +where F: Filter + Clone + Send + Sync + 'static, +::Ok: warp::Reply, +{ + serve(server, bind.try_parse().map_err(BindError::Other)?, signal).map_err(BindError::coerce) +} + +pub trait TryBindpoint: Sized +{ + type Err: error::Error + 'static; + fn try_parse(self) -> Result; +} + +impl TryBindpoint for Bindpoint +{ + type Err = std::convert::Infallible; + fn try_parse(self) -> Result + { + Ok(self) + } +} + +impl> TryBindpoint for T +{ + type Err = BindpointParseError; + fn try_parse(self) -> Result + { + self.as_ref().parse() + } +} + +impl BindError +{ + pub fn coerce(self) -> BindError + { + match self { + Self::Warp(w) => BindError::Warp(w), + Self::IO(w) => BindError::IO(w), + #[cold] _ => unreachable!(), + } + } +} diff --git a/src/config.rs b/src/config.rs index a820eb8..b69d0bf 100644 --- a/src/config.rs +++ b/src/config.rs @@ -19,7 +19,7 @@ pub const DEFAULT_FILE_LOCATION: &'static str = "markov.toml"; #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, Serialize, Deserialize)] pub struct Config { - pub bindpoint: SocketAddr, + pub bindpoint: String, pub file: String, pub max_content_length: u64, pub max_gen_size: usize, @@ -65,7 +65,7 @@ impl Default for Config fn default() -> Self { Self { - bindpoint: ([127,0,0,1], 8001).into(), + bindpoint: SocketAddr::from(([127,0,0,1], 8001)).to_string(), file: "chain.dat".to_owned(), max_content_length: 1024 * 1024 * 4, max_gen_size: 256, diff --git a/src/ext.rs b/src/ext.rs index bc44889..ad80d79 100644 --- a/src/ext.rs +++ b/src/ext.rs @@ -1,7 +1,14 @@ //! Extensions use std::{ iter, - ops::Range, + ops::{ + Range, + Deref,DerefMut, + }, + marker::{ + PhantomData, + Send, + }, }; pub trait StringJoinExt: Sized @@ -94,3 +101,52 @@ impl TrimInPlace for String self } } + +pub trait MapTuple2 +{ + fn map (V,W)>(self, fun: F) -> (V,W); +} + +impl MapTuple2 for (T,U) +{ + #[inline] fn map (V,W)>(self, fun: F) -> (V,W) + { + fun(self) + } +} + +/// To make sure we don't keep this data across an `await` boundary. +#[repr(transparent)] +pub struct AssertNotSend(pub T, PhantomData<*const T>); + +impl AssertNotSend +{ + pub const fn new(from :T) -> Self + { + Self(from, PhantomData) + } + pub fn into_inner(self) -> T + { + self.0 + } +} + +/// Require a future is Send +#[inline(always)] pub fn require_send(t: T) -> T +{ + t +} + +impl Deref for AssertNotSend +{ + type Target = T; + fn deref(&self) -> &Self::Target { + &self.0 + } +} +impl DerefMut for AssertNotSend +{ + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} diff --git a/src/main.rs b/src/main.rs index 837069c..80af06f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -81,10 +81,29 @@ mod feed; mod gen; mod sentance; -#[tokio::main] -async fn main() { +const DEFAULT_LOG_LEVEL: &str = "warn"; + +fn init_log() +{ + let level = match std::env::var_os("RUST_LOG") { + None => { + std::env::set_var("RUST_LOG", DEFAULT_LOG_LEVEL); + std::borrow::Cow::Borrowed(std::ffi::OsStr::new(DEFAULT_LOG_LEVEL)) + }, + Some(w) => std::borrow::Cow::Owned(w), + }; pretty_env_logger::init(); + trace!("Initialising `genmarkov` ({}) v{} with log level {:?}.\n\tMade by {} with <3.\n\tLicensed with GPL v3 or later", + std::env::args().next().unwrap(), + env!("CARGO_PKG_VERSION"), + level, + env!("CARGO_PKG_AUTHORS")); +} +#[tokio::main] +async fn main() { + init_log(); + let config = match config::load().await { Some(v) => v, _ => { @@ -237,14 +256,29 @@ async fn main() { #[cfg(target_family="unix")] tasks.push(tokio::spawn(signals::handle(state.clone())).map(|res| res.expect("Signal handler panicked")).boxed()); - let (addr, server) = warp::serve(push - .or(read)) - .bind_with_graceful_shutdown(state.config().bindpoint, async move { - tokio::signal::ctrl_c().await.unwrap(); - state.shutdown(); - }); - info!("Server bound on {:?}", addr); - server.await; + require_send(async { + let server = { + let s2 = AssertNotSend::new(state.clone()); //temp clone the Arcs here for shutdown if server fails to bind, assert they cannot remain cloned across an await boundary. + match bind::try_serve(warp::serve(push + .or(read)), + state.config().bindpoint.clone(), + async move { + tokio::signal::ctrl_c().await.unwrap(); + state.shutdown(); + }) { + Ok((addr, server)) => { + info!("Server bound on {:?}", addr); + server + }, + Err(err) => { + error!("Failed to bind server: {}", err); + s2.into_inner().shutdown(); + return; + }, + } + }; + server.await; + }).await; // Cleanup async move { @@ -255,3 +289,5 @@ async fn main() { }.await; info!("Shut down gracefully") } + +mod bind;