Compare commits

..

1 Commits

Author SHA1 Message Date
Avril af19935167
Attempted specialisation. Failed.
3 years ago

32
Cargo.lock generated

@ -2,12 +2,6 @@
# It is not intended for manual editing.
version = 3
[[package]]
name = "ad-hoc-iter"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90a8dd76beceb5313687262230fcbaaf8d4e25c37541350cf0932e9adb8309c8"
[[package]]
name = "addr2line"
version = "0.16.0"
@ -83,15 +77,6 @@ version = "0.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e4cec68f03f32e44924783795810fa50a7035d8c8ebe78580ad7e6c703fba38"
[[package]]
name = "bytes"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b700ce4376041dcd0a327fd0097c41095743c4c8af8887265942faf1100bd040"
dependencies = [
"serde",
]
[[package]]
name = "cc"
version = "1.0.69"
@ -112,19 +97,16 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "chacha20stream"
version = "2.1.0"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54c8d48b47fa0a89a94b80d32b1b3fc9ffc1a232a5201ff5a2d14ac77bc7561d"
checksum = "3a91f983a237d46407e744f0b9c5d2866f018954de5879905d7af6bf06953aea"
dependencies = [
"base64 0.13.0",
"getrandom 0.2.3",
"libc",
"openssl",
"pin-project",
"rustc_version",
"serde",
"smallvec",
"stackalloc",
"tokio 0.2.25",
]
@ -185,11 +167,10 @@ dependencies = [
[[package]]
name = "cryptohelpers"
version = "1.8.2"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9143447fb393f8d38abbb617af9b986a0941785ddc63685bd8de735fb31bcafc"
checksum = "14be74ce15793a86acd04872953368ce27d07f384f07b8028bd5aaa31a031a38"
dependencies = [
"base64 0.13.0",
"crc",
"futures",
"getrandom 0.1.16",
@ -799,9 +780,6 @@ dependencies = [
name = "rsh"
version = "0.1.0"
dependencies = [
"ad-hoc-iter",
"base64 0.13.0",
"bytes 1.0.1",
"chacha20stream",
"color-eyre",
"cryptohelpers",
@ -972,7 +950,7 @@ version = "0.2.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6703a273949a90131b290be1fe7b039d0fc884aa1935860dfcbe056f28cd8092"
dependencies = [
"bytes 0.5.6",
"bytes",
"fnv",
"futures-core",
"iovec",

@ -6,18 +6,15 @@ edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
ad-hoc-iter = "0.2.3"
base64 = "0.13.0"
bytes = { version = "1.0.1", features = ["serde"] }
chacha20stream = { version = "2.1.0", features = ["async", "serde"] }
chacha20stream = { version = "1.0.3", features = ["async"] }
color-eyre = "0.5.11"
cryptohelpers = { version = "1.8.2" , features = ["serialise", "full"] }
cryptohelpers = { version = "1.8.1" , features = ["serialise", "full"] }
futures = "0.3.16"
mopa = "0.2.2"
pin-project = "1.0.8"
serde = { version = "1.0.126", features = ["derive"] }
serde_cbor = "0.11.1"
smallvec = { version = "1.6.1", features = ["union", "serde", "write", "const_generics"] }
smallvec = { version = "1.6.1", features = ["union", "serde", "write"] }
stackalloc = "1.1.1"
tokio = { version = "0.2", features = ["full"] }
tokio-uring = "0.1.0"

@ -1,37 +0,0 @@
//! Binary / byte maniuplation
use bytes::BufMut;
/// Concatenate an iterator of byte slices into a buffer.
///
/// # Returns
/// The number of bytes written
/// # Panics
/// If the buffer cannot hold all the slices
pub fn collect_slices_into<B: BufMut + ?Sized, I, T>(into: &mut B, from: I) -> usize
where I: IntoIterator<Item=T>,
T: AsRef<[u8]>
{
let mut done =0;
for slice in from.into_iter()
{
let s = slice.as_ref();
into.put_slice(s);
done+=s.len();
}
done
}
/// Collect an iterator of byte slices into a new exact-size buffer.
///
/// # Returns
/// The number of bytes written, and the new array
///
/// # Panics
/// If the total bytes in all slices exceeds `SIZE`.
pub fn collect_slices_exact<T, I, const SIZE: usize>(from: I) -> (usize, [u8; SIZE])
where I: IntoIterator<Item=T>,
T: AsRef<[u8]>
{
let mut output = [0u8; SIZE];
(collect_slices_into(&mut &mut output[..], from), output)
}

@ -1,43 +0,0 @@
//! Capabilities (permissions) of a connection.
use super::*;
use std::time::Duration;
/*
pub mod fail;
pub use fail::Failures;
/// How lenient to be with a certain operation
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord, Copy)]
pub enum Leniency
{
/// Ignore **all** malformed/missed messages.
Ignore,
/// Allow `n` failures before disconnecting the socket.
Specific(Failures),
/// Allow
Normal,
/// Immediately disconnect the socket on **any** malformed/missed message.
None,
}
*/
/// A capability (permission) for a raw socket's data transmission.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum RawSockCapability
{
/// Process messages that aren't signed.
AllowUnsignedMessages,
/// Do not disconnect the socket when a malformed message is received, just ignore the message.
SoftFail,
/// Throttle the number of messages to process
RateLimit { tx: usize, rx: usize },
/// The request response timeout for messages with an expected response.
RecvRespTimeout { tx: Duration, rx: Duration },
/// Max number of bytes to read for a single message.
//TODO: Implement this for message
MaxMessageSize(usize),
}

@ -1,102 +0,0 @@
//! Failure counting/capping
use super::*;
use std::{
fmt, error,
};
/// A measure of failures, used to track or to check failures.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord, Copy)]
pub struct Failures
{
/// Number of failures happened back-to-back.
pub seq: usize,
/// Total number of failures over the socket's lifetime.
///
/// Set allowed to `0` for unlimited.
pub total: usize,
/// Window of time to keep failures.
pub window: Duration,
/// Number of failures happened in the last `window` of time.
pub last_window: usize,
}
/// When a failure cap is exceeded, which one is exceeded; and what is the limit that is exceeded?
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FailureCapExceeded
{
/// Too many sequential errors
Sequential(usize),
/// Too many total errors
Total(usize),
/// Too many errors in the refresh window
Windowed(usize, Duration),
}
impl error::Error for FailureCapExceeded{}
impl fmt::Display for FailureCapExceeded
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
{
match self {
Self::Sequential(_) => write!(f, "too many sequential errors"),
Self::Total(_) => write!(f, "too many total errors"),
Self::Windowed(_, _) => write!(f, "too many errors in refresh window"),
}
}
}
impl FailureCapExceeded
{
/// Convert into a detailed report. (This shouldn't be shared with peer, probably.)
pub fn into_detailed_report(self) -> eyre::Report
{
let rep = eyre::Report::from(&self);
match self {
Self::Sequential(cap) => rep.with_section(|| cap.header("Exceeded limit")),
Self::Total(cap) => rep.with_section(|| cap.header("Exceeded limit")),
Self::Windowed(cap, w) => rep.with_section(|| cap.header("Exceeded limit"))
.with_section(|| format!("{:?}", w).header("Refresh window was"))
}
}
}
impl Failures
{
/// Default allowed failure parameters.
pub const DEFAULT_ALLOWED: Self = Self {
seq: 10,
total: 65536,
window: Duration::from_secs(10),
last_window: 5,
};
/// Has this `Failures` exceeded failure cap `other`?
pub fn cap_check(&self, other: &Self) -> Result<(), FailureCapExceeded>
{
macro_rules! chk {
($name:ident, $err:expr) => {
if other.$name != 0 && (self.$name >= other.$name) {
return Err($err)?;
}
}
}
chk!(seq, FailureCapExceeded::Sequential(other.seq));
chk!(total, FailureCapExceeded::Total(other.total));
//debug_assert!(other.window == self.window); //TODO: Should we disallow this?
chk!(last_window, FailureCapExceeded::Windowed(other.last_window, self.window));
Ok(())
}
}
impl Default for Failures
{
#[inline]
fn default() -> Self
{
Self::DEFAULT_ALLOWED
}
}

@ -4,14 +4,11 @@ use std::mem::{self, MaybeUninit};
use std::iter;
use smallvec::SmallVec;
mod alloc;
pub use alloc::*;
/// Max size of memory allowed to be allocated on the stack.
pub const STACK_MEM_ALLOC_MAX: usize = 4096;
mod hex;
pub use hex::*;
mod base64;
pub use self::base64::*;
/// A stack-allocated vector that spills onto the heap when needed.
pub type StackVec<T> = SmallVec<[T; STACK_MEM_ALLOC_MAX]>;
/// A maybe-atom that can spill into a vector.
pub type MaybeVec<T> = SmallVec<[T; 1]>;
@ -26,6 +23,89 @@ pub fn vec_uninit<T>(sz: usize) -> Vec<MaybeUninit<T>>
}
}
/// Allocate a local buffer initialised from `init`.
pub fn alloc_local_with<T, U>(sz: usize, init: impl FnMut() -> T, within: impl FnOnce(&mut [T]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory: Vec<T> = iter::repeat_with(init).take(sz).collect();
within(&mut memory[..])
} else {
stackalloc::stackalloc_with(sz, init, within)
}
}
/// Allocate a local zero-initialised byte buffer
pub fn alloc_local_bytes<U>(sz: usize, within: impl FnOnce(&mut [u8]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory: Vec<MaybeUninit<u8>> = vec_uninit(sz);
within(unsafe {
std::ptr::write_bytes(memory.as_mut_ptr(), 0, sz);
stackalloc::helpers::slice_assume_init_mut(&mut memory[..])
})
} else {
stackalloc::alloca_zeroed(sz, within)
}
}
/// Allocate a local zero-initialised buffer
pub fn alloc_local_zeroed<T, U>(sz: usize, within: impl FnOnce(&mut [MaybeUninit<T>]) -> U) -> U
{
let sz_bytes = mem::size_of::<T>() * sz;
if sz > STACK_MEM_ALLOC_MAX {
let mut memory = vec_uninit(sz);
unsafe {
std::ptr::write_bytes(memory.as_mut_ptr(), 0, sz_bytes);
}
within(&mut memory[..])
} else {
stackalloc::alloca_zeroed(sz_bytes, move |buf| {
unsafe {
debug_assert_eq!(buf.len() / mem::size_of::<T>(), sz);
within(std::slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut MaybeUninit<T>, sz))
}
})
}
}
/// Allocate a local uninitialised buffer
pub fn alloc_local_uninit<T, U>(sz: usize, within: impl FnOnce(&mut [MaybeUninit<T>]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory = vec_uninit(sz);
within(&mut memory[..])
} else {
stackalloc::stackalloc_uninit(sz, within)
}
}
/// Allocate a local buffer initialised with `init`.
pub fn alloc_local<T: Clone, U>(sz: usize, init: T, within: impl FnOnce(&mut [T]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory: Vec<T> = iter::repeat(init).take(sz).collect();
within(&mut memory[..])
} else {
stackalloc::stackalloc(sz, init, within)
}
}
/// Allocate a local buffer initialised with `T::default()`.
pub fn alloc_local_with_default<T: Default, U>(sz: usize, within: impl FnOnce(&mut [T]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory: Vec<T> = iter::repeat_with(Default::default).take(sz).collect();
within(&mut memory[..])
} else {
stackalloc::stackalloc_with_default(sz, within)
}
}
/// Create a blanket-implementing trait that is a subtrait of any number of traits.
///
/// # Usage
@ -91,9 +171,7 @@ const _:() = {
let _ref: &std::io::Stdin = a.downcast_ref::<std::io::Stdin>().unwrap();
}
}
/*
XXX: This is broken on newest nightly?
const _TEST: () = _a::<dyn Test>();
const _TEST2: () = _b::<dyn TestAny>();
*/
};

@ -1,91 +0,0 @@
//! Stack allocation helpers
use super::*;
/// Max size of memory allowed to be allocated on the stack.
pub const STACK_MEM_ALLOC_MAX: usize = 2048; // 2KB
/// A stack-allocated vector that spills onto the heap when needed.
pub type StackVec<T> = SmallVec<[T; STACK_MEM_ALLOC_MAX]>;
/// Allocate a local buffer initialised from `init`.
pub fn alloc_local_with<T, U>(sz: usize, init: impl FnMut() -> T, within: impl FnOnce(&mut [T]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory: Vec<T> = iter::repeat_with(init).take(sz).collect();
within(&mut memory[..])
} else {
stackalloc::stackalloc_with(sz, init, within)
}
}
/// Allocate a local zero-initialised byte buffer
pub fn alloc_local_bytes<U>(sz: usize, within: impl FnOnce(&mut [u8]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory: Vec<MaybeUninit<u8>> = vec_uninit(sz);
within(unsafe {
std::ptr::write_bytes(memory.as_mut_ptr(), 0, sz);
stackalloc::helpers::slice_assume_init_mut(&mut memory[..])
})
} else {
stackalloc::alloca_zeroed(sz, within)
}
}
/// Allocate a local zero-initialised buffer
pub fn alloc_local_zeroed<T, U>(sz: usize, within: impl FnOnce(&mut [MaybeUninit<T>]) -> U) -> U
{
let sz_bytes = mem::size_of::<T>() * sz;
if sz > STACK_MEM_ALLOC_MAX {
let mut memory = vec_uninit(sz);
unsafe {
std::ptr::write_bytes(memory.as_mut_ptr(), 0, sz_bytes);
}
within(&mut memory[..])
} else {
stackalloc::alloca_zeroed(sz_bytes, move |buf| {
unsafe {
debug_assert_eq!(buf.len() / mem::size_of::<T>(), sz);
within(std::slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut MaybeUninit<T>, sz))
}
})
}
}
/// Allocate a local uninitialised buffer
pub fn alloc_local_uninit<T, U>(sz: usize, within: impl FnOnce(&mut [MaybeUninit<T>]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory = vec_uninit(sz);
within(&mut memory[..])
} else {
stackalloc::stackalloc_uninit(sz, within)
}
}
/// Allocate a local buffer initialised with `init`.
pub fn alloc_local<T: Clone, U>(sz: usize, init: T, within: impl FnOnce(&mut [T]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory: Vec<T> = iter::repeat(init).take(sz).collect();
within(&mut memory[..])
} else {
stackalloc::stackalloc(sz, init, within)
}
}
/// Allocate a local buffer initialised with `T::default()`.
pub fn alloc_local_with_default<T: Default, U>(sz: usize, within: impl FnOnce(&mut [T]) -> U) -> U
{
if sz > STACK_MEM_ALLOC_MAX {
let mut memory: Vec<T> = iter::repeat_with(Default::default).take(sz).collect();
within(&mut memory[..])
} else {
stackalloc::stackalloc_with_default(sz, within)
}
}

@ -1,15 +0,0 @@
//! Base64 formatting extensions
use super::*;
pub trait Base64StringExt
{
fn to_base64_string(&self) -> String;
}
impl<T: ?Sized> Base64StringExt for T
where T: AsRef<[u8]>
{
fn to_base64_string(&self) -> String {
::base64::encode(self.as_ref())
}
}

@ -1,124 +0,0 @@
use std::{
mem,
iter::{
self,
ExactSizeIterator,
FusedIterator,
},
slice,
fmt,
};
#[derive(Debug, Clone)]
pub struct HexStringIter<I>(I, [u8; 2]);
impl<I: Iterator<Item = u8>> HexStringIter<I>
{
/// Write this hex string iterator to a formattable buffer
pub fn consume<F>(self, f: &mut F) -> fmt::Result
where F: std::fmt::Write
{
if self.1[0] != 0 {
write!(f, "{}", self.1[0] as char)?;
}
if self.1[1] != 0 {
write!(f, "{}", self.1[1] as char)?;
}
for x in self.0 {
write!(f, "{:02x}", x)?;
}
Ok(())
}
/// Consume into a string
pub fn into_string(self) -> String
{
let mut output = match self.size_hint() {
(0, None) => String::new(),
(_, Some(x)) |
(x, None) => String::with_capacity(x),
};
self.consume(&mut output).unwrap();
output
}
}
pub trait HexStringIterExt<I>: Sized
{
fn into_hex(self) -> HexStringIter<I>;
}
pub type HexStringSliceIter<'a> = HexStringIter<iter::Copied<slice::Iter<'a, u8>>>;
pub trait HexStringSliceIterExt
{
fn hex(&self) -> HexStringSliceIter<'_>;
}
impl<S> HexStringSliceIterExt for S
where S: AsRef<[u8]>
{
fn hex(&self) -> HexStringSliceIter<'_>
{
self.as_ref().iter().copied().into_hex()
}
}
impl<I: IntoIterator<Item=u8>> HexStringIterExt<I::IntoIter> for I
{
#[inline] fn into_hex(self) -> HexStringIter<I::IntoIter> {
HexStringIter(self.into_iter(), [0u8; 2])
}
}
impl<I: Iterator<Item = u8>> Iterator for HexStringIter<I>
{
type Item = char;
fn next(&mut self) -> Option<Self::Item>
{
match self.1 {
[_, 0] => {
use std::io::Write;
write!(&mut self.1[..], "{:02x}", self.0.next()?).unwrap();
Some(mem::replace(&mut self.1[0], 0) as char)
},
[0, _] => Some(mem::replace(&mut self.1[1], 0) as char),
_ => unreachable!(),
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
let (l, h) = self.0.size_hint();
(l * 2, h.map(|x| x*2))
}
}
impl<I: Iterator<Item = u8> + ExactSizeIterator> ExactSizeIterator for HexStringIter<I>{}
impl<I: Iterator<Item = u8> + FusedIterator> FusedIterator for HexStringIter<I>{}
impl<I: Iterator<Item = u8>> From<HexStringIter<I>> for String
{
fn from(from: HexStringIter<I>) -> Self
{
from.into_string()
}
}
impl<I: Iterator<Item = u8> + Clone> fmt::Display for HexStringIter<I>
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
{
self.clone().consume(f)
}
}
/*
#[macro_export] macro_rules! prog1 {
($first:expr, $($rest:expr);+ $(;)?) => {
($first, $( $rest ),+).0
}
}
*/

@ -1,10 +1,12 @@
//! Remote communication
#![cfg_attr(nightly, feature(const_fn_trait_bound))]
#![cfg_attr(nightly, feature(specialization))]
#![allow(dead_code)]
#[macro_use] extern crate mopa;
#[macro_use] extern crate serde;
#[macro_use] extern crate ad_hoc_iter;
#[macro_use] extern crate pin_project;
#[allow(unused_imports)]
@ -22,18 +24,15 @@ use std::convert::{
};
mod ext; use ext::*;
mod bin;
mod message;
mod cancel;
mod cap;
mod sock;
//mod pipeline;
#[tokio::main]
async fn main() -> eyre::Result<()>
{
async fn main() -> eyre::Result<()> {
println!("Hello, world!");
Ok(())
}

@ -1,6 +1,7 @@
//! Messages
use super::*;
use std::marker::PhantomData;
use serde::{Serialize, Deserialize};
use cryptohelpers::{
sha256,
aes,
@ -27,25 +28,12 @@ pub use builder::*;
pub mod value;
pub use value::MessageValue;
/// A `SerializedMessage` whos type has been erased.
pub type UntypedSerializedMessage = SerializedMessage<value::UntypedMessageValue>;
/// Size of encrypted AES key
pub const RSA_BLOCK_SIZE: usize = 512;
/// Max size to pre-allocate when reading a message buffer.
pub const MAX_ALLOC_SIZE: usize = 4096; // 4kb
/// Max size to allow reading for a message buffer component.
///
/// Not including the message body, see `MAX_BODY_SIZE` for body.
pub const MAX_READ_SIZE: usize = 2048 * 1024; // 2MB.
/// Max allowed size of a single message body.
///
/// Set to 0 for unlimited.
pub const MAX_BODY_SIZE: usize = (1024 * 1024) * 1024; // 1GB
/// A message that can send a value into bytes.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Message<V: ?Sized + MessageValue>
@ -72,7 +60,7 @@ impl<V: ?Sized + MessageValue> Message<V>
macro_rules! accessor {
($name:ident, $fn_name:ident, $type:ty $(; $comment:literal)?) => {
$(#[doc = $comment])?
#[inline] pub fn $fn_name(&self) -> &$type
#[inline] pub fn $fn_name(&self) -> &$type
{
&self.0.$name
}
@ -99,8 +87,6 @@ struct SerHeader
timestamp: u64,
/// `id` of message this one is responding to, if needed.
responds_to: Option<Uuid>,
//TODO: Add `flags` bitflags
//TODO: Add `kind` enum
}
/// A reference to a message's header.
@ -119,7 +105,7 @@ impl<V: ?Sized + MessageValue> AsRef<V> for Message<V>
///
/// Messages of this type are not yet validated, and may be invalid/corrupt. The validation happens when converting back to a `Message<V>` (of the same `V`.)
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SerializedMessage<V: ?Sized>
pub struct SerializedMessage<V: ?Sized + MessageValue>
{
header: SerHeader,
/// cbor serialised `V`.
@ -168,11 +154,7 @@ impl<V: ?Sized + MessageValue> Message<V>
{
let send_with: &S = send_with.borrow();
let data = serde_cbor::to_vec(&self.value)?;
if MAX_BODY_SIZE > 0 && data.len() > MAX_BODY_SIZE {
return Err(eyre!("Encoded body is too large"))
.with_section(|| data.len().header("Body size was"))
.with_section(|| MAX_BODY_SIZE.header("Max size is"));
}
let sig = if self.sign {
Some(send_with.sign_data(&data[..]).expect("Message expected signing, sender did not support it"))
} else {
@ -250,19 +232,13 @@ impl<V: ?Sized + MessageValue> Message<V>
}
}
mod binary;
pub use binary::*;
impl<V: ?Sized> SerializedMessage<V>
impl<V: ?Sized + MessageValue> SerializedMessage<V>
{
/// Get the message header
#[inline(always)] pub fn header(&self) -> MessageHeader<'_, V>
{
MessageHeader(&self.header, PhantomData)
}
}
/*
/// Consume into an async writer
pub async fn into_writer_async<W:AsyncWrite+Unpin>(self, mut writer: W) -> eyre::Result<usize>
{
@ -340,18 +316,6 @@ impl<V: ?Sized> SerializedMessage<V>
Ok(w)
}
/// Consume into `Vec<u8>`.
pub fn into_bytes(self) -> Vec<u8>
{
let mut v = Vec::with_capacity(self.data.len()<<1);
self.into_writer(&mut v).expect("Failed to write to in-memory buffer");
v
}
}
*/
/*
impl<V: ?Sized + MessageValue> SerializedMessage<V>
{
/// Create from a reader.
///
/// The message may be in an invalid state. It is only possible to extract the value after validating it into a `Message<V>`.
@ -434,7 +398,13 @@ impl<V: ?Sized + MessageValue> SerializedMessage<V>
_phantom: PhantomData,
})
}
/// Consume into `Vec<u8>`.
pub fn into_bytes(self) -> Vec<u8>
{
let mut v = Vec::with_capacity(self.data.len()<<1);
self.into_writer(&mut v).expect("Failed to write to in-memory buffer");
v
}
/// Create from bytes
#[inline] pub fn from_bytes(bytes: impl AsRef<[u8]>) -> eyre::Result<Self>
{
@ -442,6 +412,6 @@ impl<V: ?Sized + MessageValue> SerializedMessage<V>
Self::from_reader(&mut &bytes[..])
}
}
*/
#[cfg(test)]
mod tests;

@ -1,198 +0,0 @@
//! Creating binary from messages
//!
//! `SerializedMessage` to `Bytes`.
use super::*;
use bytes::{
Bytes,
BytesMut,
BufMut,
Buf,
};
macro_rules! try_from {
(ref $into:ty, $from:expr $(; $fmt:literal)?) => {
{
let from = &$from;
#[inline] fn _type_name_of_val<T: ?Sized>(_: &T) -> &'static str
{
::std::any::type_name::<T>()
}
<$into>::try_from(*from).wrap_err(eyre!("Failed to convert type"))
.with_section(|| ::std::any::type_name::<$into>().header("New type"))
.with_section(|| _type_name_of_val(from).header("Old type"))
.with_section(|| from.to_string().header("Value was"))
}
};
}
/// Check bit written/read from binary streams to check basic integrity.
pub const MESSAGE_HEADER_CHECK: [u8; 4] = (0xc0ffee00u32).to_be_bytes();
impl<V: ?Sized> SerializedMessage<V>
{
/// Write this message to a buffer
///
/// # Panics
/// If `buffer` cannot hold enough bytes.
pub fn into_buffer(self, mut buffer: impl BufMut) -> eyre::Result<usize>
{
let mut w=0;
macro_rules! write {
($bytes:expr) => {
{
let slice: &[u8] = ($bytes).as_ref();
buffer.put_slice(slice);
w+=slice.len();
}
};
(? $o:expr) => {
{
match $o {
Some(opt) => {
buffer.put_u8(1);
write!(opt);
},
None => {buffer.put_u8(0);},
}
w+=1;
}
};
(: $ser:expr) => {
{
let mut v = StackVec::new();
#[inline] fn _type_name_of_val<T: ?Sized>(_: &T) -> &'static str
{
::std::any::type_name::<T>()
}
let ser = $ser;
serde_cbor::to_writer(&mut v, $ser)
.wrap_err(eyre!("Failed to serialise value to temporary buffer"))
.with_section(|| _type_name_of_val(ser).header("Type was"))?;
buffer.put_u64(try_from!(ref u64, v.len())?);
write!(&v[..]);
}
};
}
write!(MESSAGE_HEADER_CHECK);
write!(: &self.header);
buffer.put_u64(try_from!(ref u64, self.data.len())?);
write!(self.data);
write!(self.hash);
write!(? self.enc_key);
write!(? self.sig);
Ok(w)
}
/// Write this message to a new `Bytes`.
pub fn into_bytes(self) -> eyre::Result<Bytes>
{
let mut output = BytesMut::with_capacity(4096); //TODO: Find a better default capacity for this.
self.into_buffer(&mut output)?;
Ok(output.freeze())
}
}
impl<V: ?Sized + MessageValue> SerializedMessage<V>
{
/// Create from a buffer of bytes.
///
/// # Panics
/// If `bytes` does not contain enough data to read.
pub fn from_buffer(mut bytes: impl Buf) -> eyre::Result<Self>
{
macro_rules! read {
($bref:expr) => {
{
let by: &mut [u8] = ($bref).as_mut();
bytes.copy_to_slice(by);
}
};
(? $odef:expr) => {
{
let by = bytes.get_u8();
match by {
0 => None,
1 => {
let mut def = $odef;
read!(&mut def);
Some(def)
},
x => {
return Err(eyre!("Invalid optional-set bit (should be 0 or 1)").with_section(|| x.header("Value was")));
}
}
}
};
(: $ser:ty) => {
{
let len = try_from!(ref usize, bytes.get_u64())?;
if len > MAX_READ_SIZE {
return Err(eyre!("Invalid length read: {}", len)
.with_section(|| MAX_READ_SIZE.header("Max length read")));
}
alloc_local_bytes(len, |de| {
read!(&mut de[..]);
serde_cbor::from_slice::<$ser>(&de[..]).wrap_err(eyre!("Failed to deserialise CBOR from reader")).with_section(|| std::any::type_name::<$ser>().header("Type to deserialise was"))
})?
}
};
($into:expr, $num:expr) => {
{
let num = $num;
let reader = (&mut bytes).reader();
copy_buffer($into, reader, num).wrap_err(eyre!("Failed to read {} bytes from reader", num))?
}
}
}
let mut check = [0u8; MESSAGE_HEADER_CHECK.len()];
read!(&mut check[..]);
if check != MESSAGE_HEADER_CHECK {
return Err(eyre!("Invalid check bit for message header"))
.with_section(|| u32::from_be_bytes(check).header("Expected"))
.with_section(|| u32::from_be_bytes(MESSAGE_HEADER_CHECK).header("Got"));
}
let header = read!(: SerHeader);
let data_len = try_from!(ref usize, bytes.get_u64())?;
if MAX_BODY_SIZE > 0 && data_len > MAX_BODY_SIZE {
return Err(eyre!("Body size too large"))
.with_section(|| data_len.header("Encoded size was"))
.with_section(|| MAX_BODY_SIZE.header("Max size is"));
}
let mut data = Vec::with_capacity(std::cmp::min(data_len, MAX_ALLOC_SIZE)); //XXX: Redesign so we don't allocate OR try to read massive buffers by accident on corrupted/malformed messages
read!(&mut data, data_len);
if data.len()!=data_len {
return Err(eyre!("Failed to read body bytes from buffer"))
.with_section(|| format!("{} bytes", data_len).header("Tried to read"))
.with_section(|| format!("{} bytes", data.len()).header("Read only"));
}
let mut hash = sha256::Sha256Hash::default();
read!(&mut hash);
let enc_key: Option<[u8; RSA_BLOCK_SIZE]> = read!(? [0u8; RSA_BLOCK_SIZE]);
let sig: Option<rsa::Signature> = read!(? rsa::Signature::default());
Ok(Self {
header,
data,
hash,
enc_key,
sig,
_phantom: PhantomData,
})
}
/// Create from a slice of bytes
#[inline] pub fn from_slice(bytes: impl AsRef<[u8]>) -> eyre::Result<Self>
{
Self::from_buffer(bytes.as_ref())
}
/// Create from a `Bytes` instance
#[inline(always)] pub fn from_bytes(bytes: Bytes) -> eyre::Result<Self>
{
Self::from_buffer(bytes)
}
}

@ -12,9 +12,9 @@ where S: MessageSender,
println!(">> Created message: {:?}", message);
let serialised = message.serialise(s).expect("Failed to serialise message");
println!(">> Serialised message: {:?}", serialised);
let binary = serialised.into_bytes().expect("Failed to serialize to bytes");
let binary = serialised.into_bytes();
println!(">> Written to {} bytes", binary.len());
let read = SerializedMessage::from_bytes(binary).expect("Failed to read serialised message from binary");
let read = SerializedMessage::from_bytes(&binary).expect("Failed to read serialised message from binary");
println!(">> Read from bytes: {:?}", read);
let deserialised = Message::deserialize(&read, d).expect("Failed to deserialise message");
println!(">> Deserialised message: {:?}", deserialised);
@ -60,7 +60,7 @@ fn message_serial_sign()
#[test]
fn message_serial_encrypt()
{
//color_eyre::install().unwrap();
color_eyre::install().unwrap();
let rsa_priv = rsa::RsaPrivateKey::generate().unwrap();
struct Dec(rsa::RsaPrivateKey);
struct Enc(rsa::RsaPublicKey);

@ -1,66 +1,32 @@
//! Message values
use super::*;
use std::mem;
use serde::{Serialize, Deserialize};
/// A value that can be used for messages.
pub trait MessageValue: Serialize + for<'de> Deserialize<'de>{}
impl<T: ?Sized> MessageValue for T
where T: Serialize + for<'de> Deserialize<'de>{}
/*
use std::any::Any;
use serde::de::DeserializeOwned;
#[derive(Debug, Serialize, Deserialize)]
pub struct DynamicMessageValue<T>(T)
where T: Any + 'static;
pub struct MessageValueAnyRef<'a>(&'a (dyn Any +'static));
pub struct MessageValueAnyMut<'a>(&'a mut (dyn Any +'static));
pub struct MessageValueAny(Box<dyn Any +'static>);
//impl<T> MessageValue for DynamicMessageValue<T>
//where T: Serialize + for<'de> Deserialize<'de> + Any{}
*/
/// A type-unsafe value that can be used to transmute `SerializedMessage` instances.
///
/// This operation is unsafe and can result in deserializing the `SerializedMessage` failing, or even worse, a type confusion.
///
/// This type does not implement `MessageValue`, as serialized messages of this value cannot be created (from `Message::serialize()`), nor deserialized. They must first be converted into a typed `SerializedMessage<V>` with `UntypedSerializedMessage::into_typed<V>()`.
///
/// This is an empty (!) type.
#[derive(Debug)]
pub enum UntypedMessageValue{}
impl SerializedMessage<UntypedMessageValue>
/// A value that can be used for messages.
pub trait MessageValue: Serialize + for<'de> Deserialize<'de>
{
/// Transmute into a specifically typed `SerializedMessage`
///
/// # Safety
/// If `V` is not the original type of this message (before being untyped), then deserialisation will likely fail, or, much worse, cause a *type consufion* bug, where the object of type `V` is successfully deserialized with an invalid value from an unknown type.
/// Take special care when using this.
pub const unsafe fn into_typed<V: MessageValue + ?Sized>(self) -> SerializedMessage<V>
{
mem::transmute(self)
}
fn as_dynamic(&self) -> Option<MessageValueAnyRef<'_>> { None }
fn as_dynamic_mut(&mut self) -> Option<MessageValueAnyMut<'_>> { None }
fn into_dynamic(self) -> Result<MessageValueAny, Self> { Err(self) }
}
impl<V: MessageValue + ?Sized> SerializedMessage<V>
default impl<T: ?Sized> MessageValue for T
where T: Serialize + for<'de> Deserialize<'de>
{
/// Consume this value into an untyped `SerializedMessage`.
///
/// # Safety
/// This operation is safe, however, doing anything at all with the resulting `SerializedMessage` is not.
/// It must be unsafly converted into a types message before it can be deserialized.
pub const fn into_untyped(self) -> SerializedMessage<UntypedMessageValue>
{
unsafe {
mem::transmute(self)
}
}
default fn as_dynamic(&self) -> Option<MessageValueAnyRef<'_>> { None }
default fn as_dynamic_mut(&mut self) -> Option<MessageValueAnyMut<'_>> { None }
default fn into_dynamic(self) -> Result<MessageValueAny, Self> { Err(self) }
}
#[derive(Debug, Serialize, Deserialize)]
pub struct DynamicMessageValue<T>(T)
where T: Any + 'static;
//impl<T> MessageValue for DynamicMessageValue<T>
//where T: Serialize + for<'de> Deserialize<'de> + Any{}

@ -0,0 +1,30 @@
//! Socket handlers
use super::*;
use tokio::io::{
AsyncWrite,
AsyncRead
};
use tokio::task::JoinHandle;
use futures::Future;
use cancel::*;
/// Handles a raw, opened socket
pub fn handle_socket_with_shutdown<R, W, C: cancel::CancelFuture + 'static + Send>(tx: W, rx: R, shutdown: C) -> JoinHandle<eyre::Result<()>>
where R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static
{
tokio::spawn(async move {
match {
with_cancel!(async move {
//TODO: How to handle reads+writes?
Ok(())
}, shutdown)
} {
Ok(v) => v,
Err(x) => Err(eyre::Report::from(x)),
}
})
}

@ -1,874 +0,0 @@
//! Socket encryption wrapper
use super::*;
use cryptohelpers::{
rsa::{
self,
RsaPublicKey,
RsaPrivateKey,
openssl::{
symm::Crypter,
error::ErrorStack,
},
},
sha256,
};
use chacha20stream::{
AsyncSink,
AsyncSource,
Key, IV,
cha,
};
use std::sync::Arc;
use tokio::{
sync::{
RwLock,
RwLockReadGuard,
RwLockWriteGuard,
},
};
use std::{
io,
fmt,
task::{
Context, Poll,
},
pin::Pin,
marker::Unpin,
};
use smallvec::SmallVec;
/// Size of a single RSA ciphertext.
pub const RSA_CIPHERTEXT_SIZE: usize = 512;
/// A single, full block of RSA ciphertext.
type RsaCiphertextBlock = [u8; RSA_CIPHERTEXT_SIZE];
/// Max size to read when exchanging keys
const TRANS_KEY_MAX_SIZE: usize = 4096;
/// Encrypted socket information.
#[derive(Debug)]
struct ESockInfo {
us: RsaPrivateKey,
them: Option<RsaPublicKey>,
}
impl ESockInfo
{
/// Generate a new private key
pub fn new(us: impl Into<RsaPrivateKey>) -> Self
{
Self {
us: us.into(),
them: None,
}
}
/// Generate a new private key for the local endpoint
pub fn generate() -> Result<Self, rsa::Error>
{
Ok(Self::new(RsaPrivateKey::generate()?))
}
}
/// The encryption state of the Tx and Rx instances.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
struct ESockState {
encr: bool,
encw: bool,
}
impl Default for ESockState
{
#[inline]
fn default() -> Self
{
Self {
encr: false,
encw: false,
}
}
}
/// Contains a cc20 Key and IV that can be serialized and then encrypted
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
struct ESockSessionKey
{
key: Key,
iv: IV,
}
impl fmt::Display for ESockSessionKey
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
{
write!(f, "Key: {}, IV: {}", self.key.hex(), self.iv.hex())
}
}
impl ESockSessionKey
{
/// Generate a new cc20 key + iv,
pub fn generate() -> Self
{
let (key,iv) = cha::keygen();
Self{key,iv}
}
/// Generate an encryption device
pub fn to_decrypter(&self) -> Result<Crypter, ErrorStack>
{
cha::decrypter(&self.key, &self.iv)
}
/// Generate an encryption device
pub fn to_encrypter(&self) -> Result<Crypter, ErrorStack>
{
cha::encrypter(&self.key, &self.iv)
}
/// Encrypt with RSA
pub fn to_ciphertext<K: ?Sized + rsa::PublicKey>(&self, rsa_key: &K) -> eyre::Result<RsaCiphertextBlock>
{
let mut output = [0u8; RSA_CIPHERTEXT_SIZE];
let mut temp = SmallVec::<[u8; RSA_CIPHERTEXT_SIZE]>::new(); // We know size will fit into here.
serde_cbor::to_writer(&mut temp, self)
.wrap_err(eyre!("Failed to CBOR encode session key to buffer"))
.with_section(|| self.clone().header("Session key was"))?;
debug_assert!(temp.len() < RSA_CIPHERTEXT_SIZE);
let _wr = rsa::encrypt_slice_sync(&temp, rsa_key, &mut &mut output[..])
.wrap_err(eyre!("Failed to encrypt session key with RSA public key"))
.with_section(|| self.clone().header("Session key was"))
.with_section({let temp = temp.len(); move || temp.header("Encoded data size was")})
.with_section(move || base64::encode(temp).header("Encoded data (base64) was"))?;
debug_assert_eq!(_wr, output.len());
Ok(output)
}
/// Decrypt from RSA
pub fn from_ciphertext<K: ?Sized + rsa::PrivateKey>(data: &[u8; RSA_CIPHERTEXT_SIZE], rsa_key: &K) -> eyre::Result<Self>
where <K as rsa::PublicKey>::KeyType: rsa::openssl::pkey::HasPrivate //ugh, why do we have to have this bound??? it should be implied ffs... :/
{
let mut temp = SmallVec::<[u8; RSA_CIPHERTEXT_SIZE]>::new();
rsa::decrypt_slice_sync(data, rsa_key, &mut temp)
.wrap_err(eyre!("Failed to decrypt ciphertext to session key"))
.with_section({let data = data.len(); move || data.header("Ciphertext length was")})
.with_section(|| base64::encode(data).header("Ciphertext was"))?;
Ok(serde_cbor::from_slice(&temp[..])
.wrap_err(eyre!("Failed to decode CBOR data to session key object"))
.with_section({let temp = temp.len(); move || temp.header("Encoded data size was")})
.with_section(move || base64::encode(temp).header("Encoded data (base64) was"))?)
}
}
/// A tx+rx socket.
#[pin_project]
#[derive(Debug)]
pub struct ESock<W, R> {
info: ESockInfo,
state: ESockState,
#[pin]
rx: AsyncSource<R>,
#[pin]
tx: AsyncSink<W>,
}
impl<W: AsyncWrite, R: AsyncRead> ESock<W, R>
{
fn inner(&self) -> (&W, &R)
{
(self.tx.inner(), self.rx.inner())
}
fn inner_mut(&mut self) -> (&mut W, &mut R)
{
(self.tx.inner_mut(), self.rx.inner_mut())
}
///Get a mutable ref to unencrypted read+write
fn unencrypted(&mut self) -> (&mut W, &mut R)
{
(self.tx.inner_mut(), self.rx.inner_mut())
}
/// Get a mutable ref to encrypted write+read
fn encrypted(&mut self) -> (&mut AsyncSink<W>, &mut AsyncSource<R>)
{
(&mut self.tx, &mut self.rx)
}
/// Have the RSA keys been exchanged?
pub fn has_exchanged(&self) -> bool
{
self.info.them.is_some()
}
/// Is the Write + Read operation encrypted? Tuple is `(Tx, Rx)`.
#[inline] pub fn is_encrypted(&self) -> (bool, bool)
{
(self.state.encw, self.state.encr)
}
/// Create a new `ESock` wrapper over this writer and reader with this specific RSA key.
pub fn with_key(key: impl Into<RsaPrivateKey>, tx: W, rx: R) -> Self
{
let (tk, tiv) = cha::keygen();
Self {
info: ESockInfo::new(key),
state: Default::default(),
// Note: These key+IV pairs are never used, as `state` defaults to unencrypted, and a new key/iv pair is generated when we `set_encrypted_write/read(true)`.
// TODO: Have a method to exchange these default session keys after `exchange()`?
tx: AsyncSink::encrypt(tx, tk, tiv).expect("Failed to create temp AsyncSink"),
rx: AsyncSource::encrypt(rx, tk, tiv).expect("Failed to create temp AsyncSource"),
}
}
/// Create a new `ESock` wrapper over this writer and reader with a newly generated private key
#[inline] pub fn new(tx: W, rx: R) -> Result<Self, rsa::Error>
{
Ok(Self::with_key(RsaPrivateKey::generate()?, tx, rx))
}
/// The local RSA private key
#[inline] pub fn local_key(&self) -> &RsaPrivateKey
{
&self.info.us
}
/// THe remote RSA public key (if exchange has happened.)
#[inline] pub fn foreign_key(&self) -> Option<&RsaPublicKey>
{
self.info.them.as_ref()
}
/// Split this `ESock` into a read+write pair.
///
/// # Note
/// You must preform an `exchange()` before splitting, as exchanging RSA keys is not possible on a single half.
///
/// It is also more efficient to `set_encrypted_write/read(true)` on `ESock` than it is on the halves, but changinc encryption modes on halves is still possible.
pub fn split(self) -> (ESockWriteHalf<W>, ESockReadHalf<R>)
{
let arced = Arc::new(self.info);
(ESockWriteHalf(Arc::clone(&arced), self.tx, self.state.encw),
ESockReadHalf(arced, self.rx, self.state.encr))
}
/// Merge a previously split `ESock` into a single one again.
///
/// # Panics
/// If the two halves were not split from the same `ESock`.
pub fn unsplit(txh: ESockWriteHalf<W>, rxh: ESockReadHalf<R>) -> Self
{
#[cold]
#[inline(never)]
fn _panic_ptr_ineq() -> !
{
panic!("Cannot merge halves split from different sources")
}
if !Arc::ptr_eq(&txh.0, &rxh.0) {
_panic_ptr_ineq();
}
let tx = txh.1;
drop(txh.0);
let info = Arc::try_unwrap(rxh.0).unwrap();
let rx = rxh.1;
Self {
state: ESockState {
encw: txh.2,
encr: rxh.2,
},
info,
tx, rx
}
}
}
async fn set_encrypted_write_for<T: AsyncWrite + Unpin>(info: &ESockInfo, tx: &mut AsyncSink<T>) -> eyre::Result<()>
{
use tokio::prelude::*;
let session_key = ESockSessionKey::generate();
let data = {
let them = info.them.as_ref().expect("Cannot set encrypted write when keys have not been exchanged");
session_key.to_ciphertext(them)
.wrap_err(eyre!("Failed to encrypt session key with foreign endpoint's key"))
.with_section(|| session_key.to_string().header("Session key was"))
.with_section(|| them.to_string().header("Foreign pubkey was"))?
};
let crypter = session_key.to_encrypter()
.wrap_err(eyre!("Failed to create encryption device from session key for Tx"))
.with_section(|| session_key.to_string().header("Session key was"))?;
// Send rsa `data` over unencrypted endpoint
tx.inner_mut().write_all(&data[..]).await
.wrap_err(eyre!("Failed to write ciphertext to endpoint"))
.with_section(|| data.to_base64_string().header("Ciphertext of session key was"))?;
// Set crypter of `tx` to `session_key`.
*tx.crypter_mut() = crypter;
Ok(())
}
async fn set_encrypted_read_for<T: AsyncRead + Unpin>(info: &ESockInfo, rx: &mut AsyncSource<T>) -> eyre::Result<()>
{
use tokio::prelude::*;
let mut data = [0u8; RSA_CIPHERTEXT_SIZE];
// Read `data` from unencrypted endpoint
rx.inner_mut().read_exact(&mut data[..]).await
.wrap_err(eyre!("Failed to read ciphertext from endpoint"))?;
// Decrypt `data`
let session_key = ESockSessionKey::from_ciphertext(&data, &info.us)
.wrap_err(eyre!("Failed to decrypt session key from ciphertext"))
.with_section(|| data.to_base64_string().header("Ciphertext was"))
.with_section(|| info.us.to_string().header("Our RSA key is"))?;
// Set crypter of `rx` to `session_key`.
*rx.crypter_mut() = session_key.to_decrypter()
.wrap_err(eyre!("Failed to create decryption device from session key for Rx"))
.with_section(|| session_key.to_string().header("Decrypted session key was"))?;
Ok(())
}
impl<W: AsyncWrite+ Unpin, R: AsyncRead + Unpin> ESock<W, R>
{
/// Get the Tx and Rx of the stream.
///
/// # Returns
/// Returns encrypted stream halfs if the stream is encrypted, unencrypted if not.
pub fn stream(&mut self) -> (&mut (dyn AsyncWrite + Unpin + '_), &mut (dyn AsyncRead + Unpin + '_))
{
(if self.state.encw {
&mut self.tx
} else {
self.tx.inner_mut()
}, if self.state.encr {
&mut self.rx
} else {
self.rx.inner_mut()
})
}
/// Enable write encryption
pub async fn set_encrypted_write(&mut self, set: bool) -> eyre::Result<()>
{
if set {
set_encrypted_write_for(&self.info, &mut self.tx).await?;
// Set `encw` to true
self.state.encw = true;
Ok(())
} else {
self.state.encw = false;
Ok(())
}
}
/// Enable read encryption
///
/// The other endpoint must have sent a `set_encrypted_write()`
pub async fn set_encrypted_read(&mut self, set: bool) -> eyre::Result<()>
{
if set {
set_encrypted_read_for(&self.info, &mut self.rx).await?;
// Set `encr` to true
self.state.encr = true;
Ok(())
} else {
self.state.encr = false;
Ok(())
}
}
/// Get dynamic ref to unencrypted write+read
fn unencrypted_dyn(&mut self) -> (&mut (dyn AsyncWrite + Unpin + '_), &mut (dyn AsyncRead + Unpin + '_))
{
(self.tx.inner_mut(), self.rx.inner_mut())
}
/// Get dynamic ref to encrypted write+read
fn encrypted_dyn(&mut self) -> (&mut (dyn AsyncWrite + Unpin + '_), &mut (dyn AsyncRead + Unpin + '_))
{
(&mut self.tx, &mut self.rx)
}
/// Exchange keys.
pub async fn exchange(&mut self) -> eyre::Result<()>
{
use tokio::prelude::*;
let our_key = self.info.us.get_public_parts();
let (tx, rx) = self.inner_mut();
let read_fut = {
async move {
// Read the public key from `rx`.
//TODO: Find pubkey max size.
let mut sz_buf = [0u8; std::mem::size_of::<u64>()];
rx.read_exact(&mut sz_buf[..]).await
.wrap_err(eyre!("Failed to read size of pubkey form endpoint"))?;
let sz64 = u64::from_be_bytes(sz_buf);
let sz= match usize::try_from(sz64)
.wrap_err(eyre!("Read size could not fit into u64"))
.with_section(|| format!("{:?}", sz_buf).header("Read buffer was"))
.with_section(|| u64::from_be_bytes(sz_buf).header("64=bit size value was"))
.with_warning(|| "This should not happen, it is only possible when you are running a machine with a pointer size lower than 64 bits.")
.with_suggestion(|| "The message is likely malformed. If it is not, then you are communicating with an endpoint of 64 bits whereas your pointer size is far less.")? {
x if x > TRANS_KEY_MAX_SIZE => return Err(eyre!("Recv'd key size exceeded max acceptable key buffer size")),
x => x
};
let mut key_bytes = Vec::with_capacity(sz);
tokio::io::copy(&mut rx.take(sz64), &mut key_bytes).await
.wrap_err("Failed to read key bytes into buffer")
.with_section(move || sz64.header("Pubkey size to read was"))?;
if key_bytes.len() != sz {
return Err(eyre!("Could not read required bytes"));
}
let k = RsaPublicKey::from_bytes(&key_bytes)
.wrap_err("Failed to construct RSA public key from read bytes")
.with_section(|| sz.header("Pubkey size was"))
.with_section(move || key_bytes.to_base64_string().header("Pubkey bytes were"))?;
Result::<RsaPublicKey, eyre::Report>::Ok(k)
}
};
let write_fut = {
let key_bytes = our_key.to_bytes();
assert!(key_bytes.len() <= TRANS_KEY_MAX_SIZE);
let sz64 = u64::try_from(key_bytes.len())
.wrap_err(eyre!("Size of our pubkey could not fit into u64"))
.with_section(|| key_bytes.len().header("Size was"))
.with_warning(|| "This should not happen, it is only possible when you are running a machine with a pointer size larger than 64 bits.")
.with_warning(|| "There was likely internal memory corruption.")?;
let sz_buf = sz64.to_be_bytes();
async move {
tx.write_all(&sz_buf[..]).await
.wrap_err(eyre!("Failed to write key size"))
.with_section(|| sz64.header("Key size bytes were"))
.with_section(|| format!("{:?}", sz_buf).header("Key size bytes (BE) were"))?;
tx.write_all(&key_bytes[..]).await
.wrap_err(eyre!("Failed to write key bytes"))
.with_section(|| sz64.header("Size of key was"))
.with_section(|| key_bytes.to_base64_string().header("Key bytes are"))?;
Result::<(), eyre::Report>::Ok(())
}
};
let (send, recv) = tokio::join! [write_fut, read_fut];
send.wrap_err("Failed to send our pubkey")?;
let recv = recv.wrap_err("Failed to receive foreign pubkey")?;
self.info.them = Some(recv);
Ok(())
}
}
//XXX: For some reason, non-exact reads + writes cause garbage to be produced on the receiving end?
// Is this fixable? Why does it disjoint? I have no idea... This is supposed to be a stream cipher, right? Why does positioning matter? Have I misunderstood how it workd? Eh...
// With this bug, it seems the `while read(buffer) > 0` construct is impossible. This might make this entirely useless. Hopefully with the rigid size-based format for `Message` we won't run into this problem, but subsequent data streaming will likely be affected unless we use rigid, fixed, and (inefficiently) communicated buffer sizes.
impl<W, R> AsyncWrite for ESock<W, R>
where W: AsyncWrite
{
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
//XXX: If the encryption state of the socket is changed between polls, this breaks. Idk if we can do anything about that tho.
if self.state.encw {
self.project().tx.poll_write(cx, buf)
} else {
// SAFETY: Uhh... well I think this is fine? Because we can project the container.
// TODO: Can we project the `tx`? Or maybe add a method in `AsyncSink` to map a pinned sink to a `Pin<&mut W>`?
unsafe { self.map_unchecked_mut(|this| this.tx.inner_mut()).poll_write(cx, buf)}
}
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
// Should we do anything else here?
// Should we clear foreign key/current session key?
self.project().tx.poll_shutdown(cx)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().tx.poll_flush(cx)
}
}
impl<W, R> AsyncRead for ESock<W, R>
where R: AsyncRead
{
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
//XXX: If the encryption state of the socket is changed between polls, this breaks. Idk if we can do anything about that tho.
if self.state.encr {
self.project().rx.poll_read(cx, buf)
} else {
// SAFETY: Uhh... well I think this is fine? Because we can project the container.
// TODO: Can we project the `tx`? Or maybe add a method in `AsyncSink` to map a pinned sink to a `Pin<&mut W>`?
unsafe { self.map_unchecked_mut(|this| this.rx.inner_mut()).poll_read(cx, buf)}
}
}
}
/// Write half for `ESock`.
#[pin_project]
#[derive(Debug)]
pub struct ESockWriteHalf<W>(Arc<ESockInfo>, #[pin] AsyncSink<W>, bool);
/// Read half for `ESock`.
#[pin_project]
#[derive(Debug)]
pub struct ESockReadHalf<R>(Arc<ESockInfo>, #[pin] AsyncSource<R>, bool);
//Impl AsyncRead/Write + set_encrypted_read/write for ESockRead/WriteHalf.
impl<W: AsyncWrite> ESockWriteHalf<W>
{
/// Does this write half have a live corresponding read half?
///
/// It's not required to have one, however, exchange is not possible without since it requires sticking the halves back together.
pub fn is_bidirectional(&self) -> bool
{
Arc::strong_count(&self.0) > 1
}
/// Is write encrypted on this half?
#[inline(always)] pub fn is_encrypted(&self) -> bool
{
self.2
}
/// The local RSA private key
#[inline] pub fn local_key(&self) -> &RsaPrivateKey
{
&self.0.us
}
/// THe remote RSA public key (if exchange has happened.)
#[inline] pub fn foreign_key(&self) -> Option<&RsaPublicKey>
{
self.0.them.as_ref()
}
/// End an encrypted session syncronously.
///
/// Same as calling `set_encryption(false).now_or_never()`, but more efficient.
pub fn clear_encryption(&mut self)
{
self.2 = false;
}
}
impl<R: AsyncRead> ESockReadHalf<R>
{
/// Does this read half have a live corresponding write half?
///
/// It's not required to have one, however, exchange is not possible without since it requires sticking the halves back together.
pub fn is_bidirectional(&self) -> bool
{
Arc::strong_count(&self.0) > 1
}
/// Is write encrypted on this half?
#[inline(always)] pub fn is_encrypted(&self) -> bool
{
self.2
}
/// The local RSA private key
#[inline] pub fn local_key(&self) -> &RsaPrivateKey
{
&self.0.us
}
/// THe remote RSA public key (if exchange has happened.)
#[inline] pub fn foreign_key(&self) -> Option<&RsaPublicKey>
{
self.0.them.as_ref()
}
/// End an encrypted session syncronously.
///
/// Same as calling `set_encryption(false).now_or_never()`, but more efficient.
pub fn clear_encryption(&mut self)
{
self.2 = false;
}
}
impl<W: AsyncWrite + Unpin> ESockWriteHalf<W>
{
/// Begin or end an encrypted writing session
pub async fn set_encryption(&mut self, set: bool) -> eyre::Result<()>
{
if set {
set_encrypted_write_for(&self.0, &mut self.1).await?;
self.2 = true;
} else {
self.2 = false;
}
Ok(())
}
}
impl<R: AsyncRead + Unpin> ESockReadHalf<R>
{
/// Begin or end an encrypted reading session
pub async fn set_encryption(&mut self, set: bool) -> eyre::Result<()>
{
if set {
set_encrypted_read_for(&self.0, &mut self.1).await?;
self.2 = true;
} else {
self.2 = false;
}
Ok(())
}
}
impl<W: AsyncWrite> AsyncWrite for ESockWriteHalf<W>
{
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
if self.2 {
// Encrypted
self.project().1.poll_write(cx, buf)
} else {
// Unencrypted
// SAFETY: Uhh... well I think this is fine? Because we can project the container.
// TODO: Can we project the `tx`? Or maybe add a method in `AsyncSink` to map a pinned sink to a `Pin<&mut W>`?
unsafe { self.map_unchecked_mut(|this| this.1.inner_mut()).poll_write(cx, buf)}
}
}
#[inline(always)] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().1.poll_flush(cx)
}
#[inline(always)] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().1.poll_flush(cx)
}
}
impl<R: AsyncRead> AsyncRead for ESockReadHalf<R>
{
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
if self.2 {
// Encrypted
self.project().1.poll_read(cx, buf)
} else {
// Unencrypted
// SAFETY: Uhh... well I think this is fine? Because we can project the container.
// TODO: Can we project the `tx`? Or maybe add a method in `AsyncSink` to map a pinned sink to a `Pin<&mut W>`?
unsafe { self.map_unchecked_mut(|this| this.1.inner_mut()).poll_read(cx, buf)}
}
}
}
#[cfg(test)]
mod tests
{
use super::ESock;
#[test]
fn rsa_ciphertext_len() -> crate::eyre::Result<()>
{
let data = {
use chacha20stream::cha::{KEY_SIZE, IV_SIZE};
let (key, iv) = chacha20stream::cha::keygen();
let (sz, d) = crate::bin::collect_slices_exact::<&[u8], _, {KEY_SIZE + IV_SIZE}>([key.as_ref(), iv.as_ref()]);
assert_eq!(sz, d.len());
d
};
println!("KEY+IV: {} bytes", data.len());
let key = cryptohelpers::rsa::RsaPublicKey::generate()?;
let rsa = cryptohelpers::rsa::encrypt_slice_to_vec(data, &key)?;
println!("Rsa ciphertext size: {}", rsa.len());
assert_eq!(rsa.len(), super::RSA_CIPHERTEXT_SIZE, "Incorrect RSA ciphertext length constant for cc20 KEY+IV encoding.");
Ok(())
}
#[test]
fn rsa_serial_ciphertext_len() -> crate::eyre::Result<()>
{
let data = serde_cbor::to_vec(&{
let (key, iv) = chacha20stream::cha::keygen();
super::ESockSessionKey {
key, iv,
}
}).expect("Failed to CBOR encode Key+IV");
println!("(cbor) KEY+IV: {} bytes", data.len());
let key = cryptohelpers::rsa::RsaPublicKey::generate()?;
let rsa = cryptohelpers::rsa::encrypt_slice_to_vec(data, &key)?;
println!("Rsa ciphertext size: {}", rsa.len());
assert_eq!(rsa.len(), super::RSA_CIPHERTEXT_SIZE, "Incorrect RSA ciphertext length constant for cc20 KEY+IV CBOR encoding.");
Ok(())
}
fn gen_duplex_esock(bufsz: usize) -> crate::eyre::Result<(ESock<tokio::io::DuplexStream, tokio::io::DuplexStream>, ESock<tokio::io::DuplexStream, tokio::io::DuplexStream>)>
{
use crate::*;
let (atx, brx) = tokio::io::duplex(bufsz);
let (btx, arx) = tokio::io::duplex(bufsz);
let tx = ESock::new(atx, arx).wrap_err(eyre!("Failed to create TX"))?;
let rx = ESock::new(btx, brx).wrap_err(eyre!("Failed to create RX"))?;
Ok((tx, rx))
}
#[tokio::test]
async fn esock_exchange() -> crate::eyre::Result<()>
{
use crate::*;
const VALUE: &'static [u8] = b"Hello world!";
// The duplex buffer size here is smaller than an RSA ciphertext block. So, writing the session key must be buffered with a buffer size this small (should return Pending at least once.)
// Using a low buffer size to make sure the test passes even when the entire buffer cannot be written at once.
let (mut tx, mut rx) = gen_duplex_esock(16).wrap_err(eyre!("Failed to weave socks"))?;
let writer = tokio::spawn(async move {
use tokio::prelude::*;
tx.exchange().await?;
assert!(tx.has_exchanged());
tx.set_encrypted_write(true).await?;
assert_eq!((true, false), tx.is_encrypted());
tx.write_all(VALUE).await?;
tx.write_all(VALUE).await?;
// Check resp
tx.set_encrypted_read(true).await?;
assert_eq!({
let mut chk = [0u8; 3];
tx.read_exact(&mut chk[..]).await?;
chk
}, [0xaau8,0, 0], "Failed response check");
// Write unencrypted
tx.set_encrypted_write(false).await?;
tx.write_all(&[2,1,0xfa]).await?;
Result::<_, eyre::Report>::Ok(VALUE)
});
let reader = tokio::spawn(async move {
use tokio::prelude::*;
rx.exchange().await?;
assert!(rx.has_exchanged());
rx.set_encrypted_read(true).await?;
assert_eq!((false, true), rx.is_encrypted());
let mut val = vec![0u8; VALUE.len()];
rx.read_exact(&mut val[..]).await?;
let mut val2 = vec![0u8; VALUE.len()];
rx.read_exact(&mut val2[..]).await?;
assert_eq!(val, val2);
// Send resp
rx.set_encrypted_write(true).await?;
rx.write_all(&[0xaa, 0, 0]).await?;
// Read unencrypted
rx.set_encrypted_read(false).await?;
assert_eq!({
let mut buf = [0u8; 3];
rx.read_exact(&mut buf[..]).await?;
buf
}, [2u8,1,0xfa], "2nd response incorrect");
Result::<_, eyre::Report>::Ok(val)
});
let (writer, reader) = tokio::join![writer, reader];
let writer = writer.expect("Tx task panic");
let reader = reader.expect("Rx task panic");
eprintln!("Txr: {:?}", writer);
eprintln!("Rxr: {:?}", reader);
writer?;
let val = reader?;
println!("Read: {:?}", val);
assert_eq!(&val, VALUE);
Ok(())
}
#[tokio::test]
async fn esock_split() -> crate::eyre::Result<()>
{
use super::*;
const SLICES: &'static [&'static [u8]] = &[
&[1,5,3,7,6,9,100,0],
&[7,6,2,90],
&[3,6,1,0],
&[5,1,3,3],
];
let result = SLICES.iter().map(|&slice| slice.iter().map(|&b| u64::from(b)).sum::<u64>()).sum::<u64>();
println!("Result: {}", result);
let (mut tx, mut rx) = gen_duplex_esock(super::TRANS_KEY_MAX_SIZE * 4).wrap_err(eyre!("Failed to weave socks"))?;
let (writer, reader) = {
use tokio::prelude::*;
let writer = tokio::spawn(async move {
tx.exchange().await?;
let (mut tx, mut rx) = tx.split();
//tx.set_encryption(true).await?;
let slices = &SLICES[1..];
for &slice in slices.iter()
{
println!("Writing slice: {:?}", slice);
tx.write_all(slice).await?;
}
//let mut tx = ESock::unsplit(tx, rx);
tx.write_all(SLICES[0]).await?;
Result::<_, eyre::Report>::Ok(())
});
let reader = tokio::spawn(async move {
rx.exchange().await?;
let (mut tx, mut rx) = rx.split();
//rx.set_encryption(true).await?;
let (mut mtx, mut mrx) = tokio::sync::mpsc::channel::<Vec<u8>>(16);
let sorter = tokio::spawn(async move {
let mut done = 0u64;
while let Some(buf) = mrx.recv().await
{
//buf.sort();
done += buf.iter().map(|&b| u64::from(b)).sum::<u64>();
println!("Got buffer: {:?}", buf);
tx.write_all(&buf).await?;
}
Result::<_, eyre::Report>::Ok(done)
});
let mut buffer = [0u8; 16];
while let Ok(read) = rx.read(&mut buffer[..]).await
{
if read == 0 {
break;
}
mtx.send(Vec::from(&buffer[..read])).await?;
}
drop(mtx);
let sum = sorter.await.expect("(reader) Sorter task panic")?;
Result::<_, eyre::Report>::Ok(sum)
});
let (writer, reader) = tokio::join![writer, reader];
(writer.expect("Writer task panic"),
reader.expect("Reader task panic"))
};
writer?;
assert_eq!(result, reader?);
Ok(())
}
}

@ -1,93 +0,0 @@
//! Socket handlers
use super::*;
use std::collections::HashSet;
use std::net::{
SocketAddr,
};
use tokio::io::{
AsyncWrite,
AsyncRead
};
use tokio::task::JoinHandle;
use tokio::sync::{
mpsc,
oneshot,
};
use futures::Future;
use bytes::Bytes;
use cancel::*;
pub mod enc;
pub mod pipe;
/// Details of a newly accepted raw socket peer.
///
/// This connected will have been "accepted", but not yet trusted
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct RawSockPeerAccepted
{
/// Address of socket.
pub addr: SocketAddr,
/// Trust this peer from the start? This should almost always be false.
pub auto_trust: bool,
}
/// Details of a connected, set-up raw socket connection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RawSockPeerTrusted
{
/// The socket's details
pub sock_details: RawSockPeerAccepted,
/// Capabilities for this peer
pub cap_allow: HashSet<cap::RawSockCapability>,
}
/// A raw, received message from a `RawPeer`.
#[derive(Debug)]
pub struct RawMessage{
message_bytes: Bytes,
}
/// A connected raw peer, created and handled by `handle_new_socket_with_shutdown()`.
#[derive(Debug)]
pub struct RawPeer{
info: RawSockPeerTrusted,
rx: mpsc::Receiver<RawMessage>,
}
/// Handles a **newly connected** raw socket.
///
/// This will handle setting up socket peer encryption and validation.
pub fn handle_new_socket_with_shutdown<R, W, C: cancel::CancelFuture + 'static + Send>(
sock_details: RawSockPeerAccepted,
set_peer: oneshot::Sender<RawPeer>,
tx: W, rx: R,
shutdown: C
) -> JoinHandle<eyre::Result<()>>
where R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static
{
tokio::spawn(async move {
match {
with_cancel!(async move {
// Create empty cap
let mut sock_details = RawSockPeerTrusted {
sock_details,
cap_allow: HashSet::new(),
};
// Set up encryption
//TODO: Find caps for this peer.
Ok(())
}, shutdown)
} {
Ok(v) => v,
Err(x) => Err(eyre::Report::from(x)),
}
})
}

@ -1,53 +0,0 @@
//! Piping buffered data from a raw socket to `ESock`
//!
//! This exists because i'm too dumb to implement a functional AsyncRead/Write buffered wrapper stream :/
use super::*;
use std::{
io,
marker::{
Send, Sync,
Unpin,
PhantomData,
},
};
use tokio::sync::{
mpsc,
};
use enc::{
ESock,
ESockReadHalf,
ESockWriteHalf,
};
/// The default buffer size for `BufferedESock`.
pub const DEFAULT_BUFFER_SIZE: usize = 32;
/// Task-based buffered piping to/from encrypted sockets.
pub struct BufferedESock<W, R>
{
bufsz: usize,
_backing: PhantomData<ESock<W, R>>,
}
impl<W, R> BufferedESock<W, R>
where W: AsyncWrite + Unpin + Send + 'static,
R: AsyncRead + Unpin + Send + 'static
{
/// Create a new buffered ESock pipe with a specific buffer size
pub fn with_size(tx: W, rx: R, bufsz: usize) -> Self
{
//TODO: Spawn read+write buffer tasks
Self {
bufsz,
_backing: PhantomData,
}
}
/// Create a new buffered ESock pipe with the default buffer size (`DEFAULT_BUFFER_SIZE`).
#[inline] pub fn new(tx: W, rx: R) -> Self
{
Self::with_size(tx, rx, DEFAULT_BUFFER_SIZE)
}
}
Loading…
Cancel
Save