Compare commits
1 Commits
master
...
specialisa
Author | SHA1 | Date |
---|---|---|
Avril | af19935167 | 3 years ago |
@ -1,37 +0,0 @@
|
||||
//! Binary / byte maniuplation
|
||||
use bytes::BufMut;
|
||||
|
||||
/// Concatenate an iterator of byte slices into a buffer.
|
||||
///
|
||||
/// # Returns
|
||||
/// The number of bytes written
|
||||
/// # Panics
|
||||
/// If the buffer cannot hold all the slices
|
||||
pub fn collect_slices_into<B: BufMut + ?Sized, I, T>(into: &mut B, from: I) -> usize
|
||||
where I: IntoIterator<Item=T>,
|
||||
T: AsRef<[u8]>
|
||||
{
|
||||
let mut done =0;
|
||||
for slice in from.into_iter()
|
||||
{
|
||||
let s = slice.as_ref();
|
||||
into.put_slice(s);
|
||||
done+=s.len();
|
||||
}
|
||||
done
|
||||
}
|
||||
|
||||
/// Collect an iterator of byte slices into a new exact-size buffer.
|
||||
///
|
||||
/// # Returns
|
||||
/// The number of bytes written, and the new array
|
||||
///
|
||||
/// # Panics
|
||||
/// If the total bytes in all slices exceeds `SIZE`.
|
||||
pub fn collect_slices_exact<T, I, const SIZE: usize>(from: I) -> (usize, [u8; SIZE])
|
||||
where I: IntoIterator<Item=T>,
|
||||
T: AsRef<[u8]>
|
||||
{
|
||||
let mut output = [0u8; SIZE];
|
||||
(collect_slices_into(&mut &mut output[..], from), output)
|
||||
}
|
@ -1,43 +0,0 @@
|
||||
//! Capabilities (permissions) of a connection.
|
||||
use super::*;
|
||||
use std::time::Duration;
|
||||
|
||||
/*
|
||||
pub mod fail;
|
||||
pub use fail::Failures;
|
||||
|
||||
/// How lenient to be with a certain operation
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord, Copy)]
|
||||
pub enum Leniency
|
||||
{
|
||||
/// Ignore **all** malformed/missed messages.
|
||||
Ignore,
|
||||
/// Allow `n` failures before disconnecting the socket.
|
||||
Specific(Failures),
|
||||
/// Allow
|
||||
Normal,
|
||||
/// Immediately disconnect the socket on **any** malformed/missed message.
|
||||
None,
|
||||
}
|
||||
*/
|
||||
|
||||
/// A capability (permission) for a raw socket's data transmission.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub enum RawSockCapability
|
||||
{
|
||||
/// Process messages that aren't signed.
|
||||
AllowUnsignedMessages,
|
||||
|
||||
/// Do not disconnect the socket when a malformed message is received, just ignore the message.
|
||||
SoftFail,
|
||||
|
||||
/// Throttle the number of messages to process
|
||||
RateLimit { tx: usize, rx: usize },
|
||||
|
||||
/// The request response timeout for messages with an expected response.
|
||||
RecvRespTimeout { tx: Duration, rx: Duration },
|
||||
|
||||
/// Max number of bytes to read for a single message.
|
||||
//TODO: Implement this for message
|
||||
MaxMessageSize(usize),
|
||||
}
|
@ -1,102 +0,0 @@
|
||||
//! Failure counting/capping
|
||||
use super::*;
|
||||
use std::{
|
||||
fmt, error,
|
||||
};
|
||||
|
||||
/// A measure of failures, used to track or to check failures.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord, Copy)]
|
||||
pub struct Failures
|
||||
{
|
||||
/// Number of failures happened back-to-back.
|
||||
pub seq: usize,
|
||||
/// Total number of failures over the socket's lifetime.
|
||||
///
|
||||
/// Set allowed to `0` for unlimited.
|
||||
pub total: usize,
|
||||
|
||||
/// Window of time to keep failures.
|
||||
pub window: Duration,
|
||||
/// Number of failures happened in the last `window` of time.
|
||||
pub last_window: usize,
|
||||
}
|
||||
|
||||
/// When a failure cap is exceeded, which one is exceeded; and what is the limit that is exceeded?
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum FailureCapExceeded
|
||||
{
|
||||
/// Too many sequential errors
|
||||
Sequential(usize),
|
||||
/// Too many total errors
|
||||
Total(usize),
|
||||
/// Too many errors in the refresh window
|
||||
Windowed(usize, Duration),
|
||||
}
|
||||
|
||||
impl error::Error for FailureCapExceeded{}
|
||||
impl fmt::Display for FailureCapExceeded
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
|
||||
{
|
||||
match self {
|
||||
Self::Sequential(_) => write!(f, "too many sequential errors"),
|
||||
Self::Total(_) => write!(f, "too many total errors"),
|
||||
Self::Windowed(_, _) => write!(f, "too many errors in refresh window"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FailureCapExceeded
|
||||
{
|
||||
/// Convert into a detailed report. (This shouldn't be shared with peer, probably.)
|
||||
pub fn into_detailed_report(self) -> eyre::Report
|
||||
{
|
||||
let rep = eyre::Report::from(&self);
|
||||
match self {
|
||||
|
||||
Self::Sequential(cap) => rep.with_section(|| cap.header("Exceeded limit")),
|
||||
Self::Total(cap) => rep.with_section(|| cap.header("Exceeded limit")),
|
||||
Self::Windowed(cap, w) => rep.with_section(|| cap.header("Exceeded limit"))
|
||||
.with_section(|| format!("{:?}", w).header("Refresh window was"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Failures
|
||||
{
|
||||
/// Default allowed failure parameters.
|
||||
pub const DEFAULT_ALLOWED: Self = Self {
|
||||
seq: 10,
|
||||
total: 65536,
|
||||
window: Duration::from_secs(10),
|
||||
last_window: 5,
|
||||
};
|
||||
|
||||
/// Has this `Failures` exceeded failure cap `other`?
|
||||
pub fn cap_check(&self, other: &Self) -> Result<(), FailureCapExceeded>
|
||||
{
|
||||
macro_rules! chk {
|
||||
($name:ident, $err:expr) => {
|
||||
if other.$name != 0 && (self.$name >= other.$name) {
|
||||
return Err($err)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
chk!(seq, FailureCapExceeded::Sequential(other.seq));
|
||||
chk!(total, FailureCapExceeded::Total(other.total));
|
||||
//debug_assert!(other.window == self.window); //TODO: Should we disallow this?
|
||||
chk!(last_window, FailureCapExceeded::Windowed(other.last_window, self.window));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Failures
|
||||
{
|
||||
#[inline]
|
||||
fn default() -> Self
|
||||
{
|
||||
Self::DEFAULT_ALLOWED
|
||||
}
|
||||
}
|
||||
|
@ -1,91 +0,0 @@
|
||||
//! Stack allocation helpers
|
||||
use super::*;
|
||||
|
||||
/// Max size of memory allowed to be allocated on the stack.
|
||||
pub const STACK_MEM_ALLOC_MAX: usize = 2048; // 2KB
|
||||
|
||||
/// A stack-allocated vector that spills onto the heap when needed.
|
||||
pub type StackVec<T> = SmallVec<[T; STACK_MEM_ALLOC_MAX]>;
|
||||
|
||||
/// Allocate a local buffer initialised from `init`.
|
||||
pub fn alloc_local_with<T, U>(sz: usize, init: impl FnMut() -> T, within: impl FnOnce(&mut [T]) -> U) -> U
|
||||
{
|
||||
if sz > STACK_MEM_ALLOC_MAX {
|
||||
let mut memory: Vec<T> = iter::repeat_with(init).take(sz).collect();
|
||||
within(&mut memory[..])
|
||||
} else {
|
||||
stackalloc::stackalloc_with(sz, init, within)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
/// Allocate a local zero-initialised byte buffer
|
||||
pub fn alloc_local_bytes<U>(sz: usize, within: impl FnOnce(&mut [u8]) -> U) -> U
|
||||
{
|
||||
if sz > STACK_MEM_ALLOC_MAX {
|
||||
let mut memory: Vec<MaybeUninit<u8>> = vec_uninit(sz);
|
||||
within(unsafe {
|
||||
std::ptr::write_bytes(memory.as_mut_ptr(), 0, sz);
|
||||
stackalloc::helpers::slice_assume_init_mut(&mut memory[..])
|
||||
})
|
||||
} else {
|
||||
stackalloc::alloca_zeroed(sz, within)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Allocate a local zero-initialised buffer
|
||||
pub fn alloc_local_zeroed<T, U>(sz: usize, within: impl FnOnce(&mut [MaybeUninit<T>]) -> U) -> U
|
||||
{
|
||||
let sz_bytes = mem::size_of::<T>() * sz;
|
||||
if sz > STACK_MEM_ALLOC_MAX {
|
||||
let mut memory = vec_uninit(sz);
|
||||
unsafe {
|
||||
std::ptr::write_bytes(memory.as_mut_ptr(), 0, sz_bytes);
|
||||
}
|
||||
within(&mut memory[..])
|
||||
} else {
|
||||
stackalloc::alloca_zeroed(sz_bytes, move |buf| {
|
||||
unsafe {
|
||||
debug_assert_eq!(buf.len() / mem::size_of::<T>(), sz);
|
||||
within(std::slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut MaybeUninit<T>, sz))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Allocate a local uninitialised buffer
|
||||
pub fn alloc_local_uninit<T, U>(sz: usize, within: impl FnOnce(&mut [MaybeUninit<T>]) -> U) -> U
|
||||
{
|
||||
if sz > STACK_MEM_ALLOC_MAX {
|
||||
let mut memory = vec_uninit(sz);
|
||||
within(&mut memory[..])
|
||||
} else {
|
||||
stackalloc::stackalloc_uninit(sz, within)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Allocate a local buffer initialised with `init`.
|
||||
pub fn alloc_local<T: Clone, U>(sz: usize, init: T, within: impl FnOnce(&mut [T]) -> U) -> U
|
||||
{
|
||||
if sz > STACK_MEM_ALLOC_MAX {
|
||||
let mut memory: Vec<T> = iter::repeat(init).take(sz).collect();
|
||||
within(&mut memory[..])
|
||||
} else {
|
||||
stackalloc::stackalloc(sz, init, within)
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocate a local buffer initialised with `T::default()`.
|
||||
pub fn alloc_local_with_default<T: Default, U>(sz: usize, within: impl FnOnce(&mut [T]) -> U) -> U
|
||||
{
|
||||
if sz > STACK_MEM_ALLOC_MAX {
|
||||
let mut memory: Vec<T> = iter::repeat_with(Default::default).take(sz).collect();
|
||||
within(&mut memory[..])
|
||||
} else {
|
||||
stackalloc::stackalloc_with_default(sz, within)
|
||||
}
|
||||
}
|
@ -1,15 +0,0 @@
|
||||
//! Base64 formatting extensions
|
||||
use super::*;
|
||||
|
||||
pub trait Base64StringExt
|
||||
{
|
||||
fn to_base64_string(&self) -> String;
|
||||
}
|
||||
|
||||
impl<T: ?Sized> Base64StringExt for T
|
||||
where T: AsRef<[u8]>
|
||||
{
|
||||
fn to_base64_string(&self) -> String {
|
||||
::base64::encode(self.as_ref())
|
||||
}
|
||||
}
|
@ -1,124 +0,0 @@
|
||||
use std::{
|
||||
mem,
|
||||
iter::{
|
||||
self,
|
||||
ExactSizeIterator,
|
||||
FusedIterator,
|
||||
},
|
||||
slice,
|
||||
fmt,
|
||||
};
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HexStringIter<I>(I, [u8; 2]);
|
||||
|
||||
impl<I: Iterator<Item = u8>> HexStringIter<I>
|
||||
{
|
||||
/// Write this hex string iterator to a formattable buffer
|
||||
pub fn consume<F>(self, f: &mut F) -> fmt::Result
|
||||
where F: std::fmt::Write
|
||||
{
|
||||
if self.1[0] != 0 {
|
||||
write!(f, "{}", self.1[0] as char)?;
|
||||
}
|
||||
if self.1[1] != 0 {
|
||||
write!(f, "{}", self.1[1] as char)?;
|
||||
}
|
||||
|
||||
for x in self.0 {
|
||||
write!(f, "{:02x}", x)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Consume into a string
|
||||
pub fn into_string(self) -> String
|
||||
{
|
||||
let mut output = match self.size_hint() {
|
||||
(0, None) => String::new(),
|
||||
(_, Some(x)) |
|
||||
(x, None) => String::with_capacity(x),
|
||||
};
|
||||
self.consume(&mut output).unwrap();
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
pub trait HexStringIterExt<I>: Sized
|
||||
{
|
||||
fn into_hex(self) -> HexStringIter<I>;
|
||||
}
|
||||
|
||||
pub type HexStringSliceIter<'a> = HexStringIter<iter::Copied<slice::Iter<'a, u8>>>;
|
||||
|
||||
pub trait HexStringSliceIterExt
|
||||
{
|
||||
fn hex(&self) -> HexStringSliceIter<'_>;
|
||||
}
|
||||
|
||||
impl<S> HexStringSliceIterExt for S
|
||||
where S: AsRef<[u8]>
|
||||
{
|
||||
fn hex(&self) -> HexStringSliceIter<'_>
|
||||
{
|
||||
self.as_ref().iter().copied().into_hex()
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: IntoIterator<Item=u8>> HexStringIterExt<I::IntoIter> for I
|
||||
{
|
||||
#[inline] fn into_hex(self) -> HexStringIter<I::IntoIter> {
|
||||
HexStringIter(self.into_iter(), [0u8; 2])
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: Iterator<Item = u8>> Iterator for HexStringIter<I>
|
||||
{
|
||||
type Item = char;
|
||||
fn next(&mut self) -> Option<Self::Item>
|
||||
{
|
||||
match self.1 {
|
||||
[_, 0] => {
|
||||
use std::io::Write;
|
||||
write!(&mut self.1[..], "{:02x}", self.0.next()?).unwrap();
|
||||
|
||||
Some(mem::replace(&mut self.1[0], 0) as char)
|
||||
},
|
||||
[0, _] => Some(mem::replace(&mut self.1[1], 0) as char),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||
let (l, h) = self.0.size_hint();
|
||||
|
||||
(l * 2, h.map(|x| x*2))
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: Iterator<Item = u8> + ExactSizeIterator> ExactSizeIterator for HexStringIter<I>{}
|
||||
impl<I: Iterator<Item = u8> + FusedIterator> FusedIterator for HexStringIter<I>{}
|
||||
|
||||
impl<I: Iterator<Item = u8>> From<HexStringIter<I>> for String
|
||||
{
|
||||
fn from(from: HexStringIter<I>) -> Self
|
||||
{
|
||||
from.into_string()
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: Iterator<Item = u8> + Clone> fmt::Display for HexStringIter<I>
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
|
||||
{
|
||||
self.clone().consume(f)
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
#[macro_export] macro_rules! prog1 {
|
||||
($first:expr, $($rest:expr);+ $(;)?) => {
|
||||
($first, $( $rest ),+).0
|
||||
}
|
||||
}
|
||||
*/
|
@ -1,198 +0,0 @@
|
||||
//! Creating binary from messages
|
||||
//!
|
||||
//! `SerializedMessage` to `Bytes`.
|
||||
use super::*;
|
||||
use bytes::{
|
||||
Bytes,
|
||||
BytesMut,
|
||||
BufMut,
|
||||
Buf,
|
||||
};
|
||||
|
||||
macro_rules! try_from {
|
||||
(ref $into:ty, $from:expr $(; $fmt:literal)?) => {
|
||||
{
|
||||
let from = &$from;
|
||||
#[inline] fn _type_name_of_val<T: ?Sized>(_: &T) -> &'static str
|
||||
{
|
||||
::std::any::type_name::<T>()
|
||||
}
|
||||
<$into>::try_from(*from).wrap_err(eyre!("Failed to convert type"))
|
||||
.with_section(|| ::std::any::type_name::<$into>().header("New type"))
|
||||
.with_section(|| _type_name_of_val(from).header("Old type"))
|
||||
.with_section(|| from.to_string().header("Value was"))
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Check bit written/read from binary streams to check basic integrity.
|
||||
pub const MESSAGE_HEADER_CHECK: [u8; 4] = (0xc0ffee00u32).to_be_bytes();
|
||||
|
||||
impl<V: ?Sized> SerializedMessage<V>
|
||||
{
|
||||
/// Write this message to a buffer
|
||||
///
|
||||
/// # Panics
|
||||
/// If `buffer` cannot hold enough bytes.
|
||||
pub fn into_buffer(self, mut buffer: impl BufMut) -> eyre::Result<usize>
|
||||
{
|
||||
let mut w=0;
|
||||
macro_rules! write {
|
||||
($bytes:expr) => {
|
||||
{
|
||||
let slice: &[u8] = ($bytes).as_ref();
|
||||
buffer.put_slice(slice);
|
||||
w+=slice.len();
|
||||
}
|
||||
};
|
||||
(? $o:expr) => {
|
||||
{
|
||||
match $o {
|
||||
Some(opt) => {
|
||||
buffer.put_u8(1);
|
||||
write!(opt);
|
||||
},
|
||||
None => {buffer.put_u8(0);},
|
||||
}
|
||||
w+=1;
|
||||
}
|
||||
};
|
||||
(: $ser:expr) => {
|
||||
{
|
||||
let mut v = StackVec::new();
|
||||
#[inline] fn _type_name_of_val<T: ?Sized>(_: &T) -> &'static str
|
||||
{
|
||||
::std::any::type_name::<T>()
|
||||
}
|
||||
let ser = $ser;
|
||||
serde_cbor::to_writer(&mut v, $ser)
|
||||
.wrap_err(eyre!("Failed to serialise value to temporary buffer"))
|
||||
.with_section(|| _type_name_of_val(ser).header("Type was"))?;
|
||||
buffer.put_u64(try_from!(ref u64, v.len())?);
|
||||
write!(&v[..]);
|
||||
}
|
||||
};
|
||||
}
|
||||
write!(MESSAGE_HEADER_CHECK);
|
||||
write!(: &self.header);
|
||||
buffer.put_u64(try_from!(ref u64, self.data.len())?);
|
||||
write!(self.data);
|
||||
write!(self.hash);
|
||||
write!(? self.enc_key);
|
||||
write!(? self.sig);
|
||||
|
||||
Ok(w)
|
||||
}
|
||||
|
||||
/// Write this message to a new `Bytes`.
|
||||
pub fn into_bytes(self) -> eyre::Result<Bytes>
|
||||
{
|
||||
let mut output = BytesMut::with_capacity(4096); //TODO: Find a better default capacity for this.
|
||||
self.into_buffer(&mut output)?;
|
||||
Ok(output.freeze())
|
||||
}
|
||||
}
|
||||
|
||||
impl<V: ?Sized + MessageValue> SerializedMessage<V>
|
||||
{
|
||||
/// Create from a buffer of bytes.
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes` does not contain enough data to read.
|
||||
pub fn from_buffer(mut bytes: impl Buf) -> eyre::Result<Self>
|
||||
{
|
||||
macro_rules! read {
|
||||
($bref:expr) => {
|
||||
{
|
||||
let by: &mut [u8] = ($bref).as_mut();
|
||||
bytes.copy_to_slice(by);
|
||||
}
|
||||
};
|
||||
(? $odef:expr) => {
|
||||
{
|
||||
let by = bytes.get_u8();
|
||||
match by {
|
||||
0 => None,
|
||||
1 => {
|
||||
let mut def = $odef;
|
||||
read!(&mut def);
|
||||
Some(def)
|
||||
},
|
||||
x => {
|
||||
return Err(eyre!("Invalid optional-set bit (should be 0 or 1)").with_section(|| x.header("Value was")));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
(: $ser:ty) => {
|
||||
{
|
||||
let len = try_from!(ref usize, bytes.get_u64())?;
|
||||
if len > MAX_READ_SIZE {
|
||||
return Err(eyre!("Invalid length read: {}", len)
|
||||
.with_section(|| MAX_READ_SIZE.header("Max length read")));
|
||||
}
|
||||
alloc_local_bytes(len, |de| {
|
||||
read!(&mut de[..]);
|
||||
serde_cbor::from_slice::<$ser>(&de[..]).wrap_err(eyre!("Failed to deserialise CBOR from reader")).with_section(|| std::any::type_name::<$ser>().header("Type to deserialise was"))
|
||||
})?
|
||||
|
||||
}
|
||||
};
|
||||
($into:expr, $num:expr) => {
|
||||
{
|
||||
let num = $num;
|
||||
let reader = (&mut bytes).reader();
|
||||
copy_buffer($into, reader, num).wrap_err(eyre!("Failed to read {} bytes from reader", num))?
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut check = [0u8; MESSAGE_HEADER_CHECK.len()];
|
||||
read!(&mut check[..]);
|
||||
if check != MESSAGE_HEADER_CHECK {
|
||||
return Err(eyre!("Invalid check bit for message header"))
|
||||
.with_section(|| u32::from_be_bytes(check).header("Expected"))
|
||||
.with_section(|| u32::from_be_bytes(MESSAGE_HEADER_CHECK).header("Got"));
|
||||
}
|
||||
|
||||
let header = read!(: SerHeader);
|
||||
let data_len = try_from!(ref usize, bytes.get_u64())?;
|
||||
if MAX_BODY_SIZE > 0 && data_len > MAX_BODY_SIZE {
|
||||
return Err(eyre!("Body size too large"))
|
||||
.with_section(|| data_len.header("Encoded size was"))
|
||||
.with_section(|| MAX_BODY_SIZE.header("Max size is"));
|
||||
}
|
||||
let mut data = Vec::with_capacity(std::cmp::min(data_len, MAX_ALLOC_SIZE)); //XXX: Redesign so we don't allocate OR try to read massive buffers by accident on corrupted/malformed messages
|
||||
read!(&mut data, data_len);
|
||||
if data.len()!=data_len {
|
||||
return Err(eyre!("Failed to read body bytes from buffer"))
|
||||
.with_section(|| format!("{} bytes", data_len).header("Tried to read"))
|
||||
.with_section(|| format!("{} bytes", data.len()).header("Read only"));
|
||||
}
|
||||
let mut hash = sha256::Sha256Hash::default();
|
||||
read!(&mut hash);
|
||||
let enc_key: Option<[u8; RSA_BLOCK_SIZE]> = read!(? [0u8; RSA_BLOCK_SIZE]);
|
||||
let sig: Option<rsa::Signature> = read!(? rsa::Signature::default());
|
||||
|
||||
Ok(Self {
|
||||
header,
|
||||
data,
|
||||
hash,
|
||||
enc_key,
|
||||
sig,
|
||||
|
||||
_phantom: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create from a slice of bytes
|
||||
#[inline] pub fn from_slice(bytes: impl AsRef<[u8]>) -> eyre::Result<Self>
|
||||
{
|
||||
Self::from_buffer(bytes.as_ref())
|
||||
}
|
||||
|
||||
/// Create from a `Bytes` instance
|
||||
#[inline(always)] pub fn from_bytes(bytes: Bytes) -> eyre::Result<Self>
|
||||
{
|
||||
Self::from_buffer(bytes)
|
||||
}
|
||||
}
|
@ -1,66 +1,32 @@
|
||||
//! Message values
|
||||
use super::*;
|
||||
use std::mem;
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
/// A value that can be used for messages.
|
||||
pub trait MessageValue: Serialize + for<'de> Deserialize<'de>{}
|
||||
|
||||
impl<T: ?Sized> MessageValue for T
|
||||
where T: Serialize + for<'de> Deserialize<'de>{}
|
||||
|
||||
/*
|
||||
|
||||
use std::any::Any;
|
||||
use serde::de::DeserializeOwned;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct DynamicMessageValue<T>(T)
|
||||
where T: Any + 'static;
|
||||
|
||||
|
||||
pub struct MessageValueAnyRef<'a>(&'a (dyn Any +'static));
|
||||
pub struct MessageValueAnyMut<'a>(&'a mut (dyn Any +'static));
|
||||
pub struct MessageValueAny(Box<dyn Any +'static>);
|
||||
|
||||
//impl<T> MessageValue for DynamicMessageValue<T>
|
||||
//where T: Serialize + for<'de> Deserialize<'de> + Any{}
|
||||
*/
|
||||
|
||||
/// A type-unsafe value that can be used to transmute `SerializedMessage` instances.
|
||||
///
|
||||
/// This operation is unsafe and can result in deserializing the `SerializedMessage` failing, or even worse, a type confusion.
|
||||
///
|
||||
/// This type does not implement `MessageValue`, as serialized messages of this value cannot be created (from `Message::serialize()`), nor deserialized. They must first be converted into a typed `SerializedMessage<V>` with `UntypedSerializedMessage::into_typed<V>()`.
|
||||
///
|
||||
/// This is an empty (!) type.
|
||||
#[derive(Debug)]
|
||||
pub enum UntypedMessageValue{}
|
||||
|
||||
impl SerializedMessage<UntypedMessageValue>
|
||||
/// A value that can be used for messages.
|
||||
pub trait MessageValue: Serialize + for<'de> Deserialize<'de>
|
||||
{
|
||||
/// Transmute into a specifically typed `SerializedMessage`
|
||||
///
|
||||
/// # Safety
|
||||
/// If `V` is not the original type of this message (before being untyped), then deserialisation will likely fail, or, much worse, cause a *type consufion* bug, where the object of type `V` is successfully deserialized with an invalid value from an unknown type.
|
||||
/// Take special care when using this.
|
||||
pub const unsafe fn into_typed<V: MessageValue + ?Sized>(self) -> SerializedMessage<V>
|
||||
{
|
||||
mem::transmute(self)
|
||||
}
|
||||
fn as_dynamic(&self) -> Option<MessageValueAnyRef<'_>> { None }
|
||||
fn as_dynamic_mut(&mut self) -> Option<MessageValueAnyMut<'_>> { None }
|
||||
fn into_dynamic(self) -> Result<MessageValueAny, Self> { Err(self) }
|
||||
}
|
||||
|
||||
impl<V: MessageValue + ?Sized> SerializedMessage<V>
|
||||
default impl<T: ?Sized> MessageValue for T
|
||||
where T: Serialize + for<'de> Deserialize<'de>
|
||||
{
|
||||
/// Consume this value into an untyped `SerializedMessage`.
|
||||
///
|
||||
/// # Safety
|
||||
/// This operation is safe, however, doing anything at all with the resulting `SerializedMessage` is not.
|
||||
/// It must be unsafly converted into a types message before it can be deserialized.
|
||||
pub const fn into_untyped(self) -> SerializedMessage<UntypedMessageValue>
|
||||
{
|
||||
unsafe {
|
||||
mem::transmute(self)
|
||||
}
|
||||
}
|
||||
default fn as_dynamic(&self) -> Option<MessageValueAnyRef<'_>> { None }
|
||||
default fn as_dynamic_mut(&mut self) -> Option<MessageValueAnyMut<'_>> { None }
|
||||
default fn into_dynamic(self) -> Result<MessageValueAny, Self> { Err(self) }
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct DynamicMessageValue<T>(T)
|
||||
where T: Any + 'static;
|
||||
|
||||
//impl<T> MessageValue for DynamicMessageValue<T>
|
||||
//where T: Serialize + for<'de> Deserialize<'de> + Any{}
|
||||
|
||||
|
@ -0,0 +1,30 @@
|
||||
//! Socket handlers
|
||||
use super::*;
|
||||
|
||||
use tokio::io::{
|
||||
AsyncWrite,
|
||||
AsyncRead
|
||||
};
|
||||
use tokio::task::JoinHandle;
|
||||
use futures::Future;
|
||||
use cancel::*;
|
||||
|
||||
|
||||
|
||||
/// Handles a raw, opened socket
|
||||
pub fn handle_socket_with_shutdown<R, W, C: cancel::CancelFuture + 'static + Send>(tx: W, rx: R, shutdown: C) -> JoinHandle<eyre::Result<()>>
|
||||
where R: AsyncRead + Unpin + Send + 'static,
|
||||
W: AsyncWrite + Unpin + Send + 'static
|
||||
{
|
||||
tokio::spawn(async move {
|
||||
match {
|
||||
with_cancel!(async move {
|
||||
//TODO: How to handle reads+writes?
|
||||
Ok(())
|
||||
}, shutdown)
|
||||
} {
|
||||
Ok(v) => v,
|
||||
Err(x) => Err(eyre::Report::from(x)),
|
||||
}
|
||||
})
|
||||
}
|
@ -1,874 +0,0 @@
|
||||
//! Socket encryption wrapper
|
||||
use super::*;
|
||||
use cryptohelpers::{
|
||||
rsa::{
|
||||
self,
|
||||
RsaPublicKey,
|
||||
RsaPrivateKey,
|
||||
|
||||
openssl::{
|
||||
symm::Crypter,
|
||||
error::ErrorStack,
|
||||
},
|
||||
},
|
||||
sha256,
|
||||
};
|
||||
use chacha20stream::{
|
||||
AsyncSink,
|
||||
AsyncSource,
|
||||
|
||||
Key, IV,
|
||||
|
||||
cha,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use tokio::{
|
||||
sync::{
|
||||
RwLock,
|
||||
RwLockReadGuard,
|
||||
RwLockWriteGuard,
|
||||
},
|
||||
};
|
||||
use std::{
|
||||
io,
|
||||
fmt,
|
||||
task::{
|
||||
Context, Poll,
|
||||
},
|
||||
pin::Pin,
|
||||
marker::Unpin,
|
||||
};
|
||||
use smallvec::SmallVec;
|
||||
|
||||
/// Size of a single RSA ciphertext.
|
||||
pub const RSA_CIPHERTEXT_SIZE: usize = 512;
|
||||
|
||||
/// A single, full block of RSA ciphertext.
|
||||
type RsaCiphertextBlock = [u8; RSA_CIPHERTEXT_SIZE];
|
||||
|
||||
/// Max size to read when exchanging keys
|
||||
const TRANS_KEY_MAX_SIZE: usize = 4096;
|
||||
|
||||
/// Encrypted socket information.
|
||||
#[derive(Debug)]
|
||||
struct ESockInfo {
|
||||
us: RsaPrivateKey,
|
||||
them: Option<RsaPublicKey>,
|
||||
}
|
||||
|
||||
impl ESockInfo
|
||||
{
|
||||
/// Generate a new private key
|
||||
pub fn new(us: impl Into<RsaPrivateKey>) -> Self
|
||||
{
|
||||
Self {
|
||||
us: us.into(),
|
||||
them: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a new private key for the local endpoint
|
||||
pub fn generate() -> Result<Self, rsa::Error>
|
||||
{
|
||||
Ok(Self::new(RsaPrivateKey::generate()?))
|
||||
}
|
||||
}
|
||||
|
||||
/// The encryption state of the Tx and Rx instances.
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
|
||||
struct ESockState {
|
||||
encr: bool,
|
||||
encw: bool,
|
||||
}
|
||||
|
||||
impl Default for ESockState
|
||||
{
|
||||
#[inline]
|
||||
fn default() -> Self
|
||||
{
|
||||
Self {
|
||||
encr: false,
|
||||
encw: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Contains a cc20 Key and IV that can be serialized and then encrypted
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
struct ESockSessionKey
|
||||
{
|
||||
key: Key,
|
||||
iv: IV,
|
||||
}
|
||||
|
||||
impl fmt::Display for ESockSessionKey
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
|
||||
{
|
||||
write!(f, "Key: {}, IV: {}", self.key.hex(), self.iv.hex())
|
||||
}
|
||||
}
|
||||
|
||||
impl ESockSessionKey
|
||||
{
|
||||
/// Generate a new cc20 key + iv,
|
||||
pub fn generate() -> Self
|
||||
{
|
||||
let (key,iv) = cha::keygen();
|
||||
Self{key,iv}
|
||||
}
|
||||
|
||||
/// Generate an encryption device
|
||||
pub fn to_decrypter(&self) -> Result<Crypter, ErrorStack>
|
||||
{
|
||||
cha::decrypter(&self.key, &self.iv)
|
||||
}
|
||||
|
||||
/// Generate an encryption device
|
||||
pub fn to_encrypter(&self) -> Result<Crypter, ErrorStack>
|
||||
{
|
||||
cha::encrypter(&self.key, &self.iv)
|
||||
}
|
||||
|
||||
/// Encrypt with RSA
|
||||
pub fn to_ciphertext<K: ?Sized + rsa::PublicKey>(&self, rsa_key: &K) -> eyre::Result<RsaCiphertextBlock>
|
||||
{
|
||||
let mut output = [0u8; RSA_CIPHERTEXT_SIZE];
|
||||
let mut temp = SmallVec::<[u8; RSA_CIPHERTEXT_SIZE]>::new(); // We know size will fit into here.
|
||||
serde_cbor::to_writer(&mut temp, self)
|
||||
.wrap_err(eyre!("Failed to CBOR encode session key to buffer"))
|
||||
.with_section(|| self.clone().header("Session key was"))?;
|
||||
debug_assert!(temp.len() < RSA_CIPHERTEXT_SIZE);
|
||||
|
||||
let _wr = rsa::encrypt_slice_sync(&temp, rsa_key, &mut &mut output[..])
|
||||
.wrap_err(eyre!("Failed to encrypt session key with RSA public key"))
|
||||
.with_section(|| self.clone().header("Session key was"))
|
||||
.with_section({let temp = temp.len(); move || temp.header("Encoded data size was")})
|
||||
.with_section(move || base64::encode(temp).header("Encoded data (base64) was"))?;
|
||||
debug_assert_eq!(_wr, output.len());
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Decrypt from RSA
|
||||
pub fn from_ciphertext<K: ?Sized + rsa::PrivateKey>(data: &[u8; RSA_CIPHERTEXT_SIZE], rsa_key: &K) -> eyre::Result<Self>
|
||||
where <K as rsa::PublicKey>::KeyType: rsa::openssl::pkey::HasPrivate //ugh, why do we have to have this bound??? it should be implied ffs... :/
|
||||
{
|
||||
let mut temp = SmallVec::<[u8; RSA_CIPHERTEXT_SIZE]>::new();
|
||||
rsa::decrypt_slice_sync(data, rsa_key, &mut temp)
|
||||
.wrap_err(eyre!("Failed to decrypt ciphertext to session key"))
|
||||
.with_section({let data = data.len(); move || data.header("Ciphertext length was")})
|
||||
.with_section(|| base64::encode(data).header("Ciphertext was"))?;
|
||||
Ok(serde_cbor::from_slice(&temp[..])
|
||||
.wrap_err(eyre!("Failed to decode CBOR data to session key object"))
|
||||
.with_section({let temp = temp.len(); move || temp.header("Encoded data size was")})
|
||||
.with_section(move || base64::encode(temp).header("Encoded data (base64) was"))?)
|
||||
}
|
||||
}
|
||||
|
||||
/// A tx+rx socket.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct ESock<W, R> {
|
||||
info: ESockInfo,
|
||||
|
||||
state: ESockState,
|
||||
|
||||
#[pin]
|
||||
rx: AsyncSource<R>,
|
||||
#[pin]
|
||||
tx: AsyncSink<W>,
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite, R: AsyncRead> ESock<W, R>
|
||||
{
|
||||
fn inner(&self) -> (&W, &R)
|
||||
{
|
||||
(self.tx.inner(), self.rx.inner())
|
||||
}
|
||||
|
||||
fn inner_mut(&mut self) -> (&mut W, &mut R)
|
||||
{
|
||||
(self.tx.inner_mut(), self.rx.inner_mut())
|
||||
}
|
||||
|
||||
///Get a mutable ref to unencrypted read+write
|
||||
fn unencrypted(&mut self) -> (&mut W, &mut R)
|
||||
{
|
||||
(self.tx.inner_mut(), self.rx.inner_mut())
|
||||
}
|
||||
/// Get a mutable ref to encrypted write+read
|
||||
fn encrypted(&mut self) -> (&mut AsyncSink<W>, &mut AsyncSource<R>)
|
||||
{
|
||||
(&mut self.tx, &mut self.rx)
|
||||
}
|
||||
|
||||
/// Have the RSA keys been exchanged?
|
||||
pub fn has_exchanged(&self) -> bool
|
||||
{
|
||||
self.info.them.is_some()
|
||||
}
|
||||
|
||||
/// Is the Write + Read operation encrypted? Tuple is `(Tx, Rx)`.
|
||||
#[inline] pub fn is_encrypted(&self) -> (bool, bool)
|
||||
{
|
||||
(self.state.encw, self.state.encr)
|
||||
}
|
||||
|
||||
/// Create a new `ESock` wrapper over this writer and reader with this specific RSA key.
|
||||
pub fn with_key(key: impl Into<RsaPrivateKey>, tx: W, rx: R) -> Self
|
||||
{
|
||||
let (tk, tiv) = cha::keygen();
|
||||
Self {
|
||||
info: ESockInfo::new(key),
|
||||
state: Default::default(),
|
||||
|
||||
// Note: These key+IV pairs are never used, as `state` defaults to unencrypted, and a new key/iv pair is generated when we `set_encrypted_write/read(true)`.
|
||||
// TODO: Have a method to exchange these default session keys after `exchange()`?
|
||||
tx: AsyncSink::encrypt(tx, tk, tiv).expect("Failed to create temp AsyncSink"),
|
||||
rx: AsyncSource::encrypt(rx, tk, tiv).expect("Failed to create temp AsyncSource"),
|
||||
}
|
||||
}
|
||||
/// Create a new `ESock` wrapper over this writer and reader with a newly generated private key
|
||||
#[inline] pub fn new(tx: W, rx: R) -> Result<Self, rsa::Error>
|
||||
{
|
||||
Ok(Self::with_key(RsaPrivateKey::generate()?, tx, rx))
|
||||
}
|
||||
|
||||
/// The local RSA private key
|
||||
#[inline] pub fn local_key(&self) -> &RsaPrivateKey
|
||||
{
|
||||
&self.info.us
|
||||
}
|
||||
/// THe remote RSA public key (if exchange has happened.)
|
||||
#[inline] pub fn foreign_key(&self) -> Option<&RsaPublicKey>
|
||||
{
|
||||
self.info.them.as_ref()
|
||||
}
|
||||
|
||||
/// Split this `ESock` into a read+write pair.
|
||||
///
|
||||
/// # Note
|
||||
/// You must preform an `exchange()` before splitting, as exchanging RSA keys is not possible on a single half.
|
||||
///
|
||||
/// It is also more efficient to `set_encrypted_write/read(true)` on `ESock` than it is on the halves, but changinc encryption modes on halves is still possible.
|
||||
pub fn split(self) -> (ESockWriteHalf<W>, ESockReadHalf<R>)
|
||||
{
|
||||
let arced = Arc::new(self.info);
|
||||
|
||||
(ESockWriteHalf(Arc::clone(&arced), self.tx, self.state.encw),
|
||||
ESockReadHalf(arced, self.rx, self.state.encr))
|
||||
}
|
||||
|
||||
/// Merge a previously split `ESock` into a single one again.
|
||||
///
|
||||
/// # Panics
|
||||
/// If the two halves were not split from the same `ESock`.
|
||||
pub fn unsplit(txh: ESockWriteHalf<W>, rxh: ESockReadHalf<R>) -> Self
|
||||
{
|
||||
#[cold]
|
||||
#[inline(never)]
|
||||
fn _panic_ptr_ineq() -> !
|
||||
{
|
||||
panic!("Cannot merge halves split from different sources")
|
||||
}
|
||||
if !Arc::ptr_eq(&txh.0, &rxh.0) {
|
||||
_panic_ptr_ineq();
|
||||
}
|
||||
|
||||
let tx = txh.1;
|
||||
drop(txh.0);
|
||||
let info = Arc::try_unwrap(rxh.0).unwrap();
|
||||
let rx = rxh.1;
|
||||
|
||||
Self {
|
||||
state: ESockState {
|
||||
encw: txh.2,
|
||||
encr: rxh.2,
|
||||
},
|
||||
info,
|
||||
tx, rx
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn set_encrypted_write_for<T: AsyncWrite + Unpin>(info: &ESockInfo, tx: &mut AsyncSink<T>) -> eyre::Result<()>
|
||||
{
|
||||
use tokio::prelude::*;
|
||||
let session_key = ESockSessionKey::generate();
|
||||
let data = {
|
||||
let them = info.them.as_ref().expect("Cannot set encrypted write when keys have not been exchanged");
|
||||
session_key.to_ciphertext(them)
|
||||
.wrap_err(eyre!("Failed to encrypt session key with foreign endpoint's key"))
|
||||
.with_section(|| session_key.to_string().header("Session key was"))
|
||||
.with_section(|| them.to_string().header("Foreign pubkey was"))?
|
||||
};
|
||||
let crypter = session_key.to_encrypter()
|
||||
.wrap_err(eyre!("Failed to create encryption device from session key for Tx"))
|
||||
.with_section(|| session_key.to_string().header("Session key was"))?;
|
||||
// Send rsa `data` over unencrypted endpoint
|
||||
tx.inner_mut().write_all(&data[..]).await
|
||||
.wrap_err(eyre!("Failed to write ciphertext to endpoint"))
|
||||
.with_section(|| data.to_base64_string().header("Ciphertext of session key was"))?;
|
||||
// Set crypter of `tx` to `session_key`.
|
||||
*tx.crypter_mut() = crypter;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn set_encrypted_read_for<T: AsyncRead + Unpin>(info: &ESockInfo, rx: &mut AsyncSource<T>) -> eyre::Result<()>
|
||||
{
|
||||
use tokio::prelude::*;
|
||||
|
||||
let mut data = [0u8; RSA_CIPHERTEXT_SIZE];
|
||||
// Read `data` from unencrypted endpoint
|
||||
rx.inner_mut().read_exact(&mut data[..]).await
|
||||
.wrap_err(eyre!("Failed to read ciphertext from endpoint"))?;
|
||||
// Decrypt `data`
|
||||
let session_key = ESockSessionKey::from_ciphertext(&data, &info.us)
|
||||
.wrap_err(eyre!("Failed to decrypt session key from ciphertext"))
|
||||
.with_section(|| data.to_base64_string().header("Ciphertext was"))
|
||||
.with_section(|| info.us.to_string().header("Our RSA key is"))?;
|
||||
// Set crypter of `rx` to `session_key`.
|
||||
*rx.crypter_mut() = session_key.to_decrypter()
|
||||
.wrap_err(eyre!("Failed to create decryption device from session key for Rx"))
|
||||
.with_section(|| session_key.to_string().header("Decrypted session key was"))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite+ Unpin, R: AsyncRead + Unpin> ESock<W, R>
|
||||
{
|
||||
/// Get the Tx and Rx of the stream.
|
||||
///
|
||||
/// # Returns
|
||||
/// Returns encrypted stream halfs if the stream is encrypted, unencrypted if not.
|
||||
pub fn stream(&mut self) -> (&mut (dyn AsyncWrite + Unpin + '_), &mut (dyn AsyncRead + Unpin + '_))
|
||||
{
|
||||
(if self.state.encw {
|
||||
&mut self.tx
|
||||
} else {
|
||||
self.tx.inner_mut()
|
||||
}, if self.state.encr {
|
||||
&mut self.rx
|
||||
} else {
|
||||
self.rx.inner_mut()
|
||||
})
|
||||
}
|
||||
/// Enable write encryption
|
||||
pub async fn set_encrypted_write(&mut self, set: bool) -> eyre::Result<()>
|
||||
{
|
||||
if set {
|
||||
set_encrypted_write_for(&self.info, &mut self.tx).await?;
|
||||
// Set `encw` to true
|
||||
self.state.encw = true;
|
||||
Ok(())
|
||||
} else {
|
||||
self.state.encw = false;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Enable read encryption
|
||||
///
|
||||
/// The other endpoint must have sent a `set_encrypted_write()`
|
||||
pub async fn set_encrypted_read(&mut self, set: bool) -> eyre::Result<()>
|
||||
{
|
||||
if set {
|
||||
set_encrypted_read_for(&self.info, &mut self.rx).await?;
|
||||
// Set `encr` to true
|
||||
self.state.encr = true;
|
||||
Ok(())
|
||||
} else {
|
||||
self.state.encr = false;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Get dynamic ref to unencrypted write+read
|
||||
fn unencrypted_dyn(&mut self) -> (&mut (dyn AsyncWrite + Unpin + '_), &mut (dyn AsyncRead + Unpin + '_))
|
||||
{
|
||||
(self.tx.inner_mut(), self.rx.inner_mut())
|
||||
}
|
||||
/// Get dynamic ref to encrypted write+read
|
||||
fn encrypted_dyn(&mut self) -> (&mut (dyn AsyncWrite + Unpin + '_), &mut (dyn AsyncRead + Unpin + '_))
|
||||
{
|
||||
(&mut self.tx, &mut self.rx)
|
||||
}
|
||||
/// Exchange keys.
|
||||
pub async fn exchange(&mut self) -> eyre::Result<()>
|
||||
{
|
||||
use tokio::prelude::*;
|
||||
let our_key = self.info.us.get_public_parts();
|
||||
let (tx, rx) = self.inner_mut();
|
||||
let read_fut = {
|
||||
|
||||
async move {
|
||||
// Read the public key from `rx`.
|
||||
//TODO: Find pubkey max size.
|
||||
let mut sz_buf = [0u8; std::mem::size_of::<u64>()];
|
||||
rx.read_exact(&mut sz_buf[..]).await
|
||||
.wrap_err(eyre!("Failed to read size of pubkey form endpoint"))?;
|
||||
let sz64 = u64::from_be_bytes(sz_buf);
|
||||
let sz= match usize::try_from(sz64)
|
||||
.wrap_err(eyre!("Read size could not fit into u64"))
|
||||
.with_section(|| format!("{:?}", sz_buf).header("Read buffer was"))
|
||||
.with_section(|| u64::from_be_bytes(sz_buf).header("64=bit size value was"))
|
||||
.with_warning(|| "This should not happen, it is only possible when you are running a machine with a pointer size lower than 64 bits.")
|
||||
.with_suggestion(|| "The message is likely malformed. If it is not, then you are communicating with an endpoint of 64 bits whereas your pointer size is far less.")? {
|
||||
x if x > TRANS_KEY_MAX_SIZE => return Err(eyre!("Recv'd key size exceeded max acceptable key buffer size")),
|
||||
x => x
|
||||
};
|
||||
let mut key_bytes = Vec::with_capacity(sz);
|
||||
tokio::io::copy(&mut rx.take(sz64), &mut key_bytes).await
|
||||
.wrap_err("Failed to read key bytes into buffer")
|
||||
.with_section(move || sz64.header("Pubkey size to read was"))?;
|
||||
if key_bytes.len() != sz {
|
||||
return Err(eyre!("Could not read required bytes"));
|
||||
}
|
||||
let k = RsaPublicKey::from_bytes(&key_bytes)
|
||||
.wrap_err("Failed to construct RSA public key from read bytes")
|
||||
.with_section(|| sz.header("Pubkey size was"))
|
||||
.with_section(move || key_bytes.to_base64_string().header("Pubkey bytes were"))?;
|
||||
|
||||
Result::<RsaPublicKey, eyre::Report>::Ok(k)
|
||||
}
|
||||
};
|
||||
let write_fut = {
|
||||
let key_bytes = our_key.to_bytes();
|
||||
assert!(key_bytes.len() <= TRANS_KEY_MAX_SIZE);
|
||||
let sz64 = u64::try_from(key_bytes.len())
|
||||
.wrap_err(eyre!("Size of our pubkey could not fit into u64"))
|
||||
.with_section(|| key_bytes.len().header("Size was"))
|
||||
.with_warning(|| "This should not happen, it is only possible when you are running a machine with a pointer size larger than 64 bits.")
|
||||
.with_warning(|| "There was likely internal memory corruption.")?;
|
||||
let sz_buf = sz64.to_be_bytes();
|
||||
async move {
|
||||
tx.write_all(&sz_buf[..]).await
|
||||
.wrap_err(eyre!("Failed to write key size"))
|
||||
.with_section(|| sz64.header("Key size bytes were"))
|
||||
.with_section(|| format!("{:?}", sz_buf).header("Key size bytes (BE) were"))?;
|
||||
tx.write_all(&key_bytes[..]).await
|
||||
.wrap_err(eyre!("Failed to write key bytes"))
|
||||
.with_section(|| sz64.header("Size of key was"))
|
||||
.with_section(|| key_bytes.to_base64_string().header("Key bytes are"))?;
|
||||
Result::<(), eyre::Report>::Ok(())
|
||||
}
|
||||
};
|
||||
let (send, recv) = tokio::join! [write_fut, read_fut];
|
||||
send.wrap_err("Failed to send our pubkey")?;
|
||||
let recv = recv.wrap_err("Failed to receive foreign pubkey")?;
|
||||
self.info.them = Some(recv);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
//XXX: For some reason, non-exact reads + writes cause garbage to be produced on the receiving end?
|
||||
// Is this fixable? Why does it disjoint? I have no idea... This is supposed to be a stream cipher, right? Why does positioning matter? Have I misunderstood how it workd? Eh...
|
||||
// With this bug, it seems the `while read(buffer) > 0` construct is impossible. This might make this entirely useless. Hopefully with the rigid size-based format for `Message` we won't run into this problem, but subsequent data streaming will likely be affected unless we use rigid, fixed, and (inefficiently) communicated buffer sizes.
|
||||
|
||||
impl<W, R> AsyncWrite for ESock<W, R>
|
||||
where W: AsyncWrite
|
||||
{
|
||||
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
|
||||
//XXX: If the encryption state of the socket is changed between polls, this breaks. Idk if we can do anything about that tho.
|
||||
if self.state.encw {
|
||||
self.project().tx.poll_write(cx, buf)
|
||||
} else {
|
||||
// SAFETY: Uhh... well I think this is fine? Because we can project the container.
|
||||
// TODO: Can we project the `tx`? Or maybe add a method in `AsyncSink` to map a pinned sink to a `Pin<&mut W>`?
|
||||
unsafe { self.map_unchecked_mut(|this| this.tx.inner_mut()).poll_write(cx, buf)}
|
||||
}
|
||||
}
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
// Should we do anything else here?
|
||||
// Should we clear foreign key/current session key?
|
||||
self.project().tx.poll_shutdown(cx)
|
||||
}
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
self.project().tx.poll_flush(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<W, R> AsyncRead for ESock<W, R>
|
||||
where R: AsyncRead
|
||||
{
|
||||
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
|
||||
//XXX: If the encryption state of the socket is changed between polls, this breaks. Idk if we can do anything about that tho.
|
||||
if self.state.encr {
|
||||
self.project().rx.poll_read(cx, buf)
|
||||
} else {
|
||||
// SAFETY: Uhh... well I think this is fine? Because we can project the container.
|
||||
// TODO: Can we project the `tx`? Or maybe add a method in `AsyncSink` to map a pinned sink to a `Pin<&mut W>`?
|
||||
unsafe { self.map_unchecked_mut(|this| this.rx.inner_mut()).poll_read(cx, buf)}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Write half for `ESock`.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct ESockWriteHalf<W>(Arc<ESockInfo>, #[pin] AsyncSink<W>, bool);
|
||||
|
||||
/// Read half for `ESock`.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct ESockReadHalf<R>(Arc<ESockInfo>, #[pin] AsyncSource<R>, bool);
|
||||
|
||||
//Impl AsyncRead/Write + set_encrypted_read/write for ESockRead/WriteHalf.
|
||||
|
||||
impl<W: AsyncWrite> ESockWriteHalf<W>
|
||||
{
|
||||
/// Does this write half have a live corresponding read half?
|
||||
///
|
||||
/// It's not required to have one, however, exchange is not possible without since it requires sticking the halves back together.
|
||||
pub fn is_bidirectional(&self) -> bool
|
||||
{
|
||||
Arc::strong_count(&self.0) > 1
|
||||
}
|
||||
|
||||
/// Is write encrypted on this half?
|
||||
#[inline(always)] pub fn is_encrypted(&self) -> bool
|
||||
{
|
||||
self.2
|
||||
}
|
||||
/// The local RSA private key
|
||||
#[inline] pub fn local_key(&self) -> &RsaPrivateKey
|
||||
{
|
||||
&self.0.us
|
||||
}
|
||||
/// THe remote RSA public key (if exchange has happened.)
|
||||
#[inline] pub fn foreign_key(&self) -> Option<&RsaPublicKey>
|
||||
{
|
||||
self.0.them.as_ref()
|
||||
}
|
||||
|
||||
/// End an encrypted session syncronously.
|
||||
///
|
||||
/// Same as calling `set_encryption(false).now_or_never()`, but more efficient.
|
||||
pub fn clear_encryption(&mut self)
|
||||
{
|
||||
self.2 = false;
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead> ESockReadHalf<R>
|
||||
{
|
||||
/// Does this read half have a live corresponding write half?
|
||||
///
|
||||
/// It's not required to have one, however, exchange is not possible without since it requires sticking the halves back together.
|
||||
pub fn is_bidirectional(&self) -> bool
|
||||
{
|
||||
Arc::strong_count(&self.0) > 1
|
||||
}
|
||||
|
||||
/// Is write encrypted on this half?
|
||||
#[inline(always)] pub fn is_encrypted(&self) -> bool
|
||||
{
|
||||
self.2
|
||||
}
|
||||
/// The local RSA private key
|
||||
#[inline] pub fn local_key(&self) -> &RsaPrivateKey
|
||||
{
|
||||
&self.0.us
|
||||
}
|
||||
/// THe remote RSA public key (if exchange has happened.)
|
||||
#[inline] pub fn foreign_key(&self) -> Option<&RsaPublicKey>
|
||||
{
|
||||
self.0.them.as_ref()
|
||||
}
|
||||
|
||||
/// End an encrypted session syncronously.
|
||||
///
|
||||
/// Same as calling `set_encryption(false).now_or_never()`, but more efficient.
|
||||
pub fn clear_encryption(&mut self)
|
||||
{
|
||||
self.2 = false;
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> ESockWriteHalf<W>
|
||||
{
|
||||
/// Begin or end an encrypted writing session
|
||||
pub async fn set_encryption(&mut self, set: bool) -> eyre::Result<()>
|
||||
{
|
||||
if set {
|
||||
set_encrypted_write_for(&self.0, &mut self.1).await?;
|
||||
self.2 = true;
|
||||
} else {
|
||||
self.2 = false;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead + Unpin> ESockReadHalf<R>
|
||||
{
|
||||
/// Begin or end an encrypted reading session
|
||||
pub async fn set_encryption(&mut self, set: bool) -> eyre::Result<()>
|
||||
{
|
||||
if set {
|
||||
set_encrypted_read_for(&self.0, &mut self.1).await?;
|
||||
self.2 = true;
|
||||
} else {
|
||||
self.2 = false;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite> AsyncWrite for ESockWriteHalf<W>
|
||||
{
|
||||
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
|
||||
if self.2 {
|
||||
// Encrypted
|
||||
|
||||
self.project().1.poll_write(cx, buf)
|
||||
} else {
|
||||
// Unencrypted
|
||||
|
||||
// SAFETY: Uhh... well I think this is fine? Because we can project the container.
|
||||
// TODO: Can we project the `tx`? Or maybe add a method in `AsyncSink` to map a pinned sink to a `Pin<&mut W>`?
|
||||
unsafe { self.map_unchecked_mut(|this| this.1.inner_mut()).poll_write(cx, buf)}
|
||||
}
|
||||
}
|
||||
#[inline(always)] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
self.project().1.poll_flush(cx)
|
||||
}
|
||||
#[inline(always)] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
self.project().1.poll_flush(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead> AsyncRead for ESockReadHalf<R>
|
||||
{
|
||||
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
|
||||
if self.2 {
|
||||
// Encrypted
|
||||
|
||||
self.project().1.poll_read(cx, buf)
|
||||
} else {
|
||||
// Unencrypted
|
||||
|
||||
// SAFETY: Uhh... well I think this is fine? Because we can project the container.
|
||||
// TODO: Can we project the `tx`? Or maybe add a method in `AsyncSink` to map a pinned sink to a `Pin<&mut W>`?
|
||||
unsafe { self.map_unchecked_mut(|this| this.1.inner_mut()).poll_read(cx, buf)}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests
|
||||
{
|
||||
use super::ESock;
|
||||
|
||||
#[test]
|
||||
fn rsa_ciphertext_len() -> crate::eyre::Result<()>
|
||||
{
|
||||
let data = {
|
||||
use chacha20stream::cha::{KEY_SIZE, IV_SIZE};
|
||||
let (key, iv) = chacha20stream::cha::keygen();
|
||||
let (sz, d) = crate::bin::collect_slices_exact::<&[u8], _, {KEY_SIZE + IV_SIZE}>([key.as_ref(), iv.as_ref()]);
|
||||
assert_eq!(sz, d.len());
|
||||
d
|
||||
};
|
||||
println!("KEY+IV: {} bytes", data.len());
|
||||
|
||||
let key = cryptohelpers::rsa::RsaPublicKey::generate()?;
|
||||
let rsa = cryptohelpers::rsa::encrypt_slice_to_vec(data, &key)?;
|
||||
println!("Rsa ciphertext size: {}", rsa.len());
|
||||
|
||||
assert_eq!(rsa.len(), super::RSA_CIPHERTEXT_SIZE, "Incorrect RSA ciphertext length constant for cc20 KEY+IV encoding.");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
#[test]
|
||||
fn rsa_serial_ciphertext_len() -> crate::eyre::Result<()>
|
||||
{
|
||||
let data = serde_cbor::to_vec(&{
|
||||
let (key, iv) = chacha20stream::cha::keygen();
|
||||
super::ESockSessionKey {
|
||||
key, iv,
|
||||
}
|
||||
}).expect("Failed to CBOR encode Key+IV");
|
||||
println!("(cbor) KEY+IV: {} bytes", data.len());
|
||||
|
||||
let key = cryptohelpers::rsa::RsaPublicKey::generate()?;
|
||||
let rsa = cryptohelpers::rsa::encrypt_slice_to_vec(data, &key)?;
|
||||
println!("Rsa ciphertext size: {}", rsa.len());
|
||||
|
||||
assert_eq!(rsa.len(), super::RSA_CIPHERTEXT_SIZE, "Incorrect RSA ciphertext length constant for cc20 KEY+IV CBOR encoding.");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn gen_duplex_esock(bufsz: usize) -> crate::eyre::Result<(ESock<tokio::io::DuplexStream, tokio::io::DuplexStream>, ESock<tokio::io::DuplexStream, tokio::io::DuplexStream>)>
|
||||
{
|
||||
use crate::*;
|
||||
let (atx, brx) = tokio::io::duplex(bufsz);
|
||||
let (btx, arx) = tokio::io::duplex(bufsz);
|
||||
let tx = ESock::new(atx, arx).wrap_err(eyre!("Failed to create TX"))?;
|
||||
let rx = ESock::new(btx, brx).wrap_err(eyre!("Failed to create RX"))?;
|
||||
Ok((tx, rx))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn esock_exchange() -> crate::eyre::Result<()>
|
||||
{
|
||||
use crate::*;
|
||||
|
||||
const VALUE: &'static [u8] = b"Hello world!";
|
||||
|
||||
// The duplex buffer size here is smaller than an RSA ciphertext block. So, writing the session key must be buffered with a buffer size this small (should return Pending at least once.)
|
||||
// Using a low buffer size to make sure the test passes even when the entire buffer cannot be written at once.
|
||||
let (mut tx, mut rx) = gen_duplex_esock(16).wrap_err(eyre!("Failed to weave socks"))?;
|
||||
|
||||
let writer = tokio::spawn(async move {
|
||||
use tokio::prelude::*;
|
||||
|
||||
tx.exchange().await?;
|
||||
assert!(tx.has_exchanged());
|
||||
|
||||
tx.set_encrypted_write(true).await?;
|
||||
assert_eq!((true, false), tx.is_encrypted());
|
||||
|
||||
tx.write_all(VALUE).await?;
|
||||
tx.write_all(VALUE).await?;
|
||||
|
||||
// Check resp
|
||||
tx.set_encrypted_read(true).await?;
|
||||
assert_eq!({
|
||||
let mut chk = [0u8; 3];
|
||||
tx.read_exact(&mut chk[..]).await?;
|
||||
chk
|
||||
}, [0xaau8,0, 0], "Failed response check");
|
||||
|
||||
// Write unencrypted
|
||||
tx.set_encrypted_write(false).await?;
|
||||
tx.write_all(&[2,1,0xfa]).await?;
|
||||
|
||||
Result::<_, eyre::Report>::Ok(VALUE)
|
||||
});
|
||||
let reader = tokio::spawn(async move {
|
||||
use tokio::prelude::*;
|
||||
|
||||
rx.exchange().await?;
|
||||
assert!(rx.has_exchanged());
|
||||
|
||||
rx.set_encrypted_read(true).await?;
|
||||
assert_eq!((false, true), rx.is_encrypted());
|
||||
|
||||
let mut val = vec![0u8; VALUE.len()];
|
||||
rx.read_exact(&mut val[..]).await?;
|
||||
|
||||
let mut val2 = vec![0u8; VALUE.len()];
|
||||
rx.read_exact(&mut val2[..]).await?;
|
||||
|
||||
assert_eq!(val, val2);
|
||||
|
||||
// Send resp
|
||||
rx.set_encrypted_write(true).await?;
|
||||
rx.write_all(&[0xaa, 0, 0]).await?;
|
||||
|
||||
// Read unencrypted
|
||||
rx.set_encrypted_read(false).await?;
|
||||
assert_eq!({
|
||||
let mut buf = [0u8; 3];
|
||||
rx.read_exact(&mut buf[..]).await?;
|
||||
buf
|
||||
}, [2u8,1,0xfa], "2nd response incorrect");
|
||||
|
||||
Result::<_, eyre::Report>::Ok(val)
|
||||
});
|
||||
let (writer, reader) = tokio::join![writer, reader];
|
||||
|
||||
let writer = writer.expect("Tx task panic");
|
||||
let reader = reader.expect("Rx task panic");
|
||||
|
||||
eprintln!("Txr: {:?}", writer);
|
||||
eprintln!("Rxr: {:?}", reader);
|
||||
writer?;
|
||||
let val = reader?;
|
||||
println!("Read: {:?}", val);
|
||||
|
||||
assert_eq!(&val, VALUE);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn esock_split() -> crate::eyre::Result<()>
|
||||
{
|
||||
use super::*;
|
||||
const SLICES: &'static [&'static [u8]] = &[
|
||||
&[1,5,3,7,6,9,100,0],
|
||||
&[7,6,2,90],
|
||||
&[3,6,1,0],
|
||||
&[5,1,3,3],
|
||||
];
|
||||
let result = SLICES.iter().map(|&slice| slice.iter().map(|&b| u64::from(b)).sum::<u64>()).sum::<u64>();
|
||||
println!("Result: {}", result);
|
||||
let (mut tx, mut rx) = gen_duplex_esock(super::TRANS_KEY_MAX_SIZE * 4).wrap_err(eyre!("Failed to weave socks"))?;
|
||||
let (writer, reader) = {
|
||||
use tokio::prelude::*;
|
||||
|
||||
let writer = tokio::spawn(async move {
|
||||
tx.exchange().await?;
|
||||
|
||||
let (mut tx, mut rx) = tx.split();
|
||||
|
||||
//tx.set_encryption(true).await?;
|
||||
|
||||
let slices = &SLICES[1..];
|
||||
for &slice in slices.iter()
|
||||
{
|
||||
println!("Writing slice: {:?}", slice);
|
||||
tx.write_all(slice).await?;
|
||||
}
|
||||
|
||||
//let mut tx = ESock::unsplit(tx, rx);
|
||||
|
||||
tx.write_all(SLICES[0]).await?;
|
||||
|
||||
Result::<_, eyre::Report>::Ok(())
|
||||
});
|
||||
let reader = tokio::spawn(async move {
|
||||
rx.exchange().await?;
|
||||
|
||||
let (mut tx, mut rx) = rx.split();
|
||||
|
||||
//rx.set_encryption(true).await?;
|
||||
|
||||
let (mut mtx, mut mrx) = tokio::sync::mpsc::channel::<Vec<u8>>(16);
|
||||
let sorter = tokio::spawn(async move {
|
||||
let mut done = 0u64;
|
||||
while let Some(buf) = mrx.recv().await
|
||||
{
|
||||
//buf.sort();
|
||||
done += buf.iter().map(|&b| u64::from(b)).sum::<u64>();
|
||||
println!("Got buffer: {:?}", buf);
|
||||
tx.write_all(&buf).await?;
|
||||
}
|
||||
Result::<_, eyre::Report>::Ok(done)
|
||||
});
|
||||
let mut buffer = [0u8; 16];
|
||||
while let Ok(read) = rx.read(&mut buffer[..]).await
|
||||
{
|
||||
if read == 0 {
|
||||
break;
|
||||
}
|
||||
mtx.send(Vec::from(&buffer[..read])).await?;
|
||||
}
|
||||
drop(mtx);
|
||||
let sum = sorter.await.expect("(reader) Sorter task panic")?;
|
||||
Result::<_, eyre::Report>::Ok(sum)
|
||||
});
|
||||
let (writer, reader) = tokio::join![writer, reader];
|
||||
|
||||
(writer.expect("Writer task panic"),
|
||||
reader.expect("Reader task panic"))
|
||||
};
|
||||
|
||||
writer?;
|
||||
assert_eq!(result, reader?);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
@ -1,93 +0,0 @@
|
||||
//! Socket handlers
|
||||
use super::*;
|
||||
use std::collections::HashSet;
|
||||
use std::net::{
|
||||
SocketAddr,
|
||||
};
|
||||
use tokio::io::{
|
||||
AsyncWrite,
|
||||
AsyncRead
|
||||
};
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio::sync::{
|
||||
mpsc,
|
||||
oneshot,
|
||||
};
|
||||
use futures::Future;
|
||||
use bytes::Bytes;
|
||||
|
||||
use cancel::*;
|
||||
|
||||
pub mod enc;
|
||||
pub mod pipe;
|
||||
|
||||
/// Details of a newly accepted raw socket peer.
|
||||
///
|
||||
/// This connected will have been "accepted", but not yet trusted
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
|
||||
pub struct RawSockPeerAccepted
|
||||
{
|
||||
/// Address of socket.
|
||||
pub addr: SocketAddr,
|
||||
/// Trust this peer from the start? This should almost always be false.
|
||||
pub auto_trust: bool,
|
||||
}
|
||||
|
||||
/// Details of a connected, set-up raw socket connection.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct RawSockPeerTrusted
|
||||
{
|
||||
/// The socket's details
|
||||
pub sock_details: RawSockPeerAccepted,
|
||||
|
||||
/// Capabilities for this peer
|
||||
pub cap_allow: HashSet<cap::RawSockCapability>,
|
||||
}
|
||||
|
||||
/// A raw, received message from a `RawPeer`.
|
||||
#[derive(Debug)]
|
||||
pub struct RawMessage{
|
||||
message_bytes: Bytes,
|
||||
}
|
||||
|
||||
/// A connected raw peer, created and handled by `handle_new_socket_with_shutdown()`.
|
||||
#[derive(Debug)]
|
||||
pub struct RawPeer{
|
||||
info: RawSockPeerTrusted,
|
||||
rx: mpsc::Receiver<RawMessage>,
|
||||
}
|
||||
|
||||
/// Handles a **newly connected** raw socket.
|
||||
///
|
||||
/// This will handle setting up socket peer encryption and validation.
|
||||
pub fn handle_new_socket_with_shutdown<R, W, C: cancel::CancelFuture + 'static + Send>(
|
||||
sock_details: RawSockPeerAccepted,
|
||||
set_peer: oneshot::Sender<RawPeer>,
|
||||
tx: W, rx: R,
|
||||
shutdown: C
|
||||
) -> JoinHandle<eyre::Result<()>>
|
||||
where R: AsyncRead + Unpin + Send + 'static,
|
||||
W: AsyncWrite + Unpin + Send + 'static
|
||||
{
|
||||
tokio::spawn(async move {
|
||||
match {
|
||||
with_cancel!(async move {
|
||||
// Create empty cap
|
||||
let mut sock_details = RawSockPeerTrusted {
|
||||
sock_details,
|
||||
cap_allow: HashSet::new(),
|
||||
};
|
||||
|
||||
// Set up encryption
|
||||
|
||||
|
||||
//TODO: Find caps for this peer.
|
||||
|
||||
Ok(())
|
||||
}, shutdown)
|
||||
} {
|
||||
Ok(v) => v,
|
||||
Err(x) => Err(eyre::Report::from(x)),
|
||||
}
|
||||
})
|
||||
}
|
@ -1,53 +0,0 @@
|
||||
//! Piping buffered data from a raw socket to `ESock`
|
||||
//!
|
||||
//! This exists because i'm too dumb to implement a functional AsyncRead/Write buffered wrapper stream :/
|
||||
use super::*;
|
||||
use std::{
|
||||
io,
|
||||
marker::{
|
||||
Send, Sync,
|
||||
|
||||
Unpin,
|
||||
PhantomData,
|
||||
},
|
||||
};
|
||||
use tokio::sync::{
|
||||
mpsc,
|
||||
};
|
||||
use enc::{
|
||||
ESock,
|
||||
|
||||
ESockReadHalf,
|
||||
ESockWriteHalf,
|
||||
};
|
||||
|
||||
/// The default buffer size for `BufferedESock`.
|
||||
pub const DEFAULT_BUFFER_SIZE: usize = 32;
|
||||
|
||||
/// Task-based buffered piping to/from encrypted sockets.
|
||||
pub struct BufferedESock<W, R>
|
||||
{
|
||||
bufsz: usize,
|
||||
_backing: PhantomData<ESock<W, R>>,
|
||||
}
|
||||
|
||||
impl<W, R> BufferedESock<W, R>
|
||||
where W: AsyncWrite + Unpin + Send + 'static,
|
||||
R: AsyncRead + Unpin + Send + 'static
|
||||
{
|
||||
/// Create a new buffered ESock pipe with a specific buffer size
|
||||
pub fn with_size(tx: W, rx: R, bufsz: usize) -> Self
|
||||
{
|
||||
//TODO: Spawn read+write buffer tasks
|
||||
Self {
|
||||
bufsz,
|
||||
_backing: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new buffered ESock pipe with the default buffer size (`DEFAULT_BUFFER_SIZE`).
|
||||
#[inline] pub fn new(tx: W, rx: R) -> Self
|
||||
{
|
||||
Self::with_size(tx, rx, DEFAULT_BUFFER_SIZE)
|
||||
}
|
||||
}
|
Loading…
Reference in new issue