You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
genmarkov/src/api.rs

100 lines
2.3 KiB

//! For API calls if enabled
use super::*;
use std::{
fmt,
error,
iter,
convert::Infallible,
};
use futures::{
stream::{
self,
BoxStream,
StreamExt,
},
};
/// Collect a request body into an owned `String`.
///
/// The body is read through `Buf::to_bytes` and then validated as UTF-8.
///
/// # Errors
/// Returns the `Utf8Error` reported by `std::str::from_utf8` when the body
/// bytes are not valid UTF-8.
#[inline]
fn aggregate(mut body: impl Buf) -> Result<String, std::str::Utf8Error>
{
    let raw = body.to_bytes();
    // Validate on the borrowed bytes first; only allocate the `String` on success.
    let text = std::str::from_utf8(&raw)?;
    Ok(text.to_owned())
}
/// Handle a one-off request: build a chain from `body` and reply with generated output.
///
/// Delegates to `single_stream` and wraps the resulting stream in a response
/// body. Every item is logged together with `host` as it passes through.
///
/// # Errors
/// If `single_stream` fails, its `ApiError` is turned into a rejection with
/// `warp::reject::custom`.
pub async fn single(host: IpAddr, num: Option<usize>, body: impl Buf) -> Result<impl warp::Reply, warp::reject::Rejection>
{
    let generated = single_stream(host, num, body)
        .await
        .map_err(warp::reject::custom)?;

    // Log each item as it is pulled by the response body, then pass it on unchanged.
    let logged = generated.map(move |item| {
        info!("{} <- {:?}", host, item);
        item
    });

    Ok(Response::new(Body::wrap_stream(logged)))
}
//TODO: Change to stream impl like normal `feed` has, instead of taking aggregate?
/// Feed `body` into a fresh `Chain` and return a stream of generated strings.
///
/// * `host` - peer address, used only for logging the received body.
/// * `num`  - `None` produces a single string from `chain.generate_str()`;
///   `Some(n)` streams whatever `chain.str_iter_for(n)` yields (presumably up
///   to `n` strings -- confirm against `Chain`), produced by a spawned task.
/// * `body` - request body; must be valid UTF-8.
///
/// With the `split-newlines` feature each non-blank line of the body is fed
/// separately; otherwise the body is fed as one piece.
///
/// # Errors
/// Returns `ApiError::Body` when the body is not valid UTF-8.
async fn single_stream(host: IpAddr, num: Option<usize>, body: impl Buf) -> Result<BoxStream<'static, Result<String, Infallible>>, ApiError>
{
    let body = aggregate(body)?;
    info!("{} <- {:?}", host, &body[..]);

    let mut chain = Chain::new();
    if_debug! {
        let timer = std::time::Instant::now();
    }
    cfg_if! {
        if #[cfg(feature="split-newlines")] {
            for line in body.split('\n').filter(|line| !line.trim().is_empty()) {
                feed::feed(&mut chain, line, 1..);
            }
        } else {
            feed::feed(&mut chain, body, 1..);
        }
    }
    if_debug!{
        trace!("Write took {}ms", timer.elapsed().as_millis());
    }

    match num {
        None => Ok(stream::iter(iter::once(Ok(chain.generate_str()))).boxed()),
        Some(num) => {
            // A bounded channel must have a non-zero capacity; `Some(0)` from the
            // caller would otherwise panic here instead of yielding an empty stream.
            let (mut tx, rx) = mpsc::channel(num.max(1));
            tokio::spawn(async move {
                for string in chain.str_iter_for(num) {
                    // `send` only fails once the receiving half (the response body)
                    // has been dropped, e.g. the client went away. Nobody is left
                    // to read further output, so stop generating instead of panicking.
                    if tx.send(string).await.is_err() {
                        break;
                    }
                }
            });
            Ok(StreamExt::map(rx, |x| Ok::<_, Infallible>(x)).boxed())
        }
    }
}
/// Errors that the API handlers in this module can report.
#[derive(Debug)]
pub enum ApiError {
    /// The request body could not be used (its `Display` text is
    /// "invalid data in request body").
    Body,
}
impl warp::reject::Reject for ApiError{}
impl error::Error for ApiError{}
/// Human-readable description of each `ApiError` variant.
impl std::fmt::Display for ApiError
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
    {
        // Exhaustive match so that adding a variant forces a message to be chosen.
        let message = match self {
            Self::Body => "invalid data in request body",
        };
        f.write_str(message)
    }
}
impl From<std::str::Utf8Error> for ApiError
{
fn from(_: std::str::Utf8Error) -> Self
{
Self::Body
}
}