diff --git a/Cargo.lock b/Cargo.lock index b512b73..7dd93ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -799,6 +799,7 @@ name = "rsh" version = "0.1.0" dependencies = [ "ad-hoc-iter", + "base64 0.13.0", "bytes 1.0.1", "chacha20stream", "color-eyre", diff --git a/Cargo.toml b/Cargo.toml index d29c398..3271d69 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ edition = "2018" [dependencies] ad-hoc-iter = "0.2.3" +base64 = "0.13.0" bytes = { version = "1.0.1", features = ["serde"] } chacha20stream = { version = "2.1.0", features = ["async", "serde"] } color-eyre = "0.5.11" diff --git a/src/ext/alloc.rs b/src/ext/alloc.rs new file mode 100644 index 0000000..96613c2 --- /dev/null +++ b/src/ext/alloc.rs @@ -0,0 +1,91 @@ +//! Stack allocation helpers +use super::*; + +/// Max size of memory allowed to be allocated on the stack. +pub const STACK_MEM_ALLOC_MAX: usize = 2048; // 2KB + +/// A stack-allocated vector that spills onto the heap when needed. +pub type StackVec = SmallVec<[T; STACK_MEM_ALLOC_MAX]>; + +/// Allocate a local buffer initialised from `init`. +pub fn alloc_local_with(sz: usize, init: impl FnMut() -> T, within: impl FnOnce(&mut [T]) -> U) -> U +{ + if sz > STACK_MEM_ALLOC_MAX { + let mut memory: Vec = iter::repeat_with(init).take(sz).collect(); + within(&mut memory[..]) + } else { + stackalloc::stackalloc_with(sz, init, within) + } +} + + + +/// Allocate a local zero-initialised byte buffer +pub fn alloc_local_bytes(sz: usize, within: impl FnOnce(&mut [u8]) -> U) -> U +{ + if sz > STACK_MEM_ALLOC_MAX { + let mut memory: Vec> = vec_uninit(sz); + within(unsafe { + std::ptr::write_bytes(memory.as_mut_ptr(), 0, sz); + stackalloc::helpers::slice_assume_init_mut(&mut memory[..]) + }) + } else { + stackalloc::alloca_zeroed(sz, within) + } +} + + +/// Allocate a local zero-initialised buffer +pub fn alloc_local_zeroed(sz: usize, within: impl FnOnce(&mut [MaybeUninit]) -> U) -> U +{ + let sz_bytes = mem::size_of::() * sz; + if sz > STACK_MEM_ALLOC_MAX { + let mut memory = vec_uninit(sz); + unsafe { + std::ptr::write_bytes(memory.as_mut_ptr(), 0, sz_bytes); + } + within(&mut memory[..]) + } else { + stackalloc::alloca_zeroed(sz_bytes, move |buf| { + unsafe { + debug_assert_eq!(buf.len() / mem::size_of::(), sz); + within(std::slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut MaybeUninit, sz)) + } + }) + } +} + + +/// Allocate a local uninitialised buffer +pub fn alloc_local_uninit(sz: usize, within: impl FnOnce(&mut [MaybeUninit]) -> U) -> U +{ + if sz > STACK_MEM_ALLOC_MAX { + let mut memory = vec_uninit(sz); + within(&mut memory[..]) + } else { + stackalloc::stackalloc_uninit(sz, within) + } +} + + +/// Allocate a local buffer initialised with `init`. +pub fn alloc_local(sz: usize, init: T, within: impl FnOnce(&mut [T]) -> U) -> U +{ + if sz > STACK_MEM_ALLOC_MAX { + let mut memory: Vec = iter::repeat(init).take(sz).collect(); + within(&mut memory[..]) + } else { + stackalloc::stackalloc(sz, init, within) + } +} + +/// Allocate a local buffer initialised with `T::default()`. +pub fn alloc_local_with_default(sz: usize, within: impl FnOnce(&mut [T]) -> U) -> U +{ + if sz > STACK_MEM_ALLOC_MAX { + let mut memory: Vec = iter::repeat_with(Default::default).take(sz).collect(); + within(&mut memory[..]) + } else { + stackalloc::stackalloc_with_default(sz, within) + } +} diff --git a/src/ext/hex.rs b/src/ext/hex.rs new file mode 100644 index 0000000..1c1f277 --- /dev/null +++ b/src/ext/hex.rs @@ -0,0 +1,124 @@ +use std::{ + mem, + iter::{ + self, + ExactSizeIterator, + FusedIterator, + }, + slice, + fmt, +}; +#[derive(Debug, Clone)] +pub struct HexStringIter(I, [u8; 2]); + +impl> HexStringIter +{ + /// Write this hex string iterator to a formattable buffer + pub fn consume(self, f: &mut F) -> fmt::Result + where F: std::fmt::Write + { + if self.1[0] != 0 { + write!(f, "{}", self.1[0] as char)?; + } + if self.1[1] != 0 { + write!(f, "{}", self.1[1] as char)?; + } + + for x in self.0 { + write!(f, "{:02x}", x)?; + } + + Ok(()) + } + + /// Consume into a string + pub fn into_string(self) -> String + { + let mut output = match self.size_hint() { + (0, None) => String::new(), + (_, Some(x)) | + (x, None) => String::with_capacity(x), + }; + self.consume(&mut output).unwrap(); + output + } +} + +pub trait HexStringIterExt: Sized +{ + fn into_hex(self) -> HexStringIter; +} + +pub type HexStringSliceIter<'a> = HexStringIter>>; + +pub trait HexStringSliceIterExt +{ + fn hex(&self) -> HexStringSliceIter<'_>; +} + +impl HexStringSliceIterExt for S +where S: AsRef<[u8]> +{ + fn hex(&self) -> HexStringSliceIter<'_> + { + self.as_ref().iter().copied().into_hex() + } +} + +impl> HexStringIterExt for I +{ + #[inline] fn into_hex(self) -> HexStringIter { + HexStringIter(self.into_iter(), [0u8; 2]) + } +} + +impl> Iterator for HexStringIter +{ + type Item = char; + fn next(&mut self) -> Option + { + match self.1 { + [_, 0] => { + use std::io::Write; + write!(&mut self.1[..], "{:02x}", self.0.next()?).unwrap(); + + Some(mem::replace(&mut self.1[0], 0) as char) + }, + [0, _] => Some(mem::replace(&mut self.1[1], 0) as char), + _ => unreachable!(), + } + } + + fn size_hint(&self) -> (usize, Option) { + let (l, h) = self.0.size_hint(); + + (l * 2, h.map(|x| x*2)) + } +} + +impl + ExactSizeIterator> ExactSizeIterator for HexStringIter{} +impl + FusedIterator> FusedIterator for HexStringIter{} + +impl> From> for String +{ + fn from(from: HexStringIter) -> Self + { + from.into_string() + } +} + +impl + Clone> fmt::Display for HexStringIter +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + self.clone().consume(f) + } +} + +/* +#[macro_export] macro_rules! prog1 { + ($first:expr, $($rest:expr);+ $(;)?) => { + ($first, $( $rest ),+).0 + } +} +*/ diff --git a/src/ext.rs b/src/ext/mod.rs similarity index 53% rename from src/ext.rs rename to src/ext/mod.rs index dc648b5..84bd83a 100644 --- a/src/ext.rs +++ b/src/ext/mod.rs @@ -4,11 +4,11 @@ use std::mem::{self, MaybeUninit}; use std::iter; use smallvec::SmallVec; -/// Max size of memory allowed to be allocated on the stack. -pub const STACK_MEM_ALLOC_MAX: usize = 2048; // 2KB +mod alloc; +pub use alloc::*; -/// A stack-allocated vector that spills onto the heap when needed. -pub type StackVec = SmallVec<[T; STACK_MEM_ALLOC_MAX]>; +mod hex; +pub use hex::*; /// A maybe-atom that can spill into a vector. pub type MaybeVec = SmallVec<[T; 1]>; @@ -23,89 +23,6 @@ pub fn vec_uninit(sz: usize) -> Vec> } } -/// Allocate a local buffer initialised from `init`. -pub fn alloc_local_with(sz: usize, init: impl FnMut() -> T, within: impl FnOnce(&mut [T]) -> U) -> U -{ - if sz > STACK_MEM_ALLOC_MAX { - let mut memory: Vec = iter::repeat_with(init).take(sz).collect(); - within(&mut memory[..]) - } else { - stackalloc::stackalloc_with(sz, init, within) - } -} - - - -/// Allocate a local zero-initialised byte buffer -pub fn alloc_local_bytes(sz: usize, within: impl FnOnce(&mut [u8]) -> U) -> U -{ - if sz > STACK_MEM_ALLOC_MAX { - let mut memory: Vec> = vec_uninit(sz); - within(unsafe { - std::ptr::write_bytes(memory.as_mut_ptr(), 0, sz); - stackalloc::helpers::slice_assume_init_mut(&mut memory[..]) - }) - } else { - stackalloc::alloca_zeroed(sz, within) - } -} - - -/// Allocate a local zero-initialised buffer -pub fn alloc_local_zeroed(sz: usize, within: impl FnOnce(&mut [MaybeUninit]) -> U) -> U -{ - let sz_bytes = mem::size_of::() * sz; - if sz > STACK_MEM_ALLOC_MAX { - let mut memory = vec_uninit(sz); - unsafe { - std::ptr::write_bytes(memory.as_mut_ptr(), 0, sz_bytes); - } - within(&mut memory[..]) - } else { - stackalloc::alloca_zeroed(sz_bytes, move |buf| { - unsafe { - debug_assert_eq!(buf.len() / mem::size_of::(), sz); - within(std::slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut MaybeUninit, sz)) - } - }) - } -} - - -/// Allocate a local uninitialised buffer -pub fn alloc_local_uninit(sz: usize, within: impl FnOnce(&mut [MaybeUninit]) -> U) -> U -{ - if sz > STACK_MEM_ALLOC_MAX { - let mut memory = vec_uninit(sz); - within(&mut memory[..]) - } else { - stackalloc::stackalloc_uninit(sz, within) - } -} - - -/// Allocate a local buffer initialised with `init`. -pub fn alloc_local(sz: usize, init: T, within: impl FnOnce(&mut [T]) -> U) -> U -{ - if sz > STACK_MEM_ALLOC_MAX { - let mut memory: Vec = iter::repeat(init).take(sz).collect(); - within(&mut memory[..]) - } else { - stackalloc::stackalloc(sz, init, within) - } -} - -/// Allocate a local buffer initialised with `T::default()`. -pub fn alloc_local_with_default(sz: usize, within: impl FnOnce(&mut [T]) -> U) -> U -{ - if sz > STACK_MEM_ALLOC_MAX { - let mut memory: Vec = iter::repeat_with(Default::default).take(sz).collect(); - within(&mut memory[..]) - } else { - stackalloc::stackalloc_with_default(sz, within) - } -} - /// Create a blanket-implementing trait that is a subtrait of any number of traits. /// /// # Usage diff --git a/src/sock/enc.rs b/src/sock/enc.rs index 892eb05..c068418 100644 --- a/src/sock/enc.rs +++ b/src/sock/enc.rs @@ -2,8 +2,14 @@ use super::*; use cryptohelpers::{ rsa::{ + self, RsaPublicKey, RsaPrivateKey, + + openssl::{ + symm::Crypter, + error::ErrorStack, + }, }, sha256, }; @@ -12,6 +18,8 @@ use chacha20stream::{ AsyncSource, Key, IV, + + cha, }; use std::sync::Arc; use tokio::{ @@ -20,25 +28,24 @@ use tokio::{ RwLockReadGuard, RwLockWriteGuard, }, - io::{ - DuplexStream, - }, }; use std::{ io, + fmt, task::{ Context, Poll, }, pin::Pin, - marker::{ - Unpin, - PhantomPinned, - }, + marker::Unpin, }; +use smallvec::SmallVec; /// Size of a single RSA ciphertext. pub const RSA_CIPHERTEXT_SIZE: usize = 512; +/// A single, full block of RSA ciphertext. +type RsaCiphertextBlock = [u8; RSA_CIPHERTEXT_SIZE]; + /// Max size to read when exchanging keys const TRANS_KEY_MAX_SIZE: usize = 4096; @@ -55,7 +62,7 @@ struct ESockState { encw: bool, } -/// Contains a Key and IV that can be serialized and then encrypted +/// Contains a cc20 Key and IV that can be serialized and then encrypted #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] struct ESockSessionKey { @@ -63,6 +70,67 @@ struct ESockSessionKey iv: IV, } +impl fmt::Display for ESockSessionKey +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + write!(f, "Key: {}, IV: {}", self.key.hex(), self.iv.hex()) + } +} + + +impl ESockSessionKey +{ + /// Generate a new cc20 key + iv, + pub fn generate() -> Self + { + let (key,iv) = cha::keygen(); + Self{key,iv} + } + + /// Generate an encryption device + pub fn to_decrypter(&self) -> Result + { + cha::decrypter(&self.key, &self.iv) + } + + /// Generate an encryption device + pub fn to_encrypter(&self) -> Result + { + cha::encrypter(&self.key, &self.iv) + } + + /// Encrypt with RSA + pub fn to_ciphertext(&self, rsa_key: &K) -> eyre::Result + { + let mut output = [0u8; RSA_CIPHERTEXT_SIZE]; + let mut temp = SmallVec::<[u8; RSA_CIPHERTEXT_SIZE]>::new(); // We know size will fit into here. + serde_cbor::to_writer(&mut temp, self) + .wrap_err(eyre!("Failed to CBOR encode session key to buffer")) + .with_section(|| self.clone().header("Session key was"))?; + debug_assert!(temp.len() < RSA_CIPHERTEXT_SIZE); + + let _wr = rsa::encrypt_slice_sync(&temp, rsa_key, &mut &mut output[..]) + .wrap_err(eyre!("Failed to encrypt session key with RSA public key")) + .with_section(|| self.clone().header("Session key was")) + .with_section({let temp = temp.len(); move || temp.header("Encoded data size was")}) + .with_section(move || base64::encode(temp).header("Encoded data (base64) was")) + ?; + debug_assert_eq!(_wr, output.len()); + + Ok(output) + } + + /// Decrypt from RSA + pub fn from_ciphertext(data: &[u8; RSA_CIPHERTEXT_SIZE], rsa_key: &K) -> eyre::Result + where ::KeyType: rsa::openssl::pkey::HasPrivate //ugh, why do we have to have this bound??? it should be implied ffs... :/ + { + let mut temp = SmallVec::<[u8; RSA_CIPHERTEXT_SIZE]>::new(); + rsa::decrypt_slice_sync(data, rsa_key, &mut temp)?; + Ok(serde_cbor::from_slice(&temp[..])?) + } +} + /// A tx+rx socket. #[pin_project] #[derive(Debug)] @@ -88,32 +156,7 @@ impl ESock { (self.tx.inner_mut(), self.rx.inner_mut()) } - - /// Create a future that exchanges keys - pub fn exchange_unsafe(&mut self) -> Exchange<'_, W, R> - { - let us = self.info.us.get_public_parts(); - todo!("Currently unimplemented") - /* - Exchange{ - sock: self, - write_buf: Default::default(), - read_buf: Default::default(), - _pin: PhantomPinned, - - read_state: Default::default(), - write_state: Default::default(), - - us, - them: None, - us_written: Default::default(), - us_buf: (), - write_sz_num: (), - write_sz_buf: (), - read_sz_buf: (), - }*/ - } - + ///Get a mutable ref to unencrypted read+write fn unencrypted(&mut self) -> (&mut W, &mut R) { @@ -143,8 +186,16 @@ impl ESock /// Enable write encryption pub async fn set_encrypted_write(&mut self, set: bool) -> eyre::Result<()> { + use tokio::prelude::*; if set { - let (key, iv) = ((),()); + let session_key = ESockSessionKey::generate(); + let data = session_key.to_ciphertext(self.info.them.as_ref().expect("Cannot set encrypted write when keys have not been exchanged"))?; + let crypter = session_key.to_encrypter()?; + // Send rsa `data` over unencrypted endpoint + self.unencrypted().0.write_all(&data[..]).await?; + // Set crypter of `tx` to `session_key`. + *self.tx.crypter_mut() = crypter; + // Set `encw` to true self.state.encw = true; Ok(()) } else { @@ -152,6 +203,30 @@ impl ESock Ok(()) } } + + /// Enable read encryption + /// + /// The other endpoint must have sent a `set_encrypted_write()` + pub async fn set_encrypted_read(&mut self, set: bool) -> eyre::Result<()> + { + use tokio::prelude::*; + if set { + let mut data = [0u8; RSA_CIPHERTEXT_SIZE]; + // Read `data` from unencrypted endpoint + self.unencrypted().1.read_exact(&mut data[..]).await?; + // Decrypt `data` + let session_key = ESockSessionKey::from_ciphertext(&data, &self.info.us)?; + // Set crypter of `rx` to `session_key`. + *self.rx.crypter_mut() = session_key.to_decrypter()?; + // Set `encr` to true + self.state.encr = true; + Ok(()) + } else { + + self.state.encr = false; + Ok(()) + } + } /// Get dynamic ref to unencrypted write+read fn unencrypted_dyn(&mut self) -> (&mut (dyn AsyncWrite + Unpin + '_), &mut (dyn AsyncRead + Unpin + '_)) @@ -208,117 +283,6 @@ impl ESock } } -#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, PartialOrd, Ord)] -enum ExchangeState -{ - /// We are currently reading/writing the buffer's size - BufferSize, - /// We are currently reading/writing the buffer itself - Buffer, -} - -impl Default for ExchangeState -{ - #[inline] - fn default() -> Self - { - Self::BufferSize - } -} - - -#[pin_project] -#[derive(Debug)] -pub struct Exchange<'a, W, R> -{ - sock: &'a mut ESock, - - us: RsaPublicKey, - - us_written: usize, - us_buf: Vec, - - /// The return value - them: Option, - - write_sz_num: usize, - write_sz_buf: [u8; std::mem::size_of::()], - read_sz_buf: [u8; std::mem::size_of::()], - - read_buf: Vec, - - write_state: ExchangeState, - read_state: ExchangeState, - - #[pin] _pin: PhantomPinned, -} - -/* -impl<'a, W: AsyncWrite, R: AsyncRead> Future for Exchange<'a, W, R> -{ -type Output = eyre::Result<()>; -fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { -use futures::ready; -let this = self.project(); -let (tx, rx) = { -let sock = this.sock; -//XXX: Idk if this is safe? -unsafe { -(Pin::new_unchecked(&mut sock.tx), Pin::new_unchecked(&mut sock.rx)) - } - }; - - if this.us_buf.is_empty() { - *this.us_buf = this.us.to_bytes(); - } - - let poll_write = loop { - break match this.write_state { - ExchangeState::BufferSize => { - if *this.write_sz_num == 0 { - *this.write_sz_buf = u64::try_from(this.us_buf.len())?.to_be_bytes(); - } - // Write this to tx. - match tx.poll_write(cx, &this.write_sz_buf[(*this.write_sz_num)..]) { - x @ Poll::Ready(Ok(n)) => { - *this.write_sz_num+=n; - if *this.write_sz_num == this.write_sz_buf.len() { - // We've written all the size bytes, continue to writing the buffer bytes. - *this.write_state = ExchangeState::Buffer; - continue; - } - - x - }, - x => x, - } - }, - ExchangeState::Buffer => { - match tx.poll_write(cx, &this.us_buf[(*this.us_written)..]) { - x @ Poll::Ready(Ok(n)) => { - if *this.us_written == this.us.len() { - - } - - x - }, - x=> x, - } - }, - } - }; - let poll_read = match this.read_state { - ExchangeState::BufferSize => { - - }, - ExchangeState::Buffer => { - - }, - }; - todo!("This is going to be dificult to implement... We don't have access to write_all and read_exact") - } -} -*/ /// Write half for `ESock`. #[pin_project] #[derive(Debug)]