You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
rsh/src/message.rs

503 lines
14 KiB

//! Messages
use super::*;
use std::marker::PhantomData;
use serde::{Serialize, Deserialize};
use cryptohelpers::{
sha256,
aes,
rsa,
};
use uuid::Uuid;
use std::borrow::{
Borrow,
Cow
};
use std::io;
use std::marker::Unpin;
use tokio::io::{
AsyncWrite,
AsyncRead,
};
mod serial;
pub use serial::*;
mod builder;
pub use builder::*;
/// Length in bytes of an AES session key after it has been RSA-encrypted for the recipient
/// (the fixed size of `SerializedMessage::enc_key`).
pub const RSA_BLOCK_SIZE: usize = 512;
/// Upper bound on the number of bytes pre-allocated up front while reading a message buffer.
///
/// `SerializedMessage::from_reader` also rejects serialised headers longer than this.
pub const MAX_ALLOC_SIZE: usize = 4 * 1024; // 4kb
/// Marker trait for any type that can be carried as the value of a message.
///
/// It is implemented automatically for every type that is `Serialize` and deserialisable for
/// any `'de` lifetime (i.e. it does not borrow from the input it is deserialised from).
pub trait MessageValue
where Self: Serialize + for<'de> Deserialize<'de>
{}
impl<T> MessageValue for T
where T: ?Sized + Serialize + for<'de> Deserialize<'de>
{}
/// A message carrying a value of type `V`, its header, and the options used when serialising it.
///
/// Use `serialise()` to produce a `SerializedMessage<V>` for sending, and `Message::deserialize()`
/// to validate a received `SerializedMessage<V>` and recover the message from it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Message<V: ?Sized + MessageValue>
{
/// Message metadata; copied unchanged into (and back out of) the serialised form.
header: SerHeader,
/// Optional key to use to encrypt the message.
/// When set, `serialise()` AES-encrypts the body with it and asks the sender to encrypt the key itself.
key: Option<aes::AesKey>,
/// Should the message body be signed?
/// When true, `serialise()` signs the cbor bytes of `value` before any encryption.
sign: bool,
/// Value to serialise
value: V,
}
/// `SerializedMessage` header.
///
/// Serialised as cbor (length-prefixed) at the start of the wire format. Note that it is not
/// covered by the body hash or signature (see the TODO on `SerializedMessage`).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
struct SerHeader
{
/// Message ID
id: Uuid,
/// Message idempotence ID
idemp: Uuid,
/// Timestamp of when this message was created (Unix TS, UTC).
timestamp: u64,
/// `id` of message this one is responding to, if needed.
responds_to: Option<Uuid>,
}
impl<V: ?Sized + MessageValue> AsRef<V> for Message<V>
{
    /// Borrow the value carried by this message (same as `Message::value_ref`).
    #[inline]
    fn as_ref(&self) -> &V
    {
        &self.value
    }
}
/// A serialized message that can be sent over a socket.
///
/// Messages of this type are not yet validated, and may be invalid/corrupt. The validation happens when converting back to a `Message<V>` (of the same `V`.)
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SerializedMessage<V: ?Sized + MessageValue>
{
/// Header copied from the originating `Message`.
header: SerHeader,
/// cbor serialised `V`.
/// AES-encrypted when `enc_key` is `Some`.
data: Vec<u8>,
/// Hash of `data` (after encryption)
/// Checked first by `Message::deserialize`.
hash: sha256::Sha256Hash,
/// `key` encrypted with recipient's RSA public key.
/// Produced by the sender's `encrypt_key`; `None` if the body is not encrypted.
enc_key: Option<[u8; RSA_BLOCK_SIZE]>, // we can't derive Serialize because of this array.. meh..
/// Signature of hash of un-encrypted `data`.
/// Produced by the sender's `sign_data` over the plaintext cbor bytes (whether it hashes first depends on the sender impl).
sig: Option<rsa::Signature>,
//TODO: Add a message header checksum.
/// Ties this serialised form to the value type `V` it decodes to.
_phantom: PhantomData<V>,
}
impl<V: ?Sized + MessageValue> Message<V>
{
    /// A reference to the value itself.
    pub fn value_ref(&self) -> &V
    {
        &self.value
    }
    /// A mutable reference to the value itself.
    ///
    /// # Safety
    /// Mutating a value inside a message may cause invalid metadata.
    pub unsafe fn value_mut(&mut self) -> &mut V
    {
        &mut self.value
    }
    /// Consume into just the value
    pub fn into_value(self) -> V
    {
        self.value
    }
    /// Serialise this message into one that can be converted to/from bytes.
    ///
    /// Steps, in order:
    /// 1. The value is cbor-serialised.
    /// 2. If signing was requested, the *plaintext* cbor bytes are signed by `send_with`.
    /// 3. If an AES key is set, the body is AES-encrypted with it and the key itself is encrypted by `send_with`.
    /// 4. A SHA-256 hash of the resulting (possibly encrypted) body is stored alongside it.
    ///
    /// # Errors
    /// If cbor serialisation or AES encryption of the body fails.
    ///
    /// # Panics
    /// * If this message was specified to be encrypted, but `S` doesn't support encryption.
    /// * If this message was specified to be signed, but `S` doesn't support signing.
    pub fn serialise<S: ?Sized + MessageSender>(&self, send_with: impl Borrow<S>) -> eyre::Result<SerializedMessage<V>>
    {
        let send_with: &S = send_with.borrow();
        let data = serde_cbor::to_vec(&self.value)?;
        // Sign the plaintext (before encryption); `deserialize` verifies against the decrypted body.
        let sig = if self.sign {
            Some(send_with.sign_data(&data[..]).expect("Message expected signing, sender did not support it"))
        } else {
            None
        };
        let (data, enc_key) = if let Some(key) = &self.key {
            // Encrypt the body, and have the sender encrypt the session key for the recipient
            let enc_key = send_with.encrypt_key(key).expect("Message expected encryption, sender did not support it");
            (aes::encrypt_slice_sync(key, data)?, Some(enc_key))
        } else {
            // Don't encrypt the body
            (data, None)
        };
        // Hash the body exactly as it will be transmitted (i.e. after encryption, if any)
        let hash = sha256::compute_slice(&data);
        Ok(SerializedMessage{
            header: self.header.clone(),
            data,
            sig,
            enc_key,
            hash,
            _phantom: PhantomData,
        })
    }
    /// Try to deserialize and validate a received message.
    ///
    /// Validation runs in reverse order of `serialise()`:
    /// 1. The hash of the received body is checked.
    /// 2. If the message carries an encrypted session key, it is decrypted by `recv_with` and used to decrypt the body.
    /// 3. If the message carries a signature, it is verified by `recv_with` against the decrypted body.
    /// 4. The value is cbor-deserialised.
    ///
    /// # Errors
    /// If a part of the message is invalid, an error is returned. An error is also returned if the
    /// message is encrypted (or signed) but `recv_with` does not support decryption (or verification).
    pub fn deserialize<'a, S: ?Sized + MessageReceiver>(serial: &'a SerializedMessage<V>, recv_with: impl Borrow<S>) -> eyre::Result<Self>
    {
        let recv_with: &S = recv_with.borrow();
        // Evaluates to `Ok(())` when `$ex` holds, otherwise to an `Err` built from the format args.
        macro_rules! assert_valid {
            ($ex:expr, $fmt:literal $($tt:tt)*) => {
                if $ex {
                    Ok(())
                } else {
                    Err(eyre!($fmt $($tt)*))
                }
            }
        }
        //Validate hashes
        assert_valid!(sha256::compute_slice(&serial.data) == serial.hash, "Non-matching hashes")?;
        // Decrypted data (borrowed as-is when the message is not encrypted)
        let (data, key): (Cow<'a, [u8]>, Option<aes::AesKey>) = if let Some(enc_key) = &serial.enc_key {
            let key = recv_with.decrypt_key(enc_key)
                .ok_or_else(|| eyre!("Message was encrypted, but receiver doesn't support decryption"))?
                .wrap_err("Failed to decrypt session key")?;
            let mut data = Vec::with_capacity(serial.data.len());
            aes::decrypt_stream_sync(&key, &mut &serial.data[..], &mut data).wrap_err("Failed to decrypt body")?;
            (Cow::Owned(data), Some(key))
        } else {
            (Cow::Borrowed(&serial.data[..]), None)
        };
        let sign = if let Some(sig) = &serial.sig {
            // Validate signature against the plaintext body
            let verified = recv_with.verify_data(&data[..], sig)
                .ok_or_else(|| eyre!("Message was signed, but receiver doesn't support signature verification"))?
                .wrap_err("Failed to verify signature")?;
            assert_valid!(verified, "Non-matching signature")
                .with_section(move || format!("Embedded sig was: {}", sig))?;
            true
        } else {
            false
        };
        // Deserialise value
        Ok(Self {
            value: serde_cbor::from_slice(&data[..]).wrap_err("Failed to deserialise value")?,
            header: serial.header.clone(),
            key,
            sign
        })
    }
}
impl<V: ?Sized + MessageValue> SerializedMessage<V>
{
    /// Consume into an async writer.
    ///
    /// Writes the same byte layout as [`into_writer`](Self::into_writer) (see there for the format),
    /// so the output can be parsed back with [`from_reader`](Self::from_reader).
    ///
    /// Returns the sum of the byte counts reported by each individual write.
    pub async fn into_writer_async<W:AsyncWrite+Unpin>(self, mut writer: W) -> eyre::Result<usize>
    {
        let mut w = 0;
        macro_rules! write {
            ($b:expr) => {
                w+=write_all_async(&mut writer, $b).await?
            };
            (? $o:expr) => {
                // Option: one tag byte (1 = Some, 0 = None), followed by the value if present
                match $o {
                    Some(key) => {
                        write!([1]);
                        write!(key);
                    },
                    None => {
                        write!([0]);
                    },
                }
            };
            (: $ser:expr) => {
                {
                    let mut v = StackVec::new();
                    serde_cbor::to_writer(&mut v, $ser)?;
                    // Length-prefix the cbor bytes exactly like `into_writer` does;
                    // `from_reader` needs this prefix to know how many header bytes to read.
                    write!(u64::try_from(v.len())?.to_be_bytes());
                    write!(&v[..]);
                }
            };
        }
        write!(: &self.header);
        write!(u64::try_from(self.data.len())?.to_be_bytes());
        write!(self.data);
        write!(self.hash);
        write!(? self.enc_key);
        write!(? self.sig);
        Ok(w)
    }
    /// Consume into a synchronous writer.
    ///
    /// Wire format, in order:
    /// 1. Header: `u64` big-endian length, then that many bytes of cbor-serialised `SerHeader`.
    /// 2. Body: `u64` big-endian length, then the (possibly encrypted) body bytes.
    /// 3. The body hash bytes.
    /// 4. `enc_key`: one tag byte (`0` = absent, `1` = present), followed by the key bytes if present.
    /// 5. `sig`: one tag byte (`0` = absent, `1` = present), followed by the signature bytes if present.
    ///
    /// Returns the sum of the byte counts reported by each individual write.
    pub fn into_writer(self, mut writer: impl io::Write) -> eyre::Result<usize>
    {
        let mut w = 0;
        macro_rules! write {
            ($b:expr) => {
                w+=write_all(&mut writer, $b)?
            };
            (? $o:expr) => {
                // Option: one tag byte (1 = Some, 0 = None), followed by the value if present
                match $o {
                    Some(key) => {
                        write!([1]);
                        write!(key);
                    },
                    None => {
                        write!([0]);
                    },
                }
            };
            (: $ser:expr) => {
                {
                    // Length-prefixed cbor, so the reader can bound its allocation before decoding
                    let mut ser = StackVec::new();
                    serde_cbor::to_writer(&mut ser, $ser)?;
                    write!(u64::try_from(ser.len())?.to_be_bytes());
                    write!(ser);
                }
            };
        }
        write!(: &self.header);
        write!(u64::try_from(self.data.len())?.to_be_bytes());
        write!(self.data);
        write!(self.hash);
        write!(? self.enc_key);
        write!(? self.sig);
        Ok(w)
    }
    /// Create from a reader.
    ///
    /// Expects the format written by [`into_writer`](Self::into_writer).
    ///
    /// The message may be in an invalid state. It is only possible to extract the value after validating it into a `Message<V>`.
    ///
    /// # Errors
    /// On read failure, a header length greater than `MAX_ALLOC_SIZE`, a malformed cbor header,
    /// an option tag byte other than `0`/`1`, or a body shorter than its declared length.
    pub fn from_reader(mut reader: impl io::Read) -> eyre::Result<Self>
    {
        macro_rules! read {
            ($b:expr; $fmt:literal $($tt:tt)*) => {
                read_all($b, &mut reader).wrap_err(eyre!($fmt $($tt)*))
            };
            ($b:expr) => {
                read!($b; "Failed to read from stream")?;
            };
            (? $ot:expr) => {
                {
                    // Option: one tag byte (1 = Some, 0 = None), followed by the value if present
                    let mut b = [0u8; 1];
                    read!(&mut b[..]);
                    match b[0] {
                        1 => {
                            let mut o = $ot;
                            read!(&mut o);
                            Some(o)
                        },
                        0 => {
                            None
                        },
                        x => {
                            return Err(eyre!("Invalid option state {:?}", x));
                        }
                    }
                }
            };
            (: $ser:ty) => {
                {
                    // Length-prefixed cbor value; the length is bounded before allocating
                    let mut len = [0u8; std::mem::size_of::<u64>()];
                    read!(&mut len[..]);
                    let len = usize::try_from(u64::from_be_bytes(len))?;
                    //TODO: Find realistic max size for `$ser`.
                    if len > MAX_ALLOC_SIZE {
                        return Err(eyre!("Invalid length read: {}", len)
                            .with_section(|| format!("Max length read: {}", MAX_ALLOC_SIZE)))
                    }
                    alloc_local_bytes(len, |de| {
                        read!(&mut de[..]);
                        serde_cbor::from_slice::<$ser>(&de[..]).wrap_err(eyre!("Failed to deserialise {} from reader", std::any::type_name::<$ser>()))
                    })?
                }
            };
            ($into:expr, $num:expr) => {
                {
                    let num = $num;
                    copy_buffer($into, &mut reader, num).wrap_err(eyre!("Failed to read {} bytes from reader", num))?
                }
            }
        }
        let header = read!(: SerHeader);
        let data_len = {
            let mut bytes = [0u8; std::mem::size_of::<u64>()];
            read!(&mut bytes);
            u64::from_be_bytes(bytes)
        }.try_into()?;
        let mut data = Vec::with_capacity(std::cmp::min(data_len, MAX_ALLOC_SIZE)); //XXX: Redesign so we don't allocate OR try to read massive buffers by accident on corrupted/malformed messages
        read!(&mut data, data_len);
        if data.len()!=data_len {
            return Err(eyre!("Failed to read {} bytes from buffer (got {})", data_len, data.len()));
        }
        let mut hash = sha256::Sha256Hash::default();
        read!(&mut hash);
        let enc_key: Option<[u8; RSA_BLOCK_SIZE]> = read!(? [0u8; RSA_BLOCK_SIZE]);
        let sig: Option<rsa::Signature> = read!(? rsa::Signature::default());
        Ok(Self {
            header,
            data,
            hash,
            enc_key,
            sig,
            _phantom: PhantomData,
        })
    }
    /// Consume into `Vec<u8>`.
    ///
    /// # Panics
    /// If writing to the in-memory buffer fails.
    pub fn into_bytes(self) -> Vec<u8>
    {
        let mut v = Vec::with_capacity(self.data.len()<<1);
        self.into_writer(&mut v).expect("Failed to write to in-memory buffer");
        v
    }
    /// Create from bytes.
    ///
    /// See [`from_reader`](Self::from_reader).
    #[inline] pub fn from_bytes(bytes: impl AsRef<[u8]>) -> eyre::Result<Self>
    {
        let bytes = bytes.as_ref();
        Self::from_reader(&mut &bytes[..])
    }
}
#[cfg(test)]
mod tests
{
    use super::*;
    /// Generic message test function.
    ///
    /// Builds a message for sender type `S`, then serialises it with `s`, round-trips it through
    /// bytes, and deserialises it with `d`. Asserts that both the message and its value match the
    /// original.
    fn message_serial_generic<S,D>(s: S, d: D)
    where S: MessageSender,
          D: MessageReceiver
    {
        eprintln!("=== Message serialisation with tc, rc: S: {}, D: {}", std::any::type_name::<S>(), std::any::type_name::<D>());
        let message = MessageBuilder::for_sender::<S>()
            .create(format!("This is a string, and some random data: {:?}", aes::AesKey::generate().unwrap()))
            .expect("Failed to create message");
        println!(">> Created message: {:?}", message);
        let serialised = message.serialise(s).expect("Failed to serialise message");
        println!(">> Serialised message: {:?}", serialised);
        let binary = serialised.into_bytes();
        println!(">> Written to {} bytes", binary.len());
        let read = SerializedMessage::from_bytes(&binary).expect("Failed to read serialised message from binary");
        println!(">> Read from bytes: {:?}", read);
        let deserialised = Message::deserialize(&read, d).expect("Failed to deserialise message");
        println!(">> Deserialised message: {:?}", deserialised);
        assert_eq!(message, deserialised, "Messages not identical");
        assert_eq!(message.into_value(), deserialised.into_value(), "Message values not identical");
        eprintln!("=== Passed (S: {}, D: {}) ===", std::any::type_name::<S>(), std::any::type_name::<D>());
    }
    /// Round-trip with no signing and no encryption.
    #[test]
    fn message_serial_basic()
    {
        message_serial_generic((), ());
    }
    /// Round-trip of a signed (but not encrypted) message.
    #[test]
    fn message_serial_sign()
    {
        let rsa_priv = rsa::RsaPrivateKey::generate().unwrap();
        struct Sign(rsa::RsaPrivateKey);
        struct Verify(rsa::RsaPublicKey);
        impl MessageSender for Sign
        {
            const CAP_SIGN: bool = true;
            fn sign_data(&self, data: &[u8]) -> Option<rsa::Signature> {
                Some(rsa::sign_slice(data, &self.0).expect("Failed to sign"))
            }
        }
        impl MessageReceiver for Verify
        {
            fn verify_data(&self, data: &[u8], sig: &rsa::Signature) -> Option<eyre::Result<bool>> {
                Some(sig.verify_slice(data, &self.0).map_err(Into::into))
            }
        }
        let verify = Verify(rsa_priv.get_public_parts());
        println!("Signing priv-key: {:?}", rsa_priv);
        println!("Verifying pub-key: {:?}", verify.0);
        message_serial_generic(Sign(rsa_priv), verify);
    }
    /// Round-trip of an encrypted (but not signed) message.
    #[test]
    fn message_serial_encrypt()
    {
        // `install()` returns an error if a report hook is already installed (e.g. by another test
        // in this binary). The hook only improves error output, so that case is ignored.
        let _ = color_eyre::install();
        let rsa_priv = rsa::RsaPrivateKey::generate().unwrap();
        struct Dec(rsa::RsaPrivateKey);
        struct Enc(rsa::RsaPublicKey);
        impl MessageSender for Enc
        {
            const CAP_ENCRYPT: bool = true;
            fn encrypt_key(&self, key: &aes::AesKey) -> Option<[u8; RSA_BLOCK_SIZE]> {
                let mut output = [0u8; RSA_BLOCK_SIZE];
                use rsa::HasPublicComponents;
                let w = rsa::encrypt_slice_sync(key, &self.0, &mut &mut output[..]).expect("Failed to encrypt session key");
                assert_eq!(w, output.len());
                Some(output)
            }
        }
        impl MessageReceiver for Dec
        {
            fn decrypt_key(&self, enc_key: &[u8; RSA_BLOCK_SIZE]) -> Option<eyre::Result<aes::AesKey>> {
                let mut output = aes::AesKey::empty();
                match rsa::decrypt_slice_sync(enc_key, &self.0, &mut output.as_mut()) {
                    Ok(sz) => assert_eq!(sz,
                                         output.as_ref().len()),
                    Err(err) => return Some(Err(err.into())),
                }
                Some(Ok(output))
            }
        }
        let enc = Enc(rsa_priv.get_public_parts());
        println!("Encrypting pub-key: {:?}", enc.0);
        println!("Decrypting priv-key: {:?}", rsa_priv);
        message_serial_generic(enc, Dec(rsa_priv));
    }
    /// Checks that an RSA-encrypted AES key is exactly `RSA_BLOCK_SIZE` bytes.
    ///
    /// The public key is obtained by round-tripping a freshly generated openssl key through
    /// `RsaPrivateKey`, the same conversion path the message code relies on.
    #[test]
    fn rsa_encrypted_key_matches_block_size()
    {
        use rsa::HasPublicComponents;
        let rsa = rsa::openssl::rsa::Rsa::generate(cryptohelpers::consts::RSA_KEY_BITS).unwrap();
        eprintln!("rn: {}, re: {},", rsa.n(), rsa.e());
        let rsa_comp: rsa::RsaPrivateKey = rsa.clone().into();
        eprintln!("n: {:?}, e: {:?}", rsa_comp.n(), rsa_comp.e());
        let rsa2: rsa::openssl::rsa::Rsa<_> = rsa_comp.get_rsa_pub().unwrap();
        eprintln!("n2: {}, e2: {}", rsa2.n(), rsa2.e());
        let mut data = Vec::new();
        data.extend_from_slice(aes::AesKey::generate().unwrap().as_ref());
        let rend = rsa::encrypt_slice_to_vec(data, &rsa2).unwrap();
        assert_eq!(rend.len(), RSA_BLOCK_SIZE);
    }
    //TODO: message_serial_encrypted_sign() (round-trip of a message that is both encrypted and signed)
}