You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
yuurei/src/web/mod.rs

270 lines
5.8 KiB

//! Handle web serving and managing state of web clients
use super::*;
use std::{
sync::{
Arc,
Weak,
},
marker::{
Send, Sync,
},
iter,
};
use hyper::{
service::{
make_service_fn,
service_fn,
},
server::{
Server,
conn::AddrStream,
},
Request,
Response,
Body,
};
use futures::{
TryStreamExt as _,
};
use cidr::{
Cidr,
};
use tokio::{
sync::{
RwLock,
mpsc,
},
};
pub mod error;
pub mod route;
/// Random identifier minted for every request as it is sent through the router.
///
/// Wraps a [`uuid::Uuid`]; a fresh one is generated per request in `handle_conn`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Nonce(uuid::Uuid);
/// Per-request handle given to every router hook that matches a request.
///
/// Cloning is cheap: apart from the small [`Nonce`], every field is an `Arc`.
/// The server cannot send the response until every clone of the handle has been
/// dropped, so hooks should release their handle as soon as they are done with it.
#[derive(Debug, Clone)]
pub struct Handle
{
    /// Shared server state.
    state: Arc<State>,
    /// Unique ID of the request this handle belongs to.
    nonce: Nonce,
    /// The incoming request.
    req: Arc<Request<Body>>,
    /// The response being built. Multiple router hooks may mutate it if they desire
    /// (e.g. to add headers), hence the lock.
    resp: Arc<RwLock<Response<Body>>>,
}

impl Handle
{
    /// Get a shared, lockable reference to the response being built for this request.
    ///
    /// Lock it with `.write().await` to mutate the response in place.
    ///
    /// # Errors
    /// Currently never fails: the handle holds a strong reference to the response,
    /// so it is always available. The `Result` is kept for API compatibility.
    pub fn access_response(&self) -> Result<Arc<RwLock<Response<Body>>>, error::HandleError>
    {
        Ok(Arc::clone(&self.resp))
    }

    /// Replace the response with `rsp`, returning the response it replaced.
    ///
    /// # Errors
    /// Propagates any error from [`Handle::access_response`].
    pub async fn set_response(&self, rsp: Response<Body>) -> Result<Response<Body>, error::HandleError>
    {
        let resp = self.access_response()?;
        let mut guard = resp.write().await;
        let previous = std::mem::replace(&mut *guard, rsp);
        Ok(previous)
    }
}
/// All state shared by the web server: its configuration and its request router.
#[derive(Debug)]
pub struct State
{
    /// Configuration this server instance was built from.
    config: config::Config,
    /// Request router; behind a lock because hooks are added and removed at runtime.
    router: RwLock<route::Router<Handle>>,
}

impl State
{
    /// Build server state around `config`, starting with an empty router.
    ///
    /// # Notes
    /// Most callers want the *global* config instance; use [`State::default`] for that.
    pub fn new(config: config::Config) -> Self
    {
        let router = RwLock::new(route::Router::new());
        State { config, router }
    }
}

impl Default for State
{
    /// Build state from a clone of the global config.
    #[inline]
    fn default() -> Self
    {
        let global = config::get();
        State::new(global.clone())
    }
}
/// Returns `true` if `value` lies inside any of the CIDR ranges in `mask`.
///
/// An empty `mask` matches nothing.
fn mask_contains(mask: &[cidr::IpCidr], value: &std::net::IpAddr) -> bool
{
    mask.iter().any(|range| range.contains(value))
}
fn handle_test(state: Arc<State>) -> tokio::task::JoinHandle<()>
{
tokio::task::spawn(async move {
let (hook, mut recv) = {
let mut router = state.router.write().await;
router.hook(None, route::PrefixRouter::new("/hello"))
};
while let Some((uri, handle)) = recv.recv().await
{
match handle.set_response(Response::builder()
.status(200)
.body(format!("Hello world! You are at {}", uri).into())
.unwrap()).await {
Ok(_) => (),
Err(e) => {
error!("{}", e);
break;
},
}
}
{
let mut router = state.router.write().await;
router.unhook(iter::once(hook));
}
})
}
/// Route a single request through the router and produce its response.
///
/// A fresh [`Handle`] (carrying a new [`Nonce`]) is dispatched to the router.
/// If no hook matched, a plain `404` response is returned. Otherwise this waits until
/// every clone of the handle has been dropped, then returns the response the hooks built.
///
/// # Errors
/// Returns `error::Error::TimeoutReached` if `req_timeout_global` is set and elapses
/// before the hooks are finished: dispatch must complete *and* every handle clone
/// must be dropped within that time.
async fn handle_conn(state: Arc<State>, req: Request<Body>) -> Result<Response<Body>, error::Error>
{
    let response = Arc::new(RwLock::new(Response::new(Body::empty())));
    let nonce = Nonce(uuid::Uuid::new_v4());
    let req = Arc::new(req);

    // Resolves to `None` when no hook matched, otherwise to the finished response.
    // Dispatch and the wait for hooks to release their handles both run under the
    // global timeout. A hook that never drops its `Handle` therefore cannot keep
    // this task spinning forever.
    let work = async {
        let matched = {
            // The router lock is held for the dispatch only and released before the
            // wait below, so other requests can be routed in the meantime.
            let mut route = state.router.write().await;
            let handle = Handle {
                state: Arc::clone(&state),
                nonce,
                req: Arc::clone(&req),
                resp: Arc::clone(&response),
            };
            match route.dispatch(req.method(), req.uri().path(), handle, state.config.req_timeout_local).await {
                // Both outcomes carry the match count; the error detail is unused here.
                Ok(num) | Err((num, _)) => num,
            }
        };
        if matched == 0 {
            return None;
        }
        // Hooks signal completion by dropping their `Handle`. Once ours is the only
        // strong reference left, the response can be taken out by value.
        // NOTE(review): this is a yield-based busy-wait; a completion notification
        // from the hooks would avoid burning CPU while they work.
        let mut pending = response;
        let lock = loop {
            match Arc::try_unwrap(pending) {
                Ok(lock) => break lock,
                Err(shared) => {
                    pending = shared;
                    tokio::task::yield_now().await;
                },
            }
        };
        Some(lock.into_inner())
    };
    tokio::pin!(work);

    let outcome = match state.config.req_timeout_global {
        Some(timeout) => tokio::time::timeout(timeout, work).await,
        None => Ok(work.await),
    };
    match outcome {
        Ok(Some(resp)) => Ok(resp),
        Ok(None) => {
            // No handlers matched this
            trace!(" x {}", req.uri().path());
            Ok(Response::builder()
                .status(404)
                .body("404 not found".into())
                .unwrap())
        },
        Err(_) => {
            // Timeout reached
            Err(error::Error::TimeoutReached.info())
        },
    }
}
/// Run the web server until SIGINT (Ctrl-C) is received.
///
/// Binds to `state.config.listen` and checks each connection against the configured
/// deny/accept masks. Denied connections are refused, even if they also match the
/// accept mask. Requests are dispatched through [`handle_conn`]. On shutdown every
/// router hook is cleared and the demo hook task from [`handle_test`] is awaited.
///
/// # Errors
/// Returns an error in these cases:
/// - the listen address cannot be bound;
/// - the server fails while running;
/// - the demo hook task could not be joined.
///
/// # Panics
/// Panics in the `cfg_debug!` debug branch if `state.config` differs from the global
/// config. Also panics if the Ctrl-C handler cannot be installed.
pub async fn serve(state: State) -> Result<(), eyre::Report>
{
    cfg_debug!(if {
        if &state.config != config::get() {
            panic!("Our config is not the same as global? This is unsound.");
        }
    } else {
        if &state.config != config::get() {
            warn!("Our config is not the same as global? This is unsound.");
        }
    });
    let state = Arc::new(state);
    // Bind before spawning anything. A bad or in-use address then becomes an error
    // instead of a panic (as with `Server::bind`), and no hook task is left behind.
    let builder = Server::try_bind(&state.config.listen)?;
    let h = handle_test(Arc::clone(&state));
    let service = make_service_fn(|conn: &AddrStream| {
        let state = Arc::clone(&state);
        let remote_addr = conn.remote_addr();
        let remote_ip = remote_addr.ip();
        let denied = mask_contains(&state.config.deny_mask[..], &remote_ip);
        let allowed = mask_contains(&state.config.accept_mask[..], &remote_ip);
        async move {
            if denied {
                Err(error::Error::Denied(remote_addr, true).warn())
            } else if allowed || state.config.accept_default {
                trace!("Accepted conn: {}", remote_addr);
                Ok(service_fn(move |req: Request<Body>| {
                    handle_conn(Arc::clone(&state), req)
                }))
            } else {
                Err(error::Error::Denied(remote_addr,false).info())
            }
        }
    });
    let served = builder
        .serve(service)
        .with_graceful_shutdown(async {
            tokio::signal::ctrl_c().await.expect("Failed to catch SIGINT");
            info!("Going down for shutdown now!");
        })
        .await;
    // Remove all hooks even if the server failed, so the demo hook task can finish.
    // Presumably clearing closes the hooks' channels (confirm in `route`).
    // The guard is a temporary, released before the task is awaited below.
    state.router.write().await.clear();
    trace!("server down");
    let joined = h.await;
    // The server's own error takes precedence over a join error.
    served?;
    joined?;
    Ok(())
}