diff --git a/Cargo.toml b/Cargo.toml index ea14a10..850bb9e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,7 @@ panic = "unwind" tokio = {version = "0.2", features=["full"]} async-trait = "0.1" chrono = {version = "0.4.15", features=["serde"]} -uuid = {version = "0.8", features=["v4", "serde"]} +uuid = { version = "0.8.1", features = ["v4", "serde"] } once_cell = "1.4" crypto = {version = "1.1.2", package= "cryptohelpers", features= ["serialise", "async", "sha256"]} libc = "0.2.76" diff --git a/src/config.rs b/src/config.rs index 7722b8f..bff43a6 100644 --- a/src/config.rs +++ b/src/config.rs @@ -6,6 +6,9 @@ use std::{ Ipv4Addr, }, }; +use tokio::{ + time, +}; use cidr::Cidr; //TODO: Use tokio Watcher instead, to allow hotreloading? @@ -53,9 +56,13 @@ pub struct Config pub deny_mask: Vec, /// Accept by default pub accept_default: bool, - /// The number of connections allowed to be processed at once on one route pub dos_max: usize, + + /// The timeout for any routing dispatch + pub req_timeout_local: Option, + /// The timeout for *all* routing dispatchs + pub req_timeout_global: Option, } impl Default for Config @@ -70,6 +77,9 @@ impl Default for Config deny_mask: Vec::new(), accept_default: false, dos_max: 16, + + req_timeout_local: Some(time::Duration::from_millis(500)), + req_timeout_global: Some(time::Duration::from_secs(1)), } } } diff --git a/src/main.rs b/src/main.rs index a228cd0..3f29fbc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,6 +18,10 @@ use color_eyre::{ Help, SectionExt, }; +use futures::{ + FutureExt as _, + prelude::*, +}; mod ext; use ext::*; diff --git a/src/web/error.rs b/src/web/error.rs index 5833135..21462b2 100644 --- a/src/web/error.rs +++ b/src/web/error.rs @@ -10,9 +10,14 @@ use std::{ #[non_exhaustive] pub enum Error { Denied(SocketAddr, bool), + TimeoutReached, + NoResponse, Unknown, } +#[derive(Debug)] +pub struct HandleError; + impl Error { /// Print this error as a warning @@ -31,6 +36,7 @@ impl Error } impl error::Error for Error{} +impl error::Error for HandleError{} impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result @@ -38,7 +44,18 @@ impl fmt::Display for Error match self { Self::Denied(sock, true) => write!(f, "denied connection (explicit): {}", sock), Self::Denied(sock, _) => write!(f, "denied connection (implicit): {}", sock), + Self::TimeoutReached => write!(f, "timeout reached"), + Self::NoResponse => write!(f, "no handler for this request"), _ => write!(f, "unknown error"), } } } + +impl fmt::Display for HandleError +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + write!(f, "handle response had already been sent or timed out by the time we tried to access it") + } +} + diff --git a/src/web/mod.rs b/src/web/mod.rs index bf892aa..d6c318b 100644 --- a/src/web/mod.rs +++ b/src/web/mod.rs @@ -1,7 +1,14 @@ //! Handle web serving and managing state of web clients use super::*; use std::{ - sync::Arc, + sync::{ + Arc, + Weak, + }, + marker::{ + Send, Sync, + }, + iter, }; use hyper::{ service::{ @@ -22,14 +29,59 @@ use futures::{ use cidr::{ Cidr, }; +use tokio::{ + sync::{ + RwLock, + mpsc, + }, +}; pub mod error; pub mod route; +/// A unique ID generated each time a request is sent through router. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Nonce(uuid::Uuid); + +#[derive(Debug, Clone)] +pub struct Handle +{ + state: Arc, + nonce: Nonce, + req: Arc>, + /// We can let multiple router hooks mutate body if they desire. Such as adding headers, etc. + resp: Arc>>, +} + +impl Handle +{ + /// Attempt to upgrade the response handle into a potentially mutateable `Response`. + /// + /// Function fails if the reference count to the response has expired (i.e. the response has been sent or timed out already) + pub fn access_response(&self) -> Result>>, error::HandleError> + { + Ok(self.resp.clone()) + //self.resp.upgrade().ok_or(error::HandleError) + } + + /// Replace the response with a new one if possible. + /// + /// Fails if `access_response()` fails. + pub async fn set_response(&self, rsp: Response) -> Result, error::HandleError> + { + use std::ops::DerefMut; + match self.access_response() { + Ok(resp) => Ok(std::mem::replace(resp.write().await.deref_mut(), rsp)), + Err(err) => Err(err), + } + } +} + #[derive(Debug)] pub struct State { config: config::Config, + router: RwLock>, } impl State @@ -37,7 +89,8 @@ impl State pub fn new(config: config::Config) -> Self { Self{ - config + config, + router: RwLock::new(route::Router::new()), } } } @@ -47,7 +100,7 @@ impl Default for State #[inline] fn default() -> Self { - Self{config: config::get().clone()} + Self::new(config::get().clone()) } } @@ -62,45 +115,140 @@ fn mask_contains(mask: &[cidr::IpCidr], value: &std::net::IpAddr) -> bool false } +fn handle_test(state: Arc) -> tokio::task::JoinHandle<()> +{ + tokio::task::spawn(async move { + let (hook, mut recv) = { + let mut router = state.router.write().await; + router.hook(None, route::PrefixRouter::new("/hello")) + }; + while let Some((uri, handle)) = recv.recv().await + { + match handle.set_response(Response::builder() + .status(200) + .body(format!("Hello world! You are at {}", uri).into()) + .unwrap()).await { + Ok(_) => (), + Err(e) => { + error!("{}", e); + break; + }, + } + } + { + let mut router = state.router.write().await; + router.unhook(iter::once(hook)); + } + }) +} + async fn handle_conn(state: Arc, req: Request) -> Result, error::Error> { - //TODO: Create client, route, and such - Ok(Response::new("Hi".into())) + let response = Arc::new(RwLock::new(Response::new(Body::empty()))); + + let nonce = Nonce(uuid::Uuid::new_v4()); + let req = Arc::new(req); + let resp_num = { + let resp = Arc::clone(&response); + async { + let mut route = state.router.write().await; + let handle = Handle { + state: state.clone(), + nonce, + req: Arc::clone(&req), + resp, + }; + match route.dispatch(req.method(), req.uri().path(), handle, state.config.req_timeout_local).await { + Ok(num) => { + num + }, + Err((num, _)) => { + num + }, + } + } + }; + tokio::pin!(resp_num); + + match match state.config.req_timeout_global { + Some(timeout) => tokio::time::timeout(timeout, resp_num).await, + None => Ok(resp_num.await), + } { + Ok(0) => { + // No handlers matched this + trace!(" x {}", req.uri().path()); + Ok(Response::builder() + .status(404) + .body("404 not found".into()) + .unwrap()) + }, + Ok(_) => { + let resp = { + let mut resp = response; + + loop { + match Arc::try_unwrap(resp) { + Err(e) => { + resp = e; + tokio::task::yield_now().await; + }, + Ok(n) => break n, + } + } + }; + Ok(resp.into_inner()) + }, + Err(_) => { + // Timeout reached + Err(error::Error::TimeoutReached.info()) + }, + } } pub async fn serve(state: State) -> Result<(), eyre::Report> { - let state = Arc::new(state); - - let service = make_service_fn(|conn: &AddrStream| { - let state = Arc::clone(&state); - - let remote_addr = conn.remote_addr(); - let remote_ip = remote_addr.ip(); - let denied = mask_contains(&state.config.deny_mask[..], &remote_ip); - let allowed = mask_contains(&state.config.accept_mask[..], &remote_ip); - async move { - if denied { - Err(error::Error::Denied(remote_addr, true).warn()) - } else if allowed || state.config.accept_default { - trace!("Accepted conn: {}", remote_addr); - - Ok(service_fn(move |req: Request| { - handle_conn(Arc::clone(&state), req) - })) - } else { - Err(error::Error::Denied(remote_addr,false).info()) - } - } - }); + let h = { + let state = Arc::new(state); - let server = Server::bind(&state.config.listen).serve(service) - .with_graceful_shutdown(async { - tokio::signal::ctrl_c().await.expect("Failed to catch SIGINT"); - info!("Going down for shutdown now!"); + let h = handle_test(state.clone()); + + let service = make_service_fn(|conn: &AddrStream| { + let state = Arc::clone(&state); + + let remote_addr = conn.remote_addr(); + let remote_ip = remote_addr.ip(); + let denied = mask_contains(&state.config.deny_mask[..], &remote_ip); + let allowed = mask_contains(&state.config.accept_mask[..], &remote_ip); + async move { + if denied { + Err(error::Error::Denied(remote_addr, true).warn()) + } else if allowed || state.config.accept_default { + trace!("Accepted conn: {}", remote_addr); + + Ok(service_fn(move |req: Request| { + handle_conn(Arc::clone(&state), req) + })) + } else { + Err(error::Error::Denied(remote_addr,false).info()) + } + } }); - server.await?; + let server = Server::bind(&state.config.listen).serve(service) + .with_graceful_shutdown(async { + tokio::signal::ctrl_c().await.expect("Failed to catch SIGINT"); + info!("Going down for shutdown now!"); + }); + + server.await?; + // remove all handles now + let mut wr= state.router.write().await; + wr.clear(); + + h + }; + trace!("server down"); + h.await?; Ok(()) } diff --git a/src/web/route.rs b/src/web/route.rs index 638cbf3..fc0d67b 100644 --- a/src/web/route.rs +++ b/src/web/route.rs @@ -5,7 +5,10 @@ use hyper::{ }; use std::{ fmt, - marker::Send, + marker::{ + Send, + Sync, + }, iter, }; use tokio::{ @@ -35,6 +38,16 @@ pub trait UriRoute { "" } + + #[inline] fn type_name(&self) -> &str + { + std::any::type_name::() + } + + #[inline] fn mutate_uri(&self, uri: String) -> String + { + uri + } } impl UriRoute for str @@ -67,19 +80,68 @@ impl UriRoute for regex::Regex } } +/// A router for all under a prefix +#[derive(Debug, Clone, PartialEq, Hash)] +pub struct PrefixRouter(String); + +impl PrefixRouter +{ + /// Create a new instance with this string + pub fn new(string: impl Into) -> Self + { + Self(string.into()) + } +} + +impl UriRoute for PrefixRouter +{ + #[inline] fn is_match(&self, uri: &str) -> bool { + uri.starts_with(self.0.as_str()) + } + #[inline] fn as_string(&self) -> &str { + self.0.as_str() + } + + fn mutate_uri(&self, mut uri: String) -> String { + uri.replace_range(..self.0.len(), ""); + uri + } +} + +impl fmt::Display for PrefixRouter +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + write!(f, "{}*", self.0) + } +} + + /// Contains a routing table #[derive(Debug)] -pub struct Router +pub struct Router { - routes: Arena<(Option, OpaqueDebug>, mpsc::Sender)>, + routes: Arena<(Option, OpaqueDebug>, mpsc::Sender<(String, T)>)>, } -impl Router +impl fmt::Display for Router +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + write!(f, "Router {{ routes: ")?; + for (i, (method, route, _)) in self.routes.iter() { + writeln!(f, "\t ({:?} => ({:?}, {} ({:?}))),", i, method, route.type_name(), route.as_string())?; + } + write!(f, "}}") + } +} + +impl Router { /// Create an empty routing table pub fn new() -> Self { - Self{ + Self { routes: Arena::new(), } } @@ -88,17 +150,27 @@ impl Router /// /// # Returns /// The hook's new index, and the receiver that `dispatch()` sends to. - pub fn hook(&mut self, method: Option, uri: Uri) -> (Index, mpsc::Receiver) + pub fn hook(&mut self, method: Option, uri: Uri) -> (Index, mpsc::Receiver<(String, T)>) { let (tx, rx) = mpsc::channel(config::get_or_default().dos_max); (self.routes.insert((method, OpaqueDebug::new(Box::new(uri)), tx)), rx) } + /// Remove all hooks + pub fn clear(&mut self) + { + self.routes.clear(); + } + /// Dispatch the URI location across this router, sending to all that match it. /// + /// # Timeout + /// The timeout is waited on the *individual* dispatches. If you want a global timeout, please timeout the future returned by this function instead. + /// Timed-out dispatches are counted the same as sending errors. + /// /// # Returns /// When one or more dispatchers match but faile, `Err` is returned. Inside the `Err` tuple is the amount of successful dispatches, and also a vector containing the indecies of the failed hook sends. - pub async fn dispatch(&mut self, method: &Method, uri: impl AsRef, timeout: Option) -> Result)> + pub async fn dispatch(&mut self, method: &Method, uri: impl AsRef, nonce: T, timeout: Option) -> Result)> { let string = uri.as_ref(); let mut success=0usize; @@ -109,25 +181,29 @@ impl Router Some(x) if x != method => None, _ => { if route.is_match(string) { - trace!("{:?} @{}: -> {}",i, route.as_string(), string); + trace!("{:?} `{}`: -> {}",i, route.as_string(), string); let timeout = timeout.clone(); + let nonce= nonce.clone(); macro_rules! send { () => { - match timeout { - None => sender.send(string.to_owned()).await - .map_err(|e| SendTimeoutError::Closed(e.0)), - Some(time) => sender.send_timeout(string.to_owned(), time).await + { + let string = route.mutate_uri(string.to_owned()); + match timeout { + None => sender.send((string, nonce)).await + .map_err(|e| SendTimeoutError::Closed(e.0)), + Some(time) => sender.send_timeout((string, nonce), time).await + } } } }; Some(async move { match send!() { Err(SendTimeoutError::Closed(er)) => { - error!("{:?}: Dispatch failed on hooked route for {}", i, er); + error!("{:?}: Dispatch failed on hooked route for `{}`", i, er.0); Err(i) }, Err(SendTimeoutError::Timeout(er)) => { - warn!("{:?}: Dispatch timed out on hooked route for {}", i, er); + warn!("{:?}: Dispatch timed out on hooked route for `{}`", i, er.0); Err(i) }, _ => Ok(()), @@ -152,6 +228,21 @@ impl Router } } + /// Forcefully dispatch `uri` on hook `which`, regardless of method or URI matching. + /// + /// # Returns + /// If `which` is not contained within the table, immediately returns `None`, otherwise returns a future that completes when the dispatch is complete. + /// Note: This future must be `await`ed for the dispatch to happen. + pub fn dispatch_force(&mut self, which: Index, uri: String, nonce: T, timeout: Option) -> Option>> + '_> + { + self.routes.get_mut(which).map(move |(_,_,send)| { + match timeout { + Some(timeout) => send.send_timeout((uri, nonce), timeout).boxed(), + None => send.send((uri, nonce)).map(|res| res.map_err(|e| SendTimeoutError::Closed(e.0))).boxed(), + } + }) + } + /// Attempt to unhook these hooks. If one or more of the provided indecies does not exist in the routing table, it is ignored. pub fn unhook(&mut self, items: I) where I: IntoIterator