diff --git a/Cargo.lock b/Cargo.lock index 5904eca..a9fd4a9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -167,9 +167,9 @@ dependencies = [ [[package]] name = "cryptohelpers" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46bc3c4ea63c83c528b6a98b514a64e7767b5796ef666bf0dd8ac64cc5a5bf7e" +checksum = "14be74ce15793a86acd04872953368ce27d07f384f07b8028bd5aaa31a031a38" dependencies = [ "crc", "futures", @@ -778,8 +778,10 @@ dependencies = [ "color-eyre", "cryptohelpers", "pin-project", + "rustc_version", "serde", "serde_cbor", + "smallvec", "stackalloc", "tokio 0.2.25", "tokio-uring", @@ -894,6 +896,9 @@ name = "smallvec" version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fe0f37c9e8f3c5a4a66ad655a93c74daac4ad00c441533bf5c6e7990bb42604e" +dependencies = [ + "serde", +] [[package]] name = "stackalloc" diff --git a/Cargo.toml b/Cargo.toml index f845cd9..833c252 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,11 +8,15 @@ edition = "2018" [dependencies] chacha20stream = { version = "1.0.3", features = ["async"] } color-eyre = "0.5.11" -cryptohelpers = { version = "1.8", features = ["serialise", "full"] } +cryptohelpers = { version = "1.8.1" , features = ["serialise", "full"] } pin-project = "1.0.8" serde = { version = "1.0.126", features = ["derive"] } serde_cbor = "0.11.1" +smallvec = { version = "1.6.1", features = ["union", "serde", "write"] } stackalloc = "1.1.1" tokio = { version = "0.2", features = ["full"] } tokio-uring = "0.1.0" uuid = { version = "0.8.2", features = ["v4", "serde"] } + +[build-dependencies] +rustc_version = "0.2" diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..6399463 --- /dev/null +++ b/build.rs @@ -0,0 +1,24 @@ + +extern crate rustc_version; +use rustc_version::{version, version_meta, Channel}; + +fn main() { + // Assert we haven't travelled back in time + assert!(version().unwrap().major >= 1); + + // Set cfg flags depending on release channel + match version_meta().unwrap().channel { + Channel::Stable => { + println!("cargo:rustc-cfg=stable"); + } + Channel::Beta => { + println!("cargo:rustc-cfg=beta"); + } + Channel::Nightly => { + println!("cargo:rustc-cfg=nightly"); + } + Channel::Dev => { + println!("cargo:rustc-cfg=dev"); + } + } +} diff --git a/src/ext.rs b/src/ext.rs index c73552a..c758f45 100644 --- a/src/ext.rs +++ b/src/ext.rs @@ -2,10 +2,18 @@ use super::*; use std::mem::{self, MaybeUninit}; use std::iter; +use smallvec::SmallVec; /// Max size of memory allowed to be allocated on the stack. pub const STACK_MEM_ALLOC_MAX: usize = 4096; +/// A stack-allocated vector that spills onto the heap when needed. +pub type StackVec = SmallVec<[T; STACK_MEM_ALLOC_MAX]>; + +/// A maybe-atom that can spill into a vector. +pub type MaybeVec = SmallVec<[T; 1]>; + +/// Allocate a vector of `MaybeUninit`. pub fn vec_uninit(sz: usize) -> Vec> { let mut mem: Vec = Vec::with_capacity(sz); @@ -15,7 +23,7 @@ pub fn vec_uninit(sz: usize) -> Vec> } } -/// Allocate a local buffer initialised with `init`. +/// Allocate a local buffer initialised from `init`. pub fn alloc_local_with(sz: usize, init: impl FnMut() -> T, within: impl FnOnce(&mut [T]) -> U) -> U { if sz > STACK_MEM_ALLOC_MAX { @@ -26,6 +34,56 @@ pub fn alloc_local_with(sz: usize, init: impl FnMut() -> T, within: impl F } } + + +/// Allocate a local zero-initialised byte buffer +pub fn alloc_local_bytes(sz: usize, within: impl FnOnce(&mut [u8]) -> U) -> U +{ + if sz > STACK_MEM_ALLOC_MAX { + let mut memory: Vec> = vec_uninit(sz); + within(unsafe { + std::ptr::write_bytes(memory.as_mut_ptr(), 0, sz); + stackalloc::helpers::slice_assume_init_mut(&mut memory[..]) + }) + } else { + stackalloc::alloca_zeroed(sz, within) + } +} + + +/// Allocate a local zero-initialised buffer +pub fn alloc_local_zeroed(sz: usize, within: impl FnOnce(&mut [MaybeUninit]) -> U) -> U +{ + let sz_bytes = mem::size_of::() * sz; + if sz > STACK_MEM_ALLOC_MAX { + let mut memory = vec_uninit(sz); + unsafe { + std::ptr::write_bytes(memory.as_mut_ptr(), 0, sz_bytes); + } + within(&mut memory[..]) + } else { + stackalloc::alloca_zeroed(sz_bytes, move |buf| { + unsafe { + debug_assert_eq!(buf.len() / mem::size_of::(), sz); + within(std::slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut MaybeUninit, sz)) + } + }) + } +} + + +/// Allocate a local uninitialised buffer +pub fn alloc_local_uninit(sz: usize, within: impl FnOnce(&mut [MaybeUninit]) -> U) -> U +{ + if sz > STACK_MEM_ALLOC_MAX { + let mut memory = vec_uninit(sz); + within(&mut memory[..]) + } else { + stackalloc::stackalloc_uninit(sz, within) + } +} + + /// Allocate a local buffer initialised with `init`. pub fn alloc_local(sz: usize, init: T, within: impl FnOnce(&mut [T]) -> U) -> U { @@ -36,3 +94,14 @@ pub fn alloc_local(sz: usize, init: T, within: impl FnOnce(&mut [T] stackalloc::stackalloc(sz, init, within) } } + +/// Allocate a local buffer initialised with `T::default()`. +pub fn alloc_local_with_default(sz: usize, within: impl FnOnce(&mut [T]) -> U) -> U +{ + if sz > STACK_MEM_ALLOC_MAX { + let mut memory: Vec = iter::repeat_with(Default::default).take(sz).collect(); + within(&mut memory[..]) + } else { + stackalloc::stackalloc_with_default(sz, within) + } +} diff --git a/src/main.rs b/src/main.rs index 48135c4..797ad9d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,5 @@ //! Remote communication +#![cfg_attr(nightly, feature(const_fn_trait_bound))] #![allow(dead_code)] diff --git a/src/message.rs b/src/message.rs index 4b7b701..f034e14 100644 --- a/src/message.rs +++ b/src/message.rs @@ -186,7 +186,8 @@ impl Message let sign = if let Some(sig) = &serial.sig { // Validate signature assert_valid!(recv_with.verify_data(&data[..], sig).ok_or(eyre!("Message was signed, but receiver doesn't support signature verification"))?.wrap_err("Failed to verify signature")?, - "Non-matching signature")?; + "Non-matching signature") + .with_section(move || format!("Embedded sig was: {}", sig))?; true } else { false @@ -225,7 +226,8 @@ impl SerializedMessage }; (: $ser:expr) => { { - let v = serde_cbor::to_vec($ser)?; + let mut v = StackVec::new(); + serde_cbor::to_writer(&mut v, $ser)?; write!(&v[..]); } } @@ -260,7 +262,8 @@ impl SerializedMessage }; (: $ser:expr) => { { - let ser = serde_cbor::to_vec($ser)?; + let mut ser = StackVec::new(); + serde_cbor::to_writer(&mut ser, $ser)?; write!(u64::try_from(ser.len())?.to_be_bytes()); write!(ser); /* @@ -285,8 +288,12 @@ impl SerializedMessage pub fn from_reader(mut reader: impl io::Read) -> eyre::Result { macro_rules! read { + + ($b:expr; $fmt:literal $($tt:tt)*) => { + read_all($b, &mut reader).wrap_err(eyre!($fmt $($tt)*)) + }; ($b:expr) => { - read_all($b, &mut reader).wrap_err(eyre!("Failed to read from stream"))? + read!($b; "Failed to read from stream")?; }; (? $ot:expr) => { { @@ -314,11 +321,14 @@ impl SerializedMessage let len = usize::try_from(u64::from_be_bytes(len))?; //TODO: Find realistic max size for `$ser`. if len > MAX_ALLOC_SIZE { - return Err(eyre!("Invalid length read: {}", len)); + return Err(eyre!("Invalid length read: {}", len) + .with_section(|| format!("Max length read: {}", MAX_ALLOC_SIZE))) } - let mut de = Vec::with_capacity(len); - read!(&mut de, len); - serde_cbor::from_slice::<$ser>(&de[..]).wrap_err(eyre!("Failed to deserialise {} from reader", std::any::type_name::<$ser>()))? + alloc_local_bytes(len, |de| { + read!(&mut de[..]); + serde_cbor::from_slice::<$ser>(&de[..]).wrap_err(eyre!("Failed to deserialise {} from reader", std::any::type_name::<$ser>())) + })? + } }; ($into:expr, $num:expr) => { @@ -380,7 +390,7 @@ mod tests D: MessageReceiver { eprintln!("=== Message serialisation with tc, rc: S: {}, D: {}", std::any::type_name::(), std::any::type_name::()); - let message = MessageBuilder::new() + let message = MessageBuilder::for_sender::() .create(format!("This is a string, and some random data: {:?}", aes::AesKey::generate().unwrap())) .expect("Failed to create message"); println!(">> Created message: {:?}", message); @@ -404,5 +414,89 @@ mod tests message_serial_generic((), ()); } + + #[test] + fn message_serial_sign() + { + let rsa_priv = rsa::RsaPrivateKey::generate().unwrap(); + struct Sign(rsa::RsaPrivateKey); + struct Verify(rsa::RsaPublicKey); + + impl MessageSender for Sign + { + const CAP_SIGN: bool = true; + fn sign_data(&self, data: &[u8]) -> Option { + Some(rsa::sign_slice(data, &self.0).expect("Failed to sign")) + } + } + impl MessageReceiver for Verify + { + fn verify_data(&self, data: &[u8], sig: &rsa::Signature) -> Option> { + Some(sig.verify_slice(data, &self.0).map_err(Into::into)) + } + } + let verify = Verify(rsa_priv.get_public_parts()); + println!("Signing priv-key: {:?}", rsa_priv); + println!("Verifying pub-key: {:?}", verify.0); + message_serial_generic(Sign(rsa_priv), verify); + } + + #[test] + fn message_serial_encrypt() + { + color_eyre::install().unwrap(); + let rsa_priv = rsa::RsaPrivateKey::generate().unwrap(); + struct Dec(rsa::RsaPrivateKey); + struct Enc(rsa::RsaPublicKey); + + impl MessageSender for Enc + { + const CAP_ENCRYPT: bool = true; + + fn encrypt_key(&self, key: &aes::AesKey) -> Option<[u8; RSA_BLOCK_SIZE]> { + let mut output = [0u8; RSA_BLOCK_SIZE]; + use rsa::HasPublicComponents; + let w = rsa::encrypt_slice_sync(key, &self.0, &mut &mut output[..]).expect("Failed to encrypt session key"); + assert_eq!(w, output.len()); + + Some(output) + } + } + impl MessageReceiver for Dec + { + fn decrypt_key(&self, enc_key: &[u8; RSA_BLOCK_SIZE]) -> Option> { + let mut output = aes::AesKey::empty(); + match rsa::decrypt_slice_sync(enc_key, &self.0, &mut output.as_mut()) { + Ok(sz) => assert_eq!(sz, + output.as_ref().len()), + Err(err) => return Some(Err(err.into())), + } + Some(Ok(output)) + } + } + let enc = Enc(rsa_priv.get_public_parts()); + println!("Encrypting pub-key: {:?}", enc.0); + println!("Decrypting priv-key: {:?}", rsa_priv); + message_serial_generic(enc, Dec(rsa_priv)); + } + + #[test] + fn rsa_bullshit() + { + use rsa::HasPublicComponents; + let rsa = rsa::openssl::rsa::Rsa::generate(cryptohelpers::consts::RSA_KEY_BITS).unwrap(); + eprintln!("rn: {}, re: {},", rsa.n(), rsa.e()); + let rsa_comp: rsa::RsaPrivateKey = rsa.clone().into(); + eprintln!("n: {:?}, e: {:?}", rsa_comp.n(), rsa_comp.e()); + let rsa2: rsa::openssl::rsa::Rsa<_> = rsa_comp.get_rsa_pub().unwrap(); + eprintln!("n2: {}, e2: {}", rsa2.n(), rsa2.e()); + + let mut data = Vec::new(); + data.extend_from_slice(aes::AesKey::generate().unwrap().as_ref()); + + let rend = rsa::encrypt_slice_to_vec(data, &rsa2).unwrap(); + + assert_eq!(rend.len(), RSA_BLOCK_SIZE); + } //TODO: message_serial_sign(), message_serial_encrypted(), message_serial_encrypted_sign() } diff --git a/src/message/builder.rs b/src/message/builder.rs index c8cca90..296b82d 100644 --- a/src/message/builder.rs +++ b/src/message/builder.rs @@ -58,10 +58,29 @@ impl MessageBuilder self.respond = Some(to); self } + + /// Create a new builder with the capabilities of a sender + #[cfg(nightly)] + pub const fn for_sender() -> Self + { + Self::new() + .sign(S::CAP_SIGN) + .encrypt(S::CAP_ENCRYPT) + } + + /// Create a new builder with the capabilities of a sender + #[cfg(not(nightly))] + pub fn for_sender() -> Self + { + Self::new() + .sign(S::CAP_SIGN) + .encrypt(S::CAP_ENCRYPT) + } } impl MessageBuilder { + /// Create a message from this builder with this value. pub fn create(self, value: V) -> eyre::Result> { diff --git a/src/message/serial.rs b/src/message/serial.rs index eeb5c6b..ad4b40f 100644 --- a/src/message/serial.rs +++ b/src/message/serial.rs @@ -10,6 +10,9 @@ use std::{ /// A type that can be used to serialise a message pub trait MessageSender { + const CAP_ENCRYPT: bool = false; + const CAP_SIGN: bool = false; + #[inline] fn encrypt_key(&self, _key: &aes::AesKey) -> Option<[u8; RSA_BLOCK_SIZE]> { None } #[inline] fn sign_data(&self, _data: &[u8]) -> Option { None } }