use super::*; use sha2::{ Sha512, Digest, }; use std::{ fmt, hash::{ Hasher, }, }; /// Number of bytes required to store a SHA512 hash. pub const HASH_SIZE: usize = 64; /// Represents a single hash output from SHA512 algorithm. #[repr(transparent)] #[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub struct Sha512Hash([u8; HASH_SIZE]); #[cfg(feature="serde")] impl serde::Serialize for Sha512Hash { fn serialize(&self, serializer: S) -> Result where S: serde::ser::Serializer, { serializer.serialize_bytes(&self.0[..]) } } #[cfg(feature="serde")] pub struct Sha512HashVisitor; #[cfg(feature="serde")] impl<'de> serde::de::Visitor<'de> for Sha512HashVisitor { type Value = Sha512Hash; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { formatter.write_str("an array of 64 bytes") } fn visit_bytes(self, v: &[u8]) -> Result where E: serde::de::Error { let mut output = [0u8; HASH_SIZE]; if v.len() == output.len() { unsafe { std::ptr::copy_nonoverlapping(&v[0] as *const u8, &mut output[0] as *mut u8, HASH_SIZE); } Ok(Sha512Hash::from_bytes(output)) } else { Err(E::custom(format!("Expected {} bytes, got {}", HASH_SIZE, v.len()))) } } fn visit_seq(self, mut seq: A) -> Result where A: serde::de::SeqAccess<'de> { let mut bytes = [0u8; HASH_SIZE]; let mut i=0usize; while let Some(byte) = seq.next_element()? { bytes[i] = byte; i+=1; if i==HASH_SIZE { return Ok(Sha512Hash::from_bytes(bytes)); } } use serde::de::Error; Err(A::Error::custom(format!("Expected {} bytes, got {}", HASH_SIZE, i))) } } #[cfg(feature="serde")] impl<'de> serde::Deserialize<'de> for Sha512Hash { fn deserialize(deserializer: D) -> Result where D: serde::de::Deserializer<'de>, { deserializer.deserialize_bytes(Sha512HashVisitor) } } impl Default for Sha512Hash { #[inline] fn default() -> Self { Self([0; HASH_SIZE]) } } impl Sha512Hash { #[inline] pub const fn empty() -> Self { Self([0u8; HASH_SIZE]) } #[inline] pub const fn from_bytes(from: [u8; HASH_SIZE]) -> Self { Self(from) } #[inline] pub const fn into_bytes(self) -> [u8; HASH_SIZE] { self.0 } } impl AsRef<[u8]> for Sha512Hash { fn as_ref(&self) -> &[u8] { &self.0[..] } } impl AsMut<[u8]> for Sha512Hash { fn as_mut(&mut self) -> &mut [u8] { &mut self.0[..] } } impl From for Sha512Hash { fn from(from: Sha512) -> Self { let mut arr = [0u8; HASH_SIZE]; let from = from.finalize(); // Into not implemented for [T; 64]. sigh... debug_assert_eq!(arr.len(), from.len()); unsafe { std::ptr::copy_nonoverlapping(&from[0] as *const u8, &mut arr[0] as *mut u8, HASH_SIZE); } Self(arr) } } impl fmt::Display for Sha512Hash { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { for byte in self.0.iter() { write!(f, "{:02x}", byte)?; } Ok(()) } } #[derive(Debug)] pub(super) struct Sha512Hasher(Sha512); impl Hasher for Sha512Hasher { #[inline] fn write(&mut self, bytes: &[u8]) { self.0.update(bytes) } #[inline] fn finish(&self) -> u64 { use sha2::digest::generic_array::sequence::Split; let arr = self.0.clone().finalize(); // Take the first 8 bytes and convert to native endian u64 u64::from_ne_bytes(arr.split().0.into()) } } impl Sha512Hasher { #[inline] pub fn new() -> Self { Self(Sha512::new()) } #[inline] pub fn finalize(self) -> Sha512Hash { self.0.into() } #[inline] pub fn into_inner(self) -> Sha512 { self.0 } } #[cfg(test)] mod test { use super::*; #[test] fn match_bytes() { let mut hasher = Sha512Hasher::new(); "hello world".hash(&mut hasher); let low = hasher.finish(); let full = hasher.finalize(); let mut top = [0u8; std::mem::size_of::()]; assert!(top.len() < full.0.len()); unsafe { std::ptr::copy_nonoverlapping(&full.0[0] as *const u8, &mut top[0] as *mut u8, top.len()); } assert_eq!(u64::from_ne_bytes(top), low); } #[cfg(feature="serde")] #[test] fn ser_json() { let hash = crate::compute_hash_for("hello world"); println!("Orig: {}", hash); let string = serde_json::to_string(&hash).expect("Serial failed"); println!("Ser: {:?}", string); let de = serde_json::from_str(&string[..]).expect("Deserial failed"); println!("De: {}", de); assert_eq!(hash, de); } #[cfg(feature="serde")] #[test] fn ser_cbor() { let hash = crate::compute_hash_for("hello world"); println!("Orig: {}", hash); let vec = serde_cbor::to_vec(&hash).expect("Serial failed"); println!("Ser: {:?}", vec); let de = serde_cbor::from_slice(&vec[..]).expect("Deserial failed"); println!("De: {}", de); assert_eq!(hash, de); } }