diff --git a/src/message.rs b/src/message.rs index b07e984..939b3e1 100644 --- a/src/message.rs +++ b/src/message.rs @@ -8,7 +8,10 @@ use cryptohelpers::{ rsa, }; use uuid::Uuid; -use std::borrow::Borrow; +use std::borrow::{ + Borrow, + Cow +}; use std::io; use std::marker::Unpin; use tokio::io::{ @@ -19,6 +22,9 @@ use tokio::io::{ mod serial; pub use serial::*; +mod builder; +pub use builder::*; + /// Size of encrypted AES key pub const RSA_BLOCK_SIZE: usize = 512; @@ -28,6 +34,9 @@ pub const MAX_ALLOC_SIZE: usize = 4096; // 4kb /// A value that can be used for messages. pub trait MessageValue: Serialize + for<'de> Deserialize<'de>{} +impl MessageValue for T +where T: Serialize + for<'de> Deserialize<'de>{} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Message { @@ -55,6 +64,14 @@ struct SerHeader responds_to: Option, } +impl AsRef for Message +{ + #[inline(always)] fn as_ref(&self) -> &V + { + self.value_ref() + } +} + /// A serialized message that can be sent over a socket. /// /// Messages of this type are not yet validated, and may be invalid/corrupt. The validation happens when converting back to a `Message` (of the same `V`.) @@ -71,13 +88,39 @@ pub struct SerializedMessage enc_key: Option<[u8; RSA_BLOCK_SIZE]>, // we can't derive Serialize because of this array.. meh.. /// Signature of hash of un-encrypted `data`. sig: Option, - + + //TODO: Add a message header checksum. _phantom: PhantomData, } impl Message { + /// A reference to the value itself. + pub fn value_ref(&self) -> &V + { + &self.value + } + + /// A mutable reference to the value itself. + /// + /// # Safety + /// Mutating a value inside a message may cause invalid metadata. + pub unsafe fn value_mut(&mut self) -> &mut V + { + &mut self.value + } + + /// Consume into just the value + pub fn into_value(self) -> V + { + self.value + } + /// Serialise this message into one that can be converted to/from bytes. + /// + /// # Panics + /// * If this message was specified to be encrypted, but `S` doesn't support encryption. + /// * If this message was specified to be signed, but `S` doesn't support signing. pub fn serialise(&self, send_with: impl Borrow) -> eyre::Result> { let send_with: &S = send_with.borrow(); @@ -110,6 +153,53 @@ impl Message _phantom: PhantomData, }) } + + /// Try to deserialize and validate a received message. + /// + /// If a part of the message is invalid, an error is returned. + pub fn deserialize<'a, S: ?Sized + MessageReceiver>(serial: &'a SerializedMessage, recv_with: impl Borrow) -> eyre::Result + { + let recv_with: &S = recv_with.borrow(); + macro_rules! assert_valid { + ($ex:expr, $fmt:literal $($tt:tt)*) => { + if $ex { + Ok(()) + } else { + Err(eyre!($fmt $($tt)*)) + } + } + } + //Validate hashes + assert_valid!(sha256::compute_slice(&serial.data) == serial.hash, "Non-matching hashes")?; + + // Decrypted data + let (data, key): (Cow<'a, [u8]>, Option) = if let Some(enc_key) = &serial.enc_key { + let key = recv_with.decrypt_key(enc_key).ok_or(eyre!("Message was decrypted, but receiver doesn't support decryption"))? + .wrap_err(eyre!("Failed to decrypt session key"))?; + let mut data = Vec::with_capacity(serial.data.len()); + aes::decrypt_stream_sync(&key, &mut &serial.data[..], &mut data).wrap_err(eyre!("Failed to decrypt body"))?; + (Cow::Owned(data), Some(key)) + } else { + (Cow::Borrowed(&serial.data[..]), None) + }; + + let sign = if let Some(sig) = &serial.sig { + // Validate signature + assert_valid!(recv_with.verify_data(&data[..], sig).ok_or(eyre!("Message was signed, but receiver doesn't support signature verification"))?.wrap_err("Failed to verify signature")?, + "Non-matching signature")?; + true + } else { + false + }; + + // Deserialise value + Ok(Self { + value: serde_cbor::from_slice(&data[..]).wrap_err(eyre!("Failed to deserialise value"))?, + header: serial.header.clone(), + key, + sign + }) + } } impl SerializedMessage @@ -141,6 +231,7 @@ impl SerializedMessage } } write!(: &self.header); + write!(u64::try_from(self.data.len())?.to_be_bytes()); write!(self.data); write!(self.hash); write!(? self.enc_key); @@ -169,6 +260,7 @@ impl SerializedMessage }; (: $ser:expr) => { { + let mut w2 = WriteCounter(0, &mut writer); serde_cbor::to_writer(&mut w2, $ser)?; w+=w2.0; @@ -191,7 +283,7 @@ impl SerializedMessage { macro_rules! read { ($b:expr) => { - read_all($b, &mut reader)? + read_all($b, &mut reader).wrap_err(eyre!("Failed to read from stream"))? }; (? $ot:expr) => { { @@ -213,22 +305,25 @@ impl SerializedMessage } }; (: $ser:ty) => { - serde_cbor::from_reader::<$ty, _>(&mut reader)? + serde_cbor::from_reader::<$ser, _>(&mut reader).wrap_err(eyre!("Failed to deserialise {} from reader", std::any::type_name::<$ser>()))? }; (:) => { - serde_cbor::from_reader(&mut reader)? + serde_cbor::from_reader(&mut reader).wrap_err(eyre!("Failed to deserialise type from reader"))? }; ($into:expr, $num:expr) => { - copy_buffer($into, &mut reader, $num)? + { + let num = $num; + copy_buffer($into, &mut reader, num).wrap_err(eyre!("Failed to read {} bytes from reader", num))? + } } } - let header: SerHeader = read!(:); + let header = read!(: SerHeader); let data_len = { let mut bytes = [0u8; std::mem::size_of::()]; read!(&mut bytes); u64::from_be_bytes(bytes) }.try_into()?; - let mut data = Vec::with_capacity(std::cmp::min(data_len, MAX_ALLOC_SIZE)); //XXX: Redesign so we don't allocate massive buffers by accident on corrupted/malformed messages + let mut data = Vec::with_capacity(std::cmp::min(data_len, MAX_ALLOC_SIZE)); //XXX: Redesign so we don't allocate OR try to read massive buffers by accident on corrupted/malformed messages read!(&mut data, data_len); if data.len()!=data_len { return Err(eyre!("Failed to read {} bytes from buffer (got {})", data_len, data.len())); @@ -262,3 +357,28 @@ impl SerializedMessage Self::from_reader(&mut &bytes[..]) } } + +#[cfg(test)] +mod tests +{ + use super::*; + #[test] + fn message_serial() + { + let message = MessageBuilder::new() + .create(format!("This is a string, and some random data: {:?}", aes::AesKey::generate().unwrap())) + .expect("Failed to create message"); + println!(">> Created message: {:?}", message); + let serialised = message.serialise(()).expect("Failed to serialise message"); + println!(">> Serialised message: {:?}", serialised); + let binary = serialised.into_bytes(); + println!(">> Written to {} bytes", binary.len()); + let read = SerializedMessage::from_bytes(&binary).expect("Failed to read serialised message from binary"); + println!(">> Read from bytes: {:?}", read); + let deserialised = Message::deserialize(&read, ()).expect("Failed to deserialise message"); + println!(">> Deserialised message: {:?}", deserialised); + + assert_eq!(message, deserialised, "Messages not identical"); + assert_eq!(message.into_value(), deserialised.into_value(), "Message values not identical"); + } +} diff --git a/src/message/builder.rs b/src/message/builder.rs new file mode 100644 index 0000000..c8cca90 --- /dev/null +++ b/src/message/builder.rs @@ -0,0 +1,105 @@ +//! Building `Message`s +use super::*; +use std::time::SystemTime; + +/// Builder for the `Message` type +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MessageBuilder +{ + sign: bool, + encrypt: bool, + respond: Option, + + _phantom: PhantomData, +} + +impl Default for MessageBuilder +{ + #[inline] + fn default() -> Self + { + Self::new() + } +} + + +impl MessageBuilder +{ + /// Create a new builder for a message with default settings + pub const fn new() -> Self + { + Self { + sign: false, + encrypt: false, + respond: None, + _phantom: PhantomData, + } + } + + /// Specify if the message should be signed when serialized. + pub const fn sign(mut self, sign: bool) -> Self + { + self.sign = sign; + self + } + + /// Specify if the message should be encrypted when serialized. + /// + /// A key will be generated randomly for the message on creation. + pub const fn encrypt(mut self, encrypt: bool) -> Self + { + self.encrypt = encrypt; + self + } + + /// Specify a message ID that this message should respond to. + pub const fn respond(mut self, to: Uuid) -> Self + { + self.respond = Some(to); + self + } +} + +impl MessageBuilder +{ + /// Create a message from this builder with this value. + pub fn create(self, value: V) -> eyre::Result> + { + let key = if self.encrypt { + Some(aes::AesKey::generate().wrap_err(eyre!("Failed to generate session key"))?) + } else { + None + }; + let header = SerHeader::new(self.respond); + + Ok(Message { + header, + key, + sign: self.sign, + value, + }) + } +} + +/// Get the current unix timestamp. +pub(super) fn timestamp_now() -> u64 +{ + match SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) { + Ok(n) => n.as_secs(), + Err(_) => panic!("Timestamp for now returned before unix epoch."), + } +} + +impl SerHeader +{ + /// Create a new header with optional response ID. + #[inline] pub fn new(responds: Option) -> Self + { + Self { + id: Uuid::new_v4(), + idemp: Uuid::new_v4(), + timestamp: timestamp_now(), + responds_to: responds, + } + } +} diff --git a/src/message/serial.rs b/src/message/serial.rs index 718e0cd..5b4da0f 100644 --- a/src/message/serial.rs +++ b/src/message/serial.rs @@ -17,8 +17,8 @@ pub trait MessageSender /// A type that can be used to deserialise a message pub trait MessageReceiver { - #[inline] fn decrypt_key(&self, _enc_key: &[u8; RSA_BLOCK_SIZE]) -> Option{ None } - #[inline] fn verify_data(&self, _data: &[u8], _sig: rsa::Signature) -> Option { None } + #[inline] fn decrypt_key(&self, _enc_key: &[u8; RSA_BLOCK_SIZE]) -> Option>{ None } + #[inline] fn verify_data(&self, _data: &[u8], _sig: &rsa::Signature) -> Option> { None } } impl MessageSender for (){}