You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

346 lines
12 KiB

use super::*;
use tokio::io::{AsyncWrite, AsyncRead};
use std::sync::Arc;
use openssl::symm::Crypter;
use openssl::error::ErrorStack;
use ::bytes::{Buf, BufMut};
use std::{
pin::Pin,
task::{Poll, Context},
io,
};
use crypt::{
RsaPublicKey,
RsaPrivateKey,
};
/// Marker trait for any type that is both an `AsyncRead` and an `AsyncWrite`.
///
/// It has no methods of its own; it only names the combination of the two bounds.
pub trait AsyncStream: AsyncRead + AsyncWrite {}

// Blanket implementation: every type (sized or not) that satisfies both bounds is an `AsyncStream`.
impl<T> AsyncStream for T
where
    T: AsyncRead + AsyncWrite + ?Sized,
{}
/// Inner RSA data for the read and write halves of an encrypted stream.
///
/// # Exchange / mutation
/// For split streams, this becomes immutable. If exchange has not been performed by the combined stream before splitting, then it is impossible for the split Read and Write halves to form `EncryptedReadHalf` and `EncryptedWriteHalf` instances on top of themselves.
/// The stream must be re-joined, exchanged, and then split again in this case.
/// Therefore exchange should happen before the original stream is split at all.
///
/// Only the combined stream can mutate this structure. The halves hold it behind an immutable shared reference (an `Arc`).
struct EncryptedStreamMeta
{
/// RSA private key. NOTE(review): presumably the local side's key used for the exchange — the exchange code is not in this view, confirm.
us: RsaPrivateKey,
/// RSA public key. NOTE(review): presumably the remote peer's key, `None` until the key exchange has been performed — confirm against the exchange code.
them: Option<RsaPublicKey>,
}
/// Writable half of `EncryptedStream`.
///
/// On its own this half does no encryption: its `AsyncWrite` implementation forwards every call to `backing_write` unchanged.
#[pin_project]
pub struct WriteHalf<S>
where S: AsyncWrite
{
/// Shared reference to the RSA data of the backing stream, held by both Write and Read halves.
///
/// # Immutability of this metadata
/// Exchange can only happen on the combined Read+Write stream, so we don't need any mutability of `meta` here. Mutating `meta` happens only when it's owned by the combined stream (not in an `Arc`, which is only used to share it between the Read and Write half).
meta: Arc<EncryptedStreamMeta>,
/// The underlying writer. Structurally pinned (`#[pin]`) so it can be polled through `project()`.
#[pin] backing_write: S,//Box<dual::DualStream<S>>,
}
/// An encrypting writer layered over a mutably borrowed `WriteHalf`.
///
/// `poll_write` transforms the caller's `buf` through `cipher` into `crypt_buffer` and then polls a write of `crypt_buffer` to `backing`.
#[pin_project]
pub struct EncryptedWriteHalf<'a, S>
where S: AsyncWrite,
{
/// Used to transform input `buf` into `self.crypt_buffer` before polling a write to the backing half with the newly filled `self.crypt_buffer`.
/// See the two fields below.
cipher: Crypter,
/// Slice identity of the input `buf` that corresponds to the transformed data currently held in `crypt_buffer`.
/// Used to check if a `Pending` write was cancelled, by comparing whether the input `buf` slice of the next write is different from the last one (whose identity is stored in this field).
///
/// # Usage
/// Before deciding whether `crypt_buffer` can be re-polled as it is, `poll_write` compares the input `buf` against this value.
/// If they differ, then the `Pending` result from the last poll was discarded, and the new `buf` is re-encrypted into `crypt_buffer`.
///
/// Whenever `buf` is transformed, a `SliceMeta` made from it is written to this field.
/// After any `Ready` poll result (write, flush or shutdown), this field is reset to `Default` (described by the original author as an invalid `null` value).
///
/// Per the original author's note, this compares **pointer** and **length** identity of the slice (see `SliceMeta` for more information),
/// which is a faster way of determining if the buffer has changed than hashing the whole buffer each time `poll_write` is called just to compare.
///
/// # Initialised
/// Initialised as `Default` (`null`).
/// Reset to `Default` together with `crypt_buffer` being emptied (i.e. after a non-`Pending` poll result).
crypt_buf_ptr: SliceMeta<u8>,
/// Buffer written to when encrypting the input `buf`.
///
/// It is cleared after a `Ready` result from `poll_write`, `poll_flush` or `poll_shutdown` on the backing half.
/// Right after a transformation, this buffer is truncated to only fit the transformed data written to it; on a `Pending` write it is then left as it is until the next call to `poll_write`.
///
/// If the poll was not discarded (see the field above), the next call to this instance's `poll_write` just immediately re-polls the backing half with this buffer.
/// If it was discarded, the new input `buf` is transformed into this buffer as if the previous poll had returned `Ready`.
///
/// This exists so we don't have to transform the entire `buf` on every poll. We can just transform it once and then wait until it is `Ready` before discarding the data (`.clear()`) and allowing new data to fill it on the next, fresh `poll_write`.
crypt_buffer: Vec<u8>,
/// The borrowed write half that receives the transformed bytes.
#[pin] backing: &'a mut WriteHalf<S>,
}
/// **Unconditionally** run `buf` through `cipher`, writing the result to the front of `crypt_buffer`.
///
/// `crypt_buffer` is grown (zero-filled) to at least `buf.len()` bytes when it is shorter; it is never shrunk here.
///
/// # Returns
/// The number of bytes `cipher.update` reported writing.
///
/// # Errors
/// Any `ErrorStack` produced by `cipher.update` or `cipher.finalize` is returned unchanged.
///
/// # Does **not** do these things
/// Doesn't compare the identity of `buf` against `crypt_buf_ptr`. The caller must do that.
/// Doesn't truncate `crypt_buffer` after the transformation.
///
/// NOTE(review): `finalize` is called on every invocation and is handed the same region `update` just wrote.
/// It is expected to emit zero bytes (checked in debug builds only) — confirm the configured cipher permits further `update` calls after `finalize`.
fn transform_into(crypt_buffer: &mut Vec<u8>, cipher: &mut Crypter, buf: &[u8]) -> Result<usize, ErrorStack>
{
    let input_len = buf.len();
    if input_len > crypt_buffer.len() {
        crypt_buffer.resize(input_len, 0);
    }

    let written = cipher.update(buf, &mut crypt_buffer[..input_len])?;
    let trailing = cipher.finalize(&mut crypt_buffer[..written])?;
    debug_assert_eq!(trailing, 0);

    Ok(written)
}
impl<'a, S: AsyncWrite> EncryptedWriteHalf<'a, S>
{
    /// Turn a pinned `EncryptedWriteHalf` into a pinned reference to the `WriteHalf` it borrows.
    #[inline(always)]
    fn forward(self: Pin<&mut Self>) -> Pin<&mut WriteHalf<S>>
    {
        // NOTE(review): this pins the `WriteHalf` through a plain `&mut` borrow. Nothing visible here
        // guarantees the `WriteHalf` is never moved once the borrow ends, so this looks sound only
        // when `S: Unpin` — confirm.
        unsafe {
            self.map_unchecked_mut(|outer| &mut *outer.backing)
        }
    }
}
impl<'a, S: AsyncWrite> AsyncWrite for EncryptedWriteHalf<'a, S>
{
/// Transform `buf` through the cipher (unless the transformed copy of this exact slice is still buffered from an earlier poll) and poll a write of the transformed bytes to the backing half.
///
/// Returns whatever the backing `poll_write` returns. An `ErrorStack` from the cipher is propagated through `?`.
///
/// NOTE(review): the byte count returned is the backing stream's count for `crypt_buffer`, and the buffer is cleared on *any* `Ready` result. If the backing stream accepts only part of `crypt_buffer`, the remaining transformed bytes are dropped here while the cipher has already consumed all of `buf` — confirm this is acceptable.
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
let this = self.as_mut().project();
// Re-transform when nothing is buffered, or when `buf` is not the slice recorded by the previous poll (i.e. that poll's `Pending` result was abandoned).
if this.crypt_buffer.is_empty() || this.crypt_buf_ptr != buf {
// Transform `buf` into self.crypt_buffer
let n = transform_into(this.crypt_buffer, this.cipher, buf)?;
// Remember which input slice the buffered data belongs to, and trim the buffer to exactly the transformed bytes.
*this.crypt_buf_ptr = buf.into();
this.crypt_buffer.truncate(n);
} // else { /* No need to transform */ }
// Reborrow the `&mut WriteHalf` out of its pinned reference so it can be polled as `Pin<&mut WriteHalf<S>>`.
// NOTE(review): nothing visible here pins the `WriteHalf` itself; this looks sound only when `S: Unpin` — confirm.
let poll = unsafe {this.backing.map_unchecked_mut(|this| *this)}.poll_write(cx, &this.crypt_buffer[..]);
// Any `Ready` result (success or error) ends this write: forget the recorded slice and drop the buffered data.
if poll.is_ready()
{
*this.crypt_buf_ptr = Default::default();
this.crypt_buffer.clear();
}
poll
}
/// Flush the backing half, returning its poll result.
///
/// On a `Ready` result the buffered transformed data and the recorded slice identity are discarded.
///
/// NOTE(review): this also discards data left over from a `poll_write` that is still `Pending` — confirm that is intended.
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
let this = self.project();
let poll = unsafe {this.backing.map_unchecked_mut(|this| *this)}.poll_flush(cx);
if poll.is_ready() {
this.crypt_buffer.clear();
*this.crypt_buf_ptr = Default::default();
}
poll
}
/// Shut down the backing half, returning its poll result.
///
/// On a `Ready` result the buffer is passed to `bytes::blank`, then emptied, and the recorded slice identity is reset.
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
let this = self.project();
let poll = unsafe {this.backing.map_unchecked_mut(|this| *this)}.poll_shutdown(cx);
if poll.is_ready() {
// NOTE(review): `bytes::blank` presumably overwrites the buffer contents before they are released — confirm against its definition.
bytes::blank(&mut this.crypt_buffer[..]);
this.crypt_buffer.clear();
*this.crypt_buf_ptr = Default::default();
}
poll
}
}
/// Readable half of `EncryptedStream`.
///
/// On its own this half does no decryption: its `AsyncRead` implementation forwards every call to `backing_read` unchanged.
#[pin_project]
pub struct ReadHalf<S>
where S: AsyncRead
{
/// Shared reference to the RSA data of the backing stream, held by both Write and Read halves.
meta: Arc<EncryptedStreamMeta>,
/// The underlying reader. Structurally pinned (`#[pin]`) so it can be polled through `project()`.
/// No cipher is stored in this struct; the decrypting `Crypter` is held by `EncryptedReadHalf`.
#[pin] backing_read: S,
}
/// A decrypting reader layered over a mutably borrowed `ReadHalf`.
///
/// `poll_read` reads from the backing half and then runs the bytes that were read through `cipher`.
#[pin_project]
pub struct EncryptedReadHalf<'a, S>
where S: AsyncRead,
{
/// Cipher context applied to every chunk read from `backing`.
cipher: Crypter,
/// The borrowed read half that supplies the bytes to be transformed.
#[pin] backing: &'a mut ReadHalf<S>,
}
impl<'a, S: AsyncRead> AsyncRead for EncryptedReadHalf<'a, S>
{
    /// Read from the backing reader into `buf`, then transform the bytes that were read through `cipher` in place.
    ///
    /// # Returns
    /// `Pending` if the backing read is pending. Otherwise the number of bytes `cipher.update` produced,
    /// which are stored at the front of `buf` (`buf[..n]`).
    ///
    /// # Errors
    /// An error from the backing read is returned as-is. An `ErrorStack` from the cipher is converted
    /// to `io::Error` through `?`.
    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
        let this = self.project();
        let cipher = this.cipher;
        // NOTE(review): this pins `S` through a plain `&mut ReadHalf<S>` borrow. Nothing visible here
        // guarantees the `ReadHalf` is never moved once the borrow ends, so this looks sound only
        // when `S: Unpin` — confirm.
        let stream = unsafe {this.backing.map_unchecked_mut(|f| &mut f.backing_read)};
        let res = stream.poll_read(cx, buf);
        // Decrypt the buffer if the read succeeded
        res.map(move |res| res.and_then(move |sz| {
            // `alloca_limit` presumably hands the closure a scratch buffer of at least `sz` bytes — confirm.
            alloca_limit(sz, move |obuf| -> io::Result<usize> {
                // This `sz` and the old `sz` should always be the same.
                let sz = cipher.update(&buf[..sz], &mut obuf[..])?;
                let _f = cipher.finalize(&mut obuf[..sz])?;
                debug_assert_eq!(_f, 0);
                // Copy the decrypted bytes over the front of the output buffer only.
                // `copy_from_slice` panics unless both slices have the same length, so the destination
                // must be limited to `sz`: a read may fill less than the whole of `buf` (including 0 bytes).
                buf[..sz].copy_from_slice(&obuf[..sz]);
                Ok(sz)
            })
        }))
    }
}
impl<S: AsyncRead> AsyncRead for ReadHalf<S>
{
    /// Plain (non-decrypting) read: forwarded unchanged to the pinned backing reader.
    #[inline]
    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
        let inner = self.project().backing_read;
        inner.poll_read(cx, buf)
    }

    /// Plain (non-decrypting) buffered read: forwarded unchanged to the pinned backing reader.
    #[inline]
    fn poll_read_buf<B: BufMut>(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut B) -> Poll<io::Result<usize>>
    where
        Self: Sized,
    {
        let inner = self.project().backing_read;
        inner.poll_read_buf(cx, buf)
    }
}
impl<S: AsyncWrite> AsyncWrite for WriteHalf<S>
{
#[inline] fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
self.project().backing_write.poll_write(cx, buf)
}
#[inline] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().backing_write.poll_flush(cx)
}
#[inline] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().backing_write.poll_shutdown(cx)
}
#[inline] fn poll_write_buf<B: Buf>(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut B) -> Poll<Result<usize, io::Error>>
where
Self: Sized, {
self.project().backing_write.poll_write_buf(cx, buf)
}
}
//TODO: Rework everything past this point:
/*
struct ReadWriteCombined<R, W>
{
/// Since chacha20stream has no AsyncRead counterpart, we have to do it ourselves.
cipher_read: Option<Crypter>,
backing_read: R,
backing_write: dual::DualStream<W>,
}
/// RSA/chacha20 encrypted stream
pub struct EncryptedStream<R, W>
where R: AsyncRead,
W: AsyncWrite,
{
meta: EncryptedStreamMeta,
// Keep the streams on the heap to keep this type not hueg.
backing: Box<ReadWriteCombined<R, W>>,
}
//TODO: How do we use this with a single AsyncStream instead of requiring 2? Will we need to make our own Arc wrapper?? Ugh,, for now let's ignore this I guess... Most read+write thingies have a Read/WriteHalf split mechanism.
//
// Note that this does actually work fine with things like tokio's `duplex()` (i think)
impl<R: AsyncRead, W: AsyncWrite> EncryptedStream<R, W>
{
/// Has this stream done its RSA key exchange?
pub fn has_exchanged(&self) -> bool
{
self.meta.them.is_some()
}
/// Split this stream into a read and writeable half.
pub fn split(self) -> (WriteHalf<W>, ReadHalf<R>)
{
let meta = Arc::new(self.meta);
let (read, write) = {
let ReadWriteCombined { cipher_read, backing_read, backing_write } = *self.backing;
((cipher_read, backing_read), backing_write)
};
(WriteHalf {
meta: Arc::clone(&meta),
backing_write: Box::new(write),
}, ReadHalf {
meta,
cipher: read.0,
backing_read: Box::new(read.1),
})
}
/// Join a split `EncryptedStream` from halves.
///
/// # Panics
/// If the read and write half are not from the same split.
pub fn from_split((write, read): (WriteHalf<W>, ReadHalf<R>)) -> Self
{
if !Arc::ptr_eq(&write.meta, &read.meta) {
panic!("Read and Write halves are not from the same split");
}
todo!("Drop write's `meta`, consume read's `meta`. Move the streams into `ReadWriteCombined`")
}
}
/*
impl<S: AsyncWrite> AsyncWrite for WriteHalf<S>
{
#[inline] fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
unsafe {self.map_unchecked_mut(|this| this.backing_write.as_mut())}.poll_write(cx, buf)
}
#[inline] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
unsafe {self.map_unchecked_mut(|this| this.backing_write.as_mut())}.poll_flush(cx)
}
#[inline] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
unsafe {self.map_unchecked_mut(|this| this.backing_write.as_mut())}.poll_shutdown(cx)
}
}
impl<S: AsyncRead> AsyncRead for ReadHalf<S>
{
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
let this = self.project();
let cipher = this.cipher.as_mut();
let stream = unsafe {this.backing_read.map_unchecked_mut(|f| f.as_mut())};
let res = stream.poll_read(cx,buf);
if let Some(cipher) = cipher {
// Decrypt the buffer if the read succeeded
res.map(move |res| res.and_then(move |sz| {
alloca_limit(sz, move |obuf| -> io::Result<usize> {
// This `sz` and old `sz` should always be the same.
let sz = cipher.update(&buf[..sz], &mut obuf[..])?;
let _f = cipher.finalize(&mut obuf[..sz])?;
debug_assert_eq!(_f, 0);
// Copy decrypted buffer into output buffer
buf.copy_from_slice(&obuf[..sz]);
Ok(sz)
})
}))
} else {
res
}
}
}
*/
*/