You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

231 lines
4.9 KiB

use super::*;
use sha2::{
Sha512,
Digest,
};
use std::{
fmt,
hash::{
Hasher,
},
};
/// Length in bytes of one SHA-512 digest (512 bits / 8 bits per byte).
pub const HASH_SIZE: usize = 512 / 8;

/// A single SHA-512 digest, stored as its raw bytes.
///
/// `#[repr(transparent)]` gives this type exactly the layout of `[u8; HASH_SIZE]`.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct Sha512Hash([u8; HASH_SIZE]);
#[cfg(feature="serde")]
impl serde::Serialize for Sha512Hash {
    /// Writes the digest as one raw byte string of `HASH_SIZE` bytes.
    fn serialize<S: serde::ser::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let raw: &[u8] = &self.0;
        serializer.serialize_bytes(raw)
    }
}
/// Serde visitor that rebuilds a [`Sha512Hash`] from either a byte string or a
/// sequence of exactly `HASH_SIZE` bytes.
#[cfg(feature="serde")]
pub struct Sha512HashVisitor;
#[cfg(feature="serde")]
impl<'de> serde::de::Visitor<'de> for Sha512HashVisitor {
    type Value = Sha512Hash;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("an array of 64 bytes")
    }

    /// Accepts a byte buffer only if it is exactly `HASH_SIZE` bytes long.
    ///
    /// # Errors
    /// Returns a custom error naming the expected and actual lengths.
    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
    where E: serde::de::Error
    {
        if v.len() != HASH_SIZE {
            return Err(E::custom(format!("Expected {} bytes, got {}", HASH_SIZE, v.len())));
        }
        // The length was checked above, so this bounds-checked copy cannot panic
        // and no `unsafe` is needed.
        let mut output = [0u8; HASH_SIZE];
        output.copy_from_slice(v);
        Ok(Sha512Hash::from_bytes(output))
    }

    /// Accepts a sequence of exactly `HASH_SIZE` `u8` elements, for example the
    /// JSON array form.
    ///
    /// # Errors
    /// Fails if the sequence has fewer or more than `HASH_SIZE` elements, or if
    /// an element cannot be read as a `u8`.
    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
    where A: serde::de::SeqAccess<'de>
    {
        use serde::de::Error;
        let mut bytes = [0u8; HASH_SIZE];
        for (i, slot) in bytes.iter_mut().enumerate() {
            *slot = match seq.next_element()? {
                Some(byte) => byte,
                // `i` elements have been received so far.
                None => return Err(A::Error::custom(format!("Expected {} bytes, got {}", HASH_SIZE, i))),
            };
        }
        // Reject over-long input explicitly rather than relying on the data
        // format to notice the unconsumed trailing elements.
        if seq.next_element::<u8>()?.is_some() {
            return Err(A::Error::custom(format!("Expected {} bytes, got more than {}", HASH_SIZE, HASH_SIZE)));
        }
        Ok(Sha512Hash::from_bytes(bytes))
    }
}
#[cfg(feature="serde")]
impl<'de> serde::Deserialize<'de> for Sha512Hash {
    /// Asks the deserializer for bytes and lets `Sha512HashVisitor` validate and
    /// assemble the digest.
    fn deserialize<D: serde::de::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let visitor = Sha512HashVisitor;
        deserializer.deserialize_bytes(visitor)
    }
}
impl Default for Sha512Hash
{
#[inline]
fn default() -> Self
{
Self([0; HASH_SIZE])
}
}
impl Sha512Hash
{
#[inline] pub const fn empty() -> Self
{
Self([0u8; HASH_SIZE])
}
#[inline] pub const fn from_bytes(from: [u8; HASH_SIZE]) -> Self
{
Self(from)
}
#[inline] pub const fn into_bytes(self) -> [u8; HASH_SIZE]
{
self.0
}
}
impl AsRef<[u8]> for Sha512Hash {
    /// Borrows the digest as a read-only slice of `HASH_SIZE` bytes.
    fn as_ref(&self) -> &[u8] {
        let bytes: &[u8] = &self.0;
        bytes
    }
}
impl AsMut<[u8]> for Sha512Hash {
    /// Borrows the digest as a mutable slice of `HASH_SIZE` bytes.
    fn as_mut(&mut self) -> &mut [u8] {
        let bytes: &mut [u8] = &mut self.0;
        bytes
    }
}
impl From<Sha512> for Sha512Hash
{
    /// Finalizes the SHA-512 state and captures the resulting digest.
    fn from(from: Sha512) -> Self
    {
        // The original author notes there is no direct `Into<[u8; 64]>` for the
        // finalized output here, so the bytes are copied into a fixed array.
        let digest = from.finalize();
        let mut arr = [0u8; HASH_SIZE];
        // `copy_from_slice` checks the lengths and panics on a mismatch. The old
        // raw-pointer copy was guarded only by a `debug_assert` and would have
        // been undefined behaviour in release builds.
        arr.copy_from_slice(&digest[..]);
        Self(arr)
    }
}
impl fmt::Display for Sha512Hash {
    /// Renders the digest as lowercase hex, two characters per byte
    /// (128 characters in total), with no separators.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.0
            .iter()
            .try_for_each(|octet| write!(f, "{:02x}", octet))
    }
}
/// Adapter that lets `std::hash::Hash` implementors feed their bytes into a
/// SHA-512 computation.
#[derive(Debug)]
pub(super) struct Sha512Hasher(Sha512);
impl Hasher for Sha512Hasher {
    /// Appends `bytes` to the running SHA-512 state.
    #[inline]
    fn write(&mut self, bytes: &[u8]) {
        self.0.update(bytes);
    }

    /// Returns the first 8 bytes of the current digest, read as a
    /// native-endian `u64`.
    ///
    /// The state is cloned before finalizing, so the hasher can keep
    /// accepting writes afterwards.
    #[inline]
    fn finish(&self) -> u64 {
        let digest = self.0.clone().finalize();
        let mut head = [0u8; std::mem::size_of::<u64>()];
        head.copy_from_slice(&digest[..head.len()]);
        u64::from_ne_bytes(head)
    }
}
impl Sha512Hasher
{
#[inline] pub fn new() -> Self
{
Self(Sha512::new())
}
#[inline] pub fn finalize(self) -> Sha512Hash
{
self.0.into()
}
#[inline] pub fn into_inner(self) -> Sha512
{
self.0
}
}
#[cfg(test)]
mod test
{
    use super::*;
    // Imported explicitly so `.hash(&mut hasher)` resolves without depending on
    // whatever the parent module's glob import happens to provide.
    use std::hash::Hash;

    /// `Hasher::finish` must equal the first 8 bytes of the full digest,
    /// read as a native-endian `u64`.
    #[test]
    fn match_bytes()
    {
        let mut hasher = Sha512Hasher::new();
        "hello world".hash(&mut hasher);
        let low = hasher.finish();
        let full = hasher.finalize();
        let mut top = [0u8; std::mem::size_of::<u64>()];
        assert!(top.len() < full.0.len());
        // Safe, bounds-checked copy replaces the previous raw-pointer copy.
        top.copy_from_slice(&full.0[..top.len()]);
        assert_eq!(u64::from_ne_bytes(top), low);
    }

    /// A digest must survive a JSON round trip unchanged.
    #[cfg(feature="serde")]
    #[test]
    fn ser_json()
    {
        let hash = crate::compute_hash_for("hello world");
        println!("Orig: {}", hash);
        let string = serde_json::to_string(&hash).expect("Serial failed");
        println!("Ser: {:?}", string);
        let de = serde_json::from_str(&string[..]).expect("Deserial failed");
        println!("De: {}", de);
        assert_eq!(hash, de);
    }

    /// A digest must survive a CBOR round trip unchanged.
    #[cfg(feature="serde")]
    #[test]
    fn ser_cbor()
    {
        let hash = crate::compute_hash_for("hello world");
        println!("Orig: {}", hash);
        let vec = serde_cbor::to_vec(&hash).expect("Serial failed");
        println!("Ser: {:?}", vec);
        let de = serde_cbor::from_slice(&vec[..]).expect("Deserial failed");
        println!("De: {}", de);
        assert_eq!(hash, de);
    }
}