From a3319ee7cfcc8f25793005648140b401efe0c734 Mon Sep 17 00:00:00 2001 From: Avril Date: Mon, 19 Oct 2020 17:12:43 +0100 Subject: [PATCH] initial commit --- .gitignore | 3 + Cargo.toml | 18 ++++ src/hashing.rs | 230 +++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 214 +++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 465 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.toml create mode 100644 src/hashing.rs create mode 100644 src/lib.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..80aca69 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +/target +Cargo.lock +*~ diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..021b34a --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "refset" +description = "A non-owning HashSet" +keywords = ["hash", "set", "reference"] +version = "0.1.0" +authors = ["Avril "] +edition = "2018" +license= "mit" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +sha2 = "0.9" +serde = {version = "1.0", optional = true, features=["derive"]} + +[dev-dependencies] +serde_json = "1.0" +serde_cbor = "0.11.1" diff --git a/src/hashing.rs b/src/hashing.rs new file mode 100644 index 0000000..2b54ffc --- /dev/null +++ b/src/hashing.rs @@ -0,0 +1,230 @@ +use super::*; +use sha2::{ + Sha512, + Digest, +}; +use std::{ + fmt, + hash::{ + Hasher, + }, +}; + +/// Number of bytes required to store a SHA512 hash. +pub const HASH_SIZE: usize = 64; + +/// Represents a single hash output from SHA512 algorithm. 
+#[repr(transparent)] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct Sha512Hash([u8; HASH_SIZE]); + +#[cfg(feature="serde")] +impl serde::Serialize for Sha512Hash { + fn serialize(&self, serializer: S) -> Result + where + S: serde::ser::Serializer, + { + serializer.serialize_bytes(&self.0[..]) + } +} + +#[cfg(feature="serde")] +pub struct Sha512HashVisitor; + +#[cfg(feature="serde")] +impl<'de> serde::de::Visitor<'de> for Sha512HashVisitor { + type Value = Sha512Hash; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("an array of 64 bytes") + } + + fn visit_bytes(self, v: &[u8]) -> Result + where E: serde::de::Error + { + let mut output = [0u8; HASH_SIZE]; + if v.len() == output.len() { + unsafe { + std::ptr::copy_nonoverlapping(&v[0] as *const u8, &mut output[0] as *mut u8, HASH_SIZE); + } + Ok(Sha512Hash::from_bytes(output)) + } else { + Err(E::custom(format!("Expected {} bytes, got {}", HASH_SIZE, v.len()))) + } + } + fn visit_seq(self, mut seq: A) -> Result where + A: serde::de::SeqAccess<'de> + { + let mut bytes = [0u8; HASH_SIZE]; + let mut i=0usize; + while let Some(byte) = seq.next_element()? 
+ { + bytes[i] = byte; + i+=1; + if i==HASH_SIZE { + return Ok(Sha512Hash::from_bytes(bytes)); + } + } + use serde::de::Error; + Err(A::Error::custom(format!("Expected {} bytes, got {}", HASH_SIZE, i))) + } +} + +#[cfg(feature="serde")] +impl<'de> serde::Deserialize<'de> for Sha512Hash { + fn deserialize(deserializer: D) -> Result + where + D: serde::de::Deserializer<'de>, + { + deserializer.deserialize_bytes(Sha512HashVisitor) + } +} + +impl Default for Sha512Hash +{ + #[inline] + fn default() -> Self + { + Self([0; HASH_SIZE]) + } +} + +impl Sha512Hash +{ + #[inline] pub const fn empty() -> Self + { + Self([0u8; HASH_SIZE]) + } + #[inline] pub const fn from_bytes(from: [u8; HASH_SIZE]) -> Self + { + Self(from) + } + #[inline] pub const fn into_bytes(self) -> [u8; HASH_SIZE] + { + self.0 + } +} + +impl AsRef<[u8]> for Sha512Hash +{ + fn as_ref(&self) -> &[u8] + { + &self.0[..] + } +} + +impl AsMut<[u8]> for Sha512Hash +{ + fn as_mut(&mut self) -> &mut [u8] + { + &mut self.0[..] + } +} + +impl From for Sha512Hash +{ + fn from(from: Sha512) -> Self + { + let mut arr = [0u8; HASH_SIZE]; + let from = from.finalize(); // No `Into<[u8; HASH_SIZE]>` conversion exists for the digest output, so the bytes are copied out by hand below; that copy relies on the output being HASH_SIZE bytes long (checked only in debug builds). 
+ debug_assert_eq!(arr.len(), from.len()); + unsafe { + std::ptr::copy_nonoverlapping(&from[0] as *const u8, &mut arr[0] as *mut u8, HASH_SIZE); + } + Self(arr) + } +} + +impl fmt::Display for Sha512Hash +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + for byte in self.0.iter() + { + write!(f, "{:02x}", byte)?; + } + Ok(()) + } +} + +#[derive(Debug)] +pub(super) struct Sha512Hasher(Sha512); + +impl Hasher for Sha512Hasher +{ + #[inline] fn write(&mut self, bytes: &[u8]) { + self.0.update(bytes) + } + #[inline] fn finish(&self) -> u64 + { + use sha2::digest::generic_array::sequence::Split; + let arr = self.0.clone().finalize(); + // Take the first 8 bytes of the digest and read them as a native-endian u64 (so the value depends on target endianness) + u64::from_ne_bytes(arr.split().0.into()) + } +} + +impl Sha512Hasher +{ + #[inline] pub fn new() -> Self + { + Self(Sha512::new()) + } + #[inline] pub fn finalize(self) -> Sha512Hash + { + self.0.into() + } + #[inline] pub fn into_inner(self) -> Sha512 + { + self.0 + } +} + +#[cfg(test)] +mod test +{ + use super::*; + #[test] + fn match_bytes() + { + let mut hasher = Sha512Hasher::new(); + "hello world".hash(&mut hasher); + let low = hasher.finish(); + let full = hasher.finalize(); + + let mut top = [0u8; std::mem::size_of::()]; + assert!(top.len() < full.0.len()); + unsafe { + std::ptr::copy_nonoverlapping(&full.0[0] as *const u8, &mut top[0] as *mut u8, top.len()); + } + + assert_eq!(u64::from_ne_bytes(top), low); + } + + #[cfg(feature="serde")] + #[test] + fn ser_json() + { + let hash = crate::compute_hash_for("hello world"); + println!("Orig: {}", hash); + let string = serde_json::to_string(&hash).expect("Serial failed"); + println!("Ser: {:?}", string); + let de = serde_json::from_str(&string[..]).expect("Deserial failed"); + println!("De: {}", de); + assert_eq!(hash, de); + } + + #[cfg(feature="serde")] + #[test] + fn ser_cbor() + { + let hash = crate::compute_hash_for("hello world"); + println!("Orig: {}", hash); + let vec = 
serde_cbor::to_vec(&hash).expect("Serial failed"); + println!("Ser: {:?}", vec); + let de = serde_cbor::from_slice(&vec[..]).expect("Deserial failed"); + println!("De: {}", de); + assert_eq!(hash, de); + } + +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..9740dc7 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,214 @@ +//! A hash-set analogue that does not own its data. +//! +//! It can be used to "mark" items without the need to transfer ownership to the set. +//! +//! # Example use case +//! ``` +//! # use refset::HashRefSet; +//! /// Process arguments while ignoring duplicates +//! fn process_args(args: impl IntoIterator) { +//! let mut same= HashRefSet::new(); +//! for argument in args.into_iter() +//! { +//! if !same.insert(argument.as_str()) { +//! // Already processed this input, ignore +//! continue; +//! } +//! //do work... +//! } +//! } +//! ``` +//! # Serialisation support with `serde` crate +//! `HashRefSet` and `HashType` both implement `Serialize` and `Deserialize` from the `serde` crate if the `serde` feature is enabled. By default it is not. +//! # Drawbacks +//! Since the item is not inserted itself, we cannot use `Eq` to double check there was not a hash collision. +//! While the hashing algorithm used (Sha512) is extremely unlikely to produce collisions, especially for small data types, keep in mind that it is not infallible. +use std::{ + collections::{ + hash_set, + HashSet, + }, + marker::{ + PhantomData, + Send, + Sync, + }, + hash::Hash, + borrow::Borrow, +}; + +mod hashing; + +/// The type used to store the hash of each item. +/// +/// It is a result of the `SHA512` algorithm as a newtype 64 byte array marked with `#[repr(transparent)]`. +/// If you want to get the bytes from it, you can call `into_bytes()`, or transmute safely as shown below. 
+/// ``` +/// # use refset::HashType; +/// fn hash_bytes(hash: HashType) -> [u8; 64] +/// { +/// unsafe { +/// std::mem::transmute(hash) +/// } +/// } +/// +/// fn hash_bytes_assert() +/// { +/// assert_eq!(hash_bytes(Default::default()), [0u8; 64]); +/// } +/// ``` +pub type HashType = hashing::Sha512Hash; + +/// Compute the `HashType` of `value` by feeding its `Hash` implementation into a fresh `Sha512Hasher`. +fn compute_hash_for(value: &T) -> HashType +{ + let mut hasher = hashing::Sha512Hasher::new(); + value.hash(&mut hasher); + hasher.finalize() +} + +#[allow(dead_code)] // NOTE(review): no caller of this function is visible in this patch +#[cold] fn compute_both_hash_for(value: &T) -> (u64, HashType) +{ + use sha2::{ + Digest, + digest::generic_array::sequence::Split, + }; + let mut hasher = hashing::Sha512Hasher::new(); + value.hash(&mut hasher); + let sha512 = hasher.into_inner(); + + let full = sha512.finalize(); + + let mut arr = [0u8; hashing::HASH_SIZE]; + debug_assert_eq!(arr.len(), full.len()); + unsafe { // SAFETY: relies on `full` being HASH_SIZE bytes long; only the debug_assert above checks this + std::ptr::copy_nonoverlapping(&full[0] as *const u8, &mut arr[0] as *mut u8, hashing::HASH_SIZE); + } + (u64::from_ne_bytes(full.split().0.into()), HashType::from_bytes(arr)) +} + +/// A hash-set of references to an item. +/// +/// Instead of inserting the item into the set, the set is "marked" with the item. +/// Think of this as inserting a reference into the set with no lifetime. +/// +/// Any type that can borrow to `T` can be used to insert, and neither type needs to be `Sized`. +/// `T` need only implement `Hash`. +/// +/// # Hashing algorithm +/// The hashing algorithm used is `Sha512`, which is rather large (64 bytes). +/// At present there is no way to change the hasher used; I might implement that functionality in the future. 
+#[derive(Debug, Clone, PartialEq, Eq, Default)] +#[cfg_attr(feature="serde", derive(serde::Serialize, serde::Deserialize))] +pub struct HashRefSet(HashSet, PhantomData>); + +unsafe impl Send for HashRefSet{} +unsafe impl Sync for HashRefSet{} + +impl HashRefSet +{ + /// Create a new empty `HashRefSet` + pub fn new() -> Self + { + Self( + HashSet::new(), + PhantomData + ) + } + /// Create a new empty `HashRefSet` whose underlying `HashSet` is created with capacity `cap` + pub fn with_capacity(cap: usize) -> Self + { + Self(HashSet::with_capacity(cap), PhantomData) + } + + /// Insert a reference into the set. The reference can be any type that borrows to `T`. + /// + /// Returns `true` if there was no previous item, `false` if there was. + pub fn insert(&mut self, value: &Q) -> bool + where Q: ?Sized + Borrow + { + self.0.insert(compute_hash_for(value.borrow())) + } + + /// Remove a reference from the set. + /// + /// Returns `true` if it existed. + pub fn remove(&mut self, value: &Q) -> bool + where Q: ?Sized + Borrow + { + self.0.remove(&compute_hash_for(value.borrow())) + } + + /// Check if this value has been inserted into the set. Only needs shared access; the set is not modified. + pub fn contains(&self, value: &Q) -> bool + where Q: ?Sized + Borrow + { + self.0.contains(&compute_hash_for(value.borrow())) + } + + /// The number of hashes stored in the set + pub fn len(&self) -> usize + { + self.0.len() + } + + /// Returns `true` if no hashes are stored in the set + pub fn is_empty(&self) -> bool + { + self.0.is_empty() + } + + /// An iterator over the hashes stored in the set. 
+ pub fn hashes_iter(&self) -> hash_set::Iter<'_, HashType> + { + self.0.iter() + } + + #[inline] fn into_hashes_iter(self) -> hash_set::IntoIter + { + self.0.into_iter() + } +} + +impl IntoIterator for HashRefSet +{ + type Item= HashType; + type IntoIter = hash_set::IntoIter; + + #[inline] fn into_iter(self) -> Self::IntoIter + { + self.into_hashes_iter() + } +} + + +#[cfg(test)] +mod tests +{ + use super::*; + #[test] + fn insert() + { + let mut refset = HashRefSet::new(); + + let values= vec![ + "hi", + "hello", + "one", + "two", + ]; + for &string in values.iter() + { + refset.insert(string); + } + + for string in values + { + assert!(refset.contains(string)); + } + + assert!(refset.insert("none")); + assert!(!refset.insert("two")); + } +}