Added better error messages and context to `ESockSessionKey::to_ciphertext()`

Fortune for rsh's current commit: Blessing − 吉
master
Avril 3 years ago
parent 3a5331b5f1
commit 3991d82f93
Signed by: flanchan
GPG Key ID: 284488987C31F630

1
Cargo.lock generated

@ -799,6 +799,7 @@ name = "rsh"
version = "0.1.0"
dependencies = [
"ad-hoc-iter",
"base64 0.13.0",
"bytes 1.0.1",
"chacha20stream",
"color-eyre",

@ -7,6 +7,7 @@ edition = "2018"
[dependencies]
ad-hoc-iter = "0.2.3"
base64 = "0.13.0"
bytes = { version = "1.0.1", features = ["serde"] }
chacha20stream = { version = "2.1.0", features = ["async", "serde"] }
color-eyre = "0.5.11"

@ -0,0 +1,91 @@
//! Stack allocation helpers
use super::*;
/// Max size of memory allowed to be allocated on the stack.
pub const STACK_MEM_ALLOC_MAX: usize = 2048; // 2KB
/// A stack-allocated vector that spills onto the heap when needed.
pub type StackVec<T> = SmallVec<[T; STACK_MEM_ALLOC_MAX]>;
/// Allocate a local buffer initialised from `init`.
pub fn alloc_local_with<T, U>(sz: usize, init: impl FnMut() -> T, within: impl FnOnce(&mut [T]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory: Vec<T> = iter::repeat_with(init).take(sz).collect();
within(&mut memory[..])
} else {
stackalloc::stackalloc_with(sz, init, within)
}
}
/// Allocate a local zero-initialised byte buffer
pub fn alloc_local_bytes<U>(sz: usize, within: impl FnOnce(&mut [u8]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory: Vec<MaybeUninit<u8>> = vec_uninit(sz);
within(unsafe {
std::ptr::write_bytes(memory.as_mut_ptr(), 0, sz);
stackalloc::helpers::slice_assume_init_mut(&mut memory[..])
})
} else {
stackalloc::alloca_zeroed(sz, within)
}
}
/// Allocate a local zero-initialised buffer
pub fn alloc_local_zeroed<T, U>(sz: usize, within: impl FnOnce(&mut [MaybeUninit<T>]) -> U) -> U
{
let sz_bytes = mem::size_of::<T>() * sz;
if sz > STACK_MEM_ALLOC_MAX {
let mut memory = vec_uninit(sz);
unsafe {
std::ptr::write_bytes(memory.as_mut_ptr(), 0, sz_bytes);
}
within(&mut memory[..])
} else {
stackalloc::alloca_zeroed(sz_bytes, move |buf| {
unsafe {
debug_assert_eq!(buf.len() / mem::size_of::<T>(), sz);
within(std::slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut MaybeUninit<T>, sz))
}
})
}
}
/// Allocate a local uninitialised buffer
pub fn alloc_local_uninit<T, U>(sz: usize, within: impl FnOnce(&mut [MaybeUninit<T>]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory = vec_uninit(sz);
within(&mut memory[..])
} else {
stackalloc::stackalloc_uninit(sz, within)
}
}
/// Allocate a local buffer initialised with `init`.
pub fn alloc_local<T: Clone, U>(sz: usize, init: T, within: impl FnOnce(&mut [T]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory: Vec<T> = iter::repeat(init).take(sz).collect();
within(&mut memory[..])
} else {
stackalloc::stackalloc(sz, init, within)
}
}
/// Allocate a local buffer initialised with `T::default()`.
pub fn alloc_local_with_default<T: Default, U>(sz: usize, within: impl FnOnce(&mut [T]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory: Vec<T> = iter::repeat_with(Default::default).take(sz).collect();
within(&mut memory[..])
} else {
stackalloc::stackalloc_with_default(sz, within)
}
}

@ -0,0 +1,124 @@
use std::{
mem,
iter::{
self,
ExactSizeIterator,
FusedIterator,
},
slice,
fmt,
};
#[derive(Debug, Clone)]
pub struct HexStringIter<I>(I, [u8; 2]);
impl<I: Iterator<Item = u8>> HexStringIter<I>
{
/// Write this hex string iterator to a formattable buffer
pub fn consume<F>(self, f: &mut F) -> fmt::Result
where F: std::fmt::Write
{
if self.1[0] != 0 {
write!(f, "{}", self.1[0] as char)?;
}
if self.1[1] != 0 {
write!(f, "{}", self.1[1] as char)?;
}
for x in self.0 {
write!(f, "{:02x}", x)?;
}
Ok(())
}
/// Consume into a string
pub fn into_string(self) -> String
{
let mut output = match self.size_hint() {
(0, None) => String::new(),
(_, Some(x)) |
(x, None) => String::with_capacity(x),
};
self.consume(&mut output).unwrap();
output
}
}
pub trait HexStringIterExt<I>: Sized
{
fn into_hex(self) -> HexStringIter<I>;
}
pub type HexStringSliceIter<'a> = HexStringIter<iter::Copied<slice::Iter<'a, u8>>>;
pub trait HexStringSliceIterExt
{
fn hex(&self) -> HexStringSliceIter<'_>;
}
impl<S> HexStringSliceIterExt for S
where S: AsRef<[u8]>
{
fn hex(&self) -> HexStringSliceIter<'_>
{
self.as_ref().iter().copied().into_hex()
}
}
impl<I: IntoIterator<Item=u8>> HexStringIterExt<I::IntoIter> for I
{
#[inline] fn into_hex(self) -> HexStringIter<I::IntoIter> {
HexStringIter(self.into_iter(), [0u8; 2])
}
}
impl<I: Iterator<Item = u8>> Iterator for HexStringIter<I>
{
type Item = char;
fn next(&mut self) -> Option<Self::Item>
{
match self.1 {
[_, 0] => {
use std::io::Write;
write!(&mut self.1[..], "{:02x}", self.0.next()?).unwrap();
Some(mem::replace(&mut self.1[0], 0) as char)
},
[0, _] => Some(mem::replace(&mut self.1[1], 0) as char),
_ => unreachable!(),
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
let (l, h) = self.0.size_hint();
(l * 2, h.map(|x| x*2))
}
}
impl<I: Iterator<Item = u8> + ExactSizeIterator> ExactSizeIterator for HexStringIter<I>{}
impl<I: Iterator<Item = u8> + FusedIterator> FusedIterator for HexStringIter<I>{}
impl<I: Iterator<Item = u8>> From<HexStringIter<I>> for String
{
fn from(from: HexStringIter<I>) -> Self
{
from.into_string()
}
}
impl<I: Iterator<Item = u8> + Clone> fmt::Display for HexStringIter<I>
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
{
self.clone().consume(f)
}
}
/*
#[macro_export] macro_rules! prog1 {
($first:expr, $($rest:expr);+ $(;)?) => {
($first, $( $rest ),+).0
}
}
*/

@ -4,11 +4,11 @@ use std::mem::{self, MaybeUninit};
use std::iter;
use smallvec::SmallVec;
/// Max size of memory allowed to be allocated on the stack.
pub const STACK_MEM_ALLOC_MAX: usize = 2048; // 2KB
mod alloc;
pub use alloc::*;
/// A stack-allocated vector that spills onto the heap when needed.
pub type StackVec<T> = SmallVec<[T; STACK_MEM_ALLOC_MAX]>;
mod hex;
pub use hex::*;
/// A maybe-atom that can spill into a vector.
pub type MaybeVec<T> = SmallVec<[T; 1]>;
@ -23,89 +23,6 @@ pub fn vec_uninit<T>(sz: usize) -> Vec<MaybeUninit<T>>
}
}
/// Allocate a local buffer initialised from `init`.
pub fn alloc_local_with<T, U>(sz: usize, init: impl FnMut() -> T, within: impl FnOnce(&mut [T]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory: Vec<T> = iter::repeat_with(init).take(sz).collect();
within(&mut memory[..])
} else {
stackalloc::stackalloc_with(sz, init, within)
}
}
/// Allocate a local zero-initialised byte buffer
pub fn alloc_local_bytes<U>(sz: usize, within: impl FnOnce(&mut [u8]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory: Vec<MaybeUninit<u8>> = vec_uninit(sz);
within(unsafe {
std::ptr::write_bytes(memory.as_mut_ptr(), 0, sz);
stackalloc::helpers::slice_assume_init_mut(&mut memory[..])
})
} else {
stackalloc::alloca_zeroed(sz, within)
}
}
/// Allocate a local zero-initialised buffer
pub fn alloc_local_zeroed<T, U>(sz: usize, within: impl FnOnce(&mut [MaybeUninit<T>]) -> U) -> U
{
let sz_bytes = mem::size_of::<T>() * sz;
if sz > STACK_MEM_ALLOC_MAX {
let mut memory = vec_uninit(sz);
unsafe {
std::ptr::write_bytes(memory.as_mut_ptr(), 0, sz_bytes);
}
within(&mut memory[..])
} else {
stackalloc::alloca_zeroed(sz_bytes, move |buf| {
unsafe {
debug_assert_eq!(buf.len() / mem::size_of::<T>(), sz);
within(std::slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut MaybeUninit<T>, sz))
}
})
}
}
/// Allocate a local uninitialised buffer
pub fn alloc_local_uninit<T, U>(sz: usize, within: impl FnOnce(&mut [MaybeUninit<T>]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory = vec_uninit(sz);
within(&mut memory[..])
} else {
stackalloc::stackalloc_uninit(sz, within)
}
}
/// Allocate a local buffer initialised with `init`.
pub fn alloc_local<T: Clone, U>(sz: usize, init: T, within: impl FnOnce(&mut [T]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory: Vec<T> = iter::repeat(init).take(sz).collect();
within(&mut memory[..])
} else {
stackalloc::stackalloc(sz, init, within)
}
}
/// Allocate a local buffer initialised with `T::default()`.
pub fn alloc_local_with_default<T: Default, U>(sz: usize, within: impl FnOnce(&mut [T]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory: Vec<T> = iter::repeat_with(Default::default).take(sz).collect();
within(&mut memory[..])
} else {
stackalloc::stackalloc_with_default(sz, within)
}
}
/// Create a blanket-implementing trait that is a subtrait of any number of traits.
///
/// # Usage

@ -2,8 +2,14 @@
use super::*;
use cryptohelpers::{
rsa::{
self,
RsaPublicKey,
RsaPrivateKey,
openssl::{
symm::Crypter,
error::ErrorStack,
},
},
sha256,
};
@ -12,6 +18,8 @@ use chacha20stream::{
AsyncSource,
Key, IV,
cha,
};
use std::sync::Arc;
use tokio::{
@ -20,25 +28,24 @@ use tokio::{
RwLockReadGuard,
RwLockWriteGuard,
},
io::{
DuplexStream,
},
};
use std::{
io,
fmt,
task::{
Context, Poll,
},
pin::Pin,
marker::{
Unpin,
PhantomPinned,
},
marker::Unpin,
};
use smallvec::SmallVec;
/// Size of a single RSA ciphertext.
pub const RSA_CIPHERTEXT_SIZE: usize = 512;
/// A single, full block of RSA ciphertext.
type RsaCiphertextBlock = [u8; RSA_CIPHERTEXT_SIZE];
/// Max size to read when exchanging keys
const TRANS_KEY_MAX_SIZE: usize = 4096;
@ -55,7 +62,7 @@ struct ESockState {
encw: bool,
}
/// Contains a Key and IV that can be serialized and then encrypted
/// Contains a cc20 Key and IV that can be serialized and then encrypted
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
struct ESockSessionKey
{
@ -63,6 +70,67 @@ struct ESockSessionKey
iv: IV,
}
impl fmt::Display for ESockSessionKey
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
{
write!(f, "Key: {}, IV: {}", self.key.hex(), self.iv.hex())
}
}
impl ESockSessionKey
{
/// Generate a new cc20 key + iv,
pub fn generate() -> Self
{
let (key,iv) = cha::keygen();
Self{key,iv}
}
/// Generate an encryption device
pub fn to_decrypter(&self) -> Result<Crypter, ErrorStack>
{
cha::decrypter(&self.key, &self.iv)
}
/// Generate an encryption device
pub fn to_encrypter(&self) -> Result<Crypter, ErrorStack>
{
cha::encrypter(&self.key, &self.iv)
}
/// Encrypt with RSA
pub fn to_ciphertext<K: ?Sized + rsa::PublicKey>(&self, rsa_key: &K) -> eyre::Result<RsaCiphertextBlock>
{
let mut output = [0u8; RSA_CIPHERTEXT_SIZE];
let mut temp = SmallVec::<[u8; RSA_CIPHERTEXT_SIZE]>::new(); // We know size will fit into here.
serde_cbor::to_writer(&mut temp, self)
.wrap_err(eyre!("Failed to CBOR encode session key to buffer"))
.with_section(|| self.clone().header("Session key was"))?;
debug_assert!(temp.len() < RSA_CIPHERTEXT_SIZE);
let _wr = rsa::encrypt_slice_sync(&temp, rsa_key, &mut &mut output[..])
.wrap_err(eyre!("Failed to encrypt session key with RSA public key"))
.with_section(|| self.clone().header("Session key was"))
.with_section({let temp = temp.len(); move || temp.header("Encoded data size was")})
.with_section(move || base64::encode(temp).header("Encoded data (base64) was"))
?;
debug_assert_eq!(_wr, output.len());
Ok(output)
}
/// Decrypt from RSA
pub fn from_ciphertext<K: ?Sized + rsa::PrivateKey>(data: &[u8; RSA_CIPHERTEXT_SIZE], rsa_key: &K) -> eyre::Result<Self>
where <K as rsa::PublicKey>::KeyType: rsa::openssl::pkey::HasPrivate //ugh, why do we have to have this bound??? it should be implied ffs... :/
{
let mut temp = SmallVec::<[u8; RSA_CIPHERTEXT_SIZE]>::new();
rsa::decrypt_slice_sync(data, rsa_key, &mut temp)?;
Ok(serde_cbor::from_slice(&temp[..])?)
}
}
/// A tx+rx socket.
#[pin_project]
#[derive(Debug)]
@ -89,31 +157,6 @@ impl<W: AsyncWrite, R: AsyncRead> ESock<W, R>
(self.tx.inner_mut(), self.rx.inner_mut())
}
/// Create a future that exchanges keys
pub fn exchange_unsafe(&mut self) -> Exchange<'_, W, R>
{
let us = self.info.us.get_public_parts();
todo!("Currently unimplemented")
/*
Exchange{
sock: self,
write_buf: Default::default(),
read_buf: Default::default(),
_pin: PhantomPinned,
read_state: Default::default(),
write_state: Default::default(),
us,
them: None,
us_written: Default::default(),
us_buf: (),
write_sz_num: (),
write_sz_buf: (),
read_sz_buf: (),
}*/
}
///Get a mutable ref to unencrypted read+write
fn unencrypted(&mut self) -> (&mut W, &mut R)
{
@ -143,8 +186,16 @@ impl<W: AsyncWrite+ Unpin, R: AsyncRead + Unpin> ESock<W, R>
/// Enable write encryption
pub async fn set_encrypted_write(&mut self, set: bool) -> eyre::Result<()>
{
use tokio::prelude::*;
if set {
let (key, iv) = ((),());
let session_key = ESockSessionKey::generate();
let data = session_key.to_ciphertext(self.info.them.as_ref().expect("Cannot set encrypted write when keys have not been exchanged"))?;
let crypter = session_key.to_encrypter()?;
// Send rsa `data` over unencrypted endpoint
self.unencrypted().0.write_all(&data[..]).await?;
// Set crypter of `tx` to `session_key`.
*self.tx.crypter_mut() = crypter;
// Set `encw` to true
self.state.encw = true;
Ok(())
} else {
@ -153,6 +204,30 @@ impl<W: AsyncWrite+ Unpin, R: AsyncRead + Unpin> ESock<W, R>
}
}
/// Enable read encryption
///
/// The other endpoint must have sent a `set_encrypted_write()`
pub async fn set_encrypted_read(&mut self, set: bool) -> eyre::Result<()>
{
use tokio::prelude::*;
if set {
let mut data = [0u8; RSA_CIPHERTEXT_SIZE];
// Read `data` from unencrypted endpoint
self.unencrypted().1.read_exact(&mut data[..]).await?;
// Decrypt `data`
let session_key = ESockSessionKey::from_ciphertext(&data, &self.info.us)?;
// Set crypter of `rx` to `session_key`.
*self.rx.crypter_mut() = session_key.to_decrypter()?;
// Set `encr` to true
self.state.encr = true;
Ok(())
} else {
self.state.encr = false;
Ok(())
}
}
/// Get dynamic ref to unencrypted write+read
fn unencrypted_dyn(&mut self) -> (&mut (dyn AsyncWrite + Unpin + '_), &mut (dyn AsyncRead + Unpin + '_))
{
@ -208,117 +283,6 @@ impl<W: AsyncWrite+ Unpin, R: AsyncRead + Unpin> ESock<W, R>
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, PartialOrd, Ord)]
enum ExchangeState
{
/// We are currently reading/writing the buffer's size
BufferSize,
/// We are currently reading/writing the buffer itself
Buffer,
}
impl Default for ExchangeState
{
#[inline]
fn default() -> Self
{
Self::BufferSize
}
}
#[pin_project]
#[derive(Debug)]
pub struct Exchange<'a, W, R>
{
sock: &'a mut ESock<W, R>,
us: RsaPublicKey,
us_written: usize,
us_buf: Vec<u8>,
/// The return value
them: Option<RsaPublicKey>,
write_sz_num: usize,
write_sz_buf: [u8; std::mem::size_of::<u64>()],
read_sz_buf: [u8; std::mem::size_of::<u64>()],
read_buf: Vec<u8>,
write_state: ExchangeState,
read_state: ExchangeState,
#[pin] _pin: PhantomPinned,
}
/*
impl<'a, W: AsyncWrite, R: AsyncRead> Future for Exchange<'a, W, R>
{
type Output = eyre::Result<()>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
use futures::ready;
let this = self.project();
let (tx, rx) = {
let sock = this.sock;
//XXX: Idk if this is safe?
unsafe {
(Pin::new_unchecked(&mut sock.tx), Pin::new_unchecked(&mut sock.rx))
}
};
if this.us_buf.is_empty() {
*this.us_buf = this.us.to_bytes();
}
let poll_write = loop {
break match this.write_state {
ExchangeState::BufferSize => {
if *this.write_sz_num == 0 {
*this.write_sz_buf = u64::try_from(this.us_buf.len())?.to_be_bytes();
}
// Write this to tx.
match tx.poll_write(cx, &this.write_sz_buf[(*this.write_sz_num)..]) {
x @ Poll::Ready(Ok(n)) => {
*this.write_sz_num+=n;
if *this.write_sz_num == this.write_sz_buf.len() {
// We've written all the size bytes, continue to writing the buffer bytes.
*this.write_state = ExchangeState::Buffer;
continue;
}
x
},
x => x,
}
},
ExchangeState::Buffer => {
match tx.poll_write(cx, &this.us_buf[(*this.us_written)..]) {
x @ Poll::Ready(Ok(n)) => {
if *this.us_written == this.us.len() {
}
x
},
x=> x,
}
},
}
};
let poll_read = match this.read_state {
ExchangeState::BufferSize => {
},
ExchangeState::Buffer => {
},
};
todo!("This is going to be dificult to implement... We don't have access to write_all and read_exact")
}
}
*/
/// Write half for `ESock`.
#[pin_project]
#[derive(Debug)]

Loading…
Cancel
Save