You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
rsh/src/message/serial.rs

148 lines
4.3 KiB

//! Traits and I/O helpers for serialising and deserialising messages.
use super::*;
use std::{
pin::Pin,
task::{
Context, Poll,
},
};
/// A type that can be used to serialise a message.
///
/// Every item has a default, so an empty `impl MessageSender for T {}` yields a
/// sender whose capability flags are `false` and whose hooks return `None`.
pub trait MessageSender
{
    /// Capability flag; defaults to `false`.
    /// NOTE(review): presumably advertises that `encrypt_key` is implemented — confirm with callers.
    const CAP_ENCRYPT: bool = false;

    /// Capability flag; defaults to `false`.
    /// NOTE(review): presumably advertises that `sign_data` is implemented — confirm with callers.
    const CAP_SIGN: bool = false;

    /// Produce an `RSA_BLOCK_SIZE`-byte block from the given AES key.
    ///
    /// The default implementation ignores `_key` and returns `None`.
    /// NOTE(review): by its name this presumably RSA-encrypts the session key — confirm.
    #[inline]
    fn encrypt_key(&self, _key: &aes::AesKey) -> Option<[u8; RSA_BLOCK_SIZE]>
    {
        None
    }

    /// Produce an `rsa::Signature` for `_data`.
    ///
    /// The default implementation ignores `_data` and returns `None`.
    #[inline]
    fn sign_data(&self, _data: &[u8]) -> Option<rsa::Signature>
    {
        None
    }
}
/// A type that can be used to deserialise a message.
///
/// Both hooks default to returning `None`; an empty `impl MessageReceiver for T {}`
/// therefore performs no key decryption and no signature verification.
/// NOTE(review): `None` appears to mean "not supported by this receiver", while
/// `Some(Err(_))` reports a failed attempt — confirm against callers.
pub trait MessageReceiver
{
    /// Recover an AES key from an `RSA_BLOCK_SIZE`-byte encrypted block.
    ///
    /// The default implementation ignores `_enc_key` and returns `None`.
    #[inline]
    fn decrypt_key(&self, _enc_key: &[u8; RSA_BLOCK_SIZE]) -> Option<eyre::Result<aes::AesKey>>
    {
        None
    }

    /// Check `_sig` against `_data`, yielding `Some(Ok(bool))` when a check was performed.
    ///
    /// The default implementation ignores both arguments and returns `None`.
    #[inline]
    fn verify_data(&self, _data: &[u8], _sig: &rsa::Signature) -> Option<eyre::Result<bool>>
    {
        None
    }
}
// `()` opts into both traits using only their default methods (which all return `None`).
impl MessageSender for () {}
impl MessageReceiver for () {}
/// Unit sender that relies solely on the default [`MessageSender`] methods,
/// so it behaves exactly like `()` wherever a `MessageSender` is expected.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct DefaultMessageSender;

impl MessageSender for DefaultMessageSender {}
/// Unit receiver that relies solely on the default [`MessageReceiver`] methods,
/// so it behaves exactly like `()` wherever a `MessageReceiver` is expected.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct DefaultMessageReceiver;

impl MessageReceiver for DefaultMessageReceiver {}
/// Unit type implementing both [`MessageSender`] and [`MessageReceiver`] with
/// their default methods only; equivalent to `()` for either role.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct DefaultMessageSenderReceiver;

impl MessageSender for DefaultMessageSenderReceiver {}
impl MessageReceiver for DefaultMessageSenderReceiver {}
/// Adapter around an [`io::Write`] sink that tallies the bytes it accepts.
///
/// Field `.0` is the running count of bytes reported written by the inner
/// writer; field `.1` is the wrapped writer itself.
#[derive(Debug)]
pub(super) struct WriteCounter<W: ?Sized>(pub usize, pub W);

impl<W> io::Write for WriteCounter<W>
where
    W: io::Write + ?Sized,
{
    #[inline]
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        // Only a successful write contributes to the tally; errors pass straight through.
        self.1.write(buf).map(|accepted| {
            self.0 += accepted;
            accepted
        })
    }

    #[inline]
    fn flush(&mut self) -> io::Result<()> {
        // Flushing moves no new bytes, so the count is left alone.
        self.1.flush()
    }
}
/// Async counterpart of `WriteCounter`: wraps an [`AsyncWrite`] sink and adds the
/// size of every successful `poll_write` to field `.0`.
///
/// Field `.1` is the pin-projected wrapped writer. Flush and shutdown are
/// forwarded unchanged and never touch the count.
#[pin_project]
#[derive(Debug)]
pub(super) struct AsyncWriteCounter<W: ?Sized>(pub usize, #[pin] pub W);

impl<W> AsyncWrite for AsyncWriteCounter<W>
where
    W: AsyncWrite + ?Sized,
{
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
        let projected = self.project();
        let outcome = projected.1.poll_write(cx, buf);
        // Only bytes the inner writer actually accepted are counted; `Pending`
        // and errors leave the tally untouched.
        if let Poll::Ready(Ok(accepted)) = &outcome {
            *projected.0 += *accepted;
        }
        outcome
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let projected = self.project();
        projected.1.poll_flush(cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let projected = self.project();
        projected.1.poll_shutdown(cx)
    }
}
/// Asynchronously write every byte of `bytes` to `to`.
///
/// Returns the payload length on success (always `bytes.as_ref().len()`);
/// any error from the underlying `write_all` is returned unchanged.
pub(super) async fn write_all_async(mut to: impl AsyncWrite + Unpin, bytes: impl AsRef<[u8]>) -> io::Result<usize>
{
    use tokio::prelude::*;
    let payload: &[u8] = bytes.as_ref();
    to.write_all(payload).await.map(|()| payload.len())
}
/// Write every byte of `bytes` to `to`.
///
/// Returns the payload length on success (always `bytes.as_ref().len()`);
/// any error from [`io::Write::write_all`] is returned unchanged.
#[inline(always)] pub(super) fn write_all(mut to: impl io::Write, bytes: impl AsRef<[u8]>) -> io::Result<usize>
{
    let payload: &[u8] = bytes.as_ref();
    to.write_all(payload).map(|()| payload.len())
}
/// Fill `to` from `from`, issuing repeated reads until the buffer is full or
/// the reader reports end-of-stream (`Ok(0)`).
///
/// Returns the number of bytes placed at the start of `to`. This is less than
/// `to.len()` only if end-of-stream was reached first. `Interrupted` errors are
/// retried. Any other error is returned immediately, and the partial count is
/// not reported.
///
/// The loop stops as soon as the buffer is full rather than issuing one more
/// zero-length read. For buffered readers (e.g. `std::io::BufReader`) a read into
/// an empty slice can still refill the internal buffer, and therefore block
/// waiting for data that belongs to the *next* message.
#[inline(always)] pub(super) fn read_all(mut to: impl AsMut<[u8]>, mut from: impl io::Read) -> io::Result<usize>
{
    let to = to.as_mut();
    let mut read = 0;
    while read < to.len() {
        match from.read(&mut to[read..]) {
            // End of stream: return what we have.
            Ok(0) => break,
            Ok(r) => read += r,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        }
    }
    Ok(read)
}
/// Asynchronously fill `to` from `from` until the buffer is full or the reader
/// reports end-of-stream (`Ok(0)`).
///
/// Returns the number of bytes placed at the start of `to`. This is less than
/// `to.len()` only if end-of-stream was reached first. `Interrupted` errors are
/// retried. Any other error is returned immediately, and the partial count is
/// not reported.
///
/// The loop stops as soon as the buffer is full instead of awaiting one more
/// zero-length read. A buffered async reader may try to refill its internal
/// buffer even for an empty destination, which would leave this future pending
/// until the peer sends more data.
pub(super) async fn read_all_async(mut to: impl AsMut<[u8]>, mut from: impl AsyncRead + Unpin) -> io::Result<usize>
{
    use tokio::prelude::*;
    let to = to.as_mut();
    let mut read = 0;
    while read < to.len() {
        match from.read(&mut to[read..]).await {
            // End of stream: return what we have.
            Ok(0) => break,
            Ok(r) => read += r,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        }
    }
    Ok(read)
}
/// Copy at most `n` bytes from `from` into `to`.
///
/// Returns how many bytes were actually copied. This can be fewer than `n` if
/// `from` reaches end-of-stream first. I/O errors from either side are returned
/// unchanged.
///
/// # Panics
/// Panics if `n` does not fit in a `u64`, or the copied count does not fit in a
/// `usize`. The latter cannot happen, since the copy is limited to `n`.
#[inline(always)] pub(super) fn copy_buffer(mut to: impl io::Write, from: impl io::Read, n: usize) -> io::Result<usize>
{
    let limit: u64 = n.try_into().expect("Invalid take size");
    let mut limited = from.take(limit);
    let copied = io::copy(&mut limited, &mut to)?;
    Ok(copied.try_into().expect("Invalid read size"))
}
/// Asynchronously copy at most `n` bytes from `from` into `to`.
///
/// Returns how many bytes were actually copied. This can be fewer than `n` if
/// `from` reaches end-of-stream first. I/O errors from either side are returned
/// unchanged.
///
/// # Panics
/// Panics if `n` does not fit in a `u64`, or the copied count does not fit in a
/// `usize`. The latter cannot happen, since the copy is limited to `n`.
pub(super) async fn copy_buffer_async(mut to: impl AsyncWrite + Unpin, from: impl AsyncRead + Unpin, n: usize) -> io::Result<usize>
{
    use tokio::prelude::*;
    let limit: u64 = n.try_into().expect("Invalid take size");
    let mut limited = from.take(limit);
    let copied = tokio::io::copy(&mut limited, &mut to).await?;
    Ok(copied.try_into().expect("Invalid read size"))
}