You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
156 lines
4.4 KiB
156 lines
4.4 KiB
use super::*;
|
|
|
|
use tokio::io::{AsyncWrite, AsyncRead};
|
|
use std::sync::Arc;
|
|
use openssl::symm::Crypter;
|
|
|
|
use std::{
|
|
pin::Pin,
|
|
task::{Poll, Context},
|
|
io,
|
|
};
|
|
|
|
use crypt::{
|
|
RsaPublicKey,
|
|
RsaPrivateKey,
|
|
};
|
|
|
|
/// A type that implements both `AsyncWrite` and `AsyncRead`
|
|
pub trait AsyncStream: AsyncRead + AsyncWrite{}
|
|
impl<T: AsyncRead + AsyncWrite + ?Sized> AsyncStream for T{}
|
|
|
|
/// Inner rsa data for encrypted stream read+write halves
|
|
struct EncryptedStreamMeta
|
|
{
|
|
us: RsaPrivateKey,
|
|
them: Option<RsaPublicKey>,
|
|
}
|
|
|
|
/// Writable half of `EncryptedStream`.
|
|
pub struct WriteHalf<S>
|
|
where S: AsyncWrite
|
|
{
|
|
meta: Arc<EncryptedStreamMeta>,
|
|
|
|
backing_write: Box<dual::DualStream<S>>,
|
|
}
|
|
|
|
/// Readable half of `EncryptedStream`.
|
|
#[pin_project]
|
|
pub struct ReadHalf<S>
|
|
where S: AsyncRead
|
|
{
|
|
meta: Arc<EncryptedStreamMeta>,
|
|
|
|
/// chacha20_poly1305 decrypter for incoming reads from `S`
|
|
//TODO: chacha20stream: implement a read version of AsyncSink so we don't need to keep this?
|
|
cipher: Option<Crypter>,
|
|
#[pin] backing_read: Box<S>,
|
|
}
|
|
|
|
struct ReadWriteCombined<R, W>
|
|
{
|
|
/// Since chacha20stream has no AsyncRead counterpart, we have to do it ourselves.
|
|
cipher_read: Option<Crypter>,
|
|
backing_read: R,
|
|
|
|
backing_write: dual::DualStream<W>,
|
|
}
|
|
|
|
/// RSA/chacha20 encrypted stream
|
|
pub struct EncryptedStream<R, W>
|
|
where R: AsyncRead,
|
|
W: AsyncWrite,
|
|
{
|
|
meta: EncryptedStreamMeta,
|
|
|
|
// Keep the streams on the heap to keep this type not hueg.
|
|
backing: Box<ReadWriteCombined<R, W>>,
|
|
}
|
|
|
|
// TODO: How do we use this with a single `AsyncStream` instead of requiring two? Will we need
// to write our own `Arc` wrapper? For now this is ignored: most read+write types already
// provide a read-half/write-half split mechanism.
//
// Note that this does actually work fine with things like tokio's `duplex()` (I think).
|
|
impl<R: AsyncRead, W: AsyncWrite> EncryptedStream<R, W>
|
|
{
|
|
/// Has this stream done its RSA key exchange?
|
|
pub fn has_exchanged(&self) -> bool
|
|
{
|
|
self.meta.them.is_some()
|
|
}
|
|
|
|
/// Split this stream into a read and writeable half.
|
|
pub fn split(self) -> (WriteHalf<W>, ReadHalf<R>)
|
|
{
|
|
let meta = Arc::new(self.meta);
|
|
let (read, write) = {
|
|
let ReadWriteCombined { cipher_read, backing_read, backing_write } = *self.backing;
|
|
|
|
((cipher_read, backing_read), backing_write)
|
|
};
|
|
|
|
(WriteHalf {
|
|
meta: Arc::clone(&meta),
|
|
backing_write: Box::new(write),
|
|
}, ReadHalf {
|
|
meta,
|
|
cipher: read.0,
|
|
backing_read: Box::new(read.1),
|
|
})
|
|
}
|
|
|
|
/// Join a split `EncryptedStream` from halves.
|
|
///
|
|
/// # Panics
|
|
/// If the read and write half are not from the same split.
|
|
pub fn from_split((write, read): (WriteHalf<W>, ReadHalf<R>)) -> Self
|
|
{
|
|
if !Arc::ptr_eq(&write.meta, &read.meta) {
|
|
panic!("Read and Write halves are not from the same split");
|
|
}
|
|
|
|
todo!("Drop write's `meta`, consume read's `meta`. Move the streams into `ReadWriteCombined`")
|
|
}
|
|
}
|
|
|
|
impl<S: AsyncWrite> AsyncWrite for WriteHalf<S>
|
|
{
|
|
#[inline] fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
|
|
unsafe {self.map_unchecked_mut(|this| this.backing_write.as_mut())}.poll_write(cx, buf)
|
|
}
|
|
#[inline] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
|
unsafe {self.map_unchecked_mut(|this| this.backing_write.as_mut())}.poll_flush(cx)
|
|
}
|
|
#[inline] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
|
unsafe {self.map_unchecked_mut(|this| this.backing_write.as_mut())}.poll_shutdown(cx)
|
|
}
|
|
}
|
|
|
|
impl<S: AsyncRead> AsyncRead for ReadHalf<S>
|
|
{
|
|
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
|
|
let this = self.project();
|
|
let cipher = this.cipher.as_mut();
|
|
let stream = unsafe {this.backing_read.map_unchecked_mut(|f| f.as_mut())};
|
|
|
|
let res = stream.poll_read(cx,buf);
|
|
if let Some(cipher) = cipher {
|
|
// Decrypt the buffer if the read succeeded
|
|
res.map(move |res| res.and_then(move |sz| {
|
|
alloca_limit(sz, move |obuf| -> io::Result<usize> {
|
|
// This `sz` and old `sz` should always be the same.
|
|
let sz = cipher.update(&buf[..sz], &mut obuf[..])?;
|
|
let _f = cipher.finalize(&mut obuf[..sz])?;
|
|
debug_assert_eq!(_f, 0);
|
|
|
|
// Copy decrypted buffer into output buffer
|
|
buf.copy_from_slice(&obuf[..sz]);
|
|
Ok(sz)
|
|
})
|
|
}))
|
|
} else {
|
|
res
|
|
}
|
|
}
|
|
}
|