You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

579 lines
14 KiB

//! Extensions
use std::{
fmt,
error,
pin::Pin,
task::{Poll,Context,},
ops::{
Range,
},
marker::{
PhantomData,
},
};
use tokio::{
io::AsyncRead,
prelude::*,
};
use futures::future::Future;
/// Extension for joining any iterator of string-like items into one `String`.
pub trait JoinStrsExt: Sized
{
    /// Join an iterator of `str` with a separator placed between consecutive items.
    ///
    /// An empty iterator yields an empty `String`; a single item is returned as-is.
    fn join(self, with: &str) -> String;
}
impl<T, I> JoinStrsExt for I
where
    I: Iterator<Item = T>,
    T: AsRef<str>,
{
    fn join(self, with: &str) -> String
    {
        let mut pieces = self;
        let mut joined = String::new();
        // The first item is written without a leading separator; every later item gets one.
        if let Some(head) = pieces.next() {
            joined.push_str(head.as_ref());
            for piece in pieces {
                joined.push_str(with);
                joined.push_str(piece.as_ref());
            }
        }
        joined
    }
}
/*macro_rules! typed_swap {
(@ [] $($reversed:tt)*) => {
fn swap(self) -> ($($reversed)*);
};
(@ [$first:tt $($rest:tt)*] $($reversed:tt)*) => {
typed_swap!{@ [$($rest)*] $first $($reversed)*}
};
(@impl {$($body:tt)*} [] $($reversed:tt)*) => {
fn swap(self) -> ($($reversed)*)
{
$($body)*
}
};
(@impl {$($body:tt)*} [$first:tt $($rest:tt)*] $($reversed:tt)*) => {
typed_swap!{@impl {$($body)*} [$($rest)*] $first $($reversed)*}
};
() => {};
({$($params:tt)*} $($rest:tt)*) => {
mod swap {
pub trait SwapTupleExt<$($params)*>: Sized
{
typed_swap!(@ [$($params)*]);
}
impl<$($params)*> SwapTupleExt<$($params)*> for ($($params)*)
{
typed_swap!(@impl {
todo!()
} [$($params)*]);
}
typed_swap!($($rest)*);
}
pub use swap::*;
};
(all $first:tt $($params:tt)+) => {
typed_swap!({$first, $($params),+});
mod nswap {
typed_swap!(all $($params)+);
}
};
(all $($one:tt)?) => {};
}
typed_swap!(all A B C D E F G H I J K L M N O P Q R S T U V W X Y Z);
pub use swap::*;
fn test()
{
let sw = (1, 2).swap();
}*/
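// ^ Unfortunately this macro-based approach to a generic tuple `swap` does not work, so only the two-element version below is implemented.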
/// Extension for exchanging the two elements of a pair.
pub trait SwapTupleExt<T, U>: Sized
{
    /// Consume a `(T, U)` pair and return it reordered as `(U, T)`.
    fn swap(self) -> (U, T);
}
impl<T, U> SwapTupleExt<T, U> for (T, U)
{
    #[inline(always)]
    fn swap(self) -> (U, T) {
        let (left, right) = self;
        (right, left)
    }
}
/*typed_swap!({A, B}
{A, U, V}
{T, U, V, W});*/
/// Byte -> display character lookup table.
///
/// Visible ASCII glyphs (`'!'..='~'`) map to themselves; every other byte
/// (control characters, space, DEL and all bytes >= 0x80) maps to `'.'`.
const ASCII_MAP: [char; 256] = {
    let mut table = ['.'; 256];
    let mut byte = b'!';
    while byte <= b'~' {
        table[byte as usize] = byte as char;
        byte += 1;
    }
    table
};
/// Build, at compile time, the table mapping each byte to its two lowercase
/// ASCII hex digits as `(high nibble digit, low nibble digit)`.
const fn create_hex_map() -> [(u8, u8); 256]
{
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut table = [(0u8, 0u8); 256];
    let mut byte: usize = 0;
    while byte < 256 {
        table[byte] = (DIGITS[byte / 16], DIGITS[byte % 16]);
        byte += 1;
    }
    table
}
/// Precomputed lowercase hex digit pair for every byte value.
const HEX_MAP: [(u8, u8); 256] = create_hex_map();
/// `Display` adapter that prints each byte as two lowercase hex digits.
/// The `bool` selects whether consecutive bytes are separated by a space.
pub struct HexStringView<'a, I:?Sized>(&'a I, bool);
/// Iterator yielding the hex digits of each byte one `char` at a time.
/// The `(u8, u8)` buffers the (high, low) ASCII digits of the current byte; a `0` marks a slot that was already yielded.
pub struct HexStringIter<'a, I:?Sized>(std::slice::Iter<'a, u8>, (u8, u8), PhantomData<&'a I>);
/// `Display` adapter producing an offset / hex / ASCII dump, `SPLIT_EVERY` bytes per row.
pub struct HexView<'a, I:?Sized>(&'a I);
/// `Display` adapter printing each byte through `ASCII_MAP` (bytes outside `'!'..='~'` become `'.'`).
pub struct AsciiView<'a, I:?Sized>(&'a I);
/// Iterator over the not-yet-consumed bytes, yielding each byte's `ASCII_MAP` character.
pub struct AsciiIter<'a, I:?Sized>(&'a [u8], PhantomData<&'a I>);
/// Number of bytes shown per row by `HexView`.
const SPLIT_EVERY: usize = 16;
impl<'a, I: ?Sized + AsRef<[u8]>> Iterator for AsciiIter<'a, I>
{
    type Item = char;
    /// Yield the display character of the next byte and advance past it.
    fn next(&mut self) -> Option<Self::Item>
    {
        let (&head, tail) = self.0.split_first()?;
        self.0 = tail;
        Some(ASCII_MAP[head as usize])
    }
    /// Exactly one `char` is produced per remaining byte.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.0.len();
        (remaining, Some(remaining))
    }
}
impl<'a, I: ?Sized + AsRef<[u8]>> ExactSizeIterator for AsciiIter<'a, I> {}
impl<'a, I: ?Sized + AsRef<[u8]>> std::iter::FusedIterator for AsciiIter<'a, I> {}
impl<'a, I: ?Sized+AsRef<[u8]>> Iterator for HexStringIter<'a, I>
{
type Item = char;
fn next(&mut self) -> Option<Self::Item>
{
match self.1 {
ref mut buf @ (0, 0) => {
// both are taken
if let Some(&byte) = self.0.next() {
*buf = HEX_MAP[byte as usize];
} else {
return None;
}
(Some(buf.0 as char),buf.0 = 0).0
},
(0, ref mut second) => {
// first is taken
(Some(*second as char),*second = 0).0
},
#[cold] (ref mut first, _) => {
// neither are taken, usually shouldn't happen
(Some(*first as char),*first = 0).0
},
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
let sz = self.0.size_hint();
(sz.0 * 2, sz.1.map(|x| x*2))
}
}
impl<'a, I: ?Sized+AsRef<[u8]>> ExactSizeIterator for HexStringIter<'a, I>{}
impl<'a, I: ?Sized+AsRef<[u8]>> std::iter::FusedIterator for HexStringIter<'a, I>{}
impl<'a, I: AsRef<[u8]> + ?Sized> fmt::Display for AsciiView<'a, I>
{
    /// Write each byte's `ASCII_MAP` character, with no separators.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
    {
        use std::fmt::Write;
        for &raw in self.0.as_ref().iter() {
            f.write_char(ASCII_MAP[usize::from(raw)])?;
        }
        Ok(())
    }
}
impl<'a, I: AsRef<[u8]>+?Sized> fmt::Display for HexView<'a, I>
{
    /// Write a hex dump with one row per `SPLIT_EVERY` bytes.
    ///
    /// Each row is: the byte offset as `0x` + 16 hex digits and a tab, then `xx ` for each
    /// byte, then two tabs, then the `ASCII_MAP` characters, then a newline. A short final
    /// row is padded with three spaces per missing byte (the width of one `xx ` cell) so its
    /// ASCII column lines up with full rows. Empty input produces no output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
    {
        for (row, chunk) in self.0.as_ref().chunks(SPLIT_EVERY).enumerate()
        {
            // `usize` offset, so large inputs cannot overflow an inferred `i32` counter.
            write!(f, "0x{:016x}\t", row * SPLIT_EVERY)?;
            for byte in chunk {
                write!(f, "{:02x} ", byte)?;
            }
            // Pad only the cells that are actually missing from a short final row.
            for _ in chunk.len()..SPLIT_EVERY {
                write!(f, "   ")?;
            }
            write!(f, "\t\t")?;
            for &byte in chunk {
                write!(f, "{}", ASCII_MAP[byte as usize])?;
            }
            writeln!(f)?;
        }
        Ok(())
    }
}
impl<'a, I: AsRef<[u8]>+?Sized> fmt::Display for HexStringView<'a, I>
{
    /// Write every byte as two lowercase hex digits. When the flag (`self.1`) is set,
    /// consecutive bytes are separated by a single space; otherwise there is no separator.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
    {
        let separator = if self.1 { " " } else { "" };
        for (index, byte) in self.0.as_ref().iter().enumerate() {
            if index > 0 {
                f.write_str(separator)?;
            }
            // Lowercase hex for every byte, including the first one.
            write!(f, "{:02x}", byte)?;
        }
        Ok(())
    }
}
/// Extensions on byte slices to print them nicely
pub trait HexStringExt: AsRef<[u8]>
{
    /// An iterator that yields a readable ASCII `char` for each byte
    /// (bytes outside `'!'..='~'` are shown as `'.'`).
    fn iter_ascii(&self) -> AsciiIter<'_, Self>;
    /// A `Display` implementor that prints the readable ASCII `char` of each byte
    /// (bytes outside `'!'..='~'` are shown as `'.'`).
    fn fmt_ascii(&self) -> AsciiView<'_, Self>;
    /// A pretty hex-dump `Display` implementor: rows of `SPLIT_EVERY` bytes showing the
    /// offset, the hex of each byte, and the readable ASCII.
    fn fmt_view(&self) -> HexView<'_, Self>;
    /// A `Display` implementor that prints the hex of each byte in lowercase, with no separators.
    fn fmt_hex(&self) -> HexStringView<'_, Self>;
    /// An iterator over `char`s that yields the lowercase hex digits of each byte
    ///
    /// # Notes
    /// This yields each character one at a time (high digit, then low digit). To get the hex
    /// of each byte, group the output into consecutive, non-overlapping pairs.
    fn iter_hex(&self) -> HexStringIter<'_, Self>;
    /// Convenience method for creating a lowercase hex string (two digits per byte, no separators).
    fn to_hex_string(&self) -> String
    {
        // Exactly two hex digits per byte.
        let mut string = String::with_capacity(self.as_ref().len()*2);
        use fmt::Write;
        write!(&mut string, "{}", self.fmt_hex()).expect("writing to a String cannot fail");
        string
    }
    /// Convenience method for creating a lowercase hex string with consecutive bytes
    /// separated by a single space.
    fn to_broken_hex_string(&self) -> String
    {
        // Same byte view as `fmt_hex()`, but with the separator flag turned on.
        let fmt = HexStringView(
            self.fmt_hex().0,
            true
        );
        // `n` bytes produce `2n` digits plus `n - 1` separators, so `3n` is an upper bound.
        let mut string = String::with_capacity(self.as_ref().len()*3);
        use fmt::Write;
        write!(&mut string, "{}", fmt).expect("writing to a String cannot fail");
        string
    }
    /// Convenience method for creating a string from `fmt_view()`
    #[inline] fn to_view_string(&self) -> String
    {
        format!("{}", self.fmt_view())
    }
    /// Convenience method for creating the same text as `fmt_ascii()` prints, collected
    /// from `iter_ascii()`.
    #[inline] fn to_ascii_string(&self) -> String
    {
        self.iter_ascii().collect()
    }
}
impl<T: AsRef<[u8]> + ?Sized> HexStringExt for T
{
    fn fmt_hex(&self) -> HexStringView<'_, Self>
    {
        // Unseparated hex output.
        HexStringView(self, false)
    }
    fn fmt_view(&self) -> HexView<'_, Self>
    {
        HexView(self)
    }
    fn fmt_ascii(&self) -> AsciiView<'_, Self>
    {
        AsciiView(self)
    }
    fn iter_ascii(&self) -> AsciiIter<'_, Self>
    {
        let bytes: &[u8] = self.as_ref();
        AsciiIter(bytes, PhantomData)
    }
    fn iter_hex(&self) -> HexStringIter<'_, Self>
    {
        // Start with no buffered digits so the first `next()` loads the first byte.
        let bytes: &[u8] = self.as_ref();
        HexStringIter(bytes.iter(), (0, 0), PhantomData)
    }
}
#[pin_project]
pub struct ReadAllBytes<'a, T: AsyncRead+Unpin+?Sized>(#[pin] &'a mut T, Option<usize>);
impl<'a, T: AsyncRead+Unpin+?Sized> Future for ReadAllBytes<'a, T>
{
type Output = std::io::Result<Vec<u8>>;
fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output>
{
let fut = async move {
let this = self.project();
let mut output = Vec::with_capacity(4096*10);
let mut input = this.0;
let max = *this.1;
let mut buffer =[0u8; 4096];
let mut read;
while {read = input.read(&mut buffer[..]).await?; read!=0} {
output.extend_from_slice(&buffer[..read]);
if let Some(max) = max {
if output.len() >=max {
return Err(std::io::Error::new(std::io::ErrorKind::Other, format!("Attempted to read more than allowed max {} bytes", max)));
}
}
}
Ok(output)
};
tokio::pin!(fut);
fut.poll(ctx)
}
}
pub trait ReadAllBytesExt: AsyncRead+Unpin
{
/// Attempt to read the whole stream to a new `Vec<u8>`.
fn read_whole_stream(&mut self, max: Option<usize>) -> ReadAllBytes<'_, Self>
{
ReadAllBytes(self, max)
}
}
impl<T: AsyncRead+Unpin+?Sized> ReadAllBytesExt for T{}
/// Decode a hex string into an existing byte buffer.
pub trait FromHexExt
{
    /// Overwrite leading bytes of `self` with the bytes decoded from the hex digits in `input`.
    ///
    /// Digits are read in pairs (high nibble first), upper or lower case. Decoding stops
    /// when either the buffer or the complete digit pairs run out, and a trailing unpaired
    /// digit is ignored. On an invalid digit, bytes decoded before it have already been
    /// written.
    fn repl_with_hex<U: AsRef<[u8]>>(&mut self, input: U) -> Result<(), HexDecodeError>;
}
impl<T: AsMut<[u8]> + ?Sized> FromHexExt for T
{
    fn repl_with_hex<U: AsRef<[u8]>>(&mut self, input: U) -> Result<(), HexDecodeError> {
        /// Value of a single hex digit; `idx` is its position in the input, reported on error.
        fn nibble(c: u8, idx: usize) -> Result<u8, HexDecodeError> {
            match c {
                b'0'..=b'9' => Ok(c - b'0'),
                b'a'..=b'f' => Ok(c - b'a' + 10),
                b'A'..=b'F' => Ok(c - b'A' + 10),
                _ => Err(HexDecodeError { idx, chr: c as char }),
            }
        }
        let pairs = input.as_ref().chunks_exact(2);
        for (i, (slot, pair)) in self.as_mut().iter_mut().zip(pairs).enumerate() {
            // High digit is validated before the low one, so the first bad digit is reported.
            let high = nibble(pair[0], 2 * i)?;
            let low = nibble(pair[1], 2 * i + 1)?;
            *slot = (high << 4) | low;
        }
        Ok(())
    }
}
/// An invalid hex digit found while decoding.
#[derive(Debug)]
pub struct HexDecodeError {
    /// Position of the offending character in the input.
    idx: usize,
    /// The offending character.
    chr: char,
}
impl error::Error for HexDecodeError {}
impl fmt::Display for HexDecodeError
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
    {
        let HexDecodeError { idx, chr } = self;
        write!(f, "Invalid hex at index {} (character was {:?})", idx, chr)
    }
}
// Unit tests; the nested `benchmarks` module is only compiled with `--cfg nightly`.
#[cfg(test)]
mod tests
{
use super::*;
// NOTE(review): this is not marked `#[test]`, so the harness never runs it (it is dead code in
// test builds). When called by hand it deliberately panics to print the `fmt_view()` dump.
fn format()
{
let bytes = b"hello world one two three \x142!";
panic!("\n{}\n", bytes.fmt_view());
}
// Decoding with `repl_with_hex` must match `hex_literal::hex!` applied to the same hex string.
#[test]
fn hex()
{
const INPUT_HEX: [u8; 32] = hex_literal::hex!("d0a2404173bac722b29282652f2c457b573261e3c8701b908bb0bd3ada3d7f2d");
const INPUT_STR: &str = "d0a2404173bac722b29282652f2c457b573261e3c8701b908bb0bd3ada3d7f2d";
let mut output = [0u8; 32];
output.repl_with_hex(INPUT_STR).expect("Failed!");
assert_eq!(&INPUT_HEX[..], &output[..]);
}
// Benchmarks comparing two digit-decoding strategies inside the same `repl_with_hex` loop.
// NOTE(review): `#[bench]` / `test::Bencher` presumably rely on `#![feature(test)]` and
// `extern crate test` at the crate root — confirm.
#[cfg(nightly)]
mod benchmarks
{
use super::*;
use test::{Bencher, black_box};
// Strategy 1: decode each digit with a `match` on its byte range (same logic as `FromHexExt`).
#[bench]
fn hex_via_val(b: &mut Bencher)
{
fn repl_with_hex<U: AsRef<[u8]>>(out: &mut [u8], input: U) -> Result<(), HexDecodeError> {
#[inline] fn val(c: u8, idx: usize) -> Result<u8, HexDecodeError> {
match c {
b'A'..=b'F' => Ok(c - b'A' + 10),
b'a'..=b'f' => Ok(c - b'a' + 10),
b'0'..=b'9' => Ok(c - b'0'),
_ => Err(HexDecodeError{
chr: c as char,
idx,
}),
}
}
for (i, (byte, digits)) in (0..).zip(out.iter_mut().zip(input.as_ref().chunks_exact(2)))
{
*byte = val(digits[0], 2*i)? << 4 | val(digits[1], 2 * i + 1)?;
}
Ok(())
}
const INPUT_HEX: [u8; 32] = hex_literal::hex!("d0a2404173bac722b29282652f2c457b573261e3c8701b908bb0bd3ada3d7f2d");
const INPUT_STR: &str = "d0a2404173bac722b29282652f2c457b573261e3c8701b908bb0bd3ada3d7f2d";
let mut output = [0u8; 32];
// NOTE(review): only the `()` result is black-boxed; the inputs are constants.
b.iter(|| {
black_box(repl_with_hex(&mut output[..], INPUT_STR).unwrap());
});
assert_eq!(&INPUT_HEX[..], &output[..]);
}
// Strategy 2: decode each digit via a lazily initialised `smallmap::Map` lookup table.
#[bench]
fn hex_via_lazy(b: &mut Bencher)
{
fn repl_with_hex<U: AsRef<[u8]>>(out: &mut [u8], input: U) -> Result<(), HexDecodeError> {
use smallmap::Map;
lazy_static::lazy_static! {
static ref MAP: Map<u8, u8> = {
let mut map = Map::new();
for c in 0..=255u8
{
// Only hex digits are inserted; `continue` skips every other byte, so looking one up yields `None`.
map.insert(c, match c {
b'A'..=b'F' => c - b'A' + 10,
b'a'..=b'f' => c - b'a' + 10,
b'0'..=b'9' => c - b'0',
_ => continue,
});
}
map
};
}
// A missing map entry is reported as an invalid digit at `idx`.
#[inline(always)] fn val(c: u8, idx: usize) -> Result<u8, HexDecodeError> {
MAP.get(&c).copied()
.ok_or_else(|| HexDecodeError{idx, chr: c as char})
}
for (i, (byte, digits)) in (0..).zip(out.iter_mut().zip(input.as_ref().chunks_exact(2)))
{
*byte = val(digits[0], 2*i)? << 4 | val(digits[1], 2 * i + 1)?;
}
Ok(())
}
const INPUT_HEX: [u8; 32] = hex_literal::hex!("d0a2404173bac722b29282652f2c457b573261e3c8701b908bb0bd3ada3d7f2d");
const INPUT_STR: &str = "d0a2404173bac722b29282652f2c457b573261e3c8701b908bb0bd3ada3d7f2d";
let mut output = [0u8; 32];
b.iter(|| {
black_box(repl_with_hex(&mut output[..], INPUT_STR).unwrap());
});
assert_eq!(&INPUT_HEX[..], &output[..]);
}
}
}