Compare commits

..

No commits in common. 'master' and 'exchange-unsafe' have entirely different histories.

6
Cargo.lock generated

@ -185,11 +185,10 @@ dependencies = [
[[package]] [[package]]
name = "cryptohelpers" name = "cryptohelpers"
version = "1.8.2" version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9143447fb393f8d38abbb617af9b986a0941785ddc63685bd8de735fb31bcafc" checksum = "14be74ce15793a86acd04872953368ce27d07f384f07b8028bd5aaa31a031a38"
dependencies = [ dependencies = [
"base64 0.13.0",
"crc", "crc",
"futures", "futures",
"getrandom 0.1.16", "getrandom 0.1.16",
@ -800,7 +799,6 @@ name = "rsh"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"ad-hoc-iter", "ad-hoc-iter",
"base64 0.13.0",
"bytes 1.0.1", "bytes 1.0.1",
"chacha20stream", "chacha20stream",
"color-eyre", "color-eyre",

@ -7,17 +7,16 @@ edition = "2018"
[dependencies] [dependencies]
ad-hoc-iter = "0.2.3" ad-hoc-iter = "0.2.3"
base64 = "0.13.0"
bytes = { version = "1.0.1", features = ["serde"] } bytes = { version = "1.0.1", features = ["serde"] }
chacha20stream = { version = "2.1.0", features = ["async", "serde"] } chacha20stream = { version = "2.1.0", features = ["async", "serde"] }
color-eyre = "0.5.11" color-eyre = "0.5.11"
cryptohelpers = { version = "1.8.2" , features = ["serialise", "full"] } cryptohelpers = { version = "1.8.1" , features = ["serialise", "full"] }
futures = "0.3.16" futures = "0.3.16"
mopa = "0.2.2" mopa = "0.2.2"
pin-project = "1.0.8" pin-project = "1.0.8"
serde = { version = "1.0.126", features = ["derive"] } serde = { version = "1.0.126", features = ["derive"] }
serde_cbor = "0.11.1" serde_cbor = "0.11.1"
smallvec = { version = "1.6.1", features = ["union", "serde", "write", "const_generics"] } smallvec = { version = "1.6.1", features = ["union", "serde", "write"] }
stackalloc = "1.1.1" stackalloc = "1.1.1"
tokio = { version = "0.2", features = ["full"] } tokio = { version = "0.2", features = ["full"] }
tokio-uring = "0.1.0" tokio-uring = "0.1.0"

@ -4,14 +4,11 @@ use std::mem::{self, MaybeUninit};
use std::iter; use std::iter;
use smallvec::SmallVec; use smallvec::SmallVec;
mod alloc; /// Max size of memory allowed to be allocated on the stack.
pub use alloc::*; pub const STACK_MEM_ALLOC_MAX: usize = 2048; // 2KB
mod hex; /// A stack-allocated vector that spills onto the heap when needed.
pub use hex::*; pub type StackVec<T> = SmallVec<[T; STACK_MEM_ALLOC_MAX]>;
mod base64;
pub use self::base64::*;
/// A maybe-atom that can spill into a vector. /// A maybe-atom that can spill into a vector.
pub type MaybeVec<T> = SmallVec<[T; 1]>; pub type MaybeVec<T> = SmallVec<[T; 1]>;
@ -26,6 +23,89 @@ pub fn vec_uninit<T>(sz: usize) -> Vec<MaybeUninit<T>>
} }
} }
/// Allocate a local buffer initialised from `init`.
pub fn alloc_local_with<T, U>(sz: usize, init: impl FnMut() -> T, within: impl FnOnce(&mut [T]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory: Vec<T> = iter::repeat_with(init).take(sz).collect();
within(&mut memory[..])
} else {
stackalloc::stackalloc_with(sz, init, within)
}
}
/// Allocate a local zero-initialised byte buffer
pub fn alloc_local_bytes<U>(sz: usize, within: impl FnOnce(&mut [u8]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory: Vec<MaybeUninit<u8>> = vec_uninit(sz);
within(unsafe {
std::ptr::write_bytes(memory.as_mut_ptr(), 0, sz);
stackalloc::helpers::slice_assume_init_mut(&mut memory[..])
})
} else {
stackalloc::alloca_zeroed(sz, within)
}
}
/// Allocate a local zero-initialised buffer
pub fn alloc_local_zeroed<T, U>(sz: usize, within: impl FnOnce(&mut [MaybeUninit<T>]) -> U) -> U
{
let sz_bytes = mem::size_of::<T>() * sz;
if sz > STACK_MEM_ALLOC_MAX {
let mut memory = vec_uninit(sz);
unsafe {
std::ptr::write_bytes(memory.as_mut_ptr(), 0, sz_bytes);
}
within(&mut memory[..])
} else {
stackalloc::alloca_zeroed(sz_bytes, move |buf| {
unsafe {
debug_assert_eq!(buf.len() / mem::size_of::<T>(), sz);
within(std::slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut MaybeUninit<T>, sz))
}
})
}
}
/// Allocate a local uninitialised buffer
pub fn alloc_local_uninit<T, U>(sz: usize, within: impl FnOnce(&mut [MaybeUninit<T>]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory = vec_uninit(sz);
within(&mut memory[..])
} else {
stackalloc::stackalloc_uninit(sz, within)
}
}
/// Allocate a local buffer initialised with `init`.
pub fn alloc_local<T: Clone, U>(sz: usize, init: T, within: impl FnOnce(&mut [T]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory: Vec<T> = iter::repeat(init).take(sz).collect();
within(&mut memory[..])
} else {
stackalloc::stackalloc(sz, init, within)
}
}
/// Allocate a local buffer initialised with `T::default()`.
pub fn alloc_local_with_default<T: Default, U>(sz: usize, within: impl FnOnce(&mut [T]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory: Vec<T> = iter::repeat_with(Default::default).take(sz).collect();
within(&mut memory[..])
} else {
stackalloc::stackalloc_with_default(sz, within)
}
}
/// Create a blanket-implementing trait that is a subtrait of any number of traits. /// Create a blanket-implementing trait that is a subtrait of any number of traits.
/// ///
/// # Usage /// # Usage
@ -91,9 +171,7 @@ const _:() = {
let _ref: &std::io::Stdin = a.downcast_ref::<std::io::Stdin>().unwrap(); let _ref: &std::io::Stdin = a.downcast_ref::<std::io::Stdin>().unwrap();
} }
} }
/*
XXX: This is broken on newest nightly?
const _TEST: () = _a::<dyn Test>(); const _TEST: () = _a::<dyn Test>();
const _TEST2: () = _b::<dyn TestAny>(); const _TEST2: () = _b::<dyn TestAny>();
*/
}; };

@ -1,91 +0,0 @@
//! Stack allocation helpers
use super::*;
/// Max size of memory allowed to be allocated on the stack.
pub const STACK_MEM_ALLOC_MAX: usize = 2048; // 2KB
/// A stack-allocated vector that spills onto the heap when needed.
pub type StackVec<T> = SmallVec<[T; STACK_MEM_ALLOC_MAX]>;
/// Allocate a local buffer initialised from `init`.
pub fn alloc_local_with<T, U>(sz: usize, init: impl FnMut() -> T, within: impl FnOnce(&mut [T]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory: Vec<T> = iter::repeat_with(init).take(sz).collect();
within(&mut memory[..])
} else {
stackalloc::stackalloc_with(sz, init, within)
}
}
/// Allocate a local zero-initialised byte buffer
pub fn alloc_local_bytes<U>(sz: usize, within: impl FnOnce(&mut [u8]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory: Vec<MaybeUninit<u8>> = vec_uninit(sz);
within(unsafe {
std::ptr::write_bytes(memory.as_mut_ptr(), 0, sz);
stackalloc::helpers::slice_assume_init_mut(&mut memory[..])
})
} else {
stackalloc::alloca_zeroed(sz, within)
}
}
/// Allocate a local zero-initialised buffer
pub fn alloc_local_zeroed<T, U>(sz: usize, within: impl FnOnce(&mut [MaybeUninit<T>]) -> U) -> U
{
let sz_bytes = mem::size_of::<T>() * sz;
if sz > STACK_MEM_ALLOC_MAX {
let mut memory = vec_uninit(sz);
unsafe {
std::ptr::write_bytes(memory.as_mut_ptr(), 0, sz_bytes);
}
within(&mut memory[..])
} else {
stackalloc::alloca_zeroed(sz_bytes, move |buf| {
unsafe {
debug_assert_eq!(buf.len() / mem::size_of::<T>(), sz);
within(std::slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut MaybeUninit<T>, sz))
}
})
}
}
/// Allocate a local uninitialised buffer
pub fn alloc_local_uninit<T, U>(sz: usize, within: impl FnOnce(&mut [MaybeUninit<T>]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory = vec_uninit(sz);
within(&mut memory[..])
} else {
stackalloc::stackalloc_uninit(sz, within)
}
}
/// Allocate a local buffer initialised with `init`.
pub fn alloc_local<T: Clone, U>(sz: usize, init: T, within: impl FnOnce(&mut [T]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory: Vec<T> = iter::repeat(init).take(sz).collect();
within(&mut memory[..])
} else {
stackalloc::stackalloc(sz, init, within)
}
}
/// Allocate a local buffer initialised with `T::default()`.
pub fn alloc_local_with_default<T: Default, U>(sz: usize, within: impl FnOnce(&mut [T]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory: Vec<T> = iter::repeat_with(Default::default).take(sz).collect();
within(&mut memory[..])
} else {
stackalloc::stackalloc_with_default(sz, within)
}
}

@ -1,15 +0,0 @@
//! Base64 formatting extensions
use super::*;
pub trait Base64StringExt
{
fn to_base64_string(&self) -> String;
}
impl<T: ?Sized> Base64StringExt for T
where T: AsRef<[u8]>
{
fn to_base64_string(&self) -> String {
::base64::encode(self.as_ref())
}
}

@ -1,124 +0,0 @@
use std::{
mem,
iter::{
self,
ExactSizeIterator,
FusedIterator,
},
slice,
fmt,
};
#[derive(Debug, Clone)]
pub struct HexStringIter<I>(I, [u8; 2]);
impl<I: Iterator<Item = u8>> HexStringIter<I>
{
/// Write this hex string iterator to a formattable buffer
pub fn consume<F>(self, f: &mut F) -> fmt::Result
where F: std::fmt::Write
{
if self.1[0] != 0 {
write!(f, "{}", self.1[0] as char)?;
}
if self.1[1] != 0 {
write!(f, "{}", self.1[1] as char)?;
}
for x in self.0 {
write!(f, "{:02x}", x)?;
}
Ok(())
}
/// Consume into a string
pub fn into_string(self) -> String
{
let mut output = match self.size_hint() {
(0, None) => String::new(),
(_, Some(x)) |
(x, None) => String::with_capacity(x),
};
self.consume(&mut output).unwrap();
output
}
}
pub trait HexStringIterExt<I>: Sized
{
fn into_hex(self) -> HexStringIter<I>;
}
pub type HexStringSliceIter<'a> = HexStringIter<iter::Copied<slice::Iter<'a, u8>>>;
pub trait HexStringSliceIterExt
{
fn hex(&self) -> HexStringSliceIter<'_>;
}
impl<S> HexStringSliceIterExt for S
where S: AsRef<[u8]>
{
fn hex(&self) -> HexStringSliceIter<'_>
{
self.as_ref().iter().copied().into_hex()
}
}
impl<I: IntoIterator<Item=u8>> HexStringIterExt<I::IntoIter> for I
{
#[inline] fn into_hex(self) -> HexStringIter<I::IntoIter> {
HexStringIter(self.into_iter(), [0u8; 2])
}
}
impl<I: Iterator<Item = u8>> Iterator for HexStringIter<I>
{
type Item = char;
fn next(&mut self) -> Option<Self::Item>
{
match self.1 {
[_, 0] => {
use std::io::Write;
write!(&mut self.1[..], "{:02x}", self.0.next()?).unwrap();
Some(mem::replace(&mut self.1[0], 0) as char)
},
[0, _] => Some(mem::replace(&mut self.1[1], 0) as char),
_ => unreachable!(),
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
let (l, h) = self.0.size_hint();
(l * 2, h.map(|x| x*2))
}
}
impl<I: Iterator<Item = u8> + ExactSizeIterator> ExactSizeIterator for HexStringIter<I>{}
impl<I: Iterator<Item = u8> + FusedIterator> FusedIterator for HexStringIter<I>{}
impl<I: Iterator<Item = u8>> From<HexStringIter<I>> for String
{
fn from(from: HexStringIter<I>) -> Self
{
from.into_string()
}
}
impl<I: Iterator<Item = u8> + Clone> fmt::Display for HexStringIter<I>
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
{
self.clone().consume(f)
}
}
/*
#[macro_export] macro_rules! prog1 {
($first:expr, $($rest:expr);+ $(;)?) => {
($first, $( $rest ),+).0
}
}
*/

@ -60,7 +60,7 @@ fn message_serial_sign()
#[test] #[test]
fn message_serial_encrypt() fn message_serial_encrypt()
{ {
//color_eyre::install().unwrap(); color_eyre::install().unwrap();
let rsa_priv = rsa::RsaPrivateKey::generate().unwrap(); let rsa_priv = rsa::RsaPrivateKey::generate().unwrap();
struct Dec(rsa::RsaPrivateKey); struct Dec(rsa::RsaPrivateKey);
struct Enc(rsa::RsaPublicKey); struct Enc(rsa::RsaPublicKey);

@ -19,7 +19,6 @@ use bytes::Bytes;
use cancel::*; use cancel::*;
pub mod enc; pub mod enc;
pub mod pipe;
/// Details of a newly accepted raw socket peer. /// Details of a newly accepted raw socket peer.
/// ///

@ -2,14 +2,8 @@
use super::*; use super::*;
use cryptohelpers::{ use cryptohelpers::{
rsa::{ rsa::{
self,
RsaPublicKey, RsaPublicKey,
RsaPrivateKey, RsaPrivateKey,
openssl::{
symm::Crypter,
error::ErrorStack,
},
}, },
sha256, sha256,
}; };
@ -18,8 +12,6 @@ use chacha20stream::{
AsyncSource, AsyncSource,
Key, IV, Key, IV,
cha,
}; };
use std::sync::Arc; use std::sync::Arc;
use tokio::{ use tokio::{
@ -28,24 +20,25 @@ use tokio::{
RwLockReadGuard, RwLockReadGuard,
RwLockWriteGuard, RwLockWriteGuard,
}, },
io::{
DuplexStream,
},
}; };
use std::{ use std::{
io, io,
fmt,
task::{ task::{
Context, Poll, Context, Poll,
}, },
pin::Pin, pin::Pin,
marker::Unpin, marker::{
Unpin,
PhantomPinned,
},
}; };
use smallvec::SmallVec;
/// Size of a single RSA ciphertext. /// Size of a single RSA ciphertext.
pub const RSA_CIPHERTEXT_SIZE: usize = 512; pub const RSA_CIPHERTEXT_SIZE: usize = 512;
/// A single, full block of RSA ciphertext.
type RsaCiphertextBlock = [u8; RSA_CIPHERTEXT_SIZE];
/// Max size to read when exchanging keys /// Max size to read when exchanging keys
const TRANS_KEY_MAX_SIZE: usize = 4096; const TRANS_KEY_MAX_SIZE: usize = 4096;
@ -56,44 +49,13 @@ struct ESockInfo {
them: Option<RsaPublicKey>, them: Option<RsaPublicKey>,
} }
impl ESockInfo #[derive(Debug)]
{
/// Generate a new private key
pub fn new(us: impl Into<RsaPrivateKey>) -> Self
{
Self {
us: us.into(),
them: None,
}
}
/// Generate a new private key for the local endpoint
pub fn generate() -> Result<Self, rsa::Error>
{
Ok(Self::new(RsaPrivateKey::generate()?))
}
}
/// The encryption state of the Tx and Rx instances.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
struct ESockState { struct ESockState {
encr: bool, encr: bool,
encw: bool, encw: bool,
} }
impl Default for ESockState /// Contains a Key and IV that can be serialized and then encrypted
{
#[inline]
fn default() -> Self
{
Self {
encr: false,
encw: false,
}
}
}
/// Contains a cc20 Key and IV that can be serialized and then encrypted
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
struct ESockSessionKey struct ESockSessionKey
{ {
@ -101,71 +63,6 @@ struct ESockSessionKey
iv: IV, iv: IV,
} }
impl fmt::Display for ESockSessionKey
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
{
write!(f, "Key: {}, IV: {}", self.key.hex(), self.iv.hex())
}
}
impl ESockSessionKey
{
/// Generate a new cc20 key + iv,
pub fn generate() -> Self
{
let (key,iv) = cha::keygen();
Self{key,iv}
}
/// Generate an encryption device
pub fn to_decrypter(&self) -> Result<Crypter, ErrorStack>
{
cha::decrypter(&self.key, &self.iv)
}
/// Generate an encryption device
pub fn to_encrypter(&self) -> Result<Crypter, ErrorStack>
{
cha::encrypter(&self.key, &self.iv)
}
/// Encrypt with RSA
pub fn to_ciphertext<K: ?Sized + rsa::PublicKey>(&self, rsa_key: &K) -> eyre::Result<RsaCiphertextBlock>
{
let mut output = [0u8; RSA_CIPHERTEXT_SIZE];
let mut temp = SmallVec::<[u8; RSA_CIPHERTEXT_SIZE]>::new(); // We know size will fit into here.
serde_cbor::to_writer(&mut temp, self)
.wrap_err(eyre!("Failed to CBOR encode session key to buffer"))
.with_section(|| self.clone().header("Session key was"))?;
debug_assert!(temp.len() < RSA_CIPHERTEXT_SIZE);
let _wr = rsa::encrypt_slice_sync(&temp, rsa_key, &mut &mut output[..])
.wrap_err(eyre!("Failed to encrypt session key with RSA public key"))
.with_section(|| self.clone().header("Session key was"))
.with_section({let temp = temp.len(); move || temp.header("Encoded data size was")})
.with_section(move || base64::encode(temp).header("Encoded data (base64) was"))?;
debug_assert_eq!(_wr, output.len());
Ok(output)
}
/// Decrypt from RSA
pub fn from_ciphertext<K: ?Sized + rsa::PrivateKey>(data: &[u8; RSA_CIPHERTEXT_SIZE], rsa_key: &K) -> eyre::Result<Self>
where <K as rsa::PublicKey>::KeyType: rsa::openssl::pkey::HasPrivate //ugh, why do we have to have this bound??? it should be implied ffs... :/
{
let mut temp = SmallVec::<[u8; RSA_CIPHERTEXT_SIZE]>::new();
rsa::decrypt_slice_sync(data, rsa_key, &mut temp)
.wrap_err(eyre!("Failed to decrypt ciphertext to session key"))
.with_section({let data = data.len(); move || data.header("Ciphertext length was")})
.with_section(|| base64::encode(data).header("Ciphertext was"))?;
Ok(serde_cbor::from_slice(&temp[..])
.wrap_err(eyre!("Failed to decode CBOR data to session key object"))
.with_section({let temp = temp.len(); move || temp.header("Encoded data size was")})
.with_section(move || base64::encode(temp).header("Encoded data (base64) was"))?)
}
}
/// A tx+rx socket. /// A tx+rx socket.
#[pin_project] #[pin_project]
#[derive(Debug)] #[derive(Debug)]
@ -182,7 +79,7 @@ pub struct ESock<W, R> {
impl<W: AsyncWrite, R: AsyncRead> ESock<W, R> impl<W: AsyncWrite, R: AsyncRead> ESock<W, R>
{ {
fn inner(&self) -> (&W, &R) pub fn inner(&self) -> (&W, &R)
{ {
(self.tx.inner(), self.rx.inner()) (self.tx.inner(), self.rx.inner())
} }
@ -192,6 +89,31 @@ impl<W: AsyncWrite, R: AsyncRead> ESock<W, R>
(self.tx.inner_mut(), self.rx.inner_mut()) (self.tx.inner_mut(), self.rx.inner_mut())
} }
/// Create a future that exchanges keys
pub fn exchange_unsafe(&mut self) -> Exchange<'_, W, R>
{
let us = self.info.us.get_public_parts();
todo!("Currently unimplemented")
/*
Exchange{
sock: self,
write_buf: Default::default(),
read_buf: Default::default(),
_pin: PhantomPinned,
read_state: Default::default(),
write_state: Default::default(),
us,
them: None,
us_written: Default::default(),
us_buf: (),
write_sz_num: (),
write_sz_buf: (),
read_sz_buf: (),
}*/
}
///Get a mutable ref to unencrypted read+write ///Get a mutable ref to unencrypted read+write
fn unencrypted(&mut self) -> (&mut W, &mut R) fn unencrypted(&mut self) -> (&mut W, &mut R)
{ {
@ -210,155 +132,19 @@ impl<W: AsyncWrite, R: AsyncRead> ESock<W, R>
} }
/// Is the Write + Read operation encrypted? Tuple is `(Tx, Rx)`. /// Is the Write + Read operation encrypted? Tuple is `(Tx, Rx)`.
#[inline] pub fn is_encrypted(&self) -> (bool, bool) pub fn is_encrypted(&self) -> (bool, bool)
{ {
(self.state.encw, self.state.encr) (self.state.encw, self.state.encr)
} }
/// Create a new `ESock` wrapper over this writer and reader with this specific RSA key.
pub fn with_key(key: impl Into<RsaPrivateKey>, tx: W, rx: R) -> Self
{
let (tk, tiv) = cha::keygen();
Self {
info: ESockInfo::new(key),
state: Default::default(),
// Note: These key+IV pairs are never used, as `state` defaults to unencrypted, and a new key/iv pair is generated when we `set_encrypted_write/read(true)`.
// TODO: Have a method to exchange these default session keys after `exchange()`?
tx: AsyncSink::encrypt(tx, tk, tiv).expect("Failed to create temp AsyncSink"),
rx: AsyncSource::encrypt(rx, tk, tiv).expect("Failed to create temp AsyncSource"),
}
}
/// Create a new `ESock` wrapper over this writer and reader with a newly generated private key
#[inline] pub fn new(tx: W, rx: R) -> Result<Self, rsa::Error>
{
Ok(Self::with_key(RsaPrivateKey::generate()?, tx, rx))
}
/// The local RSA private key
#[inline] pub fn local_key(&self) -> &RsaPrivateKey
{
&self.info.us
}
/// THe remote RSA public key (if exchange has happened.)
#[inline] pub fn foreign_key(&self) -> Option<&RsaPublicKey>
{
self.info.them.as_ref()
}
/// Split this `ESock` into a read+write pair.
///
/// # Note
/// You must preform an `exchange()` before splitting, as exchanging RSA keys is not possible on a single half.
///
/// It is also more efficient to `set_encrypted_write/read(true)` on `ESock` than it is on the halves, but changinc encryption modes on halves is still possible.
pub fn split(self) -> (ESockWriteHalf<W>, ESockReadHalf<R>)
{
let arced = Arc::new(self.info);
(ESockWriteHalf(Arc::clone(&arced), self.tx, self.state.encw),
ESockReadHalf(arced, self.rx, self.state.encr))
}
/// Merge a previously split `ESock` into a single one again.
///
/// # Panics
/// If the two halves were not split from the same `ESock`.
pub fn unsplit(txh: ESockWriteHalf<W>, rxh: ESockReadHalf<R>) -> Self
{
#[cold]
#[inline(never)]
fn _panic_ptr_ineq() -> !
{
panic!("Cannot merge halves split from different sources")
}
if !Arc::ptr_eq(&txh.0, &rxh.0) {
_panic_ptr_ineq();
}
let tx = txh.1;
drop(txh.0);
let info = Arc::try_unwrap(rxh.0).unwrap();
let rx = rxh.1;
Self {
state: ESockState {
encw: txh.2,
encr: rxh.2,
},
info,
tx, rx
}
}
}
async fn set_encrypted_write_for<T: AsyncWrite + Unpin>(info: &ESockInfo, tx: &mut AsyncSink<T>) -> eyre::Result<()>
{
use tokio::prelude::*;
let session_key = ESockSessionKey::generate();
let data = {
let them = info.them.as_ref().expect("Cannot set encrypted write when keys have not been exchanged");
session_key.to_ciphertext(them)
.wrap_err(eyre!("Failed to encrypt session key with foreign endpoint's key"))
.with_section(|| session_key.to_string().header("Session key was"))
.with_section(|| them.to_string().header("Foreign pubkey was"))?
};
let crypter = session_key.to_encrypter()
.wrap_err(eyre!("Failed to create encryption device from session key for Tx"))
.with_section(|| session_key.to_string().header("Session key was"))?;
// Send rsa `data` over unencrypted endpoint
tx.inner_mut().write_all(&data[..]).await
.wrap_err(eyre!("Failed to write ciphertext to endpoint"))
.with_section(|| data.to_base64_string().header("Ciphertext of session key was"))?;
// Set crypter of `tx` to `session_key`.
*tx.crypter_mut() = crypter;
Ok(())
}
async fn set_encrypted_read_for<T: AsyncRead + Unpin>(info: &ESockInfo, rx: &mut AsyncSource<T>) -> eyre::Result<()>
{
use tokio::prelude::*;
let mut data = [0u8; RSA_CIPHERTEXT_SIZE];
// Read `data` from unencrypted endpoint
rx.inner_mut().read_exact(&mut data[..]).await
.wrap_err(eyre!("Failed to read ciphertext from endpoint"))?;
// Decrypt `data`
let session_key = ESockSessionKey::from_ciphertext(&data, &info.us)
.wrap_err(eyre!("Failed to decrypt session key from ciphertext"))
.with_section(|| data.to_base64_string().header("Ciphertext was"))
.with_section(|| info.us.to_string().header("Our RSA key is"))?;
// Set crypter of `rx` to `session_key`.
*rx.crypter_mut() = session_key.to_decrypter()
.wrap_err(eyre!("Failed to create decryption device from session key for Rx"))
.with_section(|| session_key.to_string().header("Decrypted session key was"))?;
Ok(())
} }
impl<W: AsyncWrite+ Unpin, R: AsyncRead + Unpin> ESock<W, R> impl<W: AsyncWrite+ Unpin, R: AsyncRead + Unpin> ESock<W, R>
{ {
/// Get the Tx and Rx of the stream.
///
/// # Returns
/// Returns encrypted stream halfs if the stream is encrypted, unencrypted if not.
pub fn stream(&mut self) -> (&mut (dyn AsyncWrite + Unpin + '_), &mut (dyn AsyncRead + Unpin + '_))
{
(if self.state.encw {
&mut self.tx
} else {
self.tx.inner_mut()
}, if self.state.encr {
&mut self.rx
} else {
self.rx.inner_mut()
})
}
/// Enable write encryption /// Enable write encryption
pub async fn set_encrypted_write(&mut self, set: bool) -> eyre::Result<()> pub async fn set_encrypted_write(&mut self, set: bool) -> eyre::Result<()>
{ {
if set { if set {
set_encrypted_write_for(&self.info, &mut self.tx).await?; let (key, iv) = ((),());
// Set `encw` to true
self.state.encw = true; self.state.encw = true;
Ok(()) Ok(())
} else { } else {
@ -367,22 +153,6 @@ impl<W: AsyncWrite+ Unpin, R: AsyncRead + Unpin> ESock<W, R>
} }
} }
/// Enable read encryption
///
/// The other endpoint must have sent a `set_encrypted_write()`
pub async fn set_encrypted_read(&mut self, set: bool) -> eyre::Result<()>
{
if set {
set_encrypted_read_for(&self.info, &mut self.rx).await?;
// Set `encr` to true
self.state.encr = true;
Ok(())
} else {
self.state.encr = false;
Ok(())
}
}
/// Get dynamic ref to unencrypted write+read /// Get dynamic ref to unencrypted write+read
fn unencrypted_dyn(&mut self) -> (&mut (dyn AsyncWrite + Unpin + '_), &mut (dyn AsyncRead + Unpin + '_)) fn unencrypted_dyn(&mut self) -> (&mut (dyn AsyncWrite + Unpin + '_), &mut (dyn AsyncRead + Unpin + '_))
{ {
@ -405,29 +175,17 @@ impl<W: AsyncWrite+ Unpin, R: AsyncRead + Unpin> ESock<W, R>
// Read the public key from `rx`. // Read the public key from `rx`.
//TODO: Find pubkey max size. //TODO: Find pubkey max size.
let mut sz_buf = [0u8; std::mem::size_of::<u64>()]; let mut sz_buf = [0u8; std::mem::size_of::<u64>()];
rx.read_exact(&mut sz_buf[..]).await rx.read_exact(&mut sz_buf[..]).await?;
.wrap_err(eyre!("Failed to read size of pubkey form endpoint"))?; let sz= match usize::try_from(u64::from_be_bytes(sz_buf))? {
let sz64 = u64::from_be_bytes(sz_buf); x if x > TRANS_KEY_MAX_SIZE => return Err(eyre!("Recv'd key size exceeded max")),
let sz= match usize::try_from(sz64)
.wrap_err(eyre!("Read size could not fit into u64"))
.with_section(|| format!("{:?}", sz_buf).header("Read buffer was"))
.with_section(|| u64::from_be_bytes(sz_buf).header("64=bit size value was"))
.with_warning(|| "This should not happen, it is only possible when you are running a machine with a pointer size lower than 64 bits.")
.with_suggestion(|| "The message is likely malformed. If it is not, then you are communicating with an endpoint of 64 bits whereas your pointer size is far less.")? {
x if x > TRANS_KEY_MAX_SIZE => return Err(eyre!("Recv'd key size exceeded max acceptable key buffer size")),
x => x x => x
}; };
let mut key_bytes = Vec::with_capacity(sz); let mut key_bytes = Vec::with_capacity(sz);
tokio::io::copy(&mut rx.take(sz64), &mut key_bytes).await tokio::io::copy(&mut rx.take(sz as u64), &mut key_bytes).await?;
.wrap_err("Failed to read key bytes into buffer")
.with_section(move || sz64.header("Pubkey size to read was"))?;
if key_bytes.len() != sz { if key_bytes.len() != sz {
return Err(eyre!("Could not read required bytes")); return Err(eyre!("Could not read required bytes"));
} }
let k = RsaPublicKey::from_bytes(&key_bytes) let k = RsaPublicKey::from_bytes(key_bytes)?;
.wrap_err("Failed to construct RSA public key from read bytes")
.with_section(|| sz.header("Pubkey size was"))
.with_section(move || key_bytes.to_base64_string().header("Pubkey bytes were"))?;
Result::<RsaPublicKey, eyre::Report>::Ok(k) Result::<RsaPublicKey, eyre::Report>::Ok(k)
} }
@ -435,231 +193,145 @@ impl<W: AsyncWrite+ Unpin, R: AsyncRead + Unpin> ESock<W, R>
let write_fut = { let write_fut = {
let key_bytes = our_key.to_bytes(); let key_bytes = our_key.to_bytes();
assert!(key_bytes.len() <= TRANS_KEY_MAX_SIZE); assert!(key_bytes.len() <= TRANS_KEY_MAX_SIZE);
let sz64 = u64::try_from(key_bytes.len()) let sz_buf = u64::try_from(key_bytes.len())?.to_be_bytes();
.wrap_err(eyre!("Size of our pubkey could not fit into u64"))
.with_section(|| key_bytes.len().header("Size was"))
.with_warning(|| "This should not happen, it is only possible when you are running a machine with a pointer size larger than 64 bits.")
.with_warning(|| "There was likely internal memory corruption.")?;
let sz_buf = sz64.to_be_bytes();
async move { async move {
tx.write_all(&sz_buf[..]).await tx.write_all(&sz_buf[..]).await?;
.wrap_err(eyre!("Failed to write key size")) tx.write_all(&key_bytes[..]).await?;
.with_section(|| sz64.header("Key size bytes were"))
.with_section(|| format!("{:?}", sz_buf).header("Key size bytes (BE) were"))?;
tx.write_all(&key_bytes[..]).await
.wrap_err(eyre!("Failed to write key bytes"))
.with_section(|| sz64.header("Size of key was"))
.with_section(|| key_bytes.to_base64_string().header("Key bytes are"))?;
Result::<(), eyre::Report>::Ok(()) Result::<(), eyre::Report>::Ok(())
} }
}; };
let (send, recv) = tokio::join! [write_fut, read_fut]; let (send, recv) = tokio::join! [write_fut, read_fut];
send.wrap_err("Failed to send our pubkey")?; send?;
let recv = recv.wrap_err("Failed to receive foreign pubkey")?; let recv = recv?;
self.info.them = Some(recv); self.info.them = Some(recv);
Ok(()) Ok(())
} }
} }
//XXX: For some reason, non-exact reads + writes cause garbage to be produced on the receiving end? #[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, PartialOrd, Ord)]
// Is this fixable? Why does it disjoint? I have no idea... This is supposed to be a stream cipher, right? Why does positioning matter? Have I misunderstood how it workd? Eh... enum ExchangeState
// With this bug, it seems the `while read(buffer) > 0` construct is impossible. This might make this entirely useless. Hopefully with the rigid size-based format for `Message` we won't run into this problem, but subsequent data streaming will likely be affected unless we use rigid, fixed, and (inefficiently) communicated buffer sizes.
impl<W, R> AsyncWrite for ESock<W, R>
where W: AsyncWrite
{ {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> { /// We are currently reading/writing the buffer's size
//XXX: If the encryption state of the socket is changed between polls, this breaks. Idk if we can do anything about that tho. BufferSize,
if self.state.encw { /// We are currently reading/writing the buffer itself
self.project().tx.poll_write(cx, buf) Buffer,
} else {
// SAFETY: Uhh... well I think this is fine? Because we can project the container.
// TODO: Can we project the `tx`? Or maybe add a method in `AsyncSink` to map a pinned sink to a `Pin<&mut W>`?
unsafe { self.map_unchecked_mut(|this| this.tx.inner_mut()).poll_write(cx, buf)}
}
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
// Should we do anything else here?
// Should we clear foreign key/current session key?
self.project().tx.poll_shutdown(cx)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().tx.poll_flush(cx)
}
} }
impl<W, R> AsyncRead for ESock<W, R> impl Default for ExchangeState
where R: AsyncRead
{ {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> { #[inline]
//XXX: If the encryption state of the socket is changed between polls, this breaks. Idk if we can do anything about that tho. fn default() -> Self
if self.state.encr { {
self.project().rx.poll_read(cx, buf) Self::BufferSize
} else {
// SAFETY: Uhh... well I think this is fine? Because we can project the container.
// TODO: Can we project the `tx`? Or maybe add a method in `AsyncSink` to map a pinned sink to a `Pin<&mut W>`?
unsafe { self.map_unchecked_mut(|this| this.rx.inner_mut()).poll_read(cx, buf)}
}
} }
} }
/// Write half for `ESock`.
#[pin_project]
#[derive(Debug)]
pub struct ESockWriteHalf<W>(Arc<ESockInfo>, #[pin] AsyncSink<W>, bool);
/// Read half for `ESock`.
#[pin_project] #[pin_project]
#[derive(Debug)] #[derive(Debug)]
pub struct ESockReadHalf<R>(Arc<ESockInfo>, #[pin] AsyncSource<R>, bool); pub struct Exchange<'a, W, R>
{
sock: &'a mut ESock<W, R>,
//Impl AsyncRead/Write + set_encrypted_read/write for ESockRead/WriteHalf. us: RsaPublicKey,
impl<W: AsyncWrite> ESockWriteHalf<W> us_written: usize,
{ us_buf: Vec<u8>,
/// Does this write half have a live corresponding read half?
///
/// It's not required to have one, however, exchange is not possible without since it requires sticking the halves back together.
pub fn is_bidirectional(&self) -> bool
{
Arc::strong_count(&self.0) > 1
}
/// Is write encrypted on this half? /// The return value
#[inline(always)] pub fn is_encrypted(&self) -> bool them: Option<RsaPublicKey>,
{
self.2
}
/// The local RSA private key
#[inline] pub fn local_key(&self) -> &RsaPrivateKey
{
&self.0.us
}
/// THe remote RSA public key (if exchange has happened.)
#[inline] pub fn foreign_key(&self) -> Option<&RsaPublicKey>
{
self.0.them.as_ref()
}
/// End an encrypted session syncronously. write_sz_num: usize,
/// write_sz_buf: [u8; std::mem::size_of::<u64>()],
/// Same as calling `set_encryption(false).now_or_never()`, but more efficient. read_sz_buf: [u8; std::mem::size_of::<u64>()],
pub fn clear_encryption(&mut self)
{
self.2 = false;
}
}
impl<R: AsyncRead> ESockReadHalf<R> read_buf: Vec<u8>,
{
/// Does this read half have a live corresponding write half?
///
/// It's not required to have one, however, exchange is not possible without since it requires sticking the halves back together.
pub fn is_bidirectional(&self) -> bool
{
Arc::strong_count(&self.0) > 1
}
/// Is write encrypted on this half? write_state: ExchangeState,
#[inline(always)] pub fn is_encrypted(&self) -> bool read_state: ExchangeState,
{
self.2
}
/// The local RSA private key
#[inline] pub fn local_key(&self) -> &RsaPrivateKey
{
&self.0.us
}
/// THe remote RSA public key (if exchange has happened.)
#[inline] pub fn foreign_key(&self) -> Option<&RsaPublicKey>
{
self.0.them.as_ref()
}
/// End an encrypted session syncronously. #[pin] _pin: PhantomPinned,
///
/// Same as calling `set_encryption(false).now_or_never()`, but more efficient.
pub fn clear_encryption(&mut self)
{
self.2 = false;
}
} }
impl<W: AsyncWrite + Unpin> ESockWriteHalf<W> /*
impl<'a, W: AsyncWrite, R: AsyncRead> Future for Exchange<'a, W, R>
{ {
/// Begin or end an encrypted writing session type Output = eyre::Result<()>;
pub async fn set_encryption(&mut self, set: bool) -> eyre::Result<()> fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
{ use futures::ready;
if set { let this = self.project();
set_encrypted_write_for(&self.0, &mut self.1).await?; let (tx, rx) = {
self.2 = true; let sock = this.sock;
} else { //XXX: Idk if this is safe?
self.2 = false; unsafe {
} (Pin::new_unchecked(&mut sock.tx), Pin::new_unchecked(&mut sock.rx))
Ok(())
}
} }
};
impl<R: AsyncRead + Unpin> ESockReadHalf<R> if this.us_buf.is_empty() {
{ *this.us_buf = this.us.to_bytes();
/// Begin or end an encrypted reading session
pub async fn set_encryption(&mut self, set: bool) -> eyre::Result<()>
{
if set {
set_encrypted_read_for(&self.0, &mut self.1).await?;
self.2 = true;
} else {
self.2 = false;
} }
Ok(())
let poll_write = loop {
break match this.write_state {
ExchangeState::BufferSize => {
if *this.write_sz_num == 0 {
*this.write_sz_buf = u64::try_from(this.us_buf.len())?.to_be_bytes();
} }
// Write this to tx.
match tx.poll_write(cx, &this.write_sz_buf[(*this.write_sz_num)..]) {
x @ Poll::Ready(Ok(n)) => {
*this.write_sz_num+=n;
if *this.write_sz_num == this.write_sz_buf.len() {
// We've written all the size bytes, continue to writing the buffer bytes.
*this.write_state = ExchangeState::Buffer;
continue;
} }
impl<W: AsyncWrite> AsyncWrite for ESockWriteHalf<W> x
{ },
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> { x => x,
if self.2 {
// Encrypted
self.project().1.poll_write(cx, buf)
} else {
// Unencrypted
// SAFETY: Uhh... well I think this is fine? Because we can project the container.
// TODO: Can we project the `tx`? Or maybe add a method in `AsyncSink` to map a pinned sink to a `Pin<&mut W>`?
unsafe { self.map_unchecked_mut(|this| this.1.inner_mut()).poll_write(cx, buf)}
}
} }
#[inline(always)] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { },
self.project().1.poll_flush(cx) ExchangeState::Buffer => {
match tx.poll_write(cx, &this.us_buf[(*this.us_written)..]) {
x @ Poll::Ready(Ok(n)) => {
if *this.us_written == this.us.len() {
} }
#[inline(always)] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().1.poll_flush(cx) x
},
x=> x,
} }
},
} }
};
let poll_read = match this.read_state {
ExchangeState::BufferSize => {
impl<R: AsyncRead> AsyncRead for ESockReadHalf<R> },
{ ExchangeState::Buffer => {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
if self.2 {
// Encrypted
self.project().1.poll_read(cx, buf)
} else {
// Unencrypted
// SAFETY: Uhh... well I think this is fine? Because we can project the container. },
// TODO: Can we project the `tx`? Or maybe add a method in `AsyncSink` to map a pinned sink to a `Pin<&mut W>`? };
unsafe { self.map_unchecked_mut(|this| this.1.inner_mut()).poll_read(cx, buf)} todo!("This is going to be dificult to implement... We don't have access to write_all and read_exact")
}
} }
} }
*/
/// Write half for `ESock`.
#[pin_project]
#[derive(Debug)]
pub struct ESockWriteHalf<W>(Arc<(ESockInfo, RwLock<ESockState>)>, #[pin] AsyncSink<W>);
/// Read half for `ESock`.
#[pin_project]
#[derive(Debug)]
pub struct ESockReadHalf<R>(Arc<(ESockInfo, RwLock<ESockState>)>, #[pin] AsyncSource<R>);
#[cfg(test)] #[cfg(test)]
mod tests mod tests
{ {
use super::ESock;
#[test] #[test]
fn rsa_ciphertext_len() -> crate::eyre::Result<()> fn rsa_ciphertext_len() -> crate::eyre::Result<()>
{ {
@ -699,176 +371,4 @@ mod tests
Ok(()) Ok(())
} }
fn gen_duplex_esock(bufsz: usize) -> crate::eyre::Result<(ESock<tokio::io::DuplexStream, tokio::io::DuplexStream>, ESock<tokio::io::DuplexStream, tokio::io::DuplexStream>)>
{
use crate::*;
let (atx, brx) = tokio::io::duplex(bufsz);
let (btx, arx) = tokio::io::duplex(bufsz);
let tx = ESock::new(atx, arx).wrap_err(eyre!("Failed to create TX"))?;
let rx = ESock::new(btx, brx).wrap_err(eyre!("Failed to create RX"))?;
Ok((tx, rx))
}
#[tokio::test]
async fn esock_exchange() -> crate::eyre::Result<()>
{
use crate::*;
const VALUE: &'static [u8] = b"Hello world!";
// The duplex buffer size here is smaller than an RSA ciphertext block. So, writing the session key must be buffered with a buffer size this small (should return Pending at least once.)
// Using a low buffer size to make sure the test passes even when the entire buffer cannot be written at once.
let (mut tx, mut rx) = gen_duplex_esock(16).wrap_err(eyre!("Failed to weave socks"))?;
let writer = tokio::spawn(async move {
use tokio::prelude::*;
tx.exchange().await?;
assert!(tx.has_exchanged());
tx.set_encrypted_write(true).await?;
assert_eq!((true, false), tx.is_encrypted());
tx.write_all(VALUE).await?;
tx.write_all(VALUE).await?;
// Check resp
tx.set_encrypted_read(true).await?;
assert_eq!({
let mut chk = [0u8; 3];
tx.read_exact(&mut chk[..]).await?;
chk
}, [0xaau8,0, 0], "Failed response check");
// Write unencrypted
tx.set_encrypted_write(false).await?;
tx.write_all(&[2,1,0xfa]).await?;
Result::<_, eyre::Report>::Ok(VALUE)
});
let reader = tokio::spawn(async move {
use tokio::prelude::*;
rx.exchange().await?;
assert!(rx.has_exchanged());
rx.set_encrypted_read(true).await?;
assert_eq!((false, true), rx.is_encrypted());
let mut val = vec![0u8; VALUE.len()];
rx.read_exact(&mut val[..]).await?;
let mut val2 = vec![0u8; VALUE.len()];
rx.read_exact(&mut val2[..]).await?;
assert_eq!(val, val2);
// Send resp
rx.set_encrypted_write(true).await?;
rx.write_all(&[0xaa, 0, 0]).await?;
// Read unencrypted
rx.set_encrypted_read(false).await?;
assert_eq!({
let mut buf = [0u8; 3];
rx.read_exact(&mut buf[..]).await?;
buf
}, [2u8,1,0xfa], "2nd response incorrect");
Result::<_, eyre::Report>::Ok(val)
});
let (writer, reader) = tokio::join![writer, reader];
let writer = writer.expect("Tx task panic");
let reader = reader.expect("Rx task panic");
eprintln!("Txr: {:?}", writer);
eprintln!("Rxr: {:?}", reader);
writer?;
let val = reader?;
println!("Read: {:?}", val);
assert_eq!(&val, VALUE);
Ok(())
}
#[tokio::test]
async fn esock_split() -> crate::eyre::Result<()>
{
use super::*;
const SLICES: &'static [&'static [u8]] = &[
&[1,5,3,7,6,9,100,0],
&[7,6,2,90],
&[3,6,1,0],
&[5,1,3,3],
];
let result = SLICES.iter().map(|&slice| slice.iter().map(|&b| u64::from(b)).sum::<u64>()).sum::<u64>();
println!("Result: {}", result);
let (mut tx, mut rx) = gen_duplex_esock(super::TRANS_KEY_MAX_SIZE * 4).wrap_err(eyre!("Failed to weave socks"))?;
let (writer, reader) = {
use tokio::prelude::*;
let writer = tokio::spawn(async move {
tx.exchange().await?;
let (mut tx, mut rx) = tx.split();
//tx.set_encryption(true).await?;
let slices = &SLICES[1..];
for &slice in slices.iter()
{
println!("Writing slice: {:?}", slice);
tx.write_all(slice).await?;
}
//let mut tx = ESock::unsplit(tx, rx);
tx.write_all(SLICES[0]).await?;
Result::<_, eyre::Report>::Ok(())
});
let reader = tokio::spawn(async move {
rx.exchange().await?;
let (mut tx, mut rx) = rx.split();
//rx.set_encryption(true).await?;
let (mut mtx, mut mrx) = tokio::sync::mpsc::channel::<Vec<u8>>(16);
let sorter = tokio::spawn(async move {
let mut done = 0u64;
while let Some(buf) = mrx.recv().await
{
//buf.sort();
done += buf.iter().map(|&b| u64::from(b)).sum::<u64>();
println!("Got buffer: {:?}", buf);
tx.write_all(&buf).await?;
}
Result::<_, eyre::Report>::Ok(done)
});
let mut buffer = [0u8; 16];
while let Ok(read) = rx.read(&mut buffer[..]).await
{
if read == 0 {
break;
}
mtx.send(Vec::from(&buffer[..read])).await?;
}
drop(mtx);
let sum = sorter.await.expect("(reader) Sorter task panic")?;
Result::<_, eyre::Report>::Ok(sum)
});
let (writer, reader) = tokio::join![writer, reader];
(writer.expect("Writer task panic"),
reader.expect("Reader task panic"))
};
writer?;
assert_eq!(result, reader?);
Ok(())
}
} }

@ -1,53 +0,0 @@
//! Piping buffered data from a raw socket to `ESock`
//!
//! This exists because i'm too dumb to implement a functional AsyncRead/Write buffered wrapper stream :/
use super::*;
use std::{
io,
marker::{
Send, Sync,
Unpin,
PhantomData,
},
};
use tokio::sync::{
mpsc,
};
use enc::{
ESock,
ESockReadHalf,
ESockWriteHalf,
};
/// The default buffer size for `BufferedESock`.
pub const DEFAULT_BUFFER_SIZE: usize = 32;
/// Task-based buffered piping to/from encrypted sockets.
pub struct BufferedESock<W, R>
{
bufsz: usize,
_backing: PhantomData<ESock<W, R>>,
}
impl<W, R> BufferedESock<W, R>
where W: AsyncWrite + Unpin + Send + 'static,
R: AsyncRead + Unpin + Send + 'static
{
/// Create a new buffered ESock pipe with a specific buffer size
pub fn with_size(tx: W, rx: R, bufsz: usize) -> Self
{
//TODO: Spawn read+write buffer tasks
Self {
bufsz,
_backing: PhantomData,
}
}
/// Create a new buffered ESock pipe with the default buffer size (`DEFAULT_BUFFER_SIZE`).
#[inline] pub fn new(tx: W, rx: R) -> Self
{
Self::with_size(tx, rx, DEFAULT_BUFFER_SIZE)
}
}
Loading…
Cancel
Save