From 372c7330662772ece195548c0e6c1a542f4afc2f Mon Sep 17 00:00:00 2001 From: Avril Date: Tue, 20 Apr 2021 01:21:32 +0100 Subject: [PATCH] impl AsyncRead for EncryptedReadHalf completed impl AsyncWrite for EncryptedWriteHalf --- src/stream.rs | 50 +++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 43 insertions(+), 7 deletions(-) diff --git a/src/stream.rs b/src/stream.rs index ce2640b..e9a6010 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -49,7 +49,6 @@ where S: AsyncWrite #[pin] backing_write: S,//Box>, } -//TODO: WriteHalf's AsyncWrite impl should just forward to backing_write #[pin_project] pub struct EncryptedWriteHalf<'a, S> @@ -137,13 +136,27 @@ impl<'a, S: AsyncWrite> AsyncWrite for EncryptedWriteHalf<'a, S> poll } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - todo!("TODO: Copy the impl of chacha20stream's `AsyncSink::flush()`") + let this = self.project(); + + let poll = unsafe {this.backing.map_unchecked_mut(|this| *this)}.poll_flush(cx); + if poll.is_ready() { + this.crypt_buffer.clear(); + *this.crypt_buf_ptr = Default::default(); + } + poll } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - todo!("TODO: Copy the impl of chacha20stream's `AsyncSink::shutdown()`") + let this = self.project(); + + let poll = unsafe {this.backing.map_unchecked_mut(|this| *this)}.poll_shutdown(cx); + if poll.is_ready() { + bytes::blank(&mut this.crypt_buffer[..]); + this.crypt_buffer.clear(); + *this.crypt_buf_ptr = Default::default(); + } + poll } } -//TODO: EncryptedWriteHalf's AsyncWrite impl should en/decrypt the input buffer into `crypt_buffer` then send it to `backing`. /// Readable half of `EncryptedStream`. #[pin_project] @@ -156,15 +169,38 @@ where S: AsyncRead #[pin] backing_read: S, } -//TODO: ReadHalf's AsyncRead impl should just forward to backing_read, +#[pin_project] pub struct EncryptedReadHalf<'a, S> where S: AsyncRead, { cipher: Crypter, - backing: &'a mut ReadHalf, + #[pin] backing: &'a mut ReadHalf, +} + +impl<'a, S: AsyncRead> AsyncRead for EncryptedReadHalf<'a, S> +{ + fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + let this = self.project(); + let cipher = this.cipher; + let stream = unsafe {this.backing.map_unchecked_mut(|f| &mut f.backing_read)}; + + let res = stream.poll_read(cx,buf); + // Decrypt the buffer if the read succeeded + res.map(move |res| res.and_then(move |sz| { + alloca_limit(sz, move |obuf| -> io::Result { + // This `sz` and old `sz` should always be the same. + let sz = cipher.update(&buf[..sz], &mut obuf[..])?; + let _f = cipher.finalize(&mut obuf[..sz])?; + debug_assert_eq!(_f, 0); + + // Copy decrypted buffer into output buffer + buf.copy_from_slice(&obuf[..sz]); + Ok(sz) + }) + })) + } } -//TODO: EncryptedReadHalf's AsyncRead impl should en/decrypt the read from backing. impl AsyncRead for ReadHalf