added AsyncWrite and AsyncRead impls to WriteHalf and ReadHalf respectively

no-dual
Avril 3 years ago
parent 0035f38c3c
commit 0ba53ec354
Signed by: flanchan
GPG Key ID: 284488987C31F630

@ -10,3 +10,5 @@ edition = "2018"
tokio = {version = "0.2", features=["full"]}
chacha20stream = {version = "1.0", features=["async"]}
openssl = "0.10.33"
stackalloc = "1.1.0"
pin-project = "1.0.6"

@ -1,4 +1,5 @@
//! Extensions and macros
use std::cell::RefCell;
#[macro_export] macro_rules! basic_enum {
($(#[$meta:meta])* $vis:vis $name:ident $(; $tcomment:literal)?: $($var:ident $(=> $comment:literal)?),+ $(,)?) => {
@ -121,6 +122,32 @@
}
};
($vis:vis $name:ident $(; $comment:literal)?) => {
$crate::bool_type!($vis $name $(; $comment)? => Yes, No);
$crate::bool_type!($vis $name $(; $comment)? => Yes, No);
}
}
/// Max size of bytes we'll allocate to the stack at runtime before using a heap allocated buffer.
pub const STACK_SIZE_LIMIT: usize = 4096;
/// Allocate `size` bytes. Allocates on the stack if size is lower than `STACK_SIZE_LIMIT`, otherwise allocates on the heap.
pub fn alloca_limit<F, T>(size: usize, f: F) -> T
where F: FnOnce(&mut [u8]) -> T
{
if size > STACK_SIZE_LIMIT {
thread_local! {
static BUFFER: RefCell<Vec<u8>> = RefCell::new(vec![0u8; STACK_SIZE_LIMIT*2]);
}
BUFFER.with(move |buf| {
let mut buf = buf.borrow_mut();
if buf.len() < size {
buf.resize(size, 0);
}
f(&mut buf[..size])
})
} else {
stackalloc::alloca_zeroed(size, f)
//
// TODO: Is this okay to do? I'm not sure it is.. We'll see
//stackalloc::alloca(size, move |buf| f(unsafe { stackalloc::helpers::slice_assume_init_mut(buf) }))
}
}

@ -1,6 +1,8 @@
#![allow(dead_code)]
#[macro_use] extern crate pin_project;
// Extensions & macros
#[macro_use] mod ext;
#[allow(unused_imports)] use ext::*;

@ -4,6 +4,12 @@ use tokio::io::{AsyncWrite, AsyncRead};
use std::sync::Arc;
use openssl::symm::Crypter;
use std::{
pin::Pin,
task::{Poll, Context},
io,
};
use crypt::{
RsaPublicKey,
RsaPrivateKey,
@ -30,6 +36,7 @@ where S: AsyncWrite
}
/// Readable half of `EncryptedStream`.
#[pin_project]
pub struct ReadHalf<S>
where S: AsyncRead
{
@ -38,7 +45,7 @@ where S: AsyncRead
/// chacha20_poly1305 decrypter for incoming reads from `S`
//TODO: chacha20stream: implement a read version of AsyncSink so we don't need to keep this?
cipher: Option<Crypter>,
backing_read: Box<S>,
#[pin] backing_read: Box<S>,
}
struct ReadWriteCombined<R, W>
@ -62,6 +69,8 @@ where R: AsyncRead,
}
//TODO: How do we use this with a single AsyncStream instead of requiring 2? Will we need to make our own Arc wrapper?? Ugh,, for now let's ignore this I guess... Most read+write thingies have a Read/WriteHalf split mechanism.
//
// Note that this does actually work fine with things like tokio's `duplex()` (i think)
impl<R: AsyncRead, W: AsyncWrite> EncryptedStream<R, W>
{
/// Has this stream done its RSA key exchange?
@ -103,3 +112,44 @@ impl<R: AsyncRead, W: AsyncWrite> EncryptedStream<R, W>
todo!("Drop write's `meta`, consume read's `meta`. Move the streams into `ReadWriteCombined`")
}
}
impl<S: AsyncWrite> AsyncWrite for WriteHalf<S>
{
#[inline] fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
unsafe {self.map_unchecked_mut(|this| this.backing_write.as_mut())}.poll_write(cx, buf)
}
#[inline] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
unsafe {self.map_unchecked_mut(|this| this.backing_write.as_mut())}.poll_flush(cx)
}
#[inline] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
unsafe {self.map_unchecked_mut(|this| this.backing_write.as_mut())}.poll_shutdown(cx)
}
}
impl<S: AsyncRead> AsyncRead for ReadHalf<S>
{
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
let this = self.project();
let cipher = this.cipher.as_mut();
let stream = unsafe {this.backing_read.map_unchecked_mut(|f| f.as_mut())};
let res = stream.poll_read(cx,buf);
if let Some(cipher) = cipher {
// Decrypt the buffer if the read succeeded
res.map(move |res| res.and_then(move |sz| {
alloca_limit(sz, move |obuf| -> io::Result<usize> {
// This `sz` and old `sz` should always be the same.
let sz = cipher.update(&buf[..sz], &mut obuf[..])?;
let _f = cipher.finalize(&mut obuf[..sz])?;
debug_assert_eq!(_f, 0);
// Copy decrypted buffer into output buffer
buf.copy_from_slice(&obuf[..sz]);
Ok(sz)
})
}))
} else {
res
}
}
}

Loading…
Cancel
Save