From 0fcca646ef6e2ab0e88b7d0193e1d5e18e690867 Mon Sep 17 00:00:00 2001 From: Avril Date: Fri, 26 Mar 2021 21:39:41 +0000 Subject: [PATCH] fix size bug added drop for types that need dropping --- src/lib.rs | 49 +++++++++++++++++++++++++++++-------------------- src/tests.rs | 14 ++++++++++++++ 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 000405c..b8fa7c2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,7 @@ use std::{ mem::{ + self, MaybeUninit, ManuallyDrop, }, @@ -101,6 +102,12 @@ where F: FnOnce(&mut [MaybeUninit]) -> T } } +#[inline(always)] fn align_buffer_to(ptr: *mut u8) -> *mut T +{ + use std::mem::align_of; + ((ptr as usize) + align_of::() - (ptr as usize) % align_of::()) as *mut T +} + #[inline(always)] unsafe fn slice_assume_init_mut(buf: &mut [MaybeUninit]) -> &mut [T] { &mut *(buf as *mut [MaybeUninit] as *mut [T]) // MaybeUninit::slice_assume_init_mut() @@ -121,11 +128,6 @@ where F: FnOnce(&mut [u8]) -> T }) } -#[inline(always)] fn align_buffer_to(ptr: *mut u8) -> *mut T -{ - use std::mem::align_of; - ((ptr as usize) + align_of::() - (ptr as usize) % align_of::()) as *mut T -} /// Allocate a runtime length slice of uninitialised `T` on the stack, call `callback` with this buffer, and then deallocate the buffer. /// @@ -133,27 +135,16 @@ where F: FnOnce(&mut [u8]) -> T #[inline] pub fn stackalloc_uninit(size: usize, callback: F) -> U where F: FnOnce(&mut [MaybeUninit]) -> U { - let size = (std::mem::size_of::() * size) + std::mem::align_of::(); - alloca(size, move |buf| { + let size_bytes = (std::mem::size_of::() * size) + std::mem::align_of::(); + alloca(size_bytes, move |buf| { let abuf = align_buffer_to::>(buf.as_mut_ptr() as *mut u8); + debug_assert!(buf.as_ptr_range().contains(&(abuf as *const _ as *const MaybeUninit))); unsafe { callback(slice::from_raw_parts_mut(abuf, size)) } }) } -/// Allocate a runtime length slice of `T` on the stack, fill it by cloning `init`, call `callback` with this buffer, and then deallocate the buffer. -#[inline] pub fn stackalloc(size: usize, init: T, callback: F) -> U -where F: FnOnce(&mut [T]) -> U, -T: Clone -{ - stackalloc_uninit(size, move |buf| { - buf.fill_with(move || MaybeUninit::new(init.clone())); - // SAFETY: We have initialised the buffer above - callback(unsafe { slice_assume_init_mut(buf) }) - }) -} - /// Allocate a runtime length slice of `T` on the stack, fill it by calling `init_with`, call `callback` with this buffer, and then deallocate the buffer. #[inline] pub fn stackalloc_with(size: usize, mut init_with: I, callback: F) -> U where F: FnOnce(&mut [T]) -> U, @@ -162,10 +153,28 @@ I: FnMut() -> T stackalloc_uninit(size, move |buf| { buf.fill_with(move || MaybeUninit::new(init_with())); // SAFETY: We have initialised the buffer above - callback(unsafe { slice_assume_init_mut(buf) }) + let buf = unsafe { slice_assume_init_mut(buf) }; + let ret = callback(buf); + if mem::needs_drop::() + { + // SAFETY: We have initialised the buffer above + unsafe { + ptr::drop_in_place(buf as *mut _); + } + } + ret }) } +/// Allocate a runtime length slice of `T` on the stack, fill it by cloning `init`, call `callback` with this buffer, and then deallocate the buffer. +#[inline] pub fn stackalloc(size: usize, init: T, callback: F) -> U +where F: FnOnce(&mut [T]) -> U, +T: Clone +{ + stackalloc_with(size, move || init.clone(), callback) +} + + /// Allocate a runtime length slice of `T` on the stack, fill it by calling `T::default()`, call `callback` with this buffer, and then deallocate the buffer. #[inline] pub fn stackalloc_with_default(size: usize, callback: F) -> U where F: FnOnce(&mut [T]) -> U, diff --git a/src/tests.rs b/src/tests.rs index 311c005..22982d2 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -59,6 +59,20 @@ fn raw_trampoline() assert_eq!(output, (0..size).sum::()); } +#[test] fn non_primitive_type() +{ + assert_eq!(super::stackalloc(10, String::from("Hello world"), |strings| { + strings.iter().cloned().collect::() + }), std::iter::repeat(String::from("Hello world")).take(10).collect::()); +} + +#[test] fn primitive_type() +{ + assert_eq!(super::stackalloc(10, 12.0, |floats| { + floats.iter().copied().map(|x| x / 2.0).sum::() + }), std::iter::repeat(12.0).take(10).map(|x| x / 2.0).sum()); +} + #[cfg(nightly)] mod bench {