added poll_write to EncryptedWriteHalf

no-dual
Avril 3 years ago
parent 2056ff6a58
commit 2ebdf4ac46
Signed by: flanchan
GPG Key ID: 284488987C31F630

@ -12,3 +12,4 @@ chacha20stream = {version = "1.0", features=["async"]}
openssl = "0.10.33"
stackalloc = "1.1.0"
pin-project = "1.0.6"
bytes = "0.5.6"

@ -2,7 +2,10 @@ use super::*;
use tokio::io::{AsyncWrite, AsyncRead};
use std::sync::Arc;
use openssl::symm::Crypter;
use openssl::error::ErrorStack;
use ::bytes::{Buf, BufMut};
use std::{
pin::Pin,
@ -48,8 +51,9 @@ where S: AsyncWrite
}
//TODO: WriteHalf's AsyncWrite impl should just forward to backing_write
#[pin_project]
pub struct EncryptedWriteHalf<'a, S>
where S: AsyncWrite,
where S: AsyncWrite,
{
/// Used to transform input `buf` into `self.crypt_buffer` before polling a write to `backing_write` with the newly filled `self.crypt_buffer`.
/// See below 2 fields.
@ -83,7 +87,61 @@ pub struct EncryptedWriteHalf<'a, S>
/// This exists so we don't have to transform the entire `buf` on every poll. We can just transform it once and then wait until it is `Ready` before discarding the data (`.empty()`) and allowing new data to fill it on the next, fresh `poll_write`.
crypt_buffer: Vec<u8>,
backing: &'a mut WriteHalf<S>,
#[pin] backing: &'a mut WriteHalf<S>,
}
/// **Forcefully** transform `buf` into a transformed buffer.
///
/// # Does **not** do these things
/// Doesn't check for ptr ident with `buf` against `crypt_buf_ptr`. You should do that yourself.
/// Doesn't truncate `crypt_buffer` after transformation.
fn transform_into(crypt_buffer: &mut Vec<u8>, cipher: &mut Crypter, buf: &[u8]) -> Result<usize, ErrorStack>
{
if crypt_buffer.len() < buf.len() {
crypt_buffer.resize(buf.len(), 0);
}
let n = cipher.update(buf, &mut crypt_buffer[..buf.len()])?;
let _f = cipher.finalize(&mut crypt_buffer[..n])?;
debug_assert_eq!(_f, 0);
Ok(n)
}
impl<'a, S: AsyncWrite> EncryptedWriteHalf<'a, S>
{
#[inline(always)] fn forward(self: Pin<&mut Self>) -> Pin<&mut WriteHalf<S>>
{
unsafe {self.map_unchecked_mut(|this| this.backing)}
}
}
impl<'a, S: AsyncWrite> AsyncWrite for EncryptedWriteHalf<'a, S>
{
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
let this = self.as_mut().project();
if this.crypt_buffer.is_empty() || this.crypt_buf_ptr != buf {
// Transform `buf` into self.crypt_buffer
let n = transform_into(this.crypt_buffer, this.cipher, buf)?;
*this.crypt_buf_ptr = buf.into();
this.crypt_buffer.truncate(n);
} // else { /* No need to transform */ }
let poll = unsafe {this.backing.map_unchecked_mut(|this| *this)}.poll_write(cx, &this.crypt_buffer[..]);
if poll.is_ready()
{
*this.crypt_buf_ptr = Default::default();
this.crypt_buffer.clear();
}
poll
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
todo!()
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
todo!()
}
}
//TODO: EncryptedWriteHalf's AsyncWrite impl should en/decrypt the input buffer into `crypt_buffer` then send it to `backing`.
@ -101,7 +159,7 @@ where S: AsyncRead
//TODO: ReadHalf's AsyncRead impl should just forward to backing_read,
pub struct EncryptedReadHalf<'a, S>
where S: AsyncRead,
where S: AsyncRead,
{
cipher: Crypter,
backing: &'a mut ReadHalf<S>,
@ -109,15 +167,45 @@ pub struct EncryptedReadHalf<'a, S>
//TODO: EncryptedReadHalf's AsyncRead impl should en/decrypt the read from backing.
impl<S: AsyncRead> AsyncRead for ReadHalf<S>
{
#[inline] fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
self.project().backing_read.poll_read(cx, buf)
}
#[inline] fn poll_read_buf<B: BufMut>(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut B) -> Poll<io::Result<usize>>
where
Self: Sized, {
self.project().backing_read.poll_read_buf(cx, buf)
}
}
impl<S: AsyncWrite> AsyncWrite for WriteHalf<S>
{
#[inline] fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
self.project().backing_write.poll_write(cx, buf)
}
#[inline] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().backing_write.poll_flush(cx)
}
#[inline] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().backing_write.poll_shutdown(cx)
}
#[inline] fn poll_write_buf<B: Buf>(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut B) -> Poll<Result<usize, io::Error>>
where
Self: Sized, {
self.project().backing_write.poll_write_buf(cx, buf)
}
}
//TODO: Rework everything past this point:
/*
struct ReadWriteCombined<R, W>
{
/// Since chacha20stream has no AsyncRead counterpart, we have to do it ourselves.
/// Since chacha20stream has no AsyncRead counterpart, we have to do it ourselves.
cipher_read: Option<Crypter>,
backing_read: R,
backing_write: dual::DualStream<W>,
}

Loading…
Cancel
Save