@ -2,14 +2,8 @@
use super ::* ;
use cryptohelpers ::{
rsa ::{
self ,
RsaPublicKey ,
RsaPrivateKey ,
openssl ::{
symm ::Crypter ,
error ::ErrorStack ,
} ,
} ,
sha256 ,
} ;
@ -18,8 +12,6 @@ use chacha20stream::{
AsyncSource ,
Key , IV ,
cha ,
} ;
use std ::sync ::Arc ;
use tokio ::{
@ -28,24 +20,25 @@ use tokio::{
RwLockReadGuard ,
RwLockWriteGuard ,
} ,
io ::{
DuplexStream ,
} ,
} ;
use std ::{
io ,
fmt ,
task ::{
Context , Poll ,
} ,
pin ::Pin ,
marker ::Unpin ,
marker ::{
Unpin ,
PhantomPinned ,
} ,
} ;
use smallvec ::SmallVec ;
/// Size of a single RSA ciphertext.
pub const RSA_CIPHERTEXT_SIZE : usize = 512 ;
/// A single, full block of RSA ciphertext.
type RsaCiphertextBlock = [ u8 ; RSA_CIPHERTEXT_SIZE ] ;
/// Max size to read when exchanging keys
const TRANS_KEY_MAX_SIZE : usize = 4096 ;
@ -56,44 +49,13 @@ struct ESockInfo {
them : Option < RsaPublicKey > ,
}
impl ESockInfo
{
/// Generate a new private key
pub fn new ( us : impl Into < RsaPrivateKey > ) -> Self
{
Self {
us : us . into ( ) ,
them : None ,
}
}
/// Generate a new private key for the local endpoint
pub fn generate ( ) -> Result < Self , rsa ::Error >
{
Ok ( Self ::new ( RsaPrivateKey ::generate ( ) ? ) )
}
}
/// The encryption state of the Tx and Rx instances.
#[ derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone) ]
#[ derive(Debug) ]
struct ESockState {
encr : bool ,
encw : bool ,
}
impl Default for ESockState
{
#[ inline ]
fn default ( ) -> Self
{
Self {
encr : false ,
encw : false ,
}
}
}
/// Contains a cc20 Key and IV that can be serialized and then encrypted
/// Contains a Key and IV that can be serialized and then encrypted
#[ derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize) ]
struct ESockSessionKey
{
@ -101,71 +63,6 @@ struct ESockSessionKey
iv : IV ,
}
impl fmt ::Display for ESockSessionKey
{
fn fmt ( & self , f : & mut fmt ::Formatter < ' _ > ) -> fmt ::Result
{
write! ( f , "Key: {}, IV: {}" , self . key . hex ( ) , self . iv . hex ( ) )
}
}
impl ESockSessionKey
{
/// Generate a new cc20 key + iv,
pub fn generate ( ) -> Self
{
let ( key , iv ) = cha ::keygen ( ) ;
Self { key , iv }
}
/// Generate an encryption device
pub fn to_decrypter ( & self ) -> Result < Crypter , ErrorStack >
{
cha ::decrypter ( & self . key , & self . iv )
}
/// Generate an encryption device
pub fn to_encrypter ( & self ) -> Result < Crypter , ErrorStack >
{
cha ::encrypter ( & self . key , & self . iv )
}
/// Encrypt with RSA
pub fn to_ciphertext < K : ? Sized + rsa ::PublicKey > ( & self , rsa_key : & K ) -> eyre ::Result < RsaCiphertextBlock >
{
let mut output = [ 0 u8 ; RSA_CIPHERTEXT_SIZE ] ;
let mut temp = SmallVec ::< [ u8 ; RSA_CIPHERTEXT_SIZE ] > ::new ( ) ; // We know size will fit into here.
serde_cbor ::to_writer ( & mut temp , self )
. wrap_err ( eyre ! ( "Failed to CBOR encode session key to buffer" ) )
. with_section ( | | self . clone ( ) . header ( "Session key was" ) ) ? ;
debug_assert! ( temp . len ( ) < RSA_CIPHERTEXT_SIZE ) ;
let _wr = rsa ::encrypt_slice_sync ( & temp , rsa_key , & mut & mut output [ .. ] )
. wrap_err ( eyre ! ( "Failed to encrypt session key with RSA public key" ) )
. with_section ( | | self . clone ( ) . header ( "Session key was" ) )
. with_section ( { let temp = temp . len ( ) ; move | | temp . header ( "Encoded data size was" ) } )
. with_section ( move | | base64 ::encode ( temp ) . header ( "Encoded data (base64) was" ) ) ? ;
debug_assert_eq! ( _wr , output . len ( ) ) ;
Ok ( output )
}
/// Decrypt from RSA
pub fn from_ciphertext < K : ? Sized + rsa ::PrivateKey > ( data : & [ u8 ; RSA_CIPHERTEXT_SIZE ] , rsa_key : & K ) -> eyre ::Result < Self >
where < K as rsa ::PublicKey > ::KeyType : rsa ::openssl ::pkey ::HasPrivate //ugh, why do we have to have this bound??? it should be implied ffs... :/
{
let mut temp = SmallVec ::< [ u8 ; RSA_CIPHERTEXT_SIZE ] > ::new ( ) ;
rsa ::decrypt_slice_sync ( data , rsa_key , & mut temp )
. wrap_err ( eyre ! ( "Failed to decrypt ciphertext to session key" ) )
. with_section ( { let data = data . len ( ) ; move | | data . header ( "Ciphertext length was" ) } )
. with_section ( | | base64 ::encode ( data ) . header ( "Ciphertext was" ) ) ? ;
Ok ( serde_cbor ::from_slice ( & temp [ .. ] )
. wrap_err ( eyre ! ( "Failed to decode CBOR data to session key object" ) )
. with_section ( { let temp = temp . len ( ) ; move | | temp . header ( "Encoded data size was" ) } )
. with_section ( move | | base64 ::encode ( temp ) . header ( "Encoded data (base64) was" ) ) ? )
}
}
/// A tx+rx socket.
#[ pin_project ]
#[ derive(Debug) ]
@ -182,7 +79,7 @@ pub struct ESock<W, R> {
impl < W : AsyncWrite , R : AsyncRead > ESock < W , R >
{
fn inner ( & self ) -> ( & W , & R )
pub fn inner ( & self ) -> ( & W , & R )
{
( self . tx . inner ( ) , self . rx . inner ( ) )
}
@ -192,6 +89,31 @@ impl<W: AsyncWrite, R: AsyncRead> ESock<W, R>
( self . tx . inner_mut ( ) , self . rx . inner_mut ( ) )
}
/// Create a future that exchanges keys
pub fn exchange_unsafe ( & mut self ) -> Exchange < ' _ , W , R >
{
let us = self . info . us . get_public_parts ( ) ;
todo! ( "Currently unimplemented" )
/*
Exchange {
sock : self ,
write_buf : Default ::default ( ) ,
read_buf : Default ::default ( ) ,
_pin : PhantomPinned ,
read_state : Default ::default ( ) ,
write_state : Default ::default ( ) ,
us ,
them : None ,
us_written : Default ::default ( ) ,
us_buf : ( ) ,
write_sz_num : ( ) ,
write_sz_buf : ( ) ,
read_sz_buf : ( ) ,
} * /
}
///Get a mutable ref to unencrypted read+write
fn unencrypted ( & mut self ) -> ( & mut W , & mut R )
{
@ -210,155 +132,19 @@ impl<W: AsyncWrite, R: AsyncRead> ESock<W, R>
}
/// Is the Write + Read operation encrypted? Tuple is `(Tx, Rx)`.
#[ inline ] pub fn is_encrypted ( & self ) -> ( bool , bool )
pub fn is_encrypted ( & self ) -> ( bool , bool )
{
( self . state . encw , self . state . encr )
}
/// Create a new `ESock` wrapper over this writer and reader with this specific RSA key.
pub fn with_key ( key : impl Into < RsaPrivateKey > , tx : W , rx : R ) -> Self
{
let ( tk , tiv ) = cha ::keygen ( ) ;
Self {
info : ESockInfo ::new ( key ) ,
state : Default ::default ( ) ,
// Note: These key+IV pairs are never used, as `state` defaults to unencrypted, and a new key/iv pair is generated when we `set_encrypted_write/read(true)`.
// TODO: Have a method to exchange these default session keys after `exchange()`?
tx : AsyncSink ::encrypt ( tx , tk , tiv ) . expect ( "Failed to create temp AsyncSink" ) ,
rx : AsyncSource ::encrypt ( rx , tk , tiv ) . expect ( "Failed to create temp AsyncSource" ) ,
}
}
/// Create a new `ESock` wrapper over this writer and reader with a newly generated private key
#[ inline ] pub fn new ( tx : W , rx : R ) -> Result < Self , rsa ::Error >
{
Ok ( Self ::with_key ( RsaPrivateKey ::generate ( ) ? , tx , rx ) )
}
/// The local RSA private key
#[ inline ] pub fn local_key ( & self ) -> & RsaPrivateKey
{
& self . info . us
}
/// THe remote RSA public key (if exchange has happened.)
#[ inline ] pub fn foreign_key ( & self ) -> Option < & RsaPublicKey >
{
self . info . them . as_ref ( )
}
/// Split this `ESock` into a read+write pair.
///
/// # Note
/// You must preform an `exchange()` before splitting, as exchanging RSA keys is not possible on a single half.
///
/// It is also more efficient to `set_encrypted_write/read(true)` on `ESock` than it is on the halves, but changinc encryption modes on halves is still possible.
pub fn split ( self ) -> ( ESockWriteHalf < W > , ESockReadHalf < R > )
{
let arced = Arc ::new ( self . info ) ;
( ESockWriteHalf ( Arc ::clone ( & arced ) , self . tx , self . state . encw ) ,
ESockReadHalf ( arced , self . rx , self . state . encr ) )
}
/// Merge a previously split `ESock` into a single one again.
///
/// # Panics
/// If the two halves were not split from the same `ESock`.
pub fn unsplit ( txh : ESockWriteHalf < W > , rxh : ESockReadHalf < R > ) -> Self
{
#[ cold ]
#[ inline(never) ]
fn _panic_ptr_ineq ( ) -> !
{
panic! ( "Cannot merge halves split from different sources" )
}
if ! Arc ::ptr_eq ( & txh . 0 , & rxh . 0 ) {
_panic_ptr_ineq ( ) ;
}
let tx = txh . 1 ;
drop ( txh . 0 ) ;
let info = Arc ::try_unwrap ( rxh . 0 ) . unwrap ( ) ;
let rx = rxh . 1 ;
Self {
state : ESockState {
encw : txh . 2 ,
encr : rxh . 2 ,
} ,
info ,
tx , rx
}
}
}
async fn set_encrypted_write_for < T : AsyncWrite + Unpin > ( info : & ESockInfo , tx : & mut AsyncSink < T > ) -> eyre ::Result < ( ) >
{
use tokio ::prelude ::* ;
let session_key = ESockSessionKey ::generate ( ) ;
let data = {
let them = info . them . as_ref ( ) . expect ( "Cannot set encrypted write when keys have not been exchanged" ) ;
session_key . to_ciphertext ( them )
. wrap_err ( eyre ! ( "Failed to encrypt session key with foreign endpoint's key" ) )
. with_section ( | | session_key . to_string ( ) . header ( "Session key was" ) )
. with_section ( | | them . to_string ( ) . header ( "Foreign pubkey was" ) ) ?
} ;
let crypter = session_key . to_encrypter ( )
. wrap_err ( eyre ! ( "Failed to create encryption device from session key for Tx" ) )
. with_section ( | | session_key . to_string ( ) . header ( "Session key was" ) ) ? ;
// Send rsa `data` over unencrypted endpoint
tx . inner_mut ( ) . write_all ( & data [ .. ] ) . await
. wrap_err ( eyre ! ( "Failed to write ciphertext to endpoint" ) )
. with_section ( | | data . to_base64_string ( ) . header ( "Ciphertext of session key was" ) ) ? ;
// Set crypter of `tx` to `session_key`.
* tx . crypter_mut ( ) = crypter ;
Ok ( ( ) )
}
async fn set_encrypted_read_for < T : AsyncRead + Unpin > ( info : & ESockInfo , rx : & mut AsyncSource < T > ) -> eyre ::Result < ( ) >
{
use tokio ::prelude ::* ;
let mut data = [ 0 u8 ; RSA_CIPHERTEXT_SIZE ] ;
// Read `data` from unencrypted endpoint
rx . inner_mut ( ) . read_exact ( & mut data [ .. ] ) . await
. wrap_err ( eyre ! ( "Failed to read ciphertext from endpoint" ) ) ? ;
// Decrypt `data`
let session_key = ESockSessionKey ::from_ciphertext ( & data , & info . us )
. wrap_err ( eyre ! ( "Failed to decrypt session key from ciphertext" ) )
. with_section ( | | data . to_base64_string ( ) . header ( "Ciphertext was" ) )
. with_section ( | | info . us . to_string ( ) . header ( "Our RSA key is" ) ) ? ;
// Set crypter of `rx` to `session_key`.
* rx . crypter_mut ( ) = session_key . to_decrypter ( )
. wrap_err ( eyre ! ( "Failed to create decryption device from session key for Rx" ) )
. with_section ( | | session_key . to_string ( ) . header ( "Decrypted session key was" ) ) ? ;
Ok ( ( ) )
}
impl < W : AsyncWrite + Unpin , R : AsyncRead + Unpin > ESock < W , R >
{
/// Get the Tx and Rx of the stream.
///
/// # Returns
/// Returns encrypted stream halfs if the stream is encrypted, unencrypted if not.
pub fn stream ( & mut self ) -> ( & mut ( dyn AsyncWrite + Unpin + ' _ ) , & mut ( dyn AsyncRead + Unpin + ' _ ) )
{
( if self . state . encw {
& mut self . tx
} else {
self . tx . inner_mut ( )
} , if self . state . encr {
& mut self . rx
} else {
self . rx . inner_mut ( )
} )
}
/// Enable write encryption
pub async fn set_encrypted_write ( & mut self , set : bool ) -> eyre ::Result < ( ) >
{
if set {
set_encrypted_write_for ( & self . info , & mut self . tx ) . await ? ;
// Set `encw` to true
let ( key , iv ) = ( ( ) , ( ) ) ;
self . state . encw = true ;
Ok ( ( ) )
} else {
@ -367,22 +153,6 @@ impl<W: AsyncWrite+ Unpin, R: AsyncRead + Unpin> ESock<W, R>
}
}
/// Enable read encryption
///
/// The other endpoint must have sent a `set_encrypted_write()`
pub async fn set_encrypted_read ( & mut self , set : bool ) -> eyre ::Result < ( ) >
{
if set {
set_encrypted_read_for ( & self . info , & mut self . rx ) . await ? ;
// Set `encr` to true
self . state . encr = true ;
Ok ( ( ) )
} else {
self . state . encr = false ;
Ok ( ( ) )
}
}
/// Get dynamic ref to unencrypted write+read
fn unencrypted_dyn ( & mut self ) -> ( & mut ( dyn AsyncWrite + Unpin + ' _ ) , & mut ( dyn AsyncRead + Unpin + ' _ ) )
{
@ -405,29 +175,17 @@ impl<W: AsyncWrite+ Unpin, R: AsyncRead + Unpin> ESock<W, R>
// Read the public key from `rx`.
//TODO: Find pubkey max size.
let mut sz_buf = [ 0 u8 ; std ::mem ::size_of ::< u64 > ( ) ] ;
rx . read_exact ( & mut sz_buf [ .. ] ) . await
. wrap_err ( eyre ! ( "Failed to read size of pubkey form endpoint" ) ) ? ;
let sz64 = u64 ::from_be_bytes ( sz_buf ) ;
let sz = match usize ::try_from ( sz64 )
. wrap_err ( eyre ! ( "Read size could not fit into u64" ) )
. with_section ( | | format! ( "{:?}" , sz_buf ) . header ( "Read buffer was" ) )
. with_section ( | | u64 ::from_be_bytes ( sz_buf ) . header ( "64=bit size value was" ) )
. with_warning ( | | "This should not happen, it is only possible when you are running a machine with a pointer size lower than 64 bits." )
. with_suggestion ( | | "The message is likely malformed. If it is not, then you are communicating with an endpoint of 64 bits whereas your pointer size is far less." ) ? {
x if x > TRANS_KEY_MAX_SIZE = > return Err ( eyre ! ( "Recv'd key size exceeded max acceptable key buffer size" ) ) ,
rx . read_exact ( & mut sz_buf [ .. ] ) . await ? ;
let sz = match usize ::try_from ( u64 ::from_be_bytes ( sz_buf ) ) ? {
x if x > TRANS_KEY_MAX_SIZE = > return Err ( eyre ! ( "Recv'd key size exceeded max" ) ) ,
x = > x
} ;
let mut key_bytes = Vec ::with_capacity ( sz ) ;
tokio ::io ::copy ( & mut rx . take ( sz64 ) , & mut key_bytes ) . await
. wrap_err ( "Failed to read key bytes into buffer" )
. with_section ( move | | sz64 . header ( "Pubkey size to read was" ) ) ? ;
tokio ::io ::copy ( & mut rx . take ( sz as u64 ) , & mut key_bytes ) . await ? ;
if key_bytes . len ( ) ! = sz {
return Err ( eyre ! ( "Could not read required bytes" ) ) ;
}
let k = RsaPublicKey ::from_bytes ( & key_bytes )
. wrap_err ( "Failed to construct RSA public key from read bytes" )
. with_section ( | | sz . header ( "Pubkey size was" ) )
. with_section ( move | | key_bytes . to_base64_string ( ) . header ( "Pubkey bytes were" ) ) ? ;
let k = RsaPublicKey ::from_bytes ( key_bytes ) ? ;
Result ::< RsaPublicKey , eyre ::Report > ::Ok ( k )
}
@ -435,231 +193,145 @@ impl<W: AsyncWrite+ Unpin, R: AsyncRead + Unpin> ESock<W, R>
let write_fut = {
let key_bytes = our_key . to_bytes ( ) ;
assert! ( key_bytes . len ( ) < = TRANS_KEY_MAX_SIZE ) ;
let sz64 = u64 ::try_from ( key_bytes . len ( ) )
. wrap_err ( eyre ! ( "Size of our pubkey could not fit into u64" ) )
. with_section ( | | key_bytes . len ( ) . header ( "Size was" ) )
. with_warning ( | | "This should not happen, it is only possible when you are running a machine with a pointer size larger than 64 bits." )
. with_warning ( | | "There was likely internal memory corruption." ) ? ;
let sz_buf = sz64 . to_be_bytes ( ) ;
let sz_buf = u64 ::try_from ( key_bytes . len ( ) ) ? . to_be_bytes ( ) ;
async move {
tx . write_all ( & sz_buf [ .. ] ) . await
. wrap_err ( eyre ! ( "Failed to write key size" ) )
. with_section ( | | sz64 . header ( "Key size bytes were" ) )
. with_section ( | | format! ( "{:?}" , sz_buf ) . header ( "Key size bytes (BE) were" ) ) ? ;
tx . write_all ( & key_bytes [ .. ] ) . await
. wrap_err ( eyre ! ( "Failed to write key bytes" ) )
. with_section ( | | sz64 . header ( "Size of key was" ) )
. with_section ( | | key_bytes . to_base64_string ( ) . header ( "Key bytes are" ) ) ? ;
tx . write_all ( & sz_buf [ .. ] ) . await ? ;
tx . write_all ( & key_bytes [ .. ] ) . await ? ;
Result ::< ( ) , eyre ::Report > ::Ok ( ( ) )
}
} ;
let ( send , recv ) = tokio ::join ! [ write_fut , read_fut ] ;
send . wrap_err ( "Failed to send our pubkey" ) ? ;
let recv = recv . wrap_err ( "Failed to receive foreign pubkey" ) ? ;
send ? ;
let recv = recv ? ;
self . info . them = Some ( recv ) ;
Ok ( ( ) )
}
}
//XXX: For some reason, non-exact reads + writes cause garbage to be produced on the receiving end?
// Is this fixable? Why does it disjoint? I have no idea... This is supposed to be a stream cipher, right? Why does positioning matter? Have I misunderstood how it workd? Eh...
// With this bug, it seems the `while read(buffer) > 0` construct is impossible. This might make this entirely useless. Hopefully with the rigid size-based format for `Message` we won't run into this problem, but subsequent data streaming will likely be affected unless we use rigid, fixed, and (inefficiently) communicated buffer sizes.
impl < W , R > AsyncWrite for ESock < W , R >
where W : AsyncWrite
#[ derive(Debug, Clone, PartialEq, Eq, Hash, Copy, PartialOrd, Ord) ]
enum ExchangeState
{
fn poll_write ( self : Pin < & mut Self > , cx : & mut Context < ' _ > , buf : & [ u8 ] ) -> Poll < Result < usize , io ::Error > > {
//XXX: If the encryption state of the socket is changed between polls, this breaks. Idk if we can do anything about that tho.
if self . state . encw {
self . project ( ) . tx . poll_write ( cx , buf )
} else {
// SAFETY: Uhh... well I think this is fine? Because we can project the container.
// TODO: Can we project the `tx`? Or maybe add a method in `AsyncSink` to map a pinned sink to a `Pin<&mut W>`?
unsafe { self . map_unchecked_mut ( | this | this . tx . inner_mut ( ) ) . poll_write ( cx , buf ) }
}
}
fn poll_shutdown ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , io ::Error > > {
// Should we do anything else here?
// Should we clear foreign key/current session key?
self . project ( ) . tx . poll_shutdown ( cx )
}
fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , io ::Error > > {
self . project ( ) . tx . poll_flush ( cx )
}
/// We are currently reading/writing the buffer's size
BufferSize ,
/// We are currently reading/writing the buffer itself
Buffer ,
}
impl < W , R > AsyncRead for ESock < W , R >
where R : AsyncRead
impl Default for ExchangeState
{
fn poll_read ( self : Pin < & mut Self > , cx : & mut Context < ' _ > , buf : & mut [ u8 ] ) -> Poll < io ::Result < usize > > {
//XXX: If the encryption state of the socket is changed between polls, this breaks. Idk if we can do anything about that tho.
if self . state . encr {
self . project ( ) . rx . poll_read ( cx , buf )
} else {
// SAFETY: Uhh... well I think this is fine? Because we can project the container.
// TODO: Can we project the `tx`? Or maybe add a method in `AsyncSink` to map a pinned sink to a `Pin<&mut W>`?
unsafe { self . map_unchecked_mut ( | this | this . rx . inner_mut ( ) ) . poll_read ( cx , buf ) }
}
#[ inline ]
fn default ( ) -> Self
{
Self ::BufferSize
}
}
/// Write half for `ESock`.
#[ pin_project ]
#[ derive(Debug) ]
pub struct ESockWriteHalf < W > ( Arc < ESockInfo > , #[ pin ] AsyncSink < W > , bool ) ;
/// Read half for `ESock`.
#[ pin_project ]
#[ derive(Debug) ]
pub struct ESockReadHalf < R > ( Arc < ESockInfo > , #[ pin ] AsyncSource < R > , bool ) ;
pub struct Exchange < ' a , W , R >
{
sock : & ' a mut ESock < W , R > ,
//Impl AsyncRead/Write + set_encrypted_read/write for ESockRead/WriteHalf.
us : RsaPublicKey ,
impl < W : AsyncWrite > ESockWriteHalf < W >
{
/// Does this write half have a live corresponding read half?
///
/// It's not required to have one, however, exchange is not possible without since it requires sticking the halves back together.
pub fn is_bidirectional ( & self ) -> bool
{
Arc ::strong_count ( & self . 0 ) > 1
}
us_written : usize ,
us_buf : Vec < u8 > ,
/// Is write encrypted on this half?
#[ inline(always) ] pub fn is_encrypted ( & self ) -> bool
{
self . 2
}
/// The local RSA private key
#[ inline ] pub fn local_key ( & self ) -> & RsaPrivateKey
{
& self . 0. us
}
/// THe remote RSA public key (if exchange has happened.)
#[ inline ] pub fn foreign_key ( & self ) -> Option < & RsaPublicKey >
{
self . 0. them . as_ref ( )
}
/// The return value
them : Option < RsaPublicKey > ,
/// End an encrypted session syncronously.
///
/// Same as calling `set_encryption(false).now_or_never()`, but more efficient.
pub fn clear_encryption ( & mut self )
{
self . 2 = false ;
}
}
write_sz_num : usize ,
write_sz_buf : [ u8 ; std ::mem ::size_of ::< u64 > ( ) ] ,
read_sz_buf : [ u8 ; std ::mem ::size_of ::< u64 > ( ) ] ,
impl < R : AsyncRead > ESockReadHalf < R >
{
/// Does this read half have a live corresponding write half?
///
/// It's not required to have one, however, exchange is not possible without since it requires sticking the halves back together.
pub fn is_bidirectional ( & self ) -> bool
{
Arc ::strong_count ( & self . 0 ) > 1
}
read_buf : Vec < u8 > ,
/// Is write encrypted on this half?
#[ inline(always) ] pub fn is_encrypted ( & self ) -> bool
{
self . 2
}
/// The local RSA private key
#[ inline ] pub fn local_key ( & self ) -> & RsaPrivateKey
{
& self . 0. us
}
/// THe remote RSA public key (if exchange has happened.)
#[ inline ] pub fn foreign_key ( & self ) -> Option < & RsaPublicKey >
{
self . 0. them . as_ref ( )
}
write_state : ExchangeState ,
read_state : ExchangeState ,
/// End an encrypted session syncronously.
///
/// Same as calling `set_encryption(false).now_or_never()`, but more efficient.
pub fn clear_encryption ( & mut self )
{
self . 2 = false ;
}
#[ pin ] _pin : PhantomPinned ,
}
impl < W : AsyncWrite + Unpin > ESockWriteHalf < W >
/*
impl < ' a , W : AsyncWrite , R : AsyncRead > Future for Exchange < ' a , W , R >
{
/// Begin or end an encrypted writing session
pub async fn set_encryption ( & mut self , set : bool ) -> eyre ::Result < ( ) >
{
if set {
set_encrypted_write_for ( & self . 0 , & mut self . 1 ) . await ? ;
self . 2 = true ;
} else {
self . 2 = false ;
}
Ok ( ( ) )
}
type Output = eyre ::Result < ( ) > ;
fn poll ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self ::Output > {
use futures ::ready ;
let this = self . project ( ) ;
let ( tx , rx ) = {
let sock = this . sock ;
//XXX: Idk if this is safe?
unsafe {
( Pin ::new_unchecked ( & mut sock . tx ) , Pin ::new_unchecked ( & mut sock . rx ) )
}
} ;
impl < R : AsyncRead + Unpin > ESockReadHalf < R >
{
/// Begin or end an encrypted reading session
pub async fn set_encryption ( & mut self , set : bool ) -> eyre ::Result < ( ) >
{
if set {
set_encrypted_read_for ( & self . 0 , & mut self . 1 ) . await ? ;
self . 2 = true ;
} else {
self . 2 = false ;
if this . us_buf . is_empty ( ) {
* this . us_buf = this . us . to_bytes ( ) ;
}
Ok ( ( ) )
let poll_write = loop {
break match this . write_state {
ExchangeState ::BufferSize = > {
if * this . write_sz_num = = 0 {
* this . write_sz_buf = u64 ::try_from ( this . us_buf . len ( ) ) ? . to_be_bytes ( ) ;
}
// Write this to tx.
match tx . poll_write ( cx , & this . write_sz_buf [ ( * this . write_sz_num ) .. ] ) {
x @ Poll ::Ready ( Ok ( n ) ) = > {
* this . write_sz_num + = n ;
if * this . write_sz_num = = this . write_sz_buf . len ( ) {
// We've written all the size bytes, continue to writing the buffer bytes.
* this . write_state = ExchangeState ::Buffer ;
continue ;
}
impl < W : AsyncWrite > AsyncWrite for ESockWriteHalf < W >
{
fn poll_write ( self : Pin < & mut Self > , cx : & mut Context < ' _ > , buf : & [ u8 ] ) -> Poll < Result < usize , io ::Error > > {
if self . 2 {
// Encrypted
self . project ( ) . 1. poll_write ( cx , buf )
} else {
// Unencrypted
// SAFETY: Uhh... well I think this is fine? Because we can project the container.
// TODO: Can we project the `tx`? Or maybe add a method in `AsyncSink` to map a pinned sink to a `Pin<&mut W>`?
unsafe { self . map_unchecked_mut ( | this | this . 1. inner_mut ( ) ) . poll_write ( cx , buf ) }
}
x
} ,
x = > x ,
}
#[ inline(always) ] fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , io ::Error > > {
self . project ( ) . 1. poll_flush ( cx )
} ,
ExchangeState ::Buffer = > {
match tx . poll_write ( cx , & this . us_buf [ ( * this . us_written ) .. ] ) {
x @ Poll ::Ready ( Ok ( n ) ) = > {
if * this . us_written = = this . us . len ( ) {
}
#[ inline(always) ] fn poll_shutdown ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , io ::Error > > {
self . project ( ) . 1. poll_flush ( cx )
x
} ,
x = > x ,
}
} ,
}
} ;
let poll_read = match this . read_state {
ExchangeState ::BufferSize = > {
impl < R : AsyncRead > AsyncRead for ESockReadHalf < R >
{
fn poll_read ( self : Pin < & mut Self > , cx : & mut Context < ' _ > , buf : & mut [ u8 ] ) -> Poll < io ::Result < usize > > {
if self . 2 {
// Encrypted
self . project ( ) . 1. poll_read ( cx , buf )
} else {
// Unencrypted
} ,
ExchangeState ::Buffer = > {
// SAFETY: Uhh... well I think this is fine? Because we can project the container.
// TODO: Can we project the `tx`? Or maybe add a method in `AsyncSink` to map a pinned sink to a `Pin<&mut W>`?
unsafe { self . map_unchecked_mut ( | this | this . 1. inner_mut ( ) ) . poll_read ( cx , buf ) }
}
} ,
} ;
todo! ( "This is going to be dificult to implement... We don't have access to write_all and read_exact" )
}
}
* /
/// Write half for `ESock`.
#[ pin_project ]
#[ derive(Debug) ]
pub struct ESockWriteHalf < W > ( Arc < ( ESockInfo , RwLock < ESockState > ) > , #[ pin ] AsyncSink < W > ) ;
/// Read half for `ESock`.
#[ pin_project ]
#[ derive(Debug) ]
pub struct ESockReadHalf < R > ( Arc < ( ESockInfo , RwLock < ESockState > ) > , #[ pin ] AsyncSource < R > ) ;
#[ cfg(test) ]
mod tests
{
use super ::ESock ;
#[ test ]
fn rsa_ciphertext_len ( ) -> crate ::eyre ::Result < ( ) >
{
@ -699,176 +371,4 @@ mod tests
Ok ( ( ) )
}
fn gen_duplex_esock ( bufsz : usize ) -> crate ::eyre ::Result < ( ESock < tokio ::io ::DuplexStream , tokio ::io ::DuplexStream > , ESock < tokio ::io ::DuplexStream , tokio ::io ::DuplexStream > ) >
{
use crate ::* ;
let ( atx , brx ) = tokio ::io ::duplex ( bufsz ) ;
let ( btx , arx ) = tokio ::io ::duplex ( bufsz ) ;
let tx = ESock ::new ( atx , arx ) . wrap_err ( eyre ! ( "Failed to create TX" ) ) ? ;
let rx = ESock ::new ( btx , brx ) . wrap_err ( eyre ! ( "Failed to create RX" ) ) ? ;
Ok ( ( tx , rx ) )
}
#[ tokio::test ]
async fn esock_exchange ( ) -> crate ::eyre ::Result < ( ) >
{
use crate ::* ;
const VALUE : & ' static [ u8 ] = b" Hello world! " ;
// The duplex buffer size here is smaller than an RSA ciphertext block. So, writing the session key must be buffered with a buffer size this small (should return Pending at least once.)
// Using a low buffer size to make sure the test passes even when the entire buffer cannot be written at once.
let ( mut tx , mut rx ) = gen_duplex_esock ( 16 ) . wrap_err ( eyre ! ( "Failed to weave socks" ) ) ? ;
let writer = tokio ::spawn ( async move {
use tokio ::prelude ::* ;
tx . exchange ( ) . await ? ;
assert! ( tx . has_exchanged ( ) ) ;
tx . set_encrypted_write ( true ) . await ? ;
assert_eq! ( ( true , false ) , tx . is_encrypted ( ) ) ;
tx . write_all ( VALUE ) . await ? ;
tx . write_all ( VALUE ) . await ? ;
// Check resp
tx . set_encrypted_read ( true ) . await ? ;
assert_eq! ( {
let mut chk = [ 0 u8 ; 3 ] ;
tx . read_exact ( & mut chk [ .. ] ) . await ? ;
chk
} , [ 0xaa u8 , 0 , 0 ] , "Failed response check" ) ;
// Write unencrypted
tx . set_encrypted_write ( false ) . await ? ;
tx . write_all ( & [ 2 , 1 , 0xfa ] ) . await ? ;
Result ::< _ , eyre ::Report > ::Ok ( VALUE )
} ) ;
let reader = tokio ::spawn ( async move {
use tokio ::prelude ::* ;
rx . exchange ( ) . await ? ;
assert! ( rx . has_exchanged ( ) ) ;
rx . set_encrypted_read ( true ) . await ? ;
assert_eq! ( ( false , true ) , rx . is_encrypted ( ) ) ;
let mut val = vec! [ 0 u8 ; VALUE . len ( ) ] ;
rx . read_exact ( & mut val [ .. ] ) . await ? ;
let mut val2 = vec! [ 0 u8 ; VALUE . len ( ) ] ;
rx . read_exact ( & mut val2 [ .. ] ) . await ? ;
assert_eq! ( val , val2 ) ;
// Send resp
rx . set_encrypted_write ( true ) . await ? ;
rx . write_all ( & [ 0xaa , 0 , 0 ] ) . await ? ;
// Read unencrypted
rx . set_encrypted_read ( false ) . await ? ;
assert_eq! ( {
let mut buf = [ 0 u8 ; 3 ] ;
rx . read_exact ( & mut buf [ .. ] ) . await ? ;
buf
} , [ 2 u8 , 1 , 0xfa ] , "2nd response incorrect" ) ;
Result ::< _ , eyre ::Report > ::Ok ( val )
} ) ;
let ( writer , reader ) = tokio ::join ! [ writer , reader ] ;
let writer = writer . expect ( "Tx task panic" ) ;
let reader = reader . expect ( "Rx task panic" ) ;
eprintln! ( "Txr: {:?}" , writer ) ;
eprintln! ( "Rxr: {:?}" , reader ) ;
writer ? ;
let val = reader ? ;
println! ( "Read: {:?}" , val ) ;
assert_eq! ( & val , VALUE ) ;
Ok ( ( ) )
}
#[ tokio::test ]
async fn esock_split ( ) -> crate ::eyre ::Result < ( ) >
{
use super ::* ;
const SLICES : & ' static [ & ' static [ u8 ] ] = & [
& [ 1 , 5 , 3 , 7 , 6 , 9 , 100 , 0 ] ,
& [ 7 , 6 , 2 , 90 ] ,
& [ 3 , 6 , 1 , 0 ] ,
& [ 5 , 1 , 3 , 3 ] ,
] ;
let result = SLICES . iter ( ) . map ( | & slice | slice . iter ( ) . map ( | & b | u64 ::from ( b ) ) . sum ::< u64 > ( ) ) . sum ::< u64 > ( ) ;
println! ( "Result: {}" , result ) ;
let ( mut tx , mut rx ) = gen_duplex_esock ( super ::TRANS_KEY_MAX_SIZE * 4 ) . wrap_err ( eyre ! ( "Failed to weave socks" ) ) ? ;
let ( writer , reader ) = {
use tokio ::prelude ::* ;
let writer = tokio ::spawn ( async move {
tx . exchange ( ) . await ? ;
let ( mut tx , mut rx ) = tx . split ( ) ;
//tx.set_encryption(true).await?;
let slices = & SLICES [ 1 .. ] ;
for & slice in slices . iter ( )
{
println! ( "Writing slice: {:?}" , slice ) ;
tx . write_all ( slice ) . await ? ;
}
//let mut tx = ESock::unsplit(tx, rx);
tx . write_all ( SLICES [ 0 ] ) . await ? ;
Result ::< _ , eyre ::Report > ::Ok ( ( ) )
} ) ;
let reader = tokio ::spawn ( async move {
rx . exchange ( ) . await ? ;
let ( mut tx , mut rx ) = rx . split ( ) ;
//rx.set_encryption(true).await?;
let ( mut mtx , mut mrx ) = tokio ::sync ::mpsc ::channel ::< Vec < u8 > > ( 16 ) ;
let sorter = tokio ::spawn ( async move {
let mut done = 0 u64 ;
while let Some ( buf ) = mrx . recv ( ) . await
{
//buf.sort();
done + = buf . iter ( ) . map ( | & b | u64 ::from ( b ) ) . sum ::< u64 > ( ) ;
println! ( "Got buffer: {:?}" , buf ) ;
tx . write_all ( & buf ) . await ? ;
}
Result ::< _ , eyre ::Report > ::Ok ( done )
} ) ;
let mut buffer = [ 0 u8 ; 16 ] ;
while let Ok ( read ) = rx . read ( & mut buffer [ .. ] ) . await
{
if read = = 0 {
break ;
}
mtx . send ( Vec ::from ( & buffer [ .. read ] ) ) . await ? ;
}
drop ( mtx ) ;
let sum = sorter . await . expect ( "(reader) Sorter task panic" ) ? ;
Result ::< _ , eyre ::Report > ::Ok ( sum )
} ) ;
let ( writer , reader ) = tokio ::join ! [ writer , reader ] ;
( writer . expect ( "Writer task panic" ) ,
reader . expect ( "Reader task panic" ) )
} ;
writer ? ;
assert_eq! ( result , reader ? ) ;
Ok ( ( ) )
}
}