diff --git a/Cargo.toml b/Cargo.toml index a81cc3d..2ef388c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,4 +23,5 @@ sha2 = "0.9" futures = { version = "0.3", optional = true } lazy_static = "1.4" chrono = "0.4" -shellexpand = "1.1" \ No newline at end of file +shellexpand = "1.1" +# cfg_if = "0.1" \ No newline at end of file diff --git a/src/arg.rs b/src/arg.rs index b5f7420..195623a 100644 --- a/src/arg.rs +++ b/src/arg.rs @@ -37,6 +37,12 @@ pub fn usage() -> ! println!(" --cancel -w\t\tAlias for `--error-mode CANCEL`"); println!(" --error -W\t\tAlias for `--error-mode TERMINATE`"); println!(" --recurse |inf\tRecursive mode, give max depth or infinite."); + #[cfg(feature="threads")] + println!(" --threads |inf\tMax number of threads to run at once."); + #[cfg(feature="threads")] + println!(" -U\t\tUlimited max threads."); + #[cfg(feature="threads")] + println!(" --sync -S\t\tRun only one file at a time."); println!(" --\t\t\tStop reading args"); println!("Other:"); println!(" --help -h:\t\tPrint this message"); @@ -120,6 +126,9 @@ where I: IntoIterator Ok(config::Rebase{ save: files.clone(), //TODO: Seperate save+loads load: files, + + #[cfg(feature="threads")] + max_threads: config::MAX_THREADS, }) } @@ -137,6 +146,9 @@ where I: IntoIterator let mut mode_er = error::Mode::Cancel; let mut mode_rec = config::RecursionMode::None; let mut mode_log = log::Mode::Warn; + + #[cfg(feature="threads")] + let mut threads = config::MAX_THREADS; macro_rules! push { ($arg:expr) => { @@ -166,7 +178,17 @@ where I: IntoIterator "--rebase" => return Ok(Output::Rebase(parse_rebase(args)?)), "--" => reading = false, - + + #[cfg(feature="threads")] + "--threads" if take_one!() => { + if one.to_lowercase().trim() == "inf" { + threads = None; + } else { + threads = Some(one.as_str().parse::().or_else(|e| Err(Error::BadNumber(e)))?); + } + }, + #[cfg(feature="threads")] + "--sync" => threads = Some(unsafe{std::num::NonZeroUsize::new_unchecked(1)}), "--load" => { load.push(validate_path(config::DEFAULT_HASHNAME.to_string(), Ensure::File, false)?.to_owned()); }, @@ -225,7 +247,12 @@ where I: IntoIterator 'q' => mode_er = error::Mode::Ignore, 'h' => return Ok(Output::Help), - + + #[cfg(feature="threads")] + 'U' => threads = None, + #[cfg(feature="threads")] + 'S' => threads = Some(unsafe{std::num::NonZeroUsize::new_unchecked(1)}), + 'r' => mode_rec = config::RecursionMode::All, _ => return Err(Error::UnknownArgChar(argchar)), } @@ -257,6 +284,8 @@ where I: IntoIterator }, save, load, + #[cfg(feature="threads")] + max_threads: threads, })) } @@ -271,6 +300,7 @@ pub enum Error ExpectedFile(PathBuf), ExpectedDirectory(PathBuf), UnknownErrorMode(String), + BadNumber(std::num::ParseIntError), Unknown, } @@ -281,6 +311,7 @@ impl fmt::Display for Error { write!(f, "failed to parse args: ")?; match self { + Error::BadNumber(num) => write!(f, "{}", num), Error::NoInput => write!(f, "need at least one input"), Error::Parse(value, typ) => write!(f, "expected a {}, got `{}'", typ, value), Error::UnknownArg(arg) => write!(f, "i don't know how to `{}'", arg), diff --git a/src/config.rs b/src/config.rs index 433eb86..348edfd 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,6 +1,9 @@ use super::*; use lazy_static::lazy_static; +#[cfg(feature="threads")] +use std::num::NonZeroUsize; + #[derive(Debug, Clone)] pub enum RecursionMode { @@ -53,6 +56,10 @@ fn expand_path(path: impl AsRef) -> String shellexpand::tilde(path.as_ref()).to_string() } +/// Default max threads, `None` for unlimited. +#[cfg(feature="threads")] +pub const MAX_THREADS: Option = Some(unsafe{NonZeroUsize::new_unchecked(10)}); + #[derive(Debug)] pub struct Config { @@ -64,6 +71,9 @@ pub struct Config pub save: Vec, /// Load hashes from pub load: Vec, + /// Max number of threads to spawn + #[cfg(feature="threads")] + pub max_threads: Option, //TODO: Implement } #[derive(Debug)] @@ -73,4 +83,7 @@ pub struct Rebase pub load: Vec, /// Rebase to here pub save: Vec, + /// Max number of threads to spawn + #[cfg(feature="threads")] + pub max_threads: Option, //TODO: Implement } diff --git a/src/main.rs b/src/main.rs index 7066d8d..2b5f25e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -80,7 +80,7 @@ fn absolute(path: impl AsRef) -> std::path::PathBuf } #[cfg(feature="threads")] -async fn rebase_one_async(path: impl AsRef, hash: hash::Sha256Hash) -> Result, error::Error> +async fn rebase_one_async(path: impl AsRef, hash: hash::Sha256Hash, semaphore: Option>) -> Result, error::Error> { use std::{ convert::TryInto, @@ -91,6 +91,10 @@ async fn rebase_one_async(path: impl AsRef, hash: hash::Sha256H }, }; let path = path.as_ref(); + let _lock = match semaphore { + Some(sem) => Some(sem.acquire_owned().await), + None => None, + }; let mut file = OpenOptions::new() .read(true) .open(path).await?; @@ -113,11 +117,13 @@ async fn rebase(config: config::Rebase) -> Result<(), Box path::{ Path, }, + sync::Arc, }; use tokio::{ fs::{ OpenOptions, }, + sync::Semaphore, }; let mut hashes = container::DupeMap::new(); for (transient, load) in config.load.iter().map(|x| (false, x)).chain(config.save.iter().map(|x| (true, x))) @@ -140,6 +146,7 @@ async fn rebase(config: config::Rebase) -> Result<(), Box let mut remove = Vec::new(); let mut children = Vec::with_capacity(hashes.cache_len()); + let semaphore = config.max_threads.map(|num| Arc::new(Semaphore::new(num.into()))); for (path, (hash, trans)) in hashes.cache_iter() { if !trans { //Don't rebuild transient ones, this is desired I think? Maybe not... Dunno. @@ -147,8 +154,9 @@ async fn rebase(config: config::Rebase) -> Result<(), Box //Getting hash let path = path.clone(); let hash = *hash; + let semaphore = semaphore.as_ref().map(|semaphore| Arc::clone(semaphore)); children.push(tokio::task::spawn(async move { - rebase_one_async(path, hash).await + rebase_one_async(path, hash, semaphore).await })); } else { remove.push(path.clone()); @@ -266,6 +274,7 @@ async fn main() -> Result<(), Box> log!(Debug, lmode => "Loaded hashes: {}", hashes); log!(Info, lmode => "Starting checks (threaded)"); let hashes = Arc::new(Mutex::new(hashes)); + let semaphore = args.max_threads.map(|num| Arc::new(tokio::sync::Semaphore::new(num.into()))); for path in args.paths.iter() { let path = Path::new(path); @@ -274,9 +283,10 @@ async fn main() -> Result<(), Box> let mode = args.mode.clone(); let path = absolute(&path); let hashes= Arc::clone(&hashes); + let semaphore = semaphore.as_ref().map(|sem| Arc::clone(sem)); children.push(tokio::task::spawn(async move { log!(Debug, mode.logging_mode => " + {:?}", path); - let res = mode.error_mode.handle(proc::do_dir_async(path.clone(), 0, hashes, mode.clone()).await).log_and_forget(&mode.logging_mode, log::Level::Error); + let res = mode.error_mode.handle(proc::do_dir_async(path.clone(), 0, hashes, mode.clone(), semaphore).await).log_and_forget(&mode.logging_mode, log::Level::Error); log!(Info, mode.logging_mode => " - {:?}", path); res })); diff --git a/src/proc.rs b/src/proc.rs index 9c7b131..949cdb6 100644 --- a/src/proc.rs +++ b/src/proc.rs @@ -121,7 +121,7 @@ pub fn process_file>(path: P, set: &mut container::DupeMap) -> Re /// Process a file and add it to the table, returns true if is not a dupe. #[cfg(feature="threads")] -pub async fn process_file_async>(path: P, set: &std::sync::Arc>) -> Result +pub async fn process_file_async>(path: P, set: &std::sync::Arc>, sem: Option>) -> Result { use tokio::{ fs::{ @@ -136,6 +136,10 @@ pub async fn process_file_async>(path: P, set: &std::sync::Arc Some(sem.acquire_owned().await), + None => None, + }; let mut file = OpenOptions::new() .read(true) .open(path).await?; @@ -186,7 +190,7 @@ pub fn do_dir>(dir: P, depth: usize, set: &mut container::DupeMap /// Walk a dir structure and remove all dupes in it #[cfg(feature="threads")] -pub fn do_dir_async + std::marker::Send + std::marker::Sync + 'static>(dir: P, depth: usize, set: std::sync::Arc>, mode: config::Mode) -> futures::future::BoxFuture<'static, Result> +pub fn do_dir_async + std::marker::Send + std::marker::Sync + 'static>(dir: P, depth: usize, set: std::sync::Arc>, mode: config::Mode, semaphore: Option>) -> futures::future::BoxFuture<'static, Result> { use std::sync::Arc; use futures::future::{ @@ -213,9 +217,10 @@ pub fn do_dir_async + std::marker::Send + std::marker::Sync + 'st let set = Arc::clone(&set); let cmode = cmode.clone(); let mode = mode.clone(); + let semaphore = semaphore.as_ref().map(|sem| Arc::clone(sem)); children.push(tokio::task::spawn(async move { log!(Info, cmode.logging_mode => "OK {:?}", obj); - match mode.handle(do_dir_async(obj, depth+1, set, cmode).await) { + match mode.handle(do_dir_async(obj, depth+1, set, cmode, semaphore).await) { Ok(v) => Ok(v.unwrap_or_default()), Err(v) => Err(v), } @@ -224,8 +229,9 @@ pub fn do_dir_async + std::marker::Send + std::marker::Sync + 'st let set = Arc::clone(&set); let mode = mode.clone(); let cmode = cmode.clone(); + let semaphore = semaphore.as_ref().map(|sem| Arc::clone(sem)); workers.push(tokio::task::spawn(async move { - match mode.handle(process_file_async(&obj, &set).await) { + match mode.handle(process_file_async(&obj, &set, semaphore).await) { Ok(v) => { if v.unwrap_or_default() { log!(Info, cmode.logging_mode => "OK {:?}", obj);