diff --git a/Cargo.lock b/Cargo.lock index 76b1a97..06ebde6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,4 +4,4 @@ version = 3 [[package]] name = "reverse" -version = "0.3.0" +version = "0.4.0" diff --git a/Cargo.toml b/Cargo.toml index 44b3726..288a1f1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "reverse" -version = "0.3.0" +version = "0.4.0" authors = ["Avril "] edition = "2018" @@ -25,6 +25,9 @@ ignore-output-errors = [] # Ignore invalid arguments instead of removing invalid UTF8 characters if they exist in the argument ignore-invalid-args = [] +# Operate on OsString/byte arrays instead of strings; so non-utf8 characters will be preserved. +byte-strings = [] + [profile.release] opt-level = 3 lto = "fat" diff --git a/src/main.rs b/src/main.rs index f62c4b4..20711e5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -28,9 +28,69 @@ fn binsearch<'a, V: ?Sized, T: PartialEq + 'a>(slice: &'a [T], find: &V) -> O } } -fn collect_input() -> Box + 'static> +trait Input: AsRef<[u8]> + std::fmt::Debug{} +impl Input for T +where T: AsRef<[u8]> + std::fmt::Debug{} + +fn collect_input() -> Box + 'static> { - // TODO: Use non-panicking functions for both reading lines and reading args, just skip invalid lines/args + use std::{ + ffi::{OsStr, OsString}, + os::unix::ffi::*, + }; + + #[derive(Debug)] + enum MaybeUTF8 + { + UTF8(String), + Raw(OsString), + } + + impl AsRef<[u8]> for MaybeUTF8 + { + #[inline] + fn as_ref(&self) -> &[u8] + { + match self { + Self::UTF8(string) => string.as_bytes(), + Self::Raw(raw) => raw.as_bytes(), + } + } + } + + impl From for MaybeUTF8 + { + #[inline(always)] + fn from(from: String) -> Self + { + Self::UTF8(from) + } + } + + impl From for MaybeUTF8 + { + #[inline(always)] + fn from(from: OsString) -> Self + { + Self::Raw(from) + } + } + + #[allow(dead_code)] + impl MaybeUTF8 + { + #[inline(always)] + pub fn from_raw_bytes(bytes: &[u8]) -> Self + { + Self::Raw(OsStr::from_bytes(bytes).to_os_string()) + } + #[inline(always)] + #[deprecated(note="XXX: TODO: Only use this if the read_until() into vec does not add the '\n' into the vec as well. Otherwise, *always* use this.")] + pub fn from_raw_vec(vec: Vec) -> Self + { + Self::Raw(OsString::from_vec(vec)) + } + } if std::env::args_os().len() <= 1 { use std::io::{ @@ -38,20 +98,62 @@ fn collect_input() -> Box + 'static> BufRead, }; // No args, collect stdin lines - Box::new(io::stdin() - .lock() - .lines() - .filter_map(Result::ok)) + if !cfg!(feature="byte-strings") { + // Collect utf8 string lines + Box::new(io::stdin() + .lock() + .lines() + .filter_map(Result::ok) + .map(MaybeUTF8::from)) + } else { + // Collect arbitrary byte strings + struct OsLineReader<'a>(io::StdinLock<'a>, Vec); + + impl<'a> Iterator for OsLineReader<'a> + { + type Item = MaybeUTF8; + fn next(&mut self) -> Option + { + Some(match handle_fmt_err_or(self.0.read_until(b'\n', &mut self.1), || 0) { + 0 => return None, + read_sz => { + let line = MaybeUTF8::from_raw_bytes(&self.1[..]); //TODO: XXX: If self.1 here does not have the '\n' added into it by read_until(); use from_raw_vec(self.1.clone()) instead; it'll be more efficient. + self.1.clear(); + //TODO: todo!("Do we need read_sz ({read_sz}) at all here? Will the `\n` be inside the read string?"); + line + }, + }) + } + } + + Box::new(OsLineReader(io::stdin().lock(), // Acquire the lock until the iterator is consumed (like all other paths in this function) + Vec::with_capacity(4096))) // Buffer of size 4k + } } else { // Has arguments, return them - if cfg!(feature="ignore-invalid-args") { - Box::new(std::env::args_os().skip(1).filter_map(|os| os.into_string().ok())) + if cfg!(feature="byte-strings") { + Box::new(std::env::args_os().skip(1).map(MaybeUTF8::from)) + } else if cfg!(feature="ignore-invalid-args") { + Box::new(std::env::args_os().skip(1).filter_map(|os| os.into_string().ok()).map(MaybeUTF8::from)) } else { - Box::new(std::env::args_os().skip(1).map(|os| os.to_string_lossy().into_owned())) + Box::new(std::env::args_os().skip(1).map(|os| os.to_string_lossy().into_owned().into())) } } } +#[cfg_attr(feature="ignore-output-errors", inline)] +fn handle_fmt_err_or(res: std::io::Result, or: F) -> T +where F: FnOnce() -> T +{ + #[cfg(not(feature="ignore-output-errors"))] + if let Err(e) = res { + eprintln!("[!] failed to write line: {e}"); + or() + } + #[cfg(feature="ignore-output-errors")] + res.unwrap_or_else(|_| or()) +} + #[cfg_attr(feature="ignore-output-errors", inline(always))] fn handle_fmt_err(res: std::io::Result) { @@ -63,7 +165,7 @@ fn handle_fmt_err(res: std::io::Result) } fn main() { - let mut args: Vec = collect_input().collect(); + let mut args: Vec<_> = collect_input().collect(); reverse(&mut args[..]); //eprintln!("{:?}", binsearch(&args[..], "1")); // It works! #[cfg(feature="output-lines")] @@ -84,7 +186,7 @@ fn main() { } else { //writeln!(&mut out, "{}", x) - out.write(x.as_bytes()) + out.write(x.as_ref()) .and_then(|_| out.write(b"\n")) .map(|_| {}) }