diff --git a/src/ext.rs b/src/ext.rs index 226449b..ba53044 100644 --- a/src/ext.rs +++ b/src/ext.rs @@ -413,10 +413,10 @@ lazy_static! { /// A wrapper for hashing with a specific salt. #[derive(Debug, Hash)] -pub struct Salted<'a, T: std::hash::Hash>(&'a T, &'a [u8]); +pub struct Salted<'a, T: ?Sized + std::hash::Hash>(&'a T, &'a [u8]); impl<'a, T> Salted<'a, T> -where T: std::hash::Hash +where T: std::hash::Hash + ?Sized { /// Create a new wrapper. pub fn new(val: &'a T, salt: &'a [u8]) -> Self @@ -426,10 +426,10 @@ where T: std::hash::Hash } /// A wrapper for hashing with the global salt. #[derive(Debug, Hash)] -pub struct GloballySalted<'a, T: std::hash::Hash>(&'a T, &'static [u8]); +pub struct GloballySalted<'a, T: ?Sized + std::hash::Hash>(&'a T, &'static [u8]); impl<'a, T> GloballySalted<'a, T> -where T: std::hash::Hash +where T: std::hash::Hash + ?Sized { /// Create a new wrapper. pub fn new(val: &'a T) -> Self @@ -438,6 +438,24 @@ where T: std::hash::Hash } } +pub trait HashWithSaltExt: std::hash::Hash +{ + /// Create a hash wrapper around this instance that hashes with a specific salt. + fn hash_with_salt<'a>(&'a self, salt: &'a [u8]) -> Salted<'a, Self>; + /// Create a hash wrapper around this instance that hashes with the global salt. + #[inline] fn hash_with_global_salt(&self) -> GloballySalted<'_, Self> + { + GloballySalted::new(self) + } +} + +impl HashWithSaltExt for T where T: ?Sized + std::hash::Hash +{ + #[inline] fn hash_with_salt<'a>(&'a self, salt: &'a [u8]) -> Salted<'a, Self> { + Salted::new(self, salt) + } +} + mod sha256_hasher { use std::mem::size_of; use sha2::{ @@ -489,3 +507,24 @@ pub use sha256_hasher::Sha256HashExt; /// Value may hold one in place or allocate on the heap to hold many. pub type MaybeVec = smallvec::SmallVec<[T; 1]>; + +#[macro_export] macro_rules! impl_deref { + (for $($(frag:tt)*;)? $name:ident impl $to:ident as $expr:expr $(; mut $mut_expr:expr)?) => { + impl $($($frag)*)? ::std::ops::Deref for $name + { + type Target = $to; + fn deref(&self) -> &Self::Target + { + $expr + } + } + $( + impl $($($frag)*)? ::std::ops::DerefMut for $name + { + fn deref_mut(&mut self) -> &mut ::Target + { + $mut_expr + } + })? + }; +} diff --git a/src/state/freeze.rs b/src/state/freeze.rs index d398ac5..8a285b8 100644 --- a/src/state/freeze.rs +++ b/src/state/freeze.rs @@ -5,24 +5,24 @@ use std::{error,fmt}; /// An image of the entire post container #[derive(Debug, Default, Serialize, Deserialize)] -pub struct Imouto +pub struct Freeze { posts: Vec, } -impl From for Oneesan +impl From for Imouto { - #[inline] fn from(from: Imouto) -> Self + #[inline] fn from(from: Freeze) -> Self { Self::from_freeze(from) } } -impl TryFrom for Imouto +impl TryFrom for Freeze { type Error = FreezeError; - #[inline] fn try_from(from: Oneesan) -> Result + #[inline] fn try_from(from: Imouto) -> Result { from.try_into_freeze() } @@ -65,16 +65,16 @@ impl fmt::Display for FreezeError } } -impl Oneesan +impl Imouto { /// Create a serialisable image of this store by cloning each post into it. - pub async fn freeze(&self) -> Imouto + pub async fn freeze(&self) -> Freeze { - let read = self.posts.read().await; - let mut sis = Imouto{ - posts: Vec::with_capacity(read.0.len()), + let read = &self.all; + let mut sis = Freeze{ + posts: Vec::with_capacity(read.len()), }; - for (_, post) in read.0.iter() + for (_, post) in read.iter() { sis.posts.push(post.read().await.clone()); } @@ -86,13 +86,13 @@ impl Oneesan /// /// # Fails /// If references to any posts are still held elsewhere. - pub fn try_into_freeze(self) -> Result + pub fn try_into_freeze(self) -> Result { - let read = self.posts.into_inner(); - let mut sis = Imouto{ - posts: Vec::with_capacity(read.0.len()), + let read = self.all; + let mut sis = Freeze{ + posts: Vec::with_capacity(read.len()), }; - for post in read.0.into_iter() + for post in read.into_iter() { sis.posts.push(match Arc::try_unwrap(post) { Ok(val) => val.into_inner(), @@ -108,13 +108,13 @@ impl Oneesan /// /// # Panics /// If references to any posts are still held elsewhere. - pub fn into_freeze(self) -> Imouto + pub fn into_freeze(self) -> Freeze { self.try_into_freeze().expect("Failed to consume into freeze") } /// Create a new store from a serialisable image of one by cloning each post in it - pub fn unfreeze(freeze: &Imouto) -> Self + pub fn unfreeze(freeze: &Freeze) -> Self { let mut posts = Arena::new(); let mut user_map = HashMap::new(); @@ -128,12 +128,13 @@ impl Oneesan } Self { - posts: RwLock::new((posts, user_map)) + all: posts, + user_map, } } /// Create a new store by consuming serialisable image of one by cloning each post in it - pub fn from_freeze(freeze: Imouto) -> Self + pub fn from_freeze(freeze: Freeze) -> Self { let mut posts = Arena::new(); let mut user_map = HashMap::new(); @@ -147,7 +148,8 @@ impl Oneesan } Self { - posts: RwLock::new((posts, user_map)) + all: posts, + user_map, } } } diff --git a/src/state/mod.rs b/src/state/mod.rs index 2984ed3..2980136 100644 --- a/src/state/mod.rs +++ b/src/state/mod.rs @@ -5,6 +5,7 @@ use generational_arena::{ }; use std::sync::Arc; use tokio::sync::RwLock; +use std::ops::{Deref, DerefMut}; pub mod session; pub mod user; @@ -13,10 +14,55 @@ pub mod body; mod freeze; pub use freeze::*; -/// Entire post container -pub struct Oneesan +/// Entire post state container +#[derive(Debug)] +pub struct Imouto { - posts: RwLock<(Arena>> // All posts - , HashMap> // Post lookup by user ID - )>, + all: Arena>>, + user_map: HashMap>, +} + +impl Imouto +{ + + /// Create a new empty container + pub fn new() -> Self + { + Self { + all: Arena::new(), + user_map: HashMap::new(), + } + } +} + +#[derive(Debug)] +/// Entire program state +struct Oneesan +{ + posts: RwLock, +} + +/// Shares whole program state +#[derive(Debug, Clone)] +pub struct State(Arc); + +impl State +{ + /// Create a new empty state. + pub fn new() -> Self + { + Self(Arc::new(Oneesan { + posts: RwLock::new(Imouto::new()), + })) + } + /// Get a reference to the post state container + pub async fn imouto(&self) -> tokio::sync::RwLockReadGuard<'_, Imouto> + { + self.0.posts.read().await + } + /// Get a mutable reference to the post state container + pub async fn imouto_mut(&self) -> tokio::sync::RwLockWriteGuard<'_, Imouto> + { + self.0.posts.write().await + } } diff --git a/src/state/session.rs b/src/state/session.rs index 748caa3..7f7aabf 100644 --- a/src/state/session.rs +++ b/src/state/session.rs @@ -3,6 +3,15 @@ use super::*; id_type!(SessionID; "A unique session ID, not bound to a user."); +impl SessionID +{ + /// Generate a random session ID. + #[inline] fn generate() -> Self + { + Self::id_new() + } +} + #[derive(Debug)] pub struct Session { @@ -12,6 +21,14 @@ pub struct Session impl Session { + /// Create a new session object + pub fn create(user: user::User) -> Self + { + Self { + user, + id: SessionID::generate(), + } + } /// The randomly generated ID of this session, irrespective of the user of this session. #[inline] pub fn session_id(&self) -> &SessionID { diff --git a/src/state/user.rs b/src/state/user.rs index 1bc990c..12ac603 100644 --- a/src/state/user.rs +++ b/src/state/user.rs @@ -45,7 +45,7 @@ impl UserID } /// A user not bound to a session. -#[derive(Debug, PartialEq, Eq, Hash, Ord, PartialOrd)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Ord, PartialOrd)] pub struct User { addr: SocketAddr, @@ -59,3 +59,48 @@ impl User UserID(self.addr, session.session_id().clone()) } } + +#[cfg(test)] +mod tests +{ + use super::*; + use tokio::sync::mpsc; + use tokio::time; + use std::net::SocketAddrV4; + use std::net::Ipv4Addr; + #[tokio::test] + async fn counter_tokens() + { + let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 80)); + let usr = User{addr}; + let ses = session::Session::create(usr); + + let id = ses.user_id(); + + let (mut tx, mut rx) = mpsc::channel(5); + let task = tokio::spawn(async move { + let id = ses.user_id(); + + while let Some(token) = rx.recv().await { + if !id.validate_token(token) { + panic!("Failed to validate token {:x} for id {:?}", token, id); + } else { + eprintln!("Token {:x} valid for id {:?}", token, id); + } + } + }); + + for x in 1..=10 + { + if x % 2 == 0 { + time::delay_for(time::Duration::from_millis(10 * x)).await; + } + if tx.send(id.generate_token()).await.is_err() { + eprintln!("Failed to send to task"); + break; + } + } + drop(tx); + task.await.expect("Background validate task failed"); + } +}