diff --git a/Cargo.lock b/Cargo.lock
index b2006cf..95f1c09 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -586,6 +586,7 @@ dependencies = [
  "futures",
  "lazy_static",
  "log",
+ "pin-project",
  "pretty_env_logger",
  "recolored",
  "rustc_version",
diff --git a/Cargo.toml b/Cargo.toml
index 6f60c78..3a4630f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,7 +1,7 @@
 [package]
 name = "sever"
 description = "Coerce hardlinks into new files"
-version = "1.0.1"
+version = "1.0.2"
 authors = ["Avril "]
 edition = "2018"
 readme = "README.org"
@@ -18,7 +18,7 @@ limit-concurrency = ["parallel"]
 recursive = []
 limit-recursion = ["recursive"]
 splash = []
-parallel = ["tokio", "futures"]
+parallel = ["tokio", "futures", "pin-project"]
 threads = ["parallel", "tokio/rt-threaded"]
 
 # use PRETTY_ENV_LOGGER I guess
@@ -35,6 +35,7 @@ futures = {version = "0.3.5", optional = true}
 lazy_static = "1.4.0"
 uuid = {version = "0.8.1", features = ["v4"]}
 recolored = "1.9.3"
+pin-project = {version = "0.4.26", optional=true}
 
 [build-dependencies]
 rustc_version = "0.2"
diff --git a/src/ext.rs b/src/ext.rs
index 56aba06..0a0cd43 100644
--- a/src/ext.rs
+++ b/src/ext.rs
@@ -23,3 +23,60 @@ where I: IntoIterator,
         string
     }
 }
+
+#[cfg(feature="parallel")]
+mod para
+{
+    use super::*;
+    use std::{
+        collections::HashSet,
+        task::{Poll, Context,},
+        pin::Pin,
+        marker::PhantomData,
+        hash::Hash,
+    };
+    use futures::{
+        stream::{
+            Stream,
+        },
+    };
+
+    #[pin_project]
+    pub struct DedupStream<I, T>(#[pin] I, HashSet<map::HashOutput>, PhantomData<T>);
+
+    impl<I: Stream<Item=T>, T: Hash> Stream for DedupStream<I, T>
+    {
+        type Item = T;
+
+        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+            let this = self.as_mut().project();
+            match this.0.poll_next(cx) {
+                Poll::Ready(Some(x)) => {
+                    if this.1.insert(map::compute(&x)) {
+                        Poll::Ready(Some(x))
+                    } else {
+                        self.poll_next(cx)
+                    }
+                },
+                Poll::Ready(None) => Poll::Ready(None),
+                Poll::Pending => Poll::Pending,
+            }
+        }
+    }
+
+    pub trait DedupStreamExt: Stream + Sized
+    {
+        fn dedup(self) -> DedupStream<Self, Self::Item>;
+    }
+
+    impl<T: Stream + Sized> DedupStreamExt for T
+    where T::Item: Hash
+    {
+        fn dedup(self) -> DedupStream<Self, Self::Item>
+        {
+            DedupStream(self, HashSet::new(), PhantomData)
+        }
+    }
+
+}
+pub use para::*;
diff --git a/src/main.rs b/src/main.rs
index c7d8886..1ce4583 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,6 +1,7 @@
 #![allow(dead_code)]
 #![allow(unused_imports)]
 
+#[cfg(feature="parallel")] #[macro_use] extern crate pin_project;
 #[macro_use] extern crate log;
 
 #[macro_use] mod macros;
@@ -73,7 +74,9 @@ async fn main() -> eyre::Result<()> {
                 async move {
                     Some(parallel::expand_dir(file).await) //TODO: We gotta in here, too
                 }
-            }).flatten()).await,
+            })
+                     .flatten()
+                     .dedup()).await,
                      "Jobs failed")
     }
 
diff --git a/src/map.rs b/src/map.rs
index 79b5721..2583f31 100644
--- a/src/map.rs
+++ b/src/map.rs
@@ -7,9 +7,9 @@ use std::{
 };
 
 //TODO: Feature flag for SHA256 hashing
-type HashOutput = u64;
+pub type HashOutput = u64;
 
-fn compute<H: Hash>(what: &H) -> HashOutput
+pub fn compute<H: Hash>(what: &H) -> HashOutput
 {
     use std::hash::Hasher;
     let mut hasher = std::collections::hash_map::DefaultHasher::new();
diff --git a/src/parallel.rs b/src/parallel.rs
index 9cf1000..437084f 100644
--- a/src/parallel.rs
+++ b/src/parallel.rs
@@ -213,7 +213,8 @@ pub async fn expand_dir(p: String) -> impl Stream<Item=String>
                 tx.send(p).await.unwrap();
             }
         });
-        rx //TODO: map this to dedup
+        rx
+        //DedupStream(rx, HashSet::new())
     } else {
         stream::iter(iter::once(p).filter_map(|p| {
             if Path::new(&p).is_dir() {
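
Note (not part of the patch): the core of this commit is the `DedupStream` adapter in `src/ext.rs`, which wraps a `Stream`, records the `map::compute` hash of each yielded item in a `HashSet`, and skips items whose hash has already been seen; `main.rs` then chains `.dedup()` after `.flatten()` so duplicate paths coming out of `expand_dir` are only processed once. Below is a self-contained sketch of the same pin-project + hash-set technique, using the same crates the patch adds (`futures`, `pin-project`) but illustrative names (`Dedup`, `hash_of`) rather than the crate's own, and an iterative loop in place of the patch's recursive `poll_next` call:

```rust
use std::{
    collections::HashSet,
    hash::{Hash, Hasher},
    marker::PhantomData,
    pin::Pin,
    task::{Context, Poll},
};

use futures::stream::{self, Stream, StreamExt};
use pin_project::pin_project;

/// Hash an item with std's `DefaultHasher` (same idea as `map::compute`).
fn hash_of<T: Hash>(what: &T) -> u64 {
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    what.hash(&mut hasher);
    hasher.finish()
}

/// Stream adapter that yields an item only the first time its hash is seen.
#[pin_project]
struct Dedup<S, T> {
    #[pin]
    inner: S,
    seen: HashSet<u64>,
    _item: PhantomData<T>,
}

impl<S: Stream<Item = T>, T: Hash> Stream for Dedup<S, T> {
    type Item = T;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
        let mut this = self.project();
        // Keep pulling from the inner stream until we find an unseen item,
        // the stream ends, or it is not ready yet.
        loop {
            match this.inner.as_mut().poll_next(cx) {
                Poll::Ready(Some(x)) if this.seen.insert(hash_of(&x)) => {
                    return Poll::Ready(Some(x));
                }
                Poll::Ready(Some(_)) => continue, // duplicate: skip and poll again
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}

fn main() {
    let paths = stream::iter(vec!["a", "b", "a", "c", "b"]);
    let dedup = Dedup { inner: paths, seen: HashSet::new(), _item: PhantomData };
    let unique: Vec<_> = futures::executor::block_on(dedup.collect::<Vec<_>>());
    assert_eq!(unique, vec!["a", "b", "c"]);
}
```

Storing only the 64-bit hash keeps the per-item memory cost constant instead of cloning each path into the set, which mirrors the patch's choice of a `HashSet` of `map::HashOutput` values rather than a set of owned items.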