diff --git a/Cargo.toml b/Cargo.toml
index a45cb0e..7eb67fc 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -10,3 +10,5 @@ edition = "2018"
 tokio = {version = "0.2", features=["full"]}
 chacha20stream = {version = "1.0", features=["async"]}
 openssl = "0.10.33"
+stackalloc = "1.1.0"
+pin-project = "1.0.6"
diff --git a/src/ext.rs b/src/ext.rs
index 8dd0645..a7777d4 100644
--- a/src/ext.rs
+++ b/src/ext.rs
@@ -1,4 +1,5 @@
 //! Extensions and macros
+use std::cell::RefCell;
 #[macro_export]
 macro_rules! basic_enum {
     ($(#[$meta:meta])* $vis:vis $name:ident $(; $tcomment:literal)?: $($var:ident $(=> $comment:literal)?),+ $(,)?) => {
@@ -121,6 +122,32 @@
         }
     };
     ($vis:vis $name:ident $(; $comment:literal)?) => {
-        $crate::bool_type!($vis $name $(; $comment)? => Yes, No);
+        $crate::bool_type!($vis $name $(; $comment)? => Yes, No);
+    }
+}
+
+/// Max number of bytes we'll allocate on the stack at runtime before using a heap-allocated buffer.
+pub const STACK_SIZE_LIMIT: usize = 4096;
+
+/// Allocate `size` bytes and pass them to `f`. Allocates on the stack if `size` is at most `STACK_SIZE_LIMIT`, otherwise allocates on the heap.
+pub fn alloca_limit<T, F>(size: usize, f: F) -> T
+where F: FnOnce(&mut [u8]) -> T
+{
+    if size > STACK_SIZE_LIMIT {
+        thread_local! {
+            static BUFFER: RefCell<Vec<u8>> = RefCell::new(vec![0u8; STACK_SIZE_LIMIT*2]);
+        }
+        BUFFER.with(move |buf| {
+            let mut buf = buf.borrow_mut();
+            if buf.len() < size {
+                buf.resize(size, 0);
+            }
+            f(&mut buf[..size])
+        })
+    } else {
+        stackalloc::alloca_zeroed(size, f)
+        //
+        // TODO: Is this okay to do? I'm not sure it is.. We'll see
+        //stackalloc::alloca(size, move |buf| f(unsafe { stackalloc::helpers::slice_assume_init_mut(buf) }))
     }
 }
diff --git a/src/lib.rs b/src/lib.rs
index 950f8f6..74106d8 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,6 +1,8 @@
 #![allow(dead_code)]
 
+#[macro_use] extern crate pin_project;
+
 // Extensions & macros
 #[macro_use] mod ext;
 #[allow(unused_imports)] use ext::*;
 
diff --git a/src/stream.rs b/src/stream.rs
index 9e5d3f4..f13e317 100644
--- a/src/stream.rs
+++ b/src/stream.rs
@@ -4,6 +4,12 @@
 use tokio::io::{AsyncWrite, AsyncRead};
 use std::sync::Arc;
 use openssl::symm::Crypter;
+use std::{
+    pin::Pin,
+    task::{Poll, Context},
+    io,
+};
+
 use crypt::{
     RsaPublicKey,
     RsaPrivateKey,
@@ -30,6 +36,7 @@ where S: AsyncWrite
 }
 
 /// Readable half of `EncryptedStream`.
+#[pin_project]
 pub struct ReadHalf<S>
 where S: AsyncRead
 {
@@ -38,7 +45,7 @@ where S: AsyncRead
     /// chacha20_poly1305 decrypter for incoming reads from `S`
     //TODO: chacha20stream: implement a read version of AsyncSink so we don't need to keep this?
     cipher: Option<Crypter>,
-    backing_read: Box<S>,
+    #[pin] backing_read: Box<S>,
 }
 struct ReadWriteCombined<R, W>
 where R: AsyncRead,
@@ -62,6 +69,8 @@ where R: AsyncRead,
 }
 
 //TODO: How do we use this with a single AsyncStream instead of requiring 2? Will we need to make our own Arc wrapper?? Ugh,, for now let's ignore this I guess... Most read+write thingies have a Read/WriteHalf split mechanism.
+//
+// Note that this does actually work fine with things like tokio's `duplex()` (i think)
 impl EncryptedStream
 {
     /// Has this stream done its RSA key exchange?
@@ -103,3 +112,44 @@ impl EncryptedStream
         todo!("Drop write's `meta`, consume read's `meta`. Move the streams into `ReadWriteCombined`")
     }
 }
+
+impl<S: AsyncWrite> AsyncWrite for WriteHalf<S>
+{
+    #[inline] fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
+        unsafe { self.map_unchecked_mut(|this| this.backing_write.as_mut()) }.poll_write(cx, buf)
+    }
+    #[inline] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        unsafe { self.map_unchecked_mut(|this| this.backing_write.as_mut()) }.poll_flush(cx)
+    }
+    #[inline] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        unsafe { self.map_unchecked_mut(|this| this.backing_write.as_mut()) }.poll_shutdown(cx)
+    }
+}
+
+impl<S: AsyncRead> AsyncRead for ReadHalf<S>
+{
+    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
+        let this = self.project();
+        let cipher = this.cipher.as_mut();
+        let stream = unsafe { this.backing_read.map_unchecked_mut(|f| f.as_mut()) };
+
+        let res = stream.poll_read(cx, buf);
+        if let Some(cipher) = cipher {
+            // Decrypt the buffer if the read succeeded
+            res.map(move |res| res.and_then(move |sz| {
+                alloca_limit(sz, move |obuf| -> io::Result<usize> {
+                    // This `sz` and the old `sz` should always be the same.
+                    let sz = cipher.update(&buf[..sz], &mut obuf[..])?;
+                    let _f = cipher.finalize(&mut obuf[..sz])?;
+                    debug_assert_eq!(_f, 0);
+
+                    // Copy decrypted buffer into output buffer
+                    buf[..sz].copy_from_slice(&obuf[..sz]);
+                    Ok(sz)
+                })
+            }))
+        } else {
+            res
+        }
+    }
+}
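
A note on the new `alloca_limit` helper in src/ext.rs: it hands its closure a scratch buffer of exactly `size` bytes, on the stack for small sizes and from a reused thread-local `Vec<u8>` above `STACK_SIZE_LIMIT`. A minimal usage sketch (not part of the patch; `read_hashed` and its arguments are made up for illustration):

    use std::collections::hash_map::DefaultHasher;
    use std::hash::Hasher;
    use std::io::{self, Read};

    /// Read exactly `len` bytes into a temporary buffer and hash them,
    /// without holding onto any allocation afterwards.
    fn read_hashed(input: &mut impl Read, len: usize) -> io::Result<u64> {
        // Small `len` lands on the stack; anything over STACK_SIZE_LIMIT reuses
        // the thread-local heap buffer instead.
        alloca_limit(len, |buf| -> io::Result<u64> {
            input.read_exact(buf)?;
            let mut hasher = DefaultHasher::new();
            hasher.write(buf);
            Ok(hasher.finish())
        })
    }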
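
Two caveats on the heap fallback in `alloca_limit` worth keeping in mind: `stackalloc::alloca_zeroed` always hands the closure zeroed memory, but the reused thread-local buffer only zeroes bytes added by `resize`, so a large request can observe leftovers from an earlier call on the same thread; and since the buffer lives in a `RefCell`, a nested `alloca_limit` call with a size over `STACK_SIZE_LIMIT` made from inside the closure will panic on the second `borrow_mut()`. If the zeroed guarantee is meant to hold on both paths, the heap branch could clear the slice before lending it out; a sketch of that variant (same shape as the branch in the patch, only the zeroing loop is new):

    BUFFER.with(move |buf| {
        let mut buf = buf.borrow_mut();
        if buf.len() < size {
            buf.resize(size, 0);
        }
        // `resize` only zeroes newly-grown bytes; also clear anything left over
        // from a previous call so callers see the same all-zero contents as the
        // alloca_zeroed path.
        for b in buf[..size].iter_mut() {
            *b = 0;
        }
        f(&mut buf[..size])
    })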
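
On the `ReadHalf` changes: `#[pin_project]` is what generates the `project()` used in `poll_read`; fields marked `#[pin]` project to `Pin<&mut Field>` while the rest come out as plain `&mut` references. A small self-contained sketch of the pattern (the struct and fields here are illustrative, not the real `ReadHalf`):

    use std::pin::Pin;
    use pin_project::pin_project;

    #[pin_project]
    struct Wrapper<T> {
        #[pin]
        inner: T,        // projects to Pin<&mut T>
        bytes_seen: u64, // projects to &mut u64
    }

    impl<T> Wrapper<T> {
        /// Note a use of the wrapper and hand back the pinned inner value.
        fn touch_inner(self: Pin<&mut Self>) -> Pin<&mut T> {
            let this = self.project();
            *this.bytes_seen += 1; // ordinary &mut access, no pinning involved
            this.inner             // stays pinned
        }
    }

In the patch the pinned field is a `Box<S>`, so `poll_read` still needs the `map_unchecked_mut` step to get from `Pin<&mut Box<S>>` down to `Pin<&mut S>`.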
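
The in-place decryption in `poll_read` leans on ciphertext and plaintext having the same length, which is what makes writing `sz` decrypted bytes back over the first `sz` bytes of `buf` valid; that holds for a stream cipher, where `update` is 1:1 and `finalize` emits nothing. A standalone illustration of that property against a bare `openssl::symm::Crypter` (plain chacha20 with dummy key/IV, just to keep the sketch self-contained; the patch's decrypter comes from the `crypt` module and, per the comments, is chacha20_poly1305, whose decrypt path additionally involves the auth tag):

    use openssl::symm::{Cipher, Crypter, Mode};

    fn main() {
        // Dummy key/nonce purely for illustration (chacha20 here: 32-byte key, 16-byte IV).
        let key = [0x42u8; 32];
        let iv = [0x24u8; 16];
        let ciphertext = [0u8; 11]; // pretend this just came off the wire

        let mut dec = Crypter::new(Cipher::chacha20(), Mode::Decrypt, &key, Some(&iv)).unwrap();
        let mut out = vec![0u8; ciphertext.len()];
        let n = dec.update(&ciphertext, &mut out).unwrap();
        assert_eq!(n, ciphertext.len()); // stream cipher: no padding, output length == input length
        let trailing = dec.finalize(&mut out[n..]).unwrap();
        assert_eq!(trailing, 0); // nothing left over, mirroring the debug_assert in poll_read
    }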