From 2ebdf4ac46c3f709bf13d88e06627f60f7f997a9 Mon Sep 17 00:00:00 2001 From: Avril Date: Sat, 17 Apr 2021 01:45:07 +0100 Subject: [PATCH] added poll_write to EncryptedWriteHalf --- Cargo.toml | 1 + src/stream.rs | 98 ++++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 94 insertions(+), 5 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7eb67fc..2615c98 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,3 +12,4 @@ chacha20stream = {version = "1.0", features=["async"]} openssl = "0.10.33" stackalloc = "1.1.0" pin-project = "1.0.6" +bytes = "0.5.6" \ No newline at end of file diff --git a/src/stream.rs b/src/stream.rs index ebedbd2..0d56051 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -2,7 +2,10 @@ use super::*; use tokio::io::{AsyncWrite, AsyncRead}; use std::sync::Arc; + use openssl::symm::Crypter; +use openssl::error::ErrorStack; +use ::bytes::{Buf, BufMut}; use std::{ pin::Pin, @@ -48,8 +51,9 @@ where S: AsyncWrite } //TODO: WriteHalf's AsyncWrite impl should just forward to backing_write +#[pin_project] pub struct EncryptedWriteHalf<'a, S> - where S: AsyncWrite, +where S: AsyncWrite, { /// Used to transform input `buf` into `self.crypt_buffer` before polling a write to `backing_write` with the newly filled `self.crypt_buffer`. /// See below 2 fields. @@ -83,7 +87,61 @@ pub struct EncryptedWriteHalf<'a, S> /// This exists so we don't have to transform the entire `buf` on every poll. We can just transform it once and then wait until it is `Ready` before discarding the data (`.empty()`) and allowing new data to fill it on the next, fresh `poll_write`. crypt_buffer: Vec, - backing: &'a mut WriteHalf, + #[pin] backing: &'a mut WriteHalf, +} +/// **Forcefully** transform `buf` into a transformed buffer. +/// +/// # Does **not** do these things +/// Doesn't check for ptr ident with `buf` against `crypt_buf_ptr`. You should do that yourself. +/// Doesn't truncate `crypt_buffer` after transformation. +fn transform_into(crypt_buffer: &mut Vec, cipher: &mut Crypter, buf: &[u8]) -> Result +{ + if crypt_buffer.len() < buf.len() { + crypt_buffer.resize(buf.len(), 0); + } + let n = cipher.update(buf, &mut crypt_buffer[..buf.len()])?; + let _f = cipher.finalize(&mut crypt_buffer[..n])?; + debug_assert_eq!(_f, 0); + + Ok(n) +} +impl<'a, S: AsyncWrite> EncryptedWriteHalf<'a, S> +{ + #[inline(always)] fn forward(self: Pin<&mut Self>) -> Pin<&mut WriteHalf> + { + unsafe {self.map_unchecked_mut(|this| this.backing)} + } +} + +impl<'a, S: AsyncWrite> AsyncWrite for EncryptedWriteHalf<'a, S> +{ + fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + let this = self.as_mut().project(); + + if this.crypt_buffer.is_empty() || this.crypt_buf_ptr != buf { + // Transform `buf` into self.crypt_buffer + + let n = transform_into(this.crypt_buffer, this.cipher, buf)?; + *this.crypt_buf_ptr = buf.into(); + + this.crypt_buffer.truncate(n); + } // else { /* No need to transform */ } + let poll = unsafe {this.backing.map_unchecked_mut(|this| *this)}.poll_write(cx, &this.crypt_buffer[..]); + + if poll.is_ready() + { + *this.crypt_buf_ptr = Default::default(); + this.crypt_buffer.clear(); + } + + poll + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + todo!() + } + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + todo!() + } } //TODO: EncryptedWriteHalf's AsyncWrite impl should en/decrypt the input buffer into `crypt_buffer` then send it to `backing`. @@ -101,7 +159,7 @@ where S: AsyncRead //TODO: ReadHalf's AsyncRead impl should just forward to backing_read, pub struct EncryptedReadHalf<'a, S> - where S: AsyncRead, +where S: AsyncRead, { cipher: Crypter, backing: &'a mut ReadHalf, @@ -109,15 +167,45 @@ pub struct EncryptedReadHalf<'a, S> //TODO: EncryptedReadHalf's AsyncRead impl should en/decrypt the read from backing. +impl AsyncRead for ReadHalf +{ + #[inline] fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + self.project().backing_read.poll_read(cx, buf) + } + #[inline] fn poll_read_buf(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut B) -> Poll> + where + Self: Sized, { + self.project().backing_read.poll_read_buf(cx, buf) + } +} + +impl AsyncWrite for WriteHalf +{ + #[inline] fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + self.project().backing_write.poll_write(cx, buf) + } + #[inline] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().backing_write.poll_flush(cx) + } + #[inline] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().backing_write.poll_shutdown(cx) + } + #[inline] fn poll_write_buf(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut B) -> Poll> + where + Self: Sized, { + self.project().backing_write.poll_write_buf(cx, buf) + } +} + //TODO: Rework everything past this point: /* struct ReadWriteCombined { - /// Since chacha20stream has no AsyncRead counterpart, we have to do it ourselves. +/// Since chacha20stream has no AsyncRead counterpart, we have to do it ourselves. cipher_read: Option, backing_read: R, - + backing_write: dual::DualStream, }