//! Messages use super::*; use std::marker::PhantomData; use serde::{Serialize, Deserialize}; use cryptohelpers::{ sha256, aes, rsa, }; use uuid::Uuid; use std::borrow::{ Borrow, Cow }; use std::io; use std::marker::Unpin; use tokio::io::{ AsyncWrite, AsyncRead, }; mod serial; pub use serial::*; mod builder; pub use builder::*; /// Size of encrypted AES key pub const RSA_BLOCK_SIZE: usize = 512; /// Max size to pre-allocate when reading a message buffer. pub const MAX_ALLOC_SIZE: usize = 4096; // 4kb /// A value that can be used for messages. pub trait MessageValue: Serialize + for<'de> Deserialize<'de>{} impl MessageValue for T where T: Serialize + for<'de> Deserialize<'de>{} #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Message { header: SerHeader, /// Optional key to use to encrypt the message key: Option, /// Should the message body be signed? sign: bool, /// Value to serialise value: V, } /// `SerializedMessage` header. #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] struct SerHeader { /// Message ID id: Uuid, /// Message idempodence ID idemp: Uuid, /// Timestamp of when this message was created (Unix TS, UTC). timestamp: u64, /// `id` of message this one is responding to, if needed. responds_to: Option, } impl AsRef for Message { #[inline(always)] fn as_ref(&self) -> &V { self.value_ref() } } /// A serialized message that can be sent over a socket. /// /// Messages of this type are not yet validated, and may be invalid/corrupt. The validation happens when converting back to a `Message` (of the same `V`.) #[derive(Debug, Clone, PartialEq, Eq)] pub struct SerializedMessage { header: SerHeader, /// cbor serialised `V`. data: Vec, /// Hash of `data` (after encryption) hash: sha256::Sha256Hash, /// `key` encrypted with recipient's RSA public key. enc_key: Option<[u8; RSA_BLOCK_SIZE]>, // we can't derive Serialize because of this array.. meh.. /// Signature of hash of un-encrypted `data`. sig: Option, //TODO: Add a message header checksum. _phantom: PhantomData, } impl Message { /// A reference to the value itself. pub fn value_ref(&self) -> &V { &self.value } /// A mutable reference to the value itself. /// /// # Safety /// Mutating a value inside a message may cause invalid metadata. pub unsafe fn value_mut(&mut self) -> &mut V { &mut self.value } /// Consume into just the value pub fn into_value(self) -> V { self.value } /// Serialise this message into one that can be converted to/from bytes. /// /// # Panics /// * If this message was specified to be encrypted, but `S` doesn't support encryption. /// * If this message was specified to be signed, but `S` doesn't support signing. pub fn serialise(&self, send_with: impl Borrow) -> eyre::Result> { let send_with: &S = send_with.borrow(); let data = serde_cbor::to_vec(&self.value)?; let sig = if self.sign { Some(send_with.sign_data(&data[..]).expect("Message expected signing, sender did not support it")) } else { None }; let (data, enc_key) = if let Some(key) = &self.key { // Encrypt the body let enc_key = send_with.encrypt_key(key).expect("Message expected encryption, sender did not support it"); (aes::encrypt_slice_sync(key, data)?, Some(enc_key)) } else { // Don't encrypt the body (data, None) }; // Compute hash of data let hash = sha256::compute_slice(&data); Ok(SerializedMessage{ header: self.header.clone(), data, sig, enc_key, hash, _phantom: PhantomData, }) } /// Try to deserialize and validate a received message. /// /// If a part of the message is invalid, an error is returned. pub fn deserialize<'a, S: ?Sized + MessageReceiver>(serial: &'a SerializedMessage, recv_with: impl Borrow) -> eyre::Result { let recv_with: &S = recv_with.borrow(); macro_rules! assert_valid { ($ex:expr, $fmt:literal $($tt:tt)*) => { if $ex { Ok(()) } else { Err(eyre!($fmt $($tt)*)) } } } //Validate hashes assert_valid!(sha256::compute_slice(&serial.data) == serial.hash, "Non-matching hashes")?; // Decrypted data let (data, key): (Cow<'a, [u8]>, Option) = if let Some(enc_key) = &serial.enc_key { let key = recv_with.decrypt_key(enc_key).ok_or(eyre!("Message was decrypted, but receiver doesn't support decryption"))? .wrap_err(eyre!("Failed to decrypt session key"))?; let mut data = Vec::with_capacity(serial.data.len()); aes::decrypt_stream_sync(&key, &mut &serial.data[..], &mut data).wrap_err(eyre!("Failed to decrypt body"))?; (Cow::Owned(data), Some(key)) } else { (Cow::Borrowed(&serial.data[..]), None) }; let sign = if let Some(sig) = &serial.sig { // Validate signature assert_valid!(recv_with.verify_data(&data[..], sig).ok_or(eyre!("Message was signed, but receiver doesn't support signature verification"))?.wrap_err("Failed to verify signature")?, "Non-matching signature") .with_section(move || format!("Embedded sig was: {}", sig))?; true } else { false }; // Deserialise value Ok(Self { value: serde_cbor::from_slice(&data[..]).wrap_err(eyre!("Failed to deserialise value"))?, header: serial.header.clone(), key, sign }) } } impl SerializedMessage { /// Consume into an async writer pub async fn into_writer_async(self, mut writer: W) -> eyre::Result { let mut w = 0; macro_rules! write { ($b:expr) => { w+=write_all_async(&mut writer, $b).await? }; (? $o:expr) => { match $o { Some(key) => { write!([1]); write!(key); }, None => { write!([0]); }, } }; (: $ser:expr) => { { let mut v = StackVec::new(); serde_cbor::to_writer(&mut v, $ser)?; write!(&v[..]); } } } write!(: &self.header); write!(u64::try_from(self.data.len())?.to_be_bytes()); write!(self.data); write!(self.hash); write!(? self.enc_key); write!(? self.sig); Ok(w) } /// Consume into a syncronous writer pub fn into_writer(self, mut writer: impl io::Write) -> eyre::Result { let mut w = 0; macro_rules! write { ($b:expr) => { w+=write_all(&mut writer, $b)? }; (? $o:expr) => { match $o { Some(key) => { write!([1]); write!(key); }, None => { write!([0]); }, } }; (: $ser:expr) => { { let mut ser = StackVec::new(); serde_cbor::to_writer(&mut ser, $ser)?; write!(u64::try_from(ser.len())?.to_be_bytes()); write!(ser); /* let mut w2 = WriteCounter(0, &mut writer); serde_cbor::to_writer(&mut w2, $ser)?; w+=w2.0;*/ } }; } write!(: &self.header); write!(u64::try_from(self.data.len())?.to_be_bytes()); write!(self.data); write!(self.hash); write!(? self.enc_key); write!(? self.sig); Ok(w) } /// Create from a reader. /// /// The message may be in an invalid state. It is only possible to extract the value after validating it into a `Message`. pub fn from_reader(mut reader: impl io::Read) -> eyre::Result { macro_rules! read { ($b:expr; $fmt:literal $($tt:tt)*) => { read_all($b, &mut reader).wrap_err(eyre!($fmt $($tt)*)) }; ($b:expr) => { read!($b; "Failed to read from stream")?; }; (? $ot:expr) => { { let mut b = [0u8; 1]; read!(&mut b[..]); match b[0] { 1 => { let mut o = $ot; read!(&mut o); Some(o) }, 0 => { None }, x => { return Err(eyre!("Invalid option state {:?}", x)); } } } }; (: $ser:ty) => { { let mut len = [0u8; std::mem::size_of::()]; read!(&mut len[..]); let len = usize::try_from(u64::from_be_bytes(len))?; //TODO: Find realistic max size for `$ser`. if len > MAX_ALLOC_SIZE { return Err(eyre!("Invalid length read: {}", len) .with_section(|| format!("Max length read: {}", MAX_ALLOC_SIZE))) } alloc_local_bytes(len, |de| { read!(&mut de[..]); serde_cbor::from_slice::<$ser>(&de[..]).wrap_err(eyre!("Failed to deserialise {} from reader", std::any::type_name::<$ser>())) })? } }; ($into:expr, $num:expr) => { { let num = $num; copy_buffer($into, &mut reader, num).wrap_err(eyre!("Failed to read {} bytes from reader", num))? } } } let header = read!(: SerHeader); let data_len = { let mut bytes = [0u8; std::mem::size_of::()]; read!(&mut bytes); u64::from_be_bytes(bytes) }.try_into()?; let mut data = Vec::with_capacity(std::cmp::min(data_len, MAX_ALLOC_SIZE)); //XXX: Redesign so we don't allocate OR try to read massive buffers by accident on corrupted/malformed messages read!(&mut data, data_len); if data.len()!=data_len { return Err(eyre!("Failed to read {} bytes from buffer (got {})", data_len, data.len())); } let mut hash = sha256::Sha256Hash::default(); read!(&mut hash); let enc_key: Option<[u8; RSA_BLOCK_SIZE]> = read!(? [0u8; RSA_BLOCK_SIZE]); let sig: Option = read!(? rsa::Signature::default()); Ok(Self { header, data, hash, enc_key, sig, _phantom: PhantomData, }) } /// Consume into `Vec`. pub fn into_bytes(self) -> Vec { let mut v = Vec::with_capacity(self.data.len()<<1); self.into_writer(&mut v).expect("Failed to write to in-memory buffer"); v } /// Create from bytes #[inline] pub fn from_bytes(bytes: impl AsRef<[u8]>) -> eyre::Result { let bytes = bytes.as_ref(); Self::from_reader(&mut &bytes[..]) } } #[cfg(test)] mod tests { use super::*; /// Generic message test function fn message_serial_generic(s: S, d: D) where S: MessageSender, D: MessageReceiver { eprintln!("=== Message serialisation with tc, rc: S: {}, D: {}", std::any::type_name::(), std::any::type_name::()); let message = MessageBuilder::for_sender::() .create(format!("This is a string, and some random data: {:?}", aes::AesKey::generate().unwrap())) .expect("Failed to create message"); println!(">> Created message: {:?}", message); let serialised = message.serialise(s).expect("Failed to serialise message"); println!(">> Serialised message: {:?}", serialised); let binary = serialised.into_bytes(); println!(">> Written to {} bytes", binary.len()); let read = SerializedMessage::from_bytes(&binary).expect("Failed to read serialised message from binary"); println!(">> Read from bytes: {:?}", read); let deserialised = Message::deserialize(&read, d).expect("Failed to deserialise message"); println!(">> Deserialised message: {:?}", deserialised); assert_eq!(message, deserialised, "Messages not identical"); assert_eq!(message.into_value(), deserialised.into_value(), "Message values not identical"); eprintln!("=== Passed (S: {}, D: {}) ===", std::any::type_name::(), std::any::type_name::()); } #[test] fn message_serial_basic() { message_serial_generic((), ()); } #[test] fn message_serial_sign() { let rsa_priv = rsa::RsaPrivateKey::generate().unwrap(); struct Sign(rsa::RsaPrivateKey); struct Verify(rsa::RsaPublicKey); impl MessageSender for Sign { const CAP_SIGN: bool = true; fn sign_data(&self, data: &[u8]) -> Option { Some(rsa::sign_slice(data, &self.0).expect("Failed to sign")) } } impl MessageReceiver for Verify { fn verify_data(&self, data: &[u8], sig: &rsa::Signature) -> Option> { Some(sig.verify_slice(data, &self.0).map_err(Into::into)) } } let verify = Verify(rsa_priv.get_public_parts()); println!("Signing priv-key: {:?}", rsa_priv); println!("Verifying pub-key: {:?}", verify.0); message_serial_generic(Sign(rsa_priv), verify); } #[test] fn message_serial_encrypt() { color_eyre::install().unwrap(); let rsa_priv = rsa::RsaPrivateKey::generate().unwrap(); struct Dec(rsa::RsaPrivateKey); struct Enc(rsa::RsaPublicKey); impl MessageSender for Enc { const CAP_ENCRYPT: bool = true; fn encrypt_key(&self, key: &aes::AesKey) -> Option<[u8; RSA_BLOCK_SIZE]> { let mut output = [0u8; RSA_BLOCK_SIZE]; use rsa::HasPublicComponents; let w = rsa::encrypt_slice_sync(key, &self.0, &mut &mut output[..]).expect("Failed to encrypt session key"); assert_eq!(w, output.len()); Some(output) } } impl MessageReceiver for Dec { fn decrypt_key(&self, enc_key: &[u8; RSA_BLOCK_SIZE]) -> Option> { let mut output = aes::AesKey::empty(); match rsa::decrypt_slice_sync(enc_key, &self.0, &mut output.as_mut()) { Ok(sz) => assert_eq!(sz, output.as_ref().len()), Err(err) => return Some(Err(err.into())), } Some(Ok(output)) } } let enc = Enc(rsa_priv.get_public_parts()); println!("Encrypting pub-key: {:?}", enc.0); println!("Decrypting priv-key: {:?}", rsa_priv); message_serial_generic(enc, Dec(rsa_priv)); } #[test] fn rsa_bullshit() { use rsa::HasPublicComponents; let rsa = rsa::openssl::rsa::Rsa::generate(cryptohelpers::consts::RSA_KEY_BITS).unwrap(); eprintln!("rn: {}, re: {},", rsa.n(), rsa.e()); let rsa_comp: rsa::RsaPrivateKey = rsa.clone().into(); eprintln!("n: {:?}, e: {:?}", rsa_comp.n(), rsa_comp.e()); let rsa2: rsa::openssl::rsa::Rsa<_> = rsa_comp.get_rsa_pub().unwrap(); eprintln!("n2: {}, e2: {}", rsa2.n(), rsa2.e()); let mut data = Vec::new(); data.extend_from_slice(aes::AesKey::generate().unwrap().as_ref()); let rend = rsa::encrypt_slice_to_vec(data, &rsa2).unwrap(); assert_eq!(rend.len(), RSA_BLOCK_SIZE); } //TODO: message_serial_sign(), message_serial_encrypted(), message_serial_encrypted_sign() }