diff --git a/src/file.rs b/src/file.rs index 72505c5..51bf3c3 100644 --- a/src/file.rs +++ b/src/file.rs @@ -21,6 +21,27 @@ pub use self::{ pub mod memory; +#[cold] +fn last_os_error() -> io::Error +{ + io::Error::last_os_error() +} + +fn stat_file_size(file: &T) -> io::Result +where T: AsRawFd +{ + use libc::fstat; + let fd = file.as_raw_fd(); + let sz = unsafe { + let mut stat = std::mem::MaybeUninit::uninit(); + if fstat(fd, stat.as_mut_ptr()) != 0 { + return Err(last_os_error()); + } + stat.assume_init().st_size & i64::MAX + } as u64; + Ok(sz) +} + #[derive(Debug)] enum MaybeMappedInner { diff --git a/src/file/memory.rs b/src/file/memory.rs index 86bec92..aa2a98b 100644 --- a/src/file/memory.rs +++ b/src/file/memory.rs @@ -12,6 +12,11 @@ use libc::{ MFD_HUGETLB, ftruncate, + + lseek, + SEEK_SET, + SEEK_CUR, + SEEK_END, }; use std::{ ffi::{CStr, CString}, @@ -146,15 +151,15 @@ impl ops::Deref for NamedMemoryFile } } -#[cold] -fn last_os_error() -> io::Error -{ - io::Error::last_os_error() -} - //impl `MemoryFile` (memfd_create() fd wrapper) impl MemoryFile { + /// The total size of the `memfd` memory file. + #[inline] + pub fn size(&self) -> io::Result + { + stat_file_size(&self.0) + } /// Create a new, empty, memory file with the specified raw C-string pointer name and the default set of flags. /// # Safety @@ -226,7 +231,7 @@ impl MemoryFile { stackalloc::alloca_zeroed(name.len()+1, move |cname| { cname[..name.len()].copy_from_slice(name.as_bytes()); - debug_assert_ne!(cname[name.len()], 0, "Copied name not nul-terminated for `memfd_create()` call."); + debug_assert_eq!(cname[name.len()], 0, "Copied name not nul-terminated for `memfd_create()` call."); // SAFETY: We have initialised `cname[..]`, and we know the final byte will be unsafe { @@ -372,6 +377,11 @@ raw::impl_io_for_fd!(MemoryFile => .0.as_raw_fd()); #[cfg(test)] mod test { + use std::io::{ + self, + Write, Read, Seek, + }; + #[test] fn default_flag_cloexec_visible() { @@ -380,5 +390,48 @@ mod test { assert_eq!(super::DEFAULT_FLAGS, cfg!(feature="default-cloexec").then(|| super::MFD_CLOEXEC).unwrap_or_default(), "Compile-time default creation flags are not in accordance with provided global crate configuration"); } - //TODO: Test if `NamedMemoryFile.get_path()` works properly. + fn mem_seek() + { + // Test `SEEK_SET`. + let mut file = super::MemoryFile::new_named("test-mem-seek").expect("Failed to create new file 'memfd:test-mem-seek'"); + file.write_all(b"hello world").expect("Failed to write to memfd"); + file.seek(io::SeekFrom::Start(6)).expect("Failed to seek to 6.."); + + let mut buf = Vec::new(); + let sz = file.read_to_end(&mut buf).expect("Failed to read from memfd"); + + assert_eq!(sz, 5, "Invalid number of bytes read"); + assert_eq!(&buf, b"world", "Invalid string data read"); + + file.seek_relative(-(sz as i64)).expect("Failed to seek -5"); + + buf.clear(); + assert_eq!(file.read_to_end(&mut buf).expect("Failed to read from memfd (2nd pass)"), sz, "Invalid number of bytes re-read from file"); + assert_eq!(&buf, b"world", "Invalid string data read in 2nd pass"); + } + + #[test] + fn open_hugetlb() + { + const SIZE: usize = 1024 * 1024 * 4; // 1GB + + let ht = super::HugePage::Static(crate::MapHugeFlag::HUGE_2MB); // 2MB hugetlb + + let file = super::MemoryFile::new_named_with_size_hugetlb("test-mem-hugetlb", SIZE, dbg!(ht.compute_huge()).expect("Invalid `HugePage` spec.")) + .expect("Failed to open hugetlb memfile"); + + let mut file = crate::MappedFile::new(file, 4096, crate::Perm::Writeonly, crate::Flags::Shared.with_hugetlb(ht)) + .expect("Failed to map hugetlb memfile for writing"); + + write!(&mut &mut file[..], "Hello").expect("Failed to write to mapped file (1)"); + write!(&mut &mut file[2000..], "World").expect("Failed to write to mapped file (2)"); + + file.flush(crate::Flush::Wait).expect("Failed to flush mapped data to memfd"); + + let file = crate::MappedFile::new(super::ManagedFD::alias(file.inner()).expect("dup() failed"), 4096, crate::Perm::Readonly, crate::Flags::Shared) + .expect("Failed to map hugetlb memfile for reading"); + + assert_eq!(&file[..5], b"Hello", "Invalid first write at 0"); + assert_eq!(&file[2000..2005], b"World", "Invalid second write at 2000"); + } } diff --git a/src/file/raw.rs b/src/file/raw.rs index 0aa1c0c..dee8762 100644 --- a/src/file/raw.rs +++ b/src/file/raw.rs @@ -128,6 +128,12 @@ macro_rules! impl_io_for_fd { ($type:ty => .$($fd_path:tt)+) => { const _:() = { use std::io; + use libc::{ + SEEK_SET, + SEEK_CUR, + SEEK_END, + lseek, + }; #[inline(always)] fn check_error() -> bool { @@ -203,6 +209,36 @@ macro_rules! impl_io_for_fd { } } } + + impl io::Seek for $type { + #[inline] + fn seek(&mut self, pos: io::SeekFrom) -> io::Result { + let fd = self.$($fd_path)+; + let (off, pos) = match pos { + io::SeekFrom::Current(off) => (off, SEEK_CUR), + io::SeekFrom::Start(off) => (off as i64, SEEK_SET), + io::SeekFrom::End(off) => (off, SEEK_END), + }; + + match unsafe { + lseek(fd, off, pos) + } { + -1 => Err(last_os_error()), + off => Ok(off as u64), + } + } + + #[inline] + fn seek_relative(&mut self, offset: i64) -> io::Result<()> { + let fd = self.$($fd_path)+; + match unsafe { + lseek(fd, offset, SEEK_CUR) + } { + -1 => Err(last_os_error()), + _ => Ok(()), + } + } + } }; }; } diff --git a/src/flags.rs b/src/flags.rs index bcb8132..bc591e3 100644 --- a/src/flags.rs +++ b/src/flags.rs @@ -35,7 +35,7 @@ impl Flags /// # `hugetlb` support /// For adding huge-page mapping flags to these, use `with_hugetlb()` instead. #[inline] - pub unsafe fn chain_with(self, flags: impl MapFlags) -> impl MapFlags + pub unsafe fn chain_with(self, flags: F) -> impl MapFlags + use { struct Chained(Flags, T); @@ -63,7 +63,7 @@ impl Flags { #[inline(always)] fn get_mmap_flags(&self) -> c_int { - self.0.get_flags() | self.1.compute_huge().map(MapHugeFlag::get_mask).unwrap_or(0) + self.0.get_flags() | (self.1.compute_huge().map(MapHugeFlag::get_mask).unwrap_or(0) | libc::MAP_HUGETLB) } } diff --git a/src/lib.rs b/src/lib.rs index b88bd8a..0595f2e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -144,8 +144,8 @@ impl MappedFile { /// If `mmap()` succeeds, but returns an invalid address (e.g. 0) pub fn try_new(file: T, len: usize, perm: Perm, flags: impl flags::MapFlags) -> Result> { - const NULL: *mut libc::c_void = ptr::null_mut(); + let fd = file.as_raw_fd(); let slice = match unsafe { mmap(ptr::null_mut(), len, perm.get_prot(), flags.get_mmap_flags(), fd, 0)