You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

437 lines
11 KiB

//! Extensions
use std::{
fmt,
error,
pin::Pin,
task::{Poll,Context,},
};
use tokio::{
io::AsyncRead,
prelude::*,
};
use futures::future::Future;
/// Extension for joining the items of a string-like iterator into one `String`.
pub trait JoinStrsExt: Sized
{
    /// Consume the iterator and concatenate its items, inserting `with` between
    /// each pair of neighbouring items.
    ///
    /// An empty iterator yields an empty `String`; a single item is returned without
    /// any separator.
    fn join(self, with: &str) -> String;
}

impl<T,I> JoinStrsExt for I
where I: Iterator<Item=T>,
      T: AsRef<str>
{
    fn join(self, with: &str) -> String
    {
        let mut parts = self;
        let mut joined = String::new();
        // The first item is written bare; every later item is preceded by the separator.
        if let Some(head) = parts.next() {
            joined.push_str(head.as_ref());
            for part in parts {
                joined.push_str(with);
                joined.push_str(part.as_ref());
            }
        }
        joined
    }
}
/*macro_rules! typed_swap {
(@ [] $($reversed:tt)*) => {
fn swap(self) -> ($($reversed)*);
};
(@ [$first:tt $($rest:tt)*] $($reversed:tt)*) => {
typed_swap!{@ [$($rest)*] $first $($reversed)*}
};
(@impl {$($body:tt)*} [] $($reversed:tt)*) => {
fn swap(self) -> ($($reversed)*)
{
$($body)*
}
};
(@impl {$($body:tt)*} [$first:tt $($rest:tt)*] $($reversed:tt)*) => {
typed_swap!{@impl {$($body)*} [$($rest)*] $first $($reversed)*}
};
() => {};
({$($params:tt)*} $($rest:tt)*) => {
mod swap {
pub trait SwapTupleExt<$($params)*>: Sized
{
typed_swap!(@ [$($params)*]);
}
impl<$($params)*> SwapTupleExt<$($params)*> for ($($params)*)
{
typed_swap!(@impl {
todo!()
} [$($params)*]);
}
typed_swap!($($rest)*);
}
pub use swap::*;
};
(all $first:tt $($params:tt)+) => {
typed_swap!({$first, $($params),+});
mod nswap {
typed_swap!(all $($params)+);
}
};
(all $($one:tt)?) => {};
}
typed_swap!(all A B C D E F G H I J K L M N O P Q R S T U V W X Y Z);
pub use swap::*;
fn test()
{
let sw = (1, 2).swap();
}*/
// ^ unfortunately not lol
/// Extension for reversing the order of the two elements of a pair.
pub trait SwapTupleExt<T,U>: Sized
{
    /// Consume the pair `(T, U)` and return it as `(U, T)`.
    fn swap(self) -> (U,T);
}

impl<T,U> SwapTupleExt<T,U> for (T,U)
{
    #[inline(always)] fn swap(self) -> (U,T) {
        let (first, second) = self;
        (second, first)
    }
}
/*typed_swap!({A, B}
{A, U, V}
{T, U, V, W});*/
/// Byte-to-character table used for the ASCII column of `HexView`.
///
/// Bytes `0x21..=0x7e` map to their own ASCII character; every other byte (including the
/// space `0x20`) maps to `'.'`.
const ASCII_MAP: [char; 256] = [
    '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.',
    '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.',
    '.', '!', '"', '#', '$', '%', '&', '\'', '(', ')', '*', '+', ',', '-', '.', '/',
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?',
    '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O',
    'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '^', '_',
    '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o',
    'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~', '.',
    '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.',
    '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.',
    '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.',
    '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.',
    '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.',
    '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.',
    '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.',
    '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.',
];

/// `Display` adaptor that renders a byte container as lowercase, two-digit hex.
///
/// The `bool` selects the layout: `true` separates the bytes with a single space,
/// `false` writes them back to back.
pub struct HexStringIter<'a, I>(&'a I, bool);

/// `Display` adaptor that renders a byte container as a multi-line hex dump.
pub struct HexView<'a, I>(&'a I);

/// Number of bytes shown on each row of a `HexView` dump.
const SPLIT_EVERY: usize = 16;

impl<'a, I: AsRef<[u8]>> fmt::Display for HexView<'a, I>
{
    /// Write one row per `SPLIT_EVERY` bytes: the offset of the row's first byte
    /// (`0x` + 16 hex digits), a tab, each byte as two hex digits followed by a space,
    /// two tabs, the `ASCII_MAP` rendering of the row's bytes, and a newline.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
    {
        /// Finish a row: two tabs, the filled slots of the ASCII column, newline.
        fn write_ascii(f: &mut fmt::Formatter<'_>, row: &[char; SPLIT_EVERY]) -> fmt::Result
        {
            write!(f, "\t\t")?;
            // '\0' marks a slot that has not been filled; `ASCII_MAP` never produces it.
            for ch in row.iter().filter(|&&ch| ch != '\0') {
                write!(f, "{}", ch)?;
            }
            writeln!(f)
        }

        let mut abuf = ['\0'; SPLIT_EVERY];
        let mut last_n = 0;
        for (i, &byte) in self.0.as_ref().iter().enumerate()
        {
            // Column of this byte within its row.
            let n = i % SPLIT_EVERY;
            if n == 0 {
                write!(f, "0x{:016x}\t", i)?;
            }
            abuf[n] = ASCII_MAP[byte as usize];
            write!(f, "{:02x} ", byte)?;
            if n == SPLIT_EVERY-1 {
                write_ascii(f, &abuf)?;
                abuf = ['\0'; SPLIT_EVERY];
            }
            last_n = n;
        }
        // Flush a trailing, partially filled row. This branch is also taken for empty
        // input, because `last_n` starts at 0.
        if last_n != SPLIT_EVERY-1
        {
            // NOTE(review): this writes `SPLIT_EVERY - last_n` one-character pads, which is
            // one more than the number of missing byte columns, and each byte column is
            // three characters wide. Kept as-is to preserve the output; confirm the
            // intended alignment.
            for _ in 0..(SPLIT_EVERY-last_n)
            {
                write!(f, " ")?;
            }
            write_ascii(f, &abuf)?;
        }
        Ok(())
    }
}

impl<'a, I: AsRef<[u8]>> fmt::Display for HexStringIter<'a, I>
{
    /// Write every byte as two lowercase hex digits.
    ///
    /// When the flag is `true` the bytes are separated by single spaces (no leading or
    /// trailing space); otherwise they are written contiguously. Empty input writes nothing.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
    {
        let bytes = self.0.as_ref();
        if self.1 {
            let mut iter = bytes.iter();
            if let Some(first) = iter.next() {
                // The first byte must use the same hex format as the rest; it was
                // previously written in decimal (`{:02}`).
                write!(f, "{:02x}", first)?;
                for byte in iter {
                    write!(f, " {:02x}", byte)?;
                }
            }
        } else {
            for byte in bytes {
                write!(f, "{:02x}", byte)?;
            }
        }
        Ok(())
    }
}
pub trait HexStringExt: Sized + AsRef<[u8]>
{
fn fmt_view(&self) -> HexView<'_, Self>;
fn fmt_hex(&self) -> HexStringIter<'_, Self>;
fn to_hex_string(&self) -> String
{
let mut string = String::with_capacity(self.as_ref().len()*2);
use fmt::Write;
write!(&mut string, "{}", self.fmt_hex()).unwrap();
string
}
fn to_broken_hex_string(&self) -> String
{
let fmt = HexStringIter(
self.fmt_hex().0,
true
);
let mut string = String::with_capacity(self.as_ref().len()*3);
use fmt::Write;
write!(&mut string, "{}", fmt).unwrap();
string
}
}
impl<T: AsRef<[u8]>> HexStringExt for T
{
fn fmt_hex(&self) -> HexStringIter<'_, Self>
{
HexStringIter(&self, false)
}
fn fmt_view(&self) -> HexView<'_, Self>
{
HexView(&self)
}
}
/// Future returned by `ReadAllBytesExt::read_whole_stream`.
///
/// Field 0 is the borrowed reader, field 1 the optional byte limit.
#[pin_project]
pub struct ReadAllBytes<'a, T: AsyncRead+Unpin+?Sized>(#[pin] &'a mut T, Option<usize>);
impl<'a, T: AsyncRead+Unpin+?Sized> Future for ReadAllBytes<'a, T>
{
/// All bytes read from the stream, or the first I/O error.
type Output = std::io::Result<Vec<u8>>;
/// Read from the wrapped reader in 4096-byte chunks until a read returns 0 bytes.
///
/// NOTE(review): the async block below is created, pinned on the stack and polled once
/// per call to `poll`, then dropped when this function returns. If the inner read
/// returns `Pending`, the bytes collected in `output` during this call appear to be
/// discarded and the next `poll` starts again with an empty buffer. Confirm against
/// readers that can return `Pending` mid-stream.
fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output>
{
let fut = async move {
let this = self.project();
// Start with room for ten chunks.
let mut output = Vec::with_capacity(4096*10);
let mut input = this.0;
let max = *this.1;
let mut buffer =[0u8; 4096];
let mut read;
// Loop condition performs the read; a 0-byte read ends the loop.
while {read = input.read(&mut buffer[..]).await?; read!=0} {
output.extend_from_slice(&buffer[..read]);
if let Some(max) = max {
// The limit is only checked after a non-empty read, and `>=` means that
// collecting exactly `max` bytes is already reported as an error.
if output.len() >=max {
return Err(std::io::Error::new(std::io::ErrorKind::Other, format!("Attempted to read more than allowed max {} bytes", max)));
}
}
}
Ok(output)
};
tokio::pin!(fut);
fut.poll(ctx)
}
}
/// Adds `read_whole_stream` to every `AsyncRead + Unpin` type.
pub trait ReadAllBytesExt: AsyncRead+Unpin
{
/// Attempt to read the whole stream to a new `Vec<u8>`.
///
/// With `max` set to `Some(n)`, the returned future fails with an `io::Error` of kind
/// `Other` once the number of collected bytes reaches `n`; `None` applies no limit.
fn read_whole_stream(&mut self, max: Option<usize>) -> ReadAllBytes<'_, Self>
{
ReadAllBytes(self, max)
}
}
impl<T: AsyncRead+Unpin+?Sized> ReadAllBytesExt for T{}
/// Extension for overwriting a byte buffer with bytes decoded from hex text.
pub trait FromHexExt
{
    /// Decode `input` as pairs of ASCII hex digits (either case) and store each decoded
    /// byte into the corresponding position of `self`.
    ///
    /// Decoding stops at whichever runs out first: the output buffer or the complete
    /// digit pairs of `input`. A trailing unpaired digit is ignored, and output bytes
    /// beyond the decoded range are left unchanged.
    ///
    /// # Errors
    /// Returns a `HexDecodeError` naming the offending character and its index in `input`
    /// at the first non-hex digit. Bytes decoded before that point have already been written.
    fn repl_with_hex<U: AsRef<[u8]>>(&mut self, input: U) -> Result<(), HexDecodeError>;
}

impl<T: AsMut<[u8]>+?Sized> FromHexExt for T
{
    fn repl_with_hex<U: AsRef<[u8]>>(&mut self, input: U) -> Result<(), HexDecodeError> {
        /// Value of a single ASCII hex digit; `idx` is only used for error reporting.
        #[inline] fn nibble(digit: u8, idx: usize) -> Result<u8, HexDecodeError> {
            let value = match digit {
                b'0'..=b'9' => digit - b'0',
                b'a'..=b'f' => digit - b'a' + 10,
                b'A'..=b'F' => digit - b'A' + 10,
                _ => return Err(HexDecodeError{
                    idx,
                    chr: digit as char,
                }),
            };
            Ok(value)
        }

        let pairs = input.as_ref().chunks_exact(2);
        for (i, (slot, pair)) in self.as_mut().iter_mut().zip(pairs).enumerate()
        {
            // The high digit is decoded (and may fail) before the low digit.
            let high = nibble(pair[0], 2 * i)?;
            let low = nibble(pair[1], 2 * i + 1)?;
            *slot = (high << 4) | low;
        }
        Ok(())
    }
}

/// Error produced when hex input contains a character that is not a hex digit.
#[derive(Debug)]
pub struct HexDecodeError {
    /// Index of the offending byte within the input.
    idx: usize,
    /// The offending byte, converted to a `char`.
    chr: char,
}

impl error::Error for HexDecodeError{}

impl fmt::Display for HexDecodeError
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
    {
        let Self { idx, chr } = self;
        write!(f, "Invalid hex at index {} (character was {:?})", idx, chr)
    }
}
// Unit tests and (nightly-only) benchmarks for the hex helpers in this module.
#[cfg(test)]
mod tests
{
use super::*;
// Debug helper: panics with the rendered `fmt_view` output of a sample buffer.
// It carries no `#[test]` attribute, so the test harness does not run it.
fn format()
{
let bytes = b"hello world one two three \x142!";
panic!("\n{}\n", bytes.fmt_view());
}
// Decoding a 64-digit hex string via `repl_with_hex` must reproduce the same 32 bytes
// that `hex_literal::hex!` produces for that string.
#[test]
fn hex()
{
const INPUT_HEX: [u8; 32] = hex_literal::hex!("d0a2404173bac722b29282652f2c457b573261e3c8701b908bb0bd3ada3d7f2d");
const INPUT_STR: &str = "d0a2404173bac722b29282652f2c457b573261e3c8701b908bb0bd3ada3d7f2d";
let mut output = [0u8; 32];
output.repl_with_hex(INPUT_STR).expect("Failed!");
assert_eq!(&INPUT_HEX[..], &output[..]);
}
// Only compiled when the `nightly` cfg is set; uses the `test` crate's `Bencher`.
#[cfg(nightly)]
mod benchmarks
{
use super::*;
use test::{Bencher, black_box};
// Benchmarks hex decoding where each digit is converted with a `match` on byte ranges.
#[bench]
fn hex_via_val(b: &mut Bencher)
{
// Local copy of the decoding loop, operating on a plain `&mut [u8]`.
fn repl_with_hex<U: AsRef<[u8]>>(out: &mut [u8], input: U) -> Result<(), HexDecodeError> {
#[inline] fn val(c: u8, idx: usize) -> Result<u8, HexDecodeError> {
match c {
b'A'..=b'F' => Ok(c - b'A' + 10),
b'a'..=b'f' => Ok(c - b'a' + 10),
b'0'..=b'9' => Ok(c - b'0'),
_ => Err(HexDecodeError{
chr: c as char,
idx,
}),
}
}
for (i, (byte, digits)) in (0..).zip(out.iter_mut().zip(input.as_ref().chunks_exact(2)))
{
*byte = val(digits[0], 2*i)? << 4 | val(digits[1], 2 * i + 1)?;
}
Ok(())
}
const INPUT_HEX: [u8; 32] = hex_literal::hex!("d0a2404173bac722b29282652f2c457b573261e3c8701b908bb0bd3ada3d7f2d");
const INPUT_STR: &str = "d0a2404173bac722b29282652f2c457b573261e3c8701b908bb0bd3ada3d7f2d";
let mut output = [0u8; 32];
b.iter(|| {
black_box(repl_with_hex(&mut output[..], INPUT_STR).unwrap());
});
// Sanity check that the benchmarked code decoded correctly.
assert_eq!(&INPUT_HEX[..], &output[..]);
}
// Benchmarks hex decoding where each digit is looked up in a lazily built `smallmap::Map`.
#[bench]
fn hex_via_lazy(b: &mut Bencher)
{
fn repl_with_hex<U: AsRef<[u8]>>(out: &mut [u8], input: U) -> Result<(), HexDecodeError> {
use smallmap::Map;
lazy_static::lazy_static! {
// Maps each ASCII hex digit to its value; bytes that are not hex digits are
// skipped by the `continue` and therefore have no entry.
static ref MAP: Map<u8, u8> = {
let mut map = Map::new();
for c in 0..=255u8
{
map.insert(c, match c {
b'A'..=b'F' => c - b'A' + 10,
b'a'..=b'f' => c - b'a' + 10,
b'0'..=b'9' => c - b'0',
_ => continue,
});
}
map
};
}
// A missing map entry becomes a `HexDecodeError` for that character and index.
#[inline(always)] fn val(c: u8, idx: usize) -> Result<u8, HexDecodeError> {
MAP.get(&c).copied()
.ok_or_else(|| HexDecodeError{idx, chr: c as char})
}
for (i, (byte, digits)) in (0..).zip(out.iter_mut().zip(input.as_ref().chunks_exact(2)))
{
*byte = val(digits[0], 2*i)? << 4 | val(digits[1], 2 * i + 1)?;
}
Ok(())
}
const INPUT_HEX: [u8; 32] = hex_literal::hex!("d0a2404173bac722b29282652f2c457b573261e3c8701b908bb0bd3ada3d7f2d");
const INPUT_STR: &str = "d0a2404173bac722b29282652f2c457b573261e3c8701b908bb0bd3ada3d7f2d";
let mut output = [0u8; 32];
b.iter(|| {
black_box(repl_with_hex(&mut output[..], INPUT_STR).unwrap());
});
// Sanity check that the benchmarked code decoded correctly.
assert_eq!(&INPUT_HEX[..], &output[..]);
}
}
}