Message de/serial test written (currently failing due to serde_cbor being greedy when deserialising from reader)

Fortune for rsh's current commit: Middle blessing − 中吉
specialisation
Avril 3 years ago
parent c41d5c2c28
commit 3e59440609

@ -8,7 +8,10 @@ use cryptohelpers::{
rsa, rsa,
}; };
use uuid::Uuid; use uuid::Uuid;
use std::borrow::Borrow; use std::borrow::{
Borrow,
Cow
};
use std::io; use std::io;
use std::marker::Unpin; use std::marker::Unpin;
use tokio::io::{ use tokio::io::{
@ -19,6 +22,9 @@ use tokio::io::{
mod serial; mod serial;
pub use serial::*; pub use serial::*;
mod builder;
pub use builder::*;
/// Size of encrypted AES key /// Size of encrypted AES key
pub const RSA_BLOCK_SIZE: usize = 512; pub const RSA_BLOCK_SIZE: usize = 512;
@ -28,6 +34,9 @@ pub const MAX_ALLOC_SIZE: usize = 4096; // 4kb
/// A value that can be used for messages. /// A value that can be used for messages.
pub trait MessageValue: Serialize + for<'de> Deserialize<'de>{} pub trait MessageValue: Serialize + for<'de> Deserialize<'de>{}
impl<T: ?Sized> MessageValue for T
where T: Serialize + for<'de> Deserialize<'de>{}
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Message<V: ?Sized + MessageValue> pub struct Message<V: ?Sized + MessageValue>
{ {
@ -55,6 +64,14 @@ struct SerHeader
responds_to: Option<Uuid>, responds_to: Option<Uuid>,
} }
impl<V: ?Sized + MessageValue> AsRef<V> for Message<V>
{
#[inline(always)] fn as_ref(&self) -> &V
{
self.value_ref()
}
}
/// A serialized message that can be sent over a socket. /// A serialized message that can be sent over a socket.
/// ///
/// Messages of this type are not yet validated, and may be invalid/corrupt. The validation happens when converting back to a `Message<V>` (of the same `V`.) /// Messages of this type are not yet validated, and may be invalid/corrupt. The validation happens when converting back to a `Message<V>` (of the same `V`.)
@ -71,13 +88,39 @@ pub struct SerializedMessage<V: ?Sized + MessageValue>
enc_key: Option<[u8; RSA_BLOCK_SIZE]>, // we can't derive Serialize because of this array.. meh.. enc_key: Option<[u8; RSA_BLOCK_SIZE]>, // we can't derive Serialize because of this array.. meh..
/// Signature of hash of un-encrypted `data`. /// Signature of hash of un-encrypted `data`.
sig: Option<rsa::Signature>, sig: Option<rsa::Signature>,
//TODO: Add a message header checksum.
_phantom: PhantomData<V>, _phantom: PhantomData<V>,
} }
impl<V: ?Sized + MessageValue> Message<V> impl<V: ?Sized + MessageValue> Message<V>
{ {
/// A reference to the value itself.
pub fn value_ref(&self) -> &V
{
&self.value
}
/// A mutable reference to the value itself.
///
/// # Safety
/// Mutating a value inside a message may cause invalid metadata.
pub unsafe fn value_mut(&mut self) -> &mut V
{
&mut self.value
}
/// Consume into just the value
pub fn into_value(self) -> V
{
self.value
}
/// Serialise this message into one that can be converted to/from bytes. /// Serialise this message into one that can be converted to/from bytes.
///
/// # Panics
/// * If this message was specified to be encrypted, but `S` doesn't support encryption.
/// * If this message was specified to be signed, but `S` doesn't support signing.
pub fn serialise<S: ?Sized + MessageSender>(&self, send_with: impl Borrow<S>) -> eyre::Result<SerializedMessage<V>> pub fn serialise<S: ?Sized + MessageSender>(&self, send_with: impl Borrow<S>) -> eyre::Result<SerializedMessage<V>>
{ {
let send_with: &S = send_with.borrow(); let send_with: &S = send_with.borrow();
@ -110,6 +153,53 @@ impl<V: ?Sized + MessageValue> Message<V>
_phantom: PhantomData, _phantom: PhantomData,
}) })
} }
/// Try to deserialize and validate a received message.
///
/// If a part of the message is invalid, an error is returned.
pub fn deserialize<'a, S: ?Sized + MessageReceiver>(serial: &'a SerializedMessage<V>, recv_with: impl Borrow<S>) -> eyre::Result<Self>
{
let recv_with: &S = recv_with.borrow();
macro_rules! assert_valid {
($ex:expr, $fmt:literal $($tt:tt)*) => {
if $ex {
Ok(())
} else {
Err(eyre!($fmt $($tt)*))
}
}
}
//Validate hashes
assert_valid!(sha256::compute_slice(&serial.data) == serial.hash, "Non-matching hashes")?;
// Decrypted data
let (data, key): (Cow<'a, [u8]>, Option<aes::AesKey>) = if let Some(enc_key) = &serial.enc_key {
let key = recv_with.decrypt_key(enc_key).ok_or(eyre!("Message was decrypted, but receiver doesn't support decryption"))?
.wrap_err(eyre!("Failed to decrypt session key"))?;
let mut data = Vec::with_capacity(serial.data.len());
aes::decrypt_stream_sync(&key, &mut &serial.data[..], &mut data).wrap_err(eyre!("Failed to decrypt body"))?;
(Cow::Owned(data), Some(key))
} else {
(Cow::Borrowed(&serial.data[..]), None)
};
let sign = if let Some(sig) = &serial.sig {
// Validate signature
assert_valid!(recv_with.verify_data(&data[..], sig).ok_or(eyre!("Message was signed, but receiver doesn't support signature verification"))?.wrap_err("Failed to verify signature")?,
"Non-matching signature")?;
true
} else {
false
};
// Deserialise value
Ok(Self {
value: serde_cbor::from_slice(&data[..]).wrap_err(eyre!("Failed to deserialise value"))?,
header: serial.header.clone(),
key,
sign
})
}
} }
impl<V: ?Sized + MessageValue> SerializedMessage<V> impl<V: ?Sized + MessageValue> SerializedMessage<V>
@ -141,6 +231,7 @@ impl<V: ?Sized + MessageValue> SerializedMessage<V>
} }
} }
write!(: &self.header); write!(: &self.header);
write!(u64::try_from(self.data.len())?.to_be_bytes());
write!(self.data); write!(self.data);
write!(self.hash); write!(self.hash);
write!(? self.enc_key); write!(? self.enc_key);
@ -169,6 +260,7 @@ impl<V: ?Sized + MessageValue> SerializedMessage<V>
}; };
(: $ser:expr) => { (: $ser:expr) => {
{ {
let mut w2 = WriteCounter(0, &mut writer); let mut w2 = WriteCounter(0, &mut writer);
serde_cbor::to_writer(&mut w2, $ser)?; serde_cbor::to_writer(&mut w2, $ser)?;
w+=w2.0; w+=w2.0;
@ -191,7 +283,7 @@ impl<V: ?Sized + MessageValue> SerializedMessage<V>
{ {
macro_rules! read { macro_rules! read {
($b:expr) => { ($b:expr) => {
read_all($b, &mut reader)? read_all($b, &mut reader).wrap_err(eyre!("Failed to read from stream"))?
}; };
(? $ot:expr) => { (? $ot:expr) => {
{ {
@ -213,22 +305,25 @@ impl<V: ?Sized + MessageValue> SerializedMessage<V>
} }
}; };
(: $ser:ty) => { (: $ser:ty) => {
serde_cbor::from_reader::<$ty, _>(&mut reader)? serde_cbor::from_reader::<$ser, _>(&mut reader).wrap_err(eyre!("Failed to deserialise {} from reader", std::any::type_name::<$ser>()))?
}; };
(:) => { (:) => {
serde_cbor::from_reader(&mut reader)? serde_cbor::from_reader(&mut reader).wrap_err(eyre!("Failed to deserialise type from reader"))?
}; };
($into:expr, $num:expr) => { ($into:expr, $num:expr) => {
copy_buffer($into, &mut reader, $num)? {
let num = $num;
copy_buffer($into, &mut reader, num).wrap_err(eyre!("Failed to read {} bytes from reader", num))?
}
} }
} }
let header: SerHeader = read!(:); let header = read!(: SerHeader);
let data_len = { let data_len = {
let mut bytes = [0u8; std::mem::size_of::<u64>()]; let mut bytes = [0u8; std::mem::size_of::<u64>()];
read!(&mut bytes); read!(&mut bytes);
u64::from_be_bytes(bytes) u64::from_be_bytes(bytes)
}.try_into()?; }.try_into()?;
let mut data = Vec::with_capacity(std::cmp::min(data_len, MAX_ALLOC_SIZE)); //XXX: Redesign so we don't allocate massive buffers by accident on corrupted/malformed messages let mut data = Vec::with_capacity(std::cmp::min(data_len, MAX_ALLOC_SIZE)); //XXX: Redesign so we don't allocate OR try to read massive buffers by accident on corrupted/malformed messages
read!(&mut data, data_len); read!(&mut data, data_len);
if data.len()!=data_len { if data.len()!=data_len {
return Err(eyre!("Failed to read {} bytes from buffer (got {})", data_len, data.len())); return Err(eyre!("Failed to read {} bytes from buffer (got {})", data_len, data.len()));
@ -262,3 +357,28 @@ impl<V: ?Sized + MessageValue> SerializedMessage<V>
Self::from_reader(&mut &bytes[..]) Self::from_reader(&mut &bytes[..])
} }
} }
#[cfg(test)]
mod tests
{
use super::*;
#[test]
fn message_serial()
{
let message = MessageBuilder::new()
.create(format!("This is a string, and some random data: {:?}", aes::AesKey::generate().unwrap()))
.expect("Failed to create message");
println!(">> Created message: {:?}", message);
let serialised = message.serialise(()).expect("Failed to serialise message");
println!(">> Serialised message: {:?}", serialised);
let binary = serialised.into_bytes();
println!(">> Written to {} bytes", binary.len());
let read = SerializedMessage::from_bytes(&binary).expect("Failed to read serialised message from binary");
println!(">> Read from bytes: {:?}", read);
let deserialised = Message::deserialize(&read, ()).expect("Failed to deserialise message");
println!(">> Deserialised message: {:?}", deserialised);
assert_eq!(message, deserialised, "Messages not identical");
assert_eq!(message.into_value(), deserialised.into_value(), "Message values not identical");
}
}

@ -0,0 +1,105 @@
//! Building `Message<V>`s
use super::*;
use std::time::SystemTime;
/// Builder for the `Message<V>` type
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MessageBuilder<V: ?Sized>
{
sign: bool,
encrypt: bool,
respond: Option<Uuid>,
_phantom: PhantomData<V>,
}
impl<V: ?Sized + MessageValue> Default for MessageBuilder<V>
{
#[inline]
fn default() -> Self
{
Self::new()
}
}
impl<V: ?Sized> MessageBuilder<V>
{
/// Create a new builder for a message with default settings
pub const fn new() -> Self
{
Self {
sign: false,
encrypt: false,
respond: None,
_phantom: PhantomData,
}
}
/// Specify if the message should be signed when serialized.
pub const fn sign(mut self, sign: bool) -> Self
{
self.sign = sign;
self
}
/// Specify if the message should be encrypted when serialized.
///
/// A key will be generated randomly for the message on creation.
pub const fn encrypt(mut self, encrypt: bool) -> Self
{
self.encrypt = encrypt;
self
}
/// Specify a message ID that this message should respond to.
pub const fn respond(mut self, to: Uuid) -> Self
{
self.respond = Some(to);
self
}
}
impl<V: ?Sized + MessageValue> MessageBuilder<V>
{
/// Create a message from this builder with this value.
pub fn create(self, value: V) -> eyre::Result<Message<V>>
{
let key = if self.encrypt {
Some(aes::AesKey::generate().wrap_err(eyre!("Failed to generate session key"))?)
} else {
None
};
let header = SerHeader::new(self.respond);
Ok(Message {
header,
key,
sign: self.sign,
value,
})
}
}
/// Get the current unix timestamp.
pub(super) fn timestamp_now() -> u64
{
match SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
Ok(n) => n.as_secs(),
Err(_) => panic!("Timestamp for now returned before unix epoch."),
}
}
impl SerHeader
{
/// Create a new header with optional response ID.
#[inline] pub fn new(responds: Option<Uuid>) -> Self
{
Self {
id: Uuid::new_v4(),
idemp: Uuid::new_v4(),
timestamp: timestamp_now(),
responds_to: responds,
}
}
}

@ -17,8 +17,8 @@ pub trait MessageSender
/// A type that can be used to deserialise a message /// A type that can be used to deserialise a message
pub trait MessageReceiver pub trait MessageReceiver
{ {
#[inline] fn decrypt_key(&self, _enc_key: &[u8; RSA_BLOCK_SIZE]) -> Option<aes::AesKey>{ None } #[inline] fn decrypt_key(&self, _enc_key: &[u8; RSA_BLOCK_SIZE]) -> Option<eyre::Result<aes::AesKey>>{ None }
#[inline] fn verify_data(&self, _data: &[u8], _sig: rsa::Signature) -> Option<bool> { None } #[inline] fn verify_data(&self, _data: &[u8], _sig: &rsa::Signature) -> Option<eyre::Result<bool>> { None }
} }
impl MessageSender for (){} impl MessageSender for (){}

Loading…
Cancel
Save