accepts AF_UNIX

serve
Avril 4 years ago
parent ca16c97629
commit ef5dc3cbf1
Signed by: flanchan
GPG Key ID: 284488987C31F630

2
Cargo.lock generated

@ -639,7 +639,7 @@ dependencies = [
[[package]]
name = "markov"
version = "0.6.3"
version = "0.7.0"
dependencies = [
"async-compression",
"cfg-if 1.0.0",

@ -1,9 +1,10 @@
[package]
name = "markov"
version = "0.6.4"
version = "0.7.0"
description = "Generate string of text from Markov chain fed by stdin"
authors = ["Avril <flanchan@cumallover.me>"]
edition = "2018"
license = "gpl-3.0-or-later"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@ -16,7 +17,7 @@ compress-chain = ["async-compression"]
# Treat each new line as a new set to feed instead of feeding the whole data at once
split-newlines = []
# Feed each sentance seperately, instead of just each line / whole body
# Feed each sentance seperately with default /get api, instead of just each line / whole body
# Maybe better without `split-newlines`?
# Kinda experimental
split-sentance = []

@ -1,3 +0,0 @@
Maybe see if `split-sentance` is stable enough to be enabled in prod now?
Allow Unix domain socket for bind

@ -2,9 +2,9 @@ bindpoint = '127.0.0.1:8001'
file = 'chain.dat'
max_content_length = 4194304
max_gen_size = 256
#save_interval_secs = 2
save_interval_secs = 2
trust_x_forwarded_for = false
[filter]
inbound = "<>)([]/"
outbound = "*"
inbound = ''
outbound = ''

@ -0,0 +1,171 @@
//! For binding to sockets
use super::*;
use futures::{
prelude::*,
};
use std::{
marker::{
Send,
Unpin,
},
fmt,
error,
path::{
Path,
PathBuf,
},
};
use tokio::{
io::{
self,
AsyncRead,
AsyncWrite,
},
};
#[derive(Debug)]
pub enum BindError<E>
{
IO(io::Error),
Warp(warp::Error),
Other(E),
}
impl<E: error::Error + 'static> error::Error for BindError<E>
{
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
Some(match &self {
Self::IO(io) => io,
Self::Other(o) => o,
Self::Warp(w) => w,
})
}
}
impl<E: fmt::Display> fmt::Display for BindError<E>
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
{
match self {
Self::IO(io) => write!(f, "io error: {}", io),
Self::Other(other) => write!(f, "{}", other),
Self::Warp(warp) => write!(f, "server error: {}", warp),
}
}
}
#[derive(Debug)]
pub struct BindpointParseError;
impl error::Error for BindpointParseError{}
impl fmt::Display for BindpointParseError
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
{
write!(f, "Failed to parse bindpoint as IP or unix domain socket")
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd)]
pub enum Bindpoint
{
Unix(PathBuf),
TCP(SocketAddr),
}
impl fmt::Display for Bindpoint
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
{
match self {
Self::Unix(unix) => write!(f, "unix:/{}", unix.to_string_lossy()),
Self::TCP(tcp) => write!(f, "{}", tcp),
}
}
}
impl std::str::FromStr for Bindpoint
{
type Err = BindpointParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(if let Ok(ip) = s.parse::<SocketAddr>() {
Self::TCP(ip)
} else if s.starts_with("unix:/") {
Self::Unix(PathBuf::from(&s[6..].to_owned()))
} else {
return Err(BindpointParseError);
})
}
}
fn bind_unix(to: impl AsRef<Path>) -> io::Result<impl TryStream<Ok= impl AsyncRead + AsyncWrite + Send + Unpin + 'static + Send, Error = impl Into<Box<dyn std::error::Error + Send + Sync>>>>
{
debug!("Binding to AF_UNIX: {:?}", to.as_ref());
let listener = tokio::net::UnixListener::bind(to)?;
Ok(listener)
}
pub fn serve<F>(server: warp::Server<F>, bind: Bindpoint, signal: impl Future<Output=()> + Send + 'static) -> Result<(Bindpoint, BoxFuture<'static, ()>), BindError<std::convert::Infallible>>
where F: Filter + Clone + Send + Sync + 'static,
<F::Future as TryFuture>::Ok: warp::Reply,
{
Ok(match bind {
Bindpoint::TCP(sock) => server.try_bind_with_graceful_shutdown(sock, signal).map(|(sock, fut)| (Bindpoint::TCP(sock), fut.boxed())).map_err(BindError::Warp)?,
Bindpoint::Unix(unix) => {
(Bindpoint::Unix(unix.clone()),
server.serve_incoming_with_graceful_shutdown(bind_unix(unix).map_err(BindError::IO)?, signal).boxed())
},
})
}
impl From<SocketAddr> for Bindpoint
{
fn from(from: SocketAddr) -> Self
{
Self::TCP(from)
}
}
pub fn try_serve<F>(server: warp::Server<F>, bind: impl TryBindpoint, signal: impl Future<Output=()> + Send + 'static) -> Result<(Bindpoint, BoxFuture<'static, ()>), BindError<impl error::Error + 'static>>
where F: Filter + Clone + Send + Sync + 'static,
<F::Future as TryFuture>::Ok: warp::Reply,
{
serve(server, bind.try_parse().map_err(BindError::Other)?, signal).map_err(BindError::coerce)
}
pub trait TryBindpoint: Sized
{
type Err: error::Error + 'static;
fn try_parse(self) -> Result<Bindpoint, Self::Err>;
}
impl TryBindpoint for Bindpoint
{
type Err = std::convert::Infallible;
fn try_parse(self) -> Result<Bindpoint, Self::Err>
{
Ok(self)
}
}
impl<T: AsRef<str>> TryBindpoint for T
{
type Err = BindpointParseError;
fn try_parse(self) -> Result<Bindpoint, Self::Err>
{
self.as_ref().parse()
}
}
impl BindError<std::convert::Infallible>
{
pub fn coerce<T>(self) -> BindError<T>
{
match self {
Self::Warp(w) => BindError::Warp(w),
Self::IO(w) => BindError::IO(w),
#[cold] _ => unreachable!(),
}
}
}

@ -19,7 +19,7 @@ pub const DEFAULT_FILE_LOCATION: &'static str = "markov.toml";
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, Serialize, Deserialize)]
pub struct Config
{
pub bindpoint: SocketAddr,
pub bindpoint: String,
pub file: String,
pub max_content_length: u64,
pub max_gen_size: usize,
@ -65,7 +65,7 @@ impl Default for Config
fn default() -> Self
{
Self {
bindpoint: ([127,0,0,1], 8001).into(),
bindpoint: SocketAddr::from(([127,0,0,1], 8001)).to_string(),
file: "chain.dat".to_owned(),
max_content_length: 1024 * 1024 * 4,
max_gen_size: 256,

@ -1,7 +1,14 @@
//! Extensions
use std::{
iter,
ops::Range,
ops::{
Range,
Deref,DerefMut,
},
marker::{
PhantomData,
Send,
},
};
pub trait StringJoinExt: Sized
@ -94,3 +101,52 @@ impl TrimInPlace for String
self
}
}
pub trait MapTuple2<T,U>
{
fn map<V,W, F: FnOnce((T,U)) -> (V,W)>(self, fun: F) -> (V,W);
}
impl<T,U> MapTuple2<T,U> for (T,U)
{
#[inline] fn map<V,W, F: FnOnce((T,U)) -> (V,W)>(self, fun: F) -> (V,W)
{
fun(self)
}
}
/// To make sure we don't keep this data across an `await` boundary.
#[repr(transparent)]
pub struct AssertNotSend<T>(pub T, PhantomData<*const T>);
impl<T> AssertNotSend<T>
{
pub const fn new(from :T) -> Self
{
Self(from, PhantomData)
}
pub fn into_inner(self) -> T
{
self.0
}
}
/// Require a future is Send
#[inline(always)] pub fn require_send<T: Send>(t: T) -> T
{
t
}
impl<T> Deref for AssertNotSend<T>
{
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> DerefMut for AssertNotSend<T>
{
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

@ -81,10 +81,29 @@ mod feed;
mod gen;
mod sentance;
#[tokio::main]
async fn main() {
const DEFAULT_LOG_LEVEL: &str = "warn";
fn init_log()
{
let level = match std::env::var_os("RUST_LOG") {
None => {
std::env::set_var("RUST_LOG", DEFAULT_LOG_LEVEL);
std::borrow::Cow::Borrowed(std::ffi::OsStr::new(DEFAULT_LOG_LEVEL))
},
Some(w) => std::borrow::Cow::Owned(w),
};
pretty_env_logger::init();
trace!("Initialising `genmarkov` ({}) v{} with log level {:?}.\n\tMade by {} with <3.\n\tLicensed with GPL v3 or later",
std::env::args().next().unwrap(),
env!("CARGO_PKG_VERSION"),
level,
env!("CARGO_PKG_AUTHORS"));
}
#[tokio::main]
async fn main() {
init_log();
let config = match config::load().await {
Some(v) => v,
_ => {
@ -237,14 +256,29 @@ async fn main() {
#[cfg(target_family="unix")]
tasks.push(tokio::spawn(signals::handle(state.clone())).map(|res| res.expect("Signal handler panicked")).boxed());
let (addr, server) = warp::serve(push
.or(read))
.bind_with_graceful_shutdown(state.config().bindpoint, async move {
tokio::signal::ctrl_c().await.unwrap();
state.shutdown();
});
info!("Server bound on {:?}", addr);
server.await;
require_send(async {
let server = {
let s2 = AssertNotSend::new(state.clone()); //temp clone the Arcs here for shutdown if server fails to bind, assert they cannot remain cloned across an await boundary.
match bind::try_serve(warp::serve(push
.or(read)),
state.config().bindpoint.clone(),
async move {
tokio::signal::ctrl_c().await.unwrap();
state.shutdown();
}) {
Ok((addr, server)) => {
info!("Server bound on {:?}", addr);
server
},
Err(err) => {
error!("Failed to bind server: {}", err);
s2.into_inner().shutdown();
return;
},
}
};
server.await;
}).await;
// Cleanup
async move {
@ -255,3 +289,5 @@ async fn main() {
}.await;
info!("Shut down gracefully")
}
mod bind;

Loading…
Cancel
Save