From 23d022b5fe0e153f09ba2194104cde0ddb3fb62f Mon Sep 17 00:00:00 2001 From: Avril Date: Mon, 25 Apr 2022 09:37:14 +0100 Subject: [PATCH] Added working memfile implementation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fortune for collect's current commit: Blessing − 吉 --- Cargo.toml | 3 ++ src/main.rs | 97 +++++++++++++++++++++++++++++++++++++++++++------- src/memfile.rs | 26 ++++++++++++++ 3 files changed, 113 insertions(+), 13 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1022970..84b0d2e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,9 @@ default = ["jemalloc", "memfile", "logging", "tracing/release_max_level_warn"] # # TODO: mmap, memfd_create() ver memfile = ["bitflags", "lazy_static", "stackalloc"] +# When unable to determine the size of the input, preallocate the buffer to a multiple of the system page-size before writing to it. This can save extra `ftruncate()` calls, but will also result in the buffer needing to be truncated to the correct size at the end if the sizes do not match. +memfile-preallocate = ["memfile"] + # bytes: use `bytes` crate for collecting instead of `std::vec` # Use jemalloc instead of system malloc. 
diff --git a/src/main.rs b/src/main.rs index 713a194..0570690 100644 --- a/src/main.rs +++ b/src/main.rs @@ -260,12 +260,19 @@ fn non_map_work() -> eyre::Result<()> #[inline] #[cfg(feature="memfile")] fn map_work() -> eyre::Result<()> -{ - extern "C" { - fn getpagesize() -> libc::c_int; - } - /// 8 pages - const DEFAULT_BUFFER_SIZE: fn () -> Option = || { unsafe { std::num::NonZeroUsize::new((getpagesize() as usize) * 8) } }; +{ + const DEFAULT_BUFFER_SIZE: fn () -> Option = || { + cfg_if!{ + if #[cfg(feature="memfile-preallocate")] { + extern "C" { + fn getpagesize() -> libc::c_int; + } + unsafe { std::num::NonZeroUsize::new(getpagesize() as usize * 8) } + } else { + std::num::NonZeroUsize::new(0) + } + } + }; if_trace!(trace!("strategy: mapped memory file")); @@ -290,8 +297,15 @@ fn map_work() -> eyre::Result<()> let (mut file, read) = { let stdin = io::stdin(); - let buffsz = try_get_size(&stdin).or_else(DEFAULT_BUFFER_SIZE); - if_trace!(trace!("Attempted determining input size: {:?}", buffsz)); + let buffsz = try_get_size(&stdin); + if_trace!(debug!("Attempted determining input size: {:?}", buffsz)); + let buffsz = buffsz.or_else(DEFAULT_BUFFER_SIZE); + if_trace!(if let Some(buf) = buffsz.as_ref() { + trace!("Failed to determine input size: preallocating to {}", buf); + } else { + trace!("Failed to determine input size: alllocating on-the-fly (no preallocation)"); + }); + let mut file = memfile::create_memfile(Some("collect-buffer"), buffsz.map(|x| x.get()).unwrap_or(0)) .with_section(|| format!("{:?}", buffsz).header("Deduced input buffer size")) @@ -300,15 +314,72 @@ fn map_work() -> eyre::Result<()> let read = io::copy(&mut stdin.lock(), &mut file) .with_section(|| format!("{:?}", file).header("Memory buffer file"))?; - { + let read = if cfg!(any(feature="memfile-preallocate", debug_assertions)) { use io::*; + let sp = file.stream_position(); + let sl = memfile::stream_len(&file); + + if_trace!(trace!("Stream position after read: {:?}", sp)); + 
if_trace!(trace!("Stream length after read: {:?}", sl)); + let read = match sp.as_ref() { + Ok(&v) if v != read => { + if_trace!(warn!("Reported read value not equal to memfile stream position: expected from `io::copy()`: {v}, got {read}")); + v + }, + Ok(&x) => { + if_trace!(trace!("Reported memfile stream position and copy result equal: {x} == {}", read)); + x + }, + Err(e) => { + if_trace!(error!("Could not report memfile stream position, ignoring check on {read}: {e}")); + read + }, + }; + + let truncate_stream = |bad: u64, good: u64| { + use std::num::NonZeroU64; + use std::borrow::Cow; + file.set_len(good) + .map(|_| good) + .with_section(|| match NonZeroU64::new(bad) {Some (b) => Cow::Owned(b.get().to_string()), None => Cow::Borrowed("") }.header("Original (bad) length")) + .with_section(|| good.header("New (correct) length")) + .wrap_err(eyre!("Failed to truncate stream to correct length") + .with_section(|| format!("{:?}", file).header("Memory buffer file"))) + }; + + let read = match sl.as_ref() { + Ok(&v) if v != read => { + if_trace!(warn!("Reported read value not equal to memfile stream length: expected from `io::copy()`: {read}, got {v}")); + if_trace!(debug!("Attempting to correct memfile stream length from {v} to {read}")); + + truncate_stream(v, read)? 
+ }, + Ok(&v) => { + if_trace!(trace!("Reported memfile stream length and copy result equal: {v} == {}", read)); + v + }, + Err(e) => { + if_trace!(error!("Could not report memfile stream length, ignoring check on {read}: {e}")); + if_trace!(warn!("Attempting to correct memfile stream length anyway")); + if let Err(e) = truncate_stream(0, read) { + if_trace!(error!("Truncate failed: {e}")); + } + + read + } + }; + file.seek(SeekFrom::Start(0)) .with_section(|| read.header("Actual read bytes")) .wrap_err(eyre!("Failed to seek back to start of memory buffer file for output") - .with_section(|| unwrap_int_string(file.stream_position()).header("Memfile position")) + .with_section(|| unwrap_int_string(sp).header("Memfile position")) /*.with_section(|| file.stream_len().map(|x| x.to_string()) - .unwrap_or_else(|e| format!("")).header("Memfile full length"))*/)?; - } + .unwrap_or_else(|e| format!("")).header("Memfile full length"))*/)?; + + read + } else { + read + }; (file, usize::try_from(read) .wrap_err(eyre!("Failed to convert read bytes to `usize`") @@ -317,7 +388,7 @@ fn map_work() -> eyre::Result<()> .with_suggestion(|| "It is likely you are running on a 32-bit ptr width machine and this input exceeds that of the maximum 32-bit unsigned integer value") .with_note(|| usize::MAX.header("Maximum value of `usize`")))?) }; - if_trace!(info!("collected {read} from stdin. starting write.")); + if_trace!(info!("collected {} from stdin. 
starting write.", read)); let written = io::copy(&mut file, &mut io::stdout().lock()) diff --git a/src/memfile.rs b/src/memfile.rs index 18a8bba..a0e7431 100644 --- a/src/memfile.rs +++ b/src/memfile.rs @@ -24,6 +24,22 @@ const MEMFD_CREATE_FLAGS: libc::c_uint = libc::MFD_CLOEXEC; #[repr(transparent)] pub struct RawFile(fd::RawFileDescriptor); +/// Attempt to get the length of a stream's file descriptor +#[inline] +#[cfg_attr(feature="logging", instrument(level="debug", err, skip_all, fields(from_fd = from.as_raw_fd())))] +pub fn stream_len(from: &(impl AsRawFd + ?Sized)) -> io::Result +{ + let mut stat = std::mem::MaybeUninit::uninit(); + match unsafe { libc::fstat(from.as_raw_fd(), stat.as_mut_ptr()) } { + -1 => Err(io::Error::last_os_error()), + _ => { + let stat = unsafe { stat.assume_init() }; + debug_assert!(stat.st_size >= 0, "bad stat size"); + Ok(stat.st_size as u64) + }, + } +} + /// Create an in-memory `File`, with an optional name #[cfg_attr(feature="logging", instrument(level="info", err))] pub fn create_memfile(name: Option<&str>, size: usize) -> eyre::Result @@ -230,6 +246,16 @@ impl RawFile , fallocate(fd.0.get(), 0, 0, len.try_into() .map_err(|_| Allocate(None, len))?) , Allocate(Some(fd.fileno().clone()), len))?; + if cfg!(debug_assertions) { + if_trace!(trace!("Allocated {len} bytes to memory buffer")); + let seeked; + assert_eq!(attempt_call!(-1 + , { seeked = libc::lseek(fd.0.get(), 0, libc::SEEK_CUR); seeked } + , io::Error::last_os_error()) + .expect("Failed to check seek position in fd") + , 0, "memfd seek position is non-zero after fallocate()"); + if_trace!(if seeked != 0 { warn!("Trace offset is non-zero: {seeked}") } else { trace!("Trace offset verified ok") }); + } } else { if_trace!(trace!("No length provided, skipping fallocate() call")); }