From c41d5c2c2816e6be7631013da9221fe6a0b1f861 Mon Sep 17 00:00:00 2001 From: Avril Date: Thu, 29 Jul 2021 17:10:03 +0100 Subject: [PATCH] Added `SerializedMessage::from_reader()` (reading untrusted messages) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit TODO: XXX: Validate the length of message bodies somehow before naively trying to read them. Fortune for rsh's current commit: Half curse − 半凶 --- src/main.rs | 5 +++ src/message.rs | 79 ++++++++++++++++++++++++++++++++++++++++++- src/message/serial.rs | 45 ++++++++++++++++++++++++ 3 files changed, 128 insertions(+), 1 deletion(-) diff --git a/src/main.rs b/src/main.rs index d84d87a..ee8dcf5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,6 +13,11 @@ use color_eyre::{ }, SectionExt, Help, }; +#[allow(unused_imports)] +use std::convert::{ + TryFrom, + TryInto, +}; mod message; diff --git a/src/message.rs b/src/message.rs index 0679398..b07e984 100644 --- a/src/message.rs +++ b/src/message.rs @@ -22,6 +22,9 @@ pub use serial::*; /// Size of encrypted AES key pub const RSA_BLOCK_SIZE: usize = 512; +/// Max size to pre-allocate when reading a message buffer. +pub const MAX_ALLOC_SIZE: usize = 4096; // 4kb + /// A value that can be used for messages. pub trait MessageValue: Serialize + for<'de> Deserialize<'de>{} @@ -52,6 +55,9 @@ struct SerHeader responds_to: Option, } +/// A serialized message that can be sent over a socket. +/// +/// Messages of this type are not yet validated, and may be invalid/corrupt. The validation happens when converting back to a `Message` (of the same `V`.) #[derive(Debug, Clone, PartialEq, Eq)] pub struct SerializedMessage { @@ -167,9 +173,10 @@ impl SerializedMessage serde_cbor::to_writer(&mut w2, $ser)?; w+=w2.0; } - } + }; } write!(: &self.header); + write!(u64::try_from(self.data.len())?.to_be_bytes()); write!(self.data); write!(self.hash); write!(? self.enc_key); @@ -177,6 +184,70 @@ impl SerializedMessage Ok(w) } + /// Create from a reader. + /// + /// The message may be in an invalid state. It is only possible to extract the value after validating it into a `Message`. + pub fn from_reader(mut reader: impl io::Read) -> eyre::Result + { + macro_rules! read { + ($b:expr) => { + read_all($b, &mut reader)? + }; + (? $ot:expr) => { + { + let mut b = [0u8; 1]; + read!(&mut b[..]); + match b[0] { + 1 => { + let mut o = $ot; + read!(&mut o); + Some(o) + }, + 0 => { + None + }, + x => { + return Err(eyre!("Invalid option state {:?}", x)); + } + } + } + }; + (: $ser:ty) => { + serde_cbor::from_reader::<$ty, _>(&mut reader)? + }; + (:) => { + serde_cbor::from_reader(&mut reader)? + }; + ($into:expr, $num:expr) => { + copy_buffer($into, &mut reader, $num)? + } + } + let header: SerHeader = read!(:); + let data_len = { + let mut bytes = [0u8; std::mem::size_of::()]; + read!(&mut bytes); + u64::from_be_bytes(bytes) + }.try_into()?; + let mut data = Vec::with_capacity(std::cmp::min(data_len, MAX_ALLOC_SIZE)); //XXX: Redesign so we don't allocate massive buffers by accident on corrupted/malformed messages + read!(&mut data, data_len); + if data.len()!=data_len { + return Err(eyre!("Failed to read {} bytes from buffer (got {})", data_len, data.len())); + } + let mut hash = sha256::Sha256Hash::default(); + read!(&mut hash); + let enc_key: Option<[u8; RSA_BLOCK_SIZE]> = read!(? [0u8; RSA_BLOCK_SIZE]); + let sig: Option = read!(? rsa::Signature::default()); + + Ok(Self { + header, + data, + hash, + enc_key, + sig, + + _phantom: PhantomData, + }) + } /// Consume into `Vec`. pub fn into_bytes(self) -> Vec { @@ -184,4 +255,10 @@ impl SerializedMessage self.into_writer(&mut v).expect("Failed to write to in-memory buffer"); v } + /// Create from bytes + #[inline] pub fn from_bytes(bytes: impl AsRef<[u8]>) -> eyre::Result + { + let bytes = bytes.as_ref(); + Self::from_reader(&mut &bytes[..]) + } } diff --git a/src/message/serial.rs b/src/message/serial.rs index b43ed75..718e0cd 100644 --- a/src/message/serial.rs +++ b/src/message/serial.rs @@ -78,3 +78,48 @@ pub(super) async fn write_all_async(mut to: impl AsyncWrite + Unpin, bytes: impl Ok(bytes.len()) } +#[inline(always)] pub(super) fn read_all(mut to: impl AsMut<[u8]>, mut from: impl io::Read) -> io::Result +{ + let mut read=0; + let to = to.as_mut(); + loop + { + match from.read(&mut to[read..]) { + Ok(r) if r>0 => read+=r, + Err(io) if io.kind() == io::ErrorKind::Interrupted => continue, + x => {x?; break;}, + } + } + Ok(read) +} + + +pub(super) async fn read_all_async(mut to: impl AsMut<[u8]>, mut from: impl AsyncRead + Unpin) -> io::Result +{ + use tokio::prelude::*; + + let mut read=0; + let to = to.as_mut(); + loop + { + match from.read(&mut to[read..]).await { + Ok(r) if r>0 => read+=r, + Err(io) if io.kind() == io::ErrorKind::Interrupted => continue, + x => {x?; break;}, + } + } + Ok(read) +} + +#[inline(always)] pub(super) fn copy_buffer(mut to: impl io::Write, from: impl io::Read, n: usize) -> io::Result +{ + let mut reader = from.take(n.try_into().expect("Invalid take size")); + io::copy(&mut reader, &mut to).map(|x| x.try_into().expect("Invalid read size")) +} + +pub(super) async fn copy_buffer_async(mut to: impl AsyncWrite + Unpin, from: impl AsyncRead + Unpin, n: usize) -> io::Result +{ + use tokio::prelude::*; + let mut reader = from.take(n.try_into().expect("Invalid take size")); + tokio::io::copy(&mut reader, &mut to).await.map(|x| x.try_into().expect("Invalid read size")) +}