You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

431 lines
13 KiB

use super::*;
use tokio::io::{AsyncWrite, AsyncRead};
use std::sync::Arc;
use openssl::symm::Crypter;
use openssl::error::ErrorStack;
use ::bytes::{Buf, BufMut};
use std::{
pin::Pin,
task::{Poll, Context},
io,
};
use crypt::{
RsaPublicKey,
RsaPrivateKey,
};
mod traits;
pub use traits::*;
mod exchange;
/// Combined Read + Write encryptable async stream.
///
/// The `AsyncRead` and `AsyncWrite` impls of this type forward to the backing impls for `S`.
///
/// # Exchange
/// A combined stream is the only way to exchange public keys and to enable the creation of encrypted read/write wrappers on the combined stream or its splits.
#[pin_project]
#[derive(Debug)]
pub struct Stream<S>
{
    /// RSA key data for this stream (see `EncryptedStreamMeta`); owned here, shared via `Arc` between the halves after `split`.
    meta: EncryptedStreamMeta,
    /// The wrapped stream; the plain `AsyncRead`/`AsyncWrite` impls of `Stream` forward to it.
    #[pin] stream: S,
}
/// `Stream` with enabled encryption.
///
/// NOTE(review): no constructor or trait impl for this type is visible in this part of the file; the field roles
/// below are inferred from the matching fields of `EncryptedWriteHalf` — confirm.
pub struct EncryptedStream<'a, S>
{
    /// Cipher for the read direction (presumably decrypts data read from `backing` — confirm).
    read_cipher: Crypter,
    /// Cipher for the write direction (presumably encrypts data before it is written to `backing` — confirm).
    write_cipher: Crypter,
    /// Write-side input slice identity; counterpart of `EncryptedWriteHalf::crypt_buf_ptr`.
    write_crypt_buf_ptr: SliceMeta<u8>,
    /// Write-side ciphertext buffer; counterpart of `EncryptedWriteHalf::crypt_buffer`.
    write_crypt_buffer: Vec<u8>,
    /// The combined stream being wrapped, mutably borrowed for `'a`.
    backing: &'a mut Stream<S>,
}
impl<Tx, Rx> Stream<Merge<Tx, Rx>>
where Tx: AsyncWrite,
Rx: AsyncRead
{
/// Exchange RSA keys through this stream.
pub async fn exchange(&mut self) -> io::Result<()>
{
todo!()
}
/// Merge an `AsyncWrite`, and `AsyncRead` stream into `Stream`.
pub fn merged(tx: Tx, rx: Rx) -> Self
{
Self {
meta: EncryptedStreamMeta::new(),
stream: Merge(tx, rx),
}
}
}
/*
impl<S> Stream<S>
where S: Split,
S::First: AsyncWrite,
S::Second: AsyncRead
{
/// Create a new `Stream` from two streams, one implemetor of `AsyncWrite`, and one of `AsyncRead`.
pub fn new(tx: S::First, rx: S::Second) -> Self
{
Self {
meta: EncryptedStreamMeta {
them: None,
us: crypt::generate(),
},
stream: S::unsplit(tx, rx),
}
}
}
impl<S: AsyncStream> Stream<S>
{
/// Create a new `Stream` from an implementor of both `AsyncRead` and `AsyncWrite`.
pub fn new_single(stream: S) -> Self
{
Self {
meta: EncryptedStreamMeta {
them: None,
us: crypt::generate(),
},
stream,
}
}
/// Create a split by cloning `S`.
pub fn split_clone(self) -> (WriteHalf<S>, ReadHalf<S>)
where S: Clone
{
Stream {
stream: (self.stream.clone(), self.stream),
meta: self.meta
}.split()
}
}*/
impl<S> Split for Stream<S>
where S: Split,
S::First: AsyncWrite,
S::Second: AsyncRead
{
type First = WriteHalf<S::First>;
type Second = ReadHalf<S::Second>;
#[inline] fn split(self) -> (Self::First, Self::Second) {
self.split()
}
#[inline] fn unsplit(a: Self::First, b: Self::Second) -> Self {
Self::unsplit(a, b)
}
}
impl<S> Stream<S>
where S: Split,
      S::First: AsyncWrite,
      S::Second: AsyncRead
{
    /// Combine the halves of a previously split `Stream` back into a single stream.
    ///
    /// This is the inverse of `Stream::split`. The backing halves are rejoined with `S::unsplit`, and the shared key
    /// metadata is moved back out of its `Arc`, so the combined stream owns it (and may mutate it) again.
    ///
    /// # Panics
    /// If the two halves didn't originally come from the same `Stream`, i.e. their `meta` `Arc`s are not pointer-equal.
    pub fn unsplit(tx: WriteHalf<S::First>, rx: ReadHalf<S::Second>) -> Self
    {
        // Kept out of line and marked cold so the common (matching) path stays small.
        #[cold]
        #[inline(never)]
        fn panic_not_ptr_eq() -> !
        {
            panic!("Cannot join halves from different splits")
        }
        if !Arc::ptr_eq(&tx.meta, &rx.meta) {
            panic_not_ptr_eq();
        }
        let WriteHalf { meta: write_meta, backing_write: tx } = tx;
        // Release the write half's reference first so the read half's `Arc` is the last one left.
        drop(write_meta);
        let ReadHalf { meta, backing_read: rx } = rx;
        let meta = Arc::try_unwrap(meta)
            .expect("the two halves of a split `Stream` are the only holders of its `meta` Arc");
        Self {
            meta,
            stream: S::unsplit(tx, rx),
        }
    }

    /// Split this `Stream` into a write half and a read half.
    ///
    /// The key metadata is moved into an `Arc` shared by both halves. It cannot be mutated again until the halves
    /// are rejoined with `Stream::unsplit`.
    pub fn split(self) -> (WriteHalf<S::First>, ReadHalf<S::Second>)
    {
        let meta = Arc::new(self.meta);
        let (tx, rx) = self.stream.split();
        (WriteHalf {
            meta: Arc::clone(&meta),
            backing_write: tx,
        }, ReadHalf {
            meta,
            backing_read: rx,
        })
    }
}
impl<S: AsyncRead> AsyncRead for Stream<S>
{
#[inline] fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
self.project().stream.poll_read(cx, buf)
}
#[inline] fn poll_read_buf<B: BufMut>(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut B) -> Poll<io::Result<usize>>
where
Self: Sized, {
self.project().stream.poll_read_buf(cx, buf)
}
}
impl<S: AsyncWrite> AsyncWrite for Stream<S>
{
#[inline] fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
self.project().stream.poll_write(cx, buf)
}
#[inline] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().stream.poll_flush(cx)
}
#[inline] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().stream.poll_shutdown(cx)
}
#[inline] fn poll_write_buf<B: Buf>(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut B) -> Poll<Result<usize, io::Error>>
where
Self: Sized, {
self.project().stream.poll_write_buf(cx, buf)
}
}
/// Inner RSA data for encrypted stream read+write halves.
///
/// # Exchange / mutation
/// For split streams, this becomes immutable. If exchange has not been performed by the combined stream before splitting, then it is impossible for the split Read and Write halves to form EncryptedRead and EncryptedWrite instances on top of themselves.
/// The stream must be re-joined, exchanged, and then split again in this case.
/// Therefore exchange should happen before the original stream is split at all.
///
/// Only the combined stream can mutate this structure. The halves hold it behind an immutable shared reference (an `Arc`).
#[derive(Debug)]
struct EncryptedStreamMeta
{
    /// Our RSA private key; `EncryptedStreamMeta::new` fills it with a freshly generated key from `crypt::generate()`.
    us: RsaPrivateKey,
    /// The other side's RSA public key, if known.
    ///
    /// `None` after `EncryptedStreamMeta::new`. Presumably filled in by the key exchange (`Stream::exchange` is still `todo!()`) — confirm.
    them: Option<RsaPublicKey>,
}
impl EncryptedStreamMeta
{
    /// Build fresh metadata: a newly generated private key for our side, and no public key for the other side yet.
    #[inline(always)]
    pub fn new() -> Self
    {
        let us = crypt::generate();
        Self { us, them: None }
    }
}
/// Writable half of a split `Stream` (produced by `Stream::split`).
#[pin_project]
#[derive(Debug)]
pub struct WriteHalf<S>
where S: AsyncWrite
{
    /// Shared reference to the RSA data of the backing stream, held by both Write and Read halves.
    ///
    /// # Immutability of this metadata
    /// Exchange can only happen on the combined Read+Write stream, so we don't need any mutability of `meta` here. Mutating `meta` happens only when it's owned by the combined stream (not in an `Arc`, which is only used to share it between the Read and Write half).
    meta: Arc<EncryptedStreamMeta>,
    /// The backing writer; the `AsyncWrite` impl of this half forwards every call to it.
    #[pin] backing_write: S,//Box<dual::DualStream<S>>,
}
/// Encrypting wrapper around a borrowed `WriteHalf`.
///
/// Every input `buf` passed to `poll_write` is transformed through `cipher`, and the resulting ciphertext is written
/// to the backing write half.
#[pin_project]
pub struct EncryptedWriteHalf<'a, S>
where S: AsyncWrite,
{
    /// Cipher used to transform input `buf`s into `crypt_buffer` before they are written to `backing`.
    cipher: Crypter,
    /// Slice identity (**pointer** and **length**, see `SliceMeta`) of an input `buf`.
    ///
    /// The write protocol described on `crypt_buffer` never needs to recognise a repeated input `buf`. The `AsyncWrite`
    /// impl of this type therefore only ever resets this field to `Default` (an invalid `null` value).
    ///
    /// # Initialised
    /// Expected to be initialised as `Default` (`null`).
    crypt_buf_ptr: SliceMeta<u8>,
    /// Ciphertext produced from caller data that `backing` has not accepted yet.
    ///
    /// `poll_write` encrypts a new input `buf` only once this buffer is empty. It reports the whole `buf` as written
    /// as soon as the ciphertext is stored here. Bytes accepted by `backing` are drained from the front, and whatever
    /// is left is retried by the next `poll_write`, `poll_flush` or `poll_shutdown`.
    ///
    /// New input is never encrypted while older ciphertext is still pending. So a `Pending` write that the caller
    /// abandons has not advanced the cipher over data that is then never sent.
    crypt_buffer: Vec<u8>,
    /// The write half being wrapped.
    ///
    /// NOTE(review): `#[pin]` on a `&mut` field does not pin the `WriteHalf` it points to, because `&mut T` is always
    /// `Unpin`. The `unsafe` projections below still produce a `Pin<&mut WriteHalf<S>>`. That is only sound if the
    /// referent is never moved after being polled through this wrapper. A sound fix needs either an `S: Unpin` bound
    /// (and `Pin::new`) or storing `Pin<&'a mut WriteHalf<S>>`; both change this type's interface.
    #[pin] backing: &'a mut WriteHalf<S>,
}

/// Upper bound on the block size of any OpenSSL cipher (OpenSSL's `EVP_MAX_BLOCK_LENGTH`).
const MAX_CIPHER_BLOCK_SIZE: usize = 32;

/// **Forcefully** transform `buf` through `cipher` into the start of `crypt_buffer`.
///
/// Returns the number of transformed bytes written to the front of `crypt_buffer`: the `update` output followed by
/// any `finalize` output.
///
/// `crypt_buffer` is grown (never shrunk) to `buf.len() + 2 * MAX_CIPHER_BLOCK_SIZE`. That is enough for the
/// output-size preconditions of both `Crypter::update` and `Crypter::finalize`, whatever the cipher's block size.
///
/// # Does **not** do these things
/// Doesn't check for ptr ident with `buf` against `crypt_buf_ptr`. You should do that yourself.
/// Doesn't truncate `crypt_buffer` after transformation.
///
/// # Errors
/// Any OpenSSL error reported by `update` or `finalize`.
fn transform_into(crypt_buffer: &mut Vec<u8>, cipher: &mut Crypter, buf: &[u8]) -> Result<usize, ErrorStack>
{
    // `update` may emit up to `buf.len() + block_size` bytes, and `finalize` needs `block_size` more after that.
    let required = buf.len() + 2 * MAX_CIPHER_BLOCK_SIZE;
    if crypt_buffer.len() < required {
        crypt_buffer.resize(required, 0);
    }
    let n = cipher.update(buf, &mut crypt_buffer[..])?;
    // `finalize` output goes *after* the `update` output; handing it `crypt_buffer[..n]` would overwrite that output.
    let f = cipher.finalize(&mut crypt_buffer[n..])?;
    // This module assumes a stream cipher, which never emits anything from `finalize`; flag violations in debug builds.
    debug_assert_eq!(f, 0);
    Ok(n + f)
}

impl<'a, S: AsyncWrite> EncryptedWriteHalf<'a, S>
{
    /// Project a pinned `EncryptedWriteHalf` to a pinned reference to the `WriteHalf` it wraps.
    ///
    /// NOTE(review): the `AsyncWrite` impl below does not use this, because it needs `crypt_buffer` borrowed at the
    /// same time. The pinning caveat on the `backing` field applies here too.
    #[inline(always)] fn forward(self: Pin<&mut Self>) -> Pin<&mut WriteHalf<S>>
    {
        // SAFETY: relies on the referent of `backing` never being moved while pinned through this wrapper
        // (see the note on the `backing` field).
        unsafe {self.map_unchecked_mut(|this| &mut *this.backing)}
    }
}

/// Write the pending ciphertext in `crypt_buffer` to `backing` until none is left.
///
/// Accepted bytes are removed from the front of `crypt_buffer`. On `Pending` or an error, exactly the
/// still-unwritten ciphertext remains for the next attempt.
///
/// # Errors
/// Any error from `backing`, or `WriteZero` if `backing` accepts no bytes of a non-empty buffer.
fn poll_drain_crypt_buffer<S: AsyncWrite>(mut backing: Pin<&mut WriteHalf<S>>, cx: &mut Context<'_>, crypt_buffer: &mut Vec<u8>) -> Poll<io::Result<()>>
{
    while !crypt_buffer.is_empty() {
        match backing.as_mut().poll_write(cx, &crypt_buffer[..]) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(error)) => return Poll::Ready(Err(error)),
            Poll::Ready(Ok(0)) => return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::WriteZero,
                "backing stream accepted no bytes of pending ciphertext",
            ))),
            Poll::Ready(Ok(written)) => {
                crypt_buffer.drain(..written);
            },
        }
    }
    Poll::Ready(Ok(()))
}

impl<'a, S: AsyncWrite> AsyncWrite for EncryptedWriteHalf<'a, S>
{
    /// Encrypt `buf` and write the ciphertext to the backing half.
    ///
    /// 1. Ciphertext left over from earlier calls is written first. While that is `Pending` (or fails), `buf` is not
    ///    touched and the cipher is not advanced.
    /// 2. All of `buf` is then encrypted into `crypt_buffer` and reported as written (`Ok(buf.len())`).
    /// 3. As much of the new ciphertext as `backing` accepts right away is written. The rest stays buffered for the
    ///    next `poll_write`, `poll_flush` or `poll_shutdown`.
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
        let this = self.project();
        // SAFETY: see the pinning note on the `backing` field.
        let mut backing = unsafe {this.backing.map_unchecked_mut(|backing| &mut **backing)};
        match poll_drain_crypt_buffer(backing.as_mut(), cx, this.crypt_buffer) {
            Poll::Ready(Ok(())) => {},
            Poll::Ready(Err(error)) => return Poll::Ready(Err(error)),
            Poll::Pending => return Poll::Pending,
        }
        // `crypt_buffer` is empty here: encrypt all of `buf` into it.
        let n = transform_into(this.crypt_buffer, this.cipher, buf)?;
        this.crypt_buffer.truncate(n);
        *this.crypt_buf_ptr = Default::default();
        // `buf` is now committed to the cipher, so it must be reported as written whatever happens next.
        // A `Pending` or error here only leaves ciphertext buffered; a persistent error resurfaces on the next call.
        let _ = poll_drain_crypt_buffer(backing, cx, this.crypt_buffer);
        Poll::Ready(Ok(buf.len()))
    }

    /// Write out any buffered ciphertext, then flush the backing half.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let this = self.project();
        // SAFETY: see the pinning note on the `backing` field.
        let mut backing = unsafe {this.backing.map_unchecked_mut(|backing| &mut **backing)};
        match poll_drain_crypt_buffer(backing.as_mut(), cx, this.crypt_buffer) {
            Poll::Ready(Ok(())) => backing.poll_flush(cx),
            other => other,
        }
    }

    /// Write out any buffered ciphertext, then shut down the backing half.
    ///
    /// Once the shutdown completes (successfully or not), the buffer is blanked and cleared.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let this = self.project();
        // SAFETY: see the pinning note on the `backing` field.
        let mut backing = unsafe {this.backing.map_unchecked_mut(|backing| &mut **backing)};
        match poll_drain_crypt_buffer(backing.as_mut(), cx, this.crypt_buffer) {
            Poll::Ready(Ok(())) => {},
            other => return other,
        }
        let poll = backing.poll_shutdown(cx);
        if poll.is_ready() {
            bytes::blank(&mut this.crypt_buffer[..]);
            this.crypt_buffer.clear();
            *this.crypt_buf_ptr = Default::default();
        }
        poll
    }
}
/// Readable half of a split `Stream` (produced by `Stream::split`).
#[pin_project]
#[derive(Debug)]
pub struct ReadHalf<S>
where S: AsyncRead
{
    /// Shared reference to the RSA data of the backing stream; the matching `WriteHalf` holds the same `Arc`.
    meta: Arc<EncryptedStreamMeta>,
    /// The backing reader; the `AsyncRead` impl of this half forwards every call to it.
    ///
    /// No decryption happens here; `EncryptedReadHalf` wraps this half and decrypts what it reads.
    /// NOTE(review): the cipher in use is presumably ChaCha20-Poly1305 — confirm.
    #[pin] backing_read: S,
}
/// Decrypting wrapper around a borrowed `ReadHalf`.
///
/// Data read from the half's backing reader is passed through `cipher` before it is handed to the caller.
#[pin_project]
pub struct EncryptedReadHalf<'a, S>
where S: AsyncRead,
{
    /// Cipher applied to every chunk read from `backing`.
    cipher: Crypter,
    /// The read half being wrapped.
    ///
    /// NOTE(review): `#[pin]` on a `&mut` field does not pin the `ReadHalf` it points to, because `&mut T` is always
    /// `Unpin`. The `unsafe` projection in `poll_read` relies on the referent never being moved while it is polled.
    #[pin] backing: &'a mut ReadHalf<S>,
}
impl<'a, S: AsyncRead> AsyncRead for EncryptedReadHalf<'a, S>
{
    /// Read ciphertext from the backing reader into `buf`, then decrypt it back into `buf` through a temporary buffer.
    ///
    /// Only the first `n` bytes of `buf` are overwritten, where `n` is the returned count. This assumes a stream
    /// cipher, whose output is never longer than its input, so `n` never exceeds the number of bytes actually read.
    /// `Pending` and read errors pass through unchanged.
    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
        let this = self.project();
        let cipher = this.cipher;
        // SAFETY: relies on the referenced `ReadHalf` never being moved while pinned through this wrapper
        // (`#[pin]` on a `&mut` field does not guarantee that by itself).
        let stream = unsafe {this.backing.map_unchecked_mut(|f| &mut f.backing_read)};
        let res = stream.poll_read(cx, buf);
        // Decrypt the buffer only if the read succeeded.
        res.map(move |res| res.and_then(move |read| {
            alloca_limit(read, move |obuf| -> io::Result<usize> {
                let n = cipher.update(&buf[..read], &mut obuf[..])?;
                // `finalize` output goes *after* the `update` output; writing into `obuf[..n]` would overwrite it.
                let f = cipher.finalize(&mut obuf[n..])?;
                debug_assert_eq!(f, 0);
                let n = n + f;
                // `buf` is usually longer than what was read, so copy into a prefix of it.
                // A whole-slice `copy_from_slice` would panic on the length mismatch.
                buf[..n].copy_from_slice(&obuf[..n]);
                Ok(n)
            })
        }))
    }
}
impl<S: AsyncRead> AsyncRead for ReadHalf<S>
{
#[inline] fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
self.project().backing_read.poll_read(cx, buf)
}
#[inline] fn poll_read_buf<B: BufMut>(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut B) -> Poll<io::Result<usize>>
where
Self: Sized, {
self.project().backing_read.poll_read_buf(cx, buf)
}
}
impl<S: AsyncWrite> AsyncWrite for WriteHalf<S>
{
#[inline] fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
self.project().backing_write.poll_write(cx, buf)
}
#[inline] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().backing_write.poll_flush(cx)
}
#[inline] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().backing_write.poll_shutdown(cx)
}
#[inline] fn poll_write_buf<B: Buf>(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut B) -> Poll<Result<usize, io::Error>>
where
Self: Sized, {
self.project().backing_write.poll_write_buf(cx, buf)
}
}