You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
rsh/src/sock/enc.rs

375 lines
8.5 KiB

//! Socket encryption wrapper
use super::*;
use cryptohelpers::{
rsa::{
RsaPublicKey,
RsaPrivateKey,
},
sha256,
};
use chacha20stream::{
AsyncSink,
AsyncSource,
Key, IV,
};
use std::sync::Arc;
use tokio::{
sync::{
RwLock,
RwLockReadGuard,
RwLockWriteGuard,
},
io::{
DuplexStream,
},
};
use std::{
io,
task::{
Context, Poll,
},
pin::Pin,
marker::{
Unpin,
PhantomPinned,
},
};
/// Length in bytes of one RSA ciphertext (512 bytes, i.e. a 4096-bit modulus).
///
/// The tests at the bottom of this file check this value against real
/// `cryptohelpers` output for an encrypted session key.
pub const RSA_CIPHERTEXT_SIZE: usize = 4096 / 8;
/// Upper bound, in bytes, on a serialized public key during key exchange.
///
/// A received length prefix larger than this is rejected before anything is
/// allocated for it, and our own outgoing key is asserted to fit as well.
const TRANS_KEY_MAX_SIZE: usize = 4 * 1024;
/// Encrypted socket information.
///
/// Holds the RSA key material used by an [`ESock`]: our own private key and,
/// once a key exchange has completed, the peer's public key.
#[derive(Debug)]
struct ESockInfo {
/// Our RSA private key. Its public parts are what `ESock::exchange` sends to the peer.
us: RsaPrivateKey,
/// The peer's RSA public key; `None` until `ESock::exchange` succeeds (see `ESock::has_exchanged`).
them: Option<RsaPublicKey>,
}
/// Per-direction encryption flags for an [`ESock`].
///
/// NOTE(review): nothing visible in this file consults these flags on the
/// actual read/write path yet; they are only reported by `ESock::is_encrypted`.
#[derive(Debug)]
struct ESockState {
/// Whether the read (rx) half is flagged as encrypted.
encr: bool,
/// Whether the write (tx) half is flagged as encrypted; toggled by `ESock::set_encrypted_write`.
encw: bool,
}
/// Contains a ChaCha20 Key and IV that can be serialized and then encrypted.
///
/// The `rsa_serial_ciphertext_len` test CBOR-encodes this type, RSA-encrypts
/// the encoding, and checks that the ciphertext is exactly
/// [`RSA_CIPHERTEXT_SIZE`] bytes.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
struct ESockSessionKey
{
/// Symmetric stream-cipher key.
key: Key,
/// IV paired with `key`.
iv: IV,
}
/// A tx+rx socket.
///
/// The read half `R` is wrapped in a `chacha20stream` [`AsyncSource`] and the
/// write half `W` in an [`AsyncSink`]. The raw, unwrapped halves are reachable
/// through `inner`/`inner_mut`.
#[pin_project]
#[derive(Debug)]
pub struct ESock<W, R> {
/// RSA key material: ours, plus the peer's once exchanged.
info: ESockInfo,
/// Read/write encryption flags.
state: ESockState,
/// Read half, wrapped in the decrypting source.
#[pin]
rx: AsyncSource<R>,
/// Write half, wrapped in the encrypting sink.
#[pin]
tx: AsyncSink<W>,
}
impl<W: AsyncWrite, R: AsyncRead> ESock<W, R>
{
    /// Get shared references to the raw (unencrypted) write and read streams, as `(W, R)`.
    pub fn inner(&self) -> (&W, &R)
    {
        (self.tx.inner(), self.rx.inner())
    }

    /// Get mutable references to the raw (unencrypted) write and read streams, as `(W, R)`.
    ///
    /// Bytes moved through these bypass the `AsyncSink`/`AsyncSource` wrappers entirely.
    fn inner_mut(&mut self) -> (&mut W, &mut R)
    {
        (self.tx.inner_mut(), self.rx.inner_mut())
    }

    /// Create a future that exchanges keys.
    ///
    /// # Panics
    /// Always: the hand-polled [`Exchange`] future is not implemented yet.
    /// Use the `async` [`ESock::exchange`] instead.
    pub fn exchange_unsafe(&mut self) -> Exchange<'_, W, R>
    {
        // Kept for the constructor below once `Exchange` is implemented; the
        // leading underscore silences the unused-variable warning meanwhile.
        let _us = self.info.us.get_public_parts();
        todo!("Currently unimplemented")
        /*
        Draft constructor (several fields still hold `()` placeholders):
        Exchange{
        sock: self,
        write_buf: Default::default(),
        read_buf: Default::default(),
        _pin: PhantomPinned,
        read_state: Default::default(),
        write_state: Default::default(),
        us,
        them: None,
        us_written: Default::default(),
        us_buf: (),
        write_sz_num: (),
        write_sz_buf: (),
        read_sz_buf: (),
        }*/
    }

    /// Get a mutable ref to unencrypted write+read.
    ///
    /// Same as `inner_mut`; kept as a separate name to pair with `encrypted`.
    fn unencrypted(&mut self) -> (&mut W, &mut R)
    {
        self.inner_mut()
    }

    /// Get a mutable ref to the encrypting write sink and decrypting read source.
    fn encrypted(&mut self) -> (&mut AsyncSink<W>, &mut AsyncSource<R>)
    {
        (&mut self.tx, &mut self.rx)
    }

    /// Have the RSA keys been exchanged?
    ///
    /// `true` once the peer's public key has been stored by a successful `exchange`.
    pub fn has_exchanged(&self) -> bool
    {
        self.info.them.is_some()
    }

    /// Is the Write + Read operation encrypted? Tuple is `(Tx, Rx)`.
    pub fn is_encrypted(&self) -> (bool, bool)
    {
        (self.state.encw, self.state.encr)
    }
}
impl<W: AsyncWrite + Unpin, R: AsyncRead + Unpin> ESock<W, R>
{
    /// Enable or disable write encryption.
    ///
    /// Currently this only sets the `encw` flag reported by `is_encrypted`;
    /// no session key is generated or transmitted yet.
    ///
    /// # Errors
    /// Never fails at present (the `Result` is presumably reserved for the
    /// future key-transmission step).
    pub async fn set_encrypted_write(&mut self, set: bool) -> eyre::Result<()>
    {
        // TODO: when `set` is true, generate a session Key+IV (`ESockSessionKey`),
        // RSA-encrypt it with the peer's public key, and send it before flagging
        // the write half as encrypted.
        self.state.encw = set;
        Ok(())
    }

    /// Get dynamic ref to unencrypted write+read.
    fn unencrypted_dyn(&mut self) -> (&mut (dyn AsyncWrite + Unpin + '_), &mut (dyn AsyncRead + Unpin + '_))
    {
        (self.tx.inner_mut(), self.rx.inner_mut())
    }

    /// Get dynamic ref to encrypted write+read.
    fn encrypted_dyn(&mut self) -> (&mut (dyn AsyncWrite + Unpin + '_), &mut (dyn AsyncRead + Unpin + '_))
    {
        (&mut self.tx, &mut self.rx)
    }

    /// Exchange RSA public keys with the peer.
    ///
    /// Uses the raw (unencrypted) streams and runs the write and read halves
    /// concurrently. Each key travels as a big-endian `u64` byte length followed
    /// by that many key bytes. On success the peer's key is stored, and
    /// `has_exchanged` returns `true`.
    ///
    /// # Errors
    /// - I/O failure on either stream.
    /// - A peer length prefix above [`TRANS_KEY_MAX_SIZE`].
    /// - A truncated peer key.
    /// - A peer key that fails to parse.
    ///
    /// # Panics
    /// If our own serialized public key exceeds [`TRANS_KEY_MAX_SIZE`].
    pub async fn exchange(&mut self) -> eyre::Result<()>
    {
        use tokio::prelude::*;
        let our_key = self.info.us.get_public_parts();
        let (tx, rx) = self.inner_mut();
        let read_fut = {
            async move {
                // Read the public key from `rx`.
                //TODO: Confirm `TRANS_KEY_MAX_SIZE` is a tight bound for pubkey size.
                let mut sz_buf = [0u8; std::mem::size_of::<u64>()];
                rx.read_exact(&mut sz_buf[..]).await?;
                // Reject oversized lengths before allocating: the prefix is peer-controlled.
                let sz = match usize::try_from(u64::from_be_bytes(sz_buf))? {
                    x if x > TRANS_KEY_MAX_SIZE => return Err(eyre!("Recv'd key size exceeded max")),
                    x => x
                };
                let mut key_bytes = Vec::with_capacity(sz);
                tokio::io::copy(&mut rx.take(sz as u64), &mut key_bytes).await?;
                // `take` stops early at EOF, so a short result means the peer hung up mid-key.
                if key_bytes.len() != sz {
                    return Err(eyre!("Could not read required bytes"));
                }
                let k = RsaPublicKey::from_bytes(key_bytes)?;
                Result::<RsaPublicKey, eyre::Report>::Ok(k)
            }
        };
        let write_fut = {
            let key_bytes = our_key.to_bytes();
            assert!(key_bytes.len() <= TRANS_KEY_MAX_SIZE);
            let sz_buf = u64::try_from(key_bytes.len())?.to_be_bytes();
            async move {
                tx.write_all(&sz_buf[..]).await?;
                tx.write_all(&key_bytes[..]).await?;
                // Push the key past any buffering in `W`: the peer is blocked reading
                // it while we block reading theirs, so an unflushed key deadlocks both.
                tx.flush().await?;
                Result::<(), eyre::Report>::Ok(())
            }
        };
        let (send, recv) = tokio::join! [write_fut, read_fut];
        send?;
        let recv = recv?;
        self.info.them = Some(recv);
        Ok(())
    }
}
/// Which part of a length-prefixed key message one half of an [`Exchange`]
/// is currently transferring.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum ExchangeState
{
    /// Reading/writing the size prefix that precedes the buffer.
    BufferSize,
    /// Reading/writing the buffer bytes themselves.
    Buffer,
}

/// Every transfer starts with the size prefix.
impl Default for ExchangeState
{
    #[inline]
    fn default() -> Self
    {
        ExchangeState::BufferSize
    }
}
/// State for a hand-polled RSA public key exchange over an [`ESock`].
///
/// NOTE(review): its `Future` impl is currently commented out and
/// `ESock::exchange_unsafe` is a `todo!()`, so this type is not usable yet.
/// `ESock::exchange` is the working `async` alternative.
#[pin_project]
#[derive(Debug)]
pub struct Exchange<'a, W, R>
{
/// The socket whose streams the exchange reads from and writes to.
sock: &'a mut ESock<W, R>,
/// Our public key, to be sent to the peer.
us: RsaPublicKey,
/// Number of bytes of `us_buf` written so far.
us_written: usize,
/// Serialized `us`; the commented-out draft fills it lazily from `us.to_bytes()`.
us_buf: Vec<u8>,
/// The return value: presumably the peer's public key once fully read.
them: Option<RsaPublicKey>,
/// Number of bytes of `write_sz_buf` written so far.
write_sz_num: usize,
/// Big-endian `u64` length prefix for `us_buf` (per the draft `poll`).
write_sz_buf: [u8; std::mem::size_of::<u64>()],
/// Presumably receives the peer's big-endian `u64` length prefix.
read_sz_buf: [u8; std::mem::size_of::<u64>()],
/// Presumably receives the peer's serialized public key bytes.
read_buf: Vec<u8>,
/// Progress of the write side.
write_state: ExchangeState,
/// Progress of the read side.
read_state: ExchangeState,
/// Makes this type `!Unpin`.
#[pin] _pin: PhantomPinned,
}
/*
impl<'a, W: AsyncWrite, R: AsyncRead> Future for Exchange<'a, W, R>
{
type Output = eyre::Result<()>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
use futures::ready;
let this = self.project();
let (tx, rx) = {
let sock = this.sock;
//XXX: Idk if this is safe?
unsafe {
(Pin::new_unchecked(&mut sock.tx), Pin::new_unchecked(&mut sock.rx))
}
};
if this.us_buf.is_empty() {
*this.us_buf = this.us.to_bytes();
}
let poll_write = loop {
break match this.write_state {
ExchangeState::BufferSize => {
if *this.write_sz_num == 0 {
*this.write_sz_buf = u64::try_from(this.us_buf.len())?.to_be_bytes();
}
// Write this to tx.
match tx.poll_write(cx, &this.write_sz_buf[(*this.write_sz_num)..]) {
x @ Poll::Ready(Ok(n)) => {
*this.write_sz_num+=n;
if *this.write_sz_num == this.write_sz_buf.len() {
// We've written all the size bytes, continue to writing the buffer bytes.
*this.write_state = ExchangeState::Buffer;
continue;
}
x
},
x => x,
}
},
ExchangeState::Buffer => {
match tx.poll_write(cx, &this.us_buf[(*this.us_written)..]) {
x @ Poll::Ready(Ok(n)) => {
if *this.us_written == this.us.len() {
}
x
},
x=> x,
}
},
}
};
let poll_read = match this.read_state {
ExchangeState::BufferSize => {
},
ExchangeState::Buffer => {
},
};
todo!("This is going to be dificult to implement... We don't have access to write_all and read_exact")
}
}
*/
/// Write half for `ESock`.
///
/// Holds an `Arc` of the key material plus the `RwLock`-guarded encryption
/// flags, and owns the encrypting [`AsyncSink`]. Presumably the same `Arc` is
/// shared with the matching [`ESockReadHalf`]; no constructor is visible here.
#[pin_project]
#[derive(Debug)]
pub struct ESockWriteHalf<W>(Arc<(ESockInfo, RwLock<ESockState>)>, #[pin] AsyncSink<W>);
/// Read half for `ESock`.
///
/// Holds an `Arc` of the key material plus the `RwLock`-guarded encryption
/// flags, and owns the decrypting [`AsyncSource`]. Presumably the same `Arc` is
/// shared with the matching [`ESockWriteHalf`]; no constructor is visible here.
#[pin_project]
#[derive(Debug)]
pub struct ESockReadHalf<R>(Arc<(ESockInfo, RwLock<ESockState>)>, #[pin] AsyncSource<R>);
#[cfg(test)]
mod tests
{
    use super::{ESockSessionKey, RSA_CIPHERTEXT_SIZE};
    use chacha20stream::cha::{self, KEY_SIZE, IV_SIZE};
    use cryptohelpers::rsa::{RsaPublicKey, encrypt_slice_to_vec};

    /// The RSA ciphertext of a raw `KEY || IV` concatenation must be exactly
    /// `RSA_CIPHERTEXT_SIZE` bytes.
    #[test]
    fn rsa_ciphertext_len() -> crate::eyre::Result<()>
    {
        let (key, iv) = cha::keygen();
        let (written, plaintext) = crate::bin::collect_slices_exact::<&[u8], _, {KEY_SIZE + IV_SIZE}>([key.as_ref(), iv.as_ref()]);
        assert_eq!(written, plaintext.len());
        println!("KEY+IV: {} bytes", plaintext.len());

        let public = RsaPublicKey::generate()?;
        let ciphertext = encrypt_slice_to_vec(plaintext, &public)?;
        println!("Rsa ciphertext size: {}", ciphertext.len());
        assert_eq!(ciphertext.len(), RSA_CIPHERTEXT_SIZE, "Incorrect RSA ciphertext length constant for cc20 KEY+IV encoding.");
        Ok(())
    }

    /// The RSA ciphertext of a CBOR-encoded `ESockSessionKey` must also be
    /// exactly `RSA_CIPHERTEXT_SIZE` bytes.
    #[test]
    fn rsa_serial_ciphertext_len() -> crate::eyre::Result<()>
    {
        let (key, iv) = cha::keygen();
        let session = ESockSessionKey { key, iv };
        let plaintext = serde_cbor::to_vec(&session).expect("Failed to CBOR encode Key+IV");
        println!("(cbor) KEY+IV: {} bytes", plaintext.len());

        let public = RsaPublicKey::generate()?;
        let ciphertext = encrypt_slice_to_vec(plaintext, &public)?;
        println!("Rsa ciphertext size: {}", ciphertext.len());
        assert_eq!(ciphertext.len(), RSA_CIPHERTEXT_SIZE, "Incorrect RSA ciphertext length constant for cc20 KEY+IV CBOR encoding.");
        Ok(())
    }
}