@ -210,7 +210,7 @@ impl<W: AsyncWrite, R: AsyncRead> ESock<W, R>
}
/// Is the Write + Read operation encrypted? Tuple is `(Tx, Rx)`.
pub fn is_encrypted ( & self ) -> ( bool , bool )
#[ inline ] pub fn is_encrypted ( & self ) -> ( bool , bool )
{
( self . state . encw , self . state . encr )
}
@ -254,10 +254,10 @@ impl<W: AsyncWrite, R: AsyncRead> ESock<W, R>
/// It is also more efficient to `set_encrypted_write/read(true)` on `ESock` than it is on the halves, but changinc encryption modes on halves is still possible.
pub fn split ( self ) -> ( ESockWriteHalf < W > , ESockReadHalf < R > )
{
let arced = Arc ::new ( ( self . info , RwLock ::new ( self . state ) ) ) ;
let arced = Arc ::new ( self . info ) ;
( ESockWriteHalf ( Arc ::clone ( & arced ) , self . tx ),
ESockReadHalf ( arced , self . rx ))
( ESockWriteHalf ( Arc ::clone ( & arced ) , self . tx , self . state . encw ),
ESockReadHalf ( arced , self . rx , self . state . encr ))
}
/// Merge a previously split `ESock` into a single one again.
@ -270,7 +270,7 @@ impl<W: AsyncWrite, R: AsyncRead> ESock<W, R>
#[ inline(never) ]
fn _panic_ptr_ineq ( ) -> !
{
panic! ( "Cannot merge halves of different socket s")
panic! ( "Cannot merge halves split from different source s")
}
if ! Arc ::ptr_eq ( & txh . 0 , & rxh . 0 ) {
_panic_ptr_ineq ( ) ;
@ -278,17 +278,63 @@ impl<W: AsyncWrite, R: AsyncRead> ESock<W, R>
let tx = txh . 1 ;
drop ( txh . 0 ) ;
let ( info , lstate ) = Arc ::try_unwrap ( rxh . 0 ) . unwrap ( ) ;
let info = Arc ::try_unwrap ( rxh . 0 ) . unwrap ( ) ;
let rx = rxh . 1 ;
Self {
state : lstate . into_inner ( ) ,
state : ESockState {
encw : txh . 2 ,
encr : rxh . 2 ,
} ,
info ,
tx , rx
}
}
}
async fn set_encrypted_write_for < T : AsyncWrite + Unpin > ( info : & ESockInfo , tx : & mut AsyncSink < T > ) -> eyre ::Result < ( ) >
{
use tokio ::prelude ::* ;
let session_key = ESockSessionKey ::generate ( ) ;
let data = {
let them = info . them . as_ref ( ) . expect ( "Cannot set encrypted write when keys have not been exchanged" ) ;
session_key . to_ciphertext ( them )
. wrap_err ( eyre ! ( "Failed to encrypt session key with foreign endpoint's key" ) )
. with_section ( | | session_key . to_string ( ) . header ( "Session key was" ) )
. with_section ( | | them . to_string ( ) . header ( "Foreign pubkey was" ) ) ?
} ;
let crypter = session_key . to_encrypter ( )
. wrap_err ( eyre ! ( "Failed to create encryption device from session key for Tx" ) )
. with_section ( | | session_key . to_string ( ) . header ( "Session key was" ) ) ? ;
// Send rsa `data` over unencrypted endpoint
tx . inner_mut ( ) . write_all ( & data [ .. ] ) . await
. wrap_err ( eyre ! ( "Failed to write ciphertext to endpoint" ) )
. with_section ( | | data . to_base64_string ( ) . header ( "Ciphertext of session key was" ) ) ? ;
// Set crypter of `tx` to `session_key`.
* tx . crypter_mut ( ) = crypter ;
Ok ( ( ) )
}
async fn set_encrypted_read_for < T : AsyncRead + Unpin > ( info : & ESockInfo , rx : & mut AsyncSource < T > ) -> eyre ::Result < ( ) >
{
use tokio ::prelude ::* ;
let mut data = [ 0 u8 ; RSA_CIPHERTEXT_SIZE ] ;
// Read `data` from unencrypted endpoint
rx . inner_mut ( ) . read_exact ( & mut data [ .. ] ) . await
. wrap_err ( eyre ! ( "Failed to read ciphertext from endpoint" ) ) ? ;
// Decrypt `data`
let session_key = ESockSessionKey ::from_ciphertext ( & data , & info . us )
. wrap_err ( eyre ! ( "Failed to decrypt session key from ciphertext" ) )
. with_section ( | | data . to_base64_string ( ) . header ( "Ciphertext was" ) )
. with_section ( | | info . us . to_string ( ) . header ( "Our RSA key is" ) ) ? ;
// Set crypter of `rx` to `session_key`.
* rx . crypter_mut ( ) = session_key . to_decrypter ( )
. wrap_err ( eyre ! ( "Failed to create decryption device from session key for Rx" ) )
. with_section ( | | session_key . to_string ( ) . header ( "Decrypted session key was" ) ) ? ;
Ok ( ( ) )
}
impl < W : AsyncWrite + Unpin , R : AsyncRead + Unpin > ESock < W , R >
{
/// Get the Tx and Rx of the stream.
@ -308,28 +354,10 @@ impl<W: AsyncWrite+ Unpin, R: AsyncRead + Unpin> ESock<W, R>
} )
}
/// Enable write encryption
//TODO: Implement this also for write half
pub async fn set_encrypted_write ( & mut self , set : bool ) -> eyre ::Result < ( ) >
{
use tokio ::prelude ::* ;
if set {
let session_key = ESockSessionKey ::generate ( ) ;
let data = {
let them = self . info . them . as_ref ( ) . expect ( "Cannot set encrypted write when keys have not been exchanged" ) ;
session_key . to_ciphertext ( them )
. wrap_err ( eyre ! ( "Failed to encrypt session key with foreign endpoint's key" ) )
. with_section ( | | session_key . to_string ( ) . header ( "Session key was" ) )
. with_section ( | | them . to_string ( ) . header ( "Foreign pubkey was" ) ) ?
} ;
let crypter = session_key . to_encrypter ( )
. wrap_err ( eyre ! ( "Failed to create encryption device from session key for Tx" ) )
. with_section ( | | session_key . to_string ( ) . header ( "Session key was" ) ) ? ;
// Send rsa `data` over unencrypted endpoint
self . unencrypted ( ) . 0. write_all ( & data [ .. ] ) . await
. wrap_err ( eyre ! ( "Failed to write ciphertext to endpoint" ) )
. with_section ( | | data . to_base64_string ( ) . header ( "Ciphertext of session key was" ) ) ? ;
// Set crypter of `tx` to `session_key`.
* self . tx . crypter_mut ( ) = crypter ;
set_encrypted_write_for ( & self . info , & mut self . tx ) . await ? ;
// Set `encw` to true
self . state . encw = true ;
Ok ( ( ) )
@ -342,29 +370,14 @@ impl<W: AsyncWrite+ Unpin, R: AsyncRead + Unpin> ESock<W, R>
/// Enable read encryption
///
/// The other endpoint must have sent a `set_encrypted_write()`
//TODO: Implement this also for read half
pub async fn set_encrypted_read ( & mut self , set : bool ) -> eyre ::Result < ( ) >
{
use tokio ::prelude ::* ;
if set {
let mut data = [ 0 u8 ; RSA_CIPHERTEXT_SIZE ] ;
// Read `data` from unencrypted endpoint
self . unencrypted ( ) . 1. read_exact ( & mut data [ .. ] ) . await
. wrap_err ( eyre ! ( "Failed to read ciphertext from endpoint" ) ) ? ;
// Decrypt `data`
let session_key = ESockSessionKey ::from_ciphertext ( & data , & self . info . us )
. wrap_err ( eyre ! ( "Failed to decrypt session key from ciphertext" ) )
. with_section ( | | data . to_base64_string ( ) . header ( "Ciphertext was" ) )
. with_section ( | | self . info . us . to_string ( ) . header ( "Our RSA key is" ) ) ? ;
// Set crypter of `rx` to `session_key`.
* self . rx . crypter_mut ( ) = session_key . to_decrypter ( )
. wrap_err ( eyre ! ( "Failed to create decryption device from session key for Rx" ) )
. with_section ( | | session_key . to_string ( ) . header ( "Decrypted session key was" ) ) ? ;
set_encrypted_read_for ( & self . info , & mut self . rx ) . await ? ;
// Set `encr` to true
self . state . encr = true ;
Ok ( ( ) )
} else {
self . state . encr = false ;
Ok ( ( ) )
}
@ -448,11 +461,13 @@ impl<W: AsyncWrite+ Unpin, R: AsyncRead + Unpin> ESock<W, R>
}
}
//XXXXXXX: This isn't working. The first write to the socket succeeds. Any subsequent writes/reads produce garbage. Why?
impl < W , R > AsyncWrite for ESock < W , R >
where W : AsyncWrite
{
fn poll_write ( self : Pin < & mut Self > , cx : & mut Context < ' _ > , buf : & [ u8 ] ) -> Poll < Result < usize , io ::Error > > {
//XXX: If the encryption state of the socket is changed between polls, this breaks. Idk if we can do anything about that tho.
//XXX: If the encryption state of the socket is changed between polls, this breaks. Idk if we can do anything about that tho.
if self . state . encw {
self . project ( ) . tx . poll_write ( cx , buf )
} else {
@ -489,16 +504,160 @@ where R: AsyncRead
/// Write half for `ESock`.
#[ pin_project ]
#[ derive(Debug) ]
pub struct ESockWriteHalf < W > ( Arc < ( ESockInfo , RwLock < ESockState > ) > , #[ pin ] AsyncSink < W > ) ;
pub struct ESockWriteHalf < W > ( Arc < ESockInfo > , #[ pin ] AsyncSink < W > , bool ) ;
/// Read half for `ESock`.
#[ pin_project ]
#[ derive(Debug) ]
pub struct ESockReadHalf < R > ( Arc < ( ESockInfo , RwLock < ESockState > ) > , #[ pin ] AsyncSource < R > ) ;
pub struct ESockReadHalf < R > ( Arc < ESockInfo > , #[ pin ] AsyncSource < R > , bool ) ;
//Impl AsyncRead/Write + set_encrypted_read/write for ESockRead/WriteHalf.
impl < W : AsyncWrite > ESockWriteHalf < W >
{
/// Does this write half have a live corresponding read half?
///
/// It's not required to have one, however, exchange is not possible without since it requires sticking the halves back together.
pub fn is_bidirectional ( & self ) -> bool
{
Arc ::strong_count ( & self . 0 ) > 1
}
/// Is write encrypted on this half?
#[ inline(always) ] pub fn is_encrypted ( & self ) -> bool
{
self . 2
}
/// The local RSA private key
#[ inline ] pub fn local_key ( & self ) -> & RsaPrivateKey
{
& self . 0. us
}
/// THe remote RSA public key (if exchange has happened.)
#[ inline ] pub fn foreign_key ( & self ) -> Option < & RsaPublicKey >
{
self . 0. them . as_ref ( )
}
/// End an encrypted session syncronously.
///
/// Same as calling `set_encryption(false).now_or_never()`, but more efficient.
pub fn clear_encryption ( & mut self )
{
self . 2 = false ;
}
}
impl < R : AsyncRead > ESockReadHalf < R >
{
/// Does this read half have a live corresponding write half?
///
/// It's not required to have one, however, exchange is not possible without since it requires sticking the halves back together.
pub fn is_bidirectional ( & self ) -> bool
{
Arc ::strong_count ( & self . 0 ) > 1
}
/// Is write encrypted on this half?
#[ inline(always) ] pub fn is_encrypted ( & self ) -> bool
{
self . 2
}
/// The local RSA private key
#[ inline ] pub fn local_key ( & self ) -> & RsaPrivateKey
{
& self . 0. us
}
/// THe remote RSA public key (if exchange has happened.)
#[ inline ] pub fn foreign_key ( & self ) -> Option < & RsaPublicKey >
{
self . 0. them . as_ref ( )
}
/// End an encrypted session syncronously.
///
/// Same as calling `set_encryption(false).now_or_never()`, but more efficient.
pub fn clear_encryption ( & mut self )
{
self . 2 = false ;
}
}
impl < W : AsyncWrite + Unpin > ESockWriteHalf < W >
{
/// Begin or end an encrypted writing session
pub async fn set_encryption ( & mut self , set : bool ) -> eyre ::Result < ( ) >
{
if set {
set_encrypted_write_for ( & self . 0 , & mut self . 1 ) . await ? ;
self . 2 = true ;
} else {
self . 2 = false ;
}
Ok ( ( ) )
}
}
impl < R : AsyncRead + Unpin > ESockReadHalf < R >
{
/// Begin or end an encrypted reading session
pub async fn set_encryption ( & mut self , set : bool ) -> eyre ::Result < ( ) >
{
if set {
set_encrypted_read_for ( & self . 0 , & mut self . 1 ) . await ? ;
self . 2 = true ;
} else {
self . 2 = false ;
}
Ok ( ( ) )
}
}
impl < W : AsyncWrite > AsyncWrite for ESockWriteHalf < W >
{
fn poll_write ( self : Pin < & mut Self > , cx : & mut Context < ' _ > , buf : & [ u8 ] ) -> Poll < Result < usize , io ::Error > > {
if self . 2 {
// Encrypted
self . project ( ) . 1. poll_write ( cx , buf )
} else {
// Unencrypted
// SAFETY: Uhh... well I think this is fine? Because we can project the container.
// TODO: Can we project the `tx`? Or maybe add a method in `AsyncSink` to map a pinned sink to a `Pin<&mut W>`?
unsafe { self . map_unchecked_mut ( | this | this . 1. inner_mut ( ) ) . poll_write ( cx , buf ) }
}
}
#[ inline(always) ] fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , io ::Error > > {
self . project ( ) . 1. poll_flush ( cx )
}
#[ inline(always) ] fn poll_shutdown ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , io ::Error > > {
self . project ( ) . 1. poll_flush ( cx )
}
}
impl < R : AsyncRead > AsyncRead for ESockReadHalf < R >
{
fn poll_read ( self : Pin < & mut Self > , cx : & mut Context < ' _ > , buf : & mut [ u8 ] ) -> Poll < io ::Result < usize > > {
if self . 2 {
// Encrypted
self . project ( ) . 1. poll_read ( cx , buf )
} else {
// Unencrypted
// SAFETY: Uhh... well I think this is fine? Because we can project the container.
// TODO: Can we project the `tx`? Or maybe add a method in `AsyncSink` to map a pinned sink to a `Pin<&mut W>`?
unsafe { self . map_unchecked_mut ( | this | this . 1. inner_mut ( ) ) . poll_read ( cx , buf ) }
}
}
}
#[ cfg(test) ]
mod tests
{
use super ::ESock ;
#[ test ]
fn rsa_ciphertext_len ( ) -> crate ::eyre ::Result < ( ) >
{
@ -539,6 +698,16 @@ mod tests
Ok ( ( ) )
}
fn gen_duplex_esock ( bufsz : usize ) -> crate ::eyre ::Result < ( ESock < tokio ::io ::DuplexStream , tokio ::io ::DuplexStream > , ESock < tokio ::io ::DuplexStream , tokio ::io ::DuplexStream > ) >
{
use crate ::* ;
let ( atx , brx ) = tokio ::io ::duplex ( bufsz ) ;
let ( btx , arx ) = tokio ::io ::duplex ( bufsz ) ;
let tx = ESock ::new ( atx , arx ) . wrap_err ( eyre ! ( "Failed to create TX" ) ) ? ;
let rx = ESock ::new ( btx , brx ) . wrap_err ( eyre ! ( "Failed to create RX" ) ) ? ;
Ok ( ( tx , rx ) )
}
#[ tokio::test ]
async fn esock_exchange ( ) -> crate ::eyre ::Result < ( ) >
{
@ -546,10 +715,9 @@ mod tests
const VALUE : & ' static [ u8 ] = b" Hello world! " ;
let ( atx , brx ) = tokio ::io ::duplex ( super ::TRANS_KEY_MAX_SIZE * 4 ) ;
let ( btx , arx ) = tokio ::io ::duplex ( super ::TRANS_KEY_MAX_SIZE * 4 ) ;
let mut tx = super ::ESock ::new ( atx , arx ) ? ;
let mut rx = super ::ESock ::new ( btx , brx ) ? ;
// The duplex buffer size here is smaller than an RSA ciphertext block. So, writing the session key must be buffered with a buffer size this small (should return Pending at least once.)
// Using a low buffer size to make sure the test passes even when the entire buffer cannot be written at once.
let ( mut tx , mut rx ) = gen_duplex_esock ( 256 ) . wrap_err ( eyre ! ( "Failed to weave socks" ) ) ? ;
let writer = tokio ::spawn ( async move {
use tokio ::prelude ::* ;
@ -560,7 +728,9 @@ mod tests
tx . set_encrypted_write ( true ) . await ? ;
assert_eq! ( ( true , false ) , tx . is_encrypted ( ) ) ;
tx . write_all ( VALUE ) . await ? ;
tx . write_all ( & VALUE [ 0 .. 2 ] ) . await ? ;
tx . write_all ( & VALUE [ 2 .. 6 ] ) . await ? ;
tx . write_all ( & VALUE [ 6 .. ] ) . await ? ;
// Check resp
tx . set_encrypted_read ( true ) . await ? ;
@ -616,4 +786,83 @@ mod tests
assert_eq! ( & val , VALUE ) ;
Ok ( ( ) )
}
#[ tokio::test ]
async fn esock_split ( ) -> crate ::eyre ::Result < ( ) >
{
use super ::* ;
const SLICES : & ' static [ & ' static [ u8 ] ] = & [
& [ 1 , 5 , 3 , 7 , 6 , 9 , 100 , 0 ] ,
& [ 7 , 6 , 2 , 90 ] ,
& [ 3 , 6 , 1 , 0 ] ,
& [ 5 , 1 , 3 , 3 ] ,
] ;
let result = SLICES . iter ( ) . map ( | & slice | slice . iter ( ) . map ( | & b | u64 ::from ( b ) ) . sum ::< u64 > ( ) ) . sum ::< u64 > ( ) ;
println! ( "Result: {}" , result ) ;
let ( mut tx , mut rx ) = gen_duplex_esock ( super ::TRANS_KEY_MAX_SIZE * 4 ) . wrap_err ( eyre ! ( "Failed to weave socks" ) ) ? ;
let ( writer , reader ) = {
use tokio ::prelude ::* ;
let writer = tokio ::spawn ( async move {
tx . exchange ( ) . await ? ;
let ( mut tx , mut rx ) = tx . split ( ) ;
//tx.set_encryption(true).await?;
let slices = & SLICES [ 1 .. ] ;
for & slice in slices . iter ( )
{
println! ( "Writing slice: {:?}" , slice ) ;
tx . write_all ( slice ) . await ? ;
}
//let mut tx = ESock::unsplit(tx, rx);
tx . write_all ( SLICES [ 0 ] ) . await ? ;
Result ::< _ , eyre ::Report > ::Ok ( ( ) )
} ) ;
let reader = tokio ::spawn ( async move {
rx . exchange ( ) . await ? ;
let ( mut tx , mut rx ) = rx . split ( ) ;
//rx.set_encryption(true).await?;
let ( mut mtx , mut mrx ) = tokio ::sync ::mpsc ::channel ::< Vec < u8 > > ( 16 ) ;
let sorter = tokio ::spawn ( async move {
let mut done = 0 u64 ;
while let Some ( buf ) = mrx . recv ( ) . await
{
//buf.sort();
done + = buf . iter ( ) . map ( | & b | u64 ::from ( b ) ) . sum ::< u64 > ( ) ;
println! ( "Got buffer: {:?}" , buf ) ;
tx . write_all ( & buf ) . await ? ;
}
Result ::< _ , eyre ::Report > ::Ok ( done )
} ) ;
let mut buffer = [ 0 u8 ; 16 ] ;
while let Ok ( read ) = rx . read ( & mut buffer [ .. ] ) . await
{
if read = = 0 {
break ;
}
mtx . send ( Vec ::from ( & buffer [ .. read ] ) ) . await ? ;
}
drop ( mtx ) ;
let sum = sorter . await . expect ( "(reader) Sorter task panic" ) ? ;
Result ::< _ , eyre ::Report > ::Ok ( sum )
} ) ;
let ( writer , reader ) = tokio ::join ! [ writer , reader ] ;
( writer . expect ( "Writer task panic" ) ,
reader . expect ( "Reader task panic" ) )
} ;
writer ? ;
assert_eq! ( result , reader ? ) ;
Ok ( ( ) )
}
}