Added `SerializedMessage::from_reader()` (reading untrusted messages)

TODO: XXX: Validate the length of message bodies somehow before naively trying to read them.

Fortune for rsh's current commit: Half curse − 半凶
specialisation
Avril 3 years ago
parent 9142244bca
commit c41d5c2c28
Signed by: flanchan
GPG Key ID: 284488987C31F630

@ -13,6 +13,11 @@ use color_eyre::{
}, },
SectionExt, Help, SectionExt, Help,
}; };
#[allow(unused_imports)]
use std::convert::{
TryFrom,
TryInto,
};
mod message; mod message;

@ -22,6 +22,9 @@ pub use serial::*;
/// Size of encrypted AES key /// Size of encrypted AES key
pub const RSA_BLOCK_SIZE: usize = 512; pub const RSA_BLOCK_SIZE: usize = 512;
/// Max size to pre-allocate when reading a message buffer.
pub const MAX_ALLOC_SIZE: usize = 4096; // 4kb
/// A value that can be used for messages. /// A value that can be used for messages.
pub trait MessageValue: Serialize + for<'de> Deserialize<'de>{} pub trait MessageValue: Serialize + for<'de> Deserialize<'de>{}
@ -52,6 +55,9 @@ struct SerHeader
responds_to: Option<Uuid>, responds_to: Option<Uuid>,
} }
/// A serialized message that can be sent over a socket.
///
/// Messages of this type are not yet validated, and may be invalid/corrupt. The validation happens when converting back to a `Message<V>` (of the same `V`.)
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct SerializedMessage<V: ?Sized + MessageValue> pub struct SerializedMessage<V: ?Sized + MessageValue>
{ {
@ -167,9 +173,10 @@ impl<V: ?Sized + MessageValue> SerializedMessage<V>
serde_cbor::to_writer(&mut w2, $ser)?; serde_cbor::to_writer(&mut w2, $ser)?;
w+=w2.0; w+=w2.0;
} }
} };
} }
write!(: &self.header); write!(: &self.header);
write!(u64::try_from(self.data.len())?.to_be_bytes());
write!(self.data); write!(self.data);
write!(self.hash); write!(self.hash);
write!(? self.enc_key); write!(? self.enc_key);
@ -177,6 +184,70 @@ impl<V: ?Sized + MessageValue> SerializedMessage<V>
Ok(w) Ok(w)
} }
/// Create from a reader.
///
/// The message may be in an invalid state. It is only possible to extract the value after validating it into a `Message<V>`.
pub fn from_reader(mut reader: impl io::Read) -> eyre::Result<Self>
{
macro_rules! read {
($b:expr) => {
read_all($b, &mut reader)?
};
(? $ot:expr) => {
{
let mut b = [0u8; 1];
read!(&mut b[..]);
match b[0] {
1 => {
let mut o = $ot;
read!(&mut o);
Some(o)
},
0 => {
None
},
x => {
return Err(eyre!("Invalid option state {:?}", x));
}
}
}
};
(: $ser:ty) => {
serde_cbor::from_reader::<$ty, _>(&mut reader)?
};
(:) => {
serde_cbor::from_reader(&mut reader)?
};
($into:expr, $num:expr) => {
copy_buffer($into, &mut reader, $num)?
}
}
let header: SerHeader = read!(:);
let data_len = {
let mut bytes = [0u8; std::mem::size_of::<u64>()];
read!(&mut bytes);
u64::from_be_bytes(bytes)
}.try_into()?;
let mut data = Vec::with_capacity(std::cmp::min(data_len, MAX_ALLOC_SIZE)); //XXX: Redesign so we don't allocate massive buffers by accident on corrupted/malformed messages
read!(&mut data, data_len);
if data.len()!=data_len {
return Err(eyre!("Failed to read {} bytes from buffer (got {})", data_len, data.len()));
}
let mut hash = sha256::Sha256Hash::default();
read!(&mut hash);
let enc_key: Option<[u8; RSA_BLOCK_SIZE]> = read!(? [0u8; RSA_BLOCK_SIZE]);
let sig: Option<rsa::Signature> = read!(? rsa::Signature::default());
Ok(Self {
header,
data,
hash,
enc_key,
sig,
_phantom: PhantomData,
})
}
/// Consume into `Vec<u8>`. /// Consume into `Vec<u8>`.
pub fn into_bytes(self) -> Vec<u8> pub fn into_bytes(self) -> Vec<u8>
{ {
@ -184,4 +255,10 @@ impl<V: ?Sized + MessageValue> SerializedMessage<V>
self.into_writer(&mut v).expect("Failed to write to in-memory buffer"); self.into_writer(&mut v).expect("Failed to write to in-memory buffer");
v v
} }
/// Create from bytes
#[inline] pub fn from_bytes(bytes: impl AsRef<[u8]>) -> eyre::Result<Self>
{
let bytes = bytes.as_ref();
Self::from_reader(&mut &bytes[..])
}
} }

@ -78,3 +78,48 @@ pub(super) async fn write_all_async(mut to: impl AsyncWrite + Unpin, bytes: impl
Ok(bytes.len()) Ok(bytes.len())
} }
#[inline(always)] pub(super) fn read_all(mut to: impl AsMut<[u8]>, mut from: impl io::Read) -> io::Result<usize>
{
let mut read=0;
let to = to.as_mut();
loop
{
match from.read(&mut to[read..]) {
Ok(r) if r>0 => read+=r,
Err(io) if io.kind() == io::ErrorKind::Interrupted => continue,
x => {x?; break;},
}
}
Ok(read)
}
pub(super) async fn read_all_async(mut to: impl AsMut<[u8]>, mut from: impl AsyncRead + Unpin) -> io::Result<usize>
{
use tokio::prelude::*;
let mut read=0;
let to = to.as_mut();
loop
{
match from.read(&mut to[read..]).await {
Ok(r) if r>0 => read+=r,
Err(io) if io.kind() == io::ErrorKind::Interrupted => continue,
x => {x?; break;},
}
}
Ok(read)
}
#[inline(always)] pub(super) fn copy_buffer(mut to: impl io::Write, from: impl io::Read, n: usize) -> io::Result<usize>
{
let mut reader = from.take(n.try_into().expect("Invalid take size"));
io::copy(&mut reader, &mut to).map(|x| x.try_into().expect("Invalid read size"))
}
pub(super) async fn copy_buffer_async(mut to: impl AsyncWrite + Unpin, from: impl AsyncRead + Unpin, n: usize) -> io::Result<usize>
{
use tokio::prelude::*;
let mut reader = from.take(n.try_into().expect("Invalid take size"));
tokio::io::copy(&mut reader, &mut to).await.map(|x| x.try_into().expect("Invalid read size"))
}

Loading…
Cancel
Save