From 2ef2744eca9861588a05e538ab0cb3cc401e77fb Mon Sep 17 00:00:00 2001 From: Janis Date: Fri, 21 Feb 2025 10:24:22 +0100 Subject: [PATCH] refactoring compare-exchange function --- Cargo.toml | 1 - src/praetor/mod.rs | 134 ++++++++++++++++++++++--------------------- src/praetor/tests.rs | 32 ++++++++++- 3 files changed, 100 insertions(+), 67 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 23c44e7..212548b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,6 @@ never-local = [] [profile.bench] debug = true -# opt-level = 0 [dependencies] diff --git a/src/praetor/mod.rs b/src/praetor/mod.rs index 183eb75..f18c35f 100644 --- a/src/praetor/mod.rs +++ b/src/praetor/mod.rs @@ -4,7 +4,6 @@ mod util { marker::PhantomData, mem::ManuallyDrop, num::NonZero, - ops::{Deref, DerefMut}, ptr::NonNull, sync::atomic::{AtomicPtr, Ordering}, }; @@ -59,6 +58,54 @@ mod util { self.0.load(order).addr() & Self::mask() } + /// returns tag + #[inline(always)] + fn compare_exchange_tag_inner( + &self, + old: usize, + new: usize, + success: Ordering, + failure: Ordering, + cmpxchg: fn( + &AtomicPtr<()>, + *mut (), + *mut (), + Ordering, + Ordering, + ) -> Result<*mut (), *mut ()>, + ) -> Result { + let mask = Self::mask(); + let old_ptr = self.0.load(failure); + + let old = old_ptr.with_addr((old_ptr.addr() & !mask) | (old & mask)); + let new = old_ptr.with_addr((old_ptr.addr() & !mask) | (new & mask)); + + let result = cmpxchg(&self.0, old, new, success, failure); + + result + .map(|ptr| ptr.addr() & mask) + .map_err(|ptr| ptr.addr() & mask) + } + + /// returns tag + #[inline] + #[allow(dead_code)] + pub fn compare_exchange_tag( + &self, + old: usize, + new: usize, + success: Ordering, + failure: Ordering, + ) -> Result { + self.compare_exchange_tag_inner( + old, + new, + success, + failure, + AtomicPtr::<()>::compare_exchange, + ) + } + /// returns tag #[inline] pub fn compare_exchange_weak_tag( @@ -68,17 +115,13 @@ mod util { success: Ordering, failure: Ordering, ) -> Result { - let mask = Self::mask(); - let old_ptr = self.0.load(failure); - - let old = old_ptr.with_addr((old_ptr.addr() & !mask) | (old & mask)); - let new = old_ptr.with_addr((old_ptr.addr() & !mask) | (new & mask)); - - let result = self.0.compare_exchange_weak(old, new, success, failure); - - result - .map(|ptr| ptr.addr() & mask) - .map_err(|ptr| ptr.addr() & mask) + self.compare_exchange_tag_inner( + old, + new, + success, + failure, + AtomicPtr::<()>::compare_exchange_weak, + ) } #[allow(dead_code)] @@ -122,43 +165,6 @@ mod util { (ptr, tag) } } - - #[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] - #[repr(transparent)] - pub struct SendPtr(NonNull); - - impl SendPtr { - #[allow(dead_code)] - pub fn as_ptr(&self) -> *mut T { - self.0.as_ptr() - } - #[allow(dead_code)] - pub unsafe fn new_unchecked(t: *const T) -> Self { - unsafe { Self(NonNull::new_unchecked(t.cast_mut())) } - } - #[allow(dead_code)] - pub fn new(t: *const T) -> Option { - NonNull::new(t.cast_mut()).map(Self) - } - #[allow(dead_code)] - pub fn cast(self) -> SendPtr { - SendPtr(self.0.cast::()) - } - } - - impl Deref for SendPtr { - type Target = T; - - fn deref(&self) -> &Self::Target { - unsafe { &*self.0.as_ptr() } - } - } - - impl DerefMut for SendPtr { - fn deref_mut(&mut self) -> &mut Self::Target { - unsafe { &mut *self.0.as_ptr() } - } - } } mod job { @@ -216,9 +222,9 @@ mod job { let inline = Self::is_inline(); if inline { - let this = MaybeUninit::new(Self(MaybeUninit::uninit())); + let mut this = MaybeUninit::new(Self(MaybeUninit::uninit())); unsafe { - this.as_ptr().cast::().cast_mut().write(value); + this.as_mut_ptr().cast::().write(value); this.assume_init() } } else { @@ -549,7 +555,7 @@ mod job { // after sleeping, state should be `Finished` } Err(state) => { - debug_assert_ne!(state, JobState::Pending as usize); + // debug_assert_ne!(state, JobState::Pending as usize); if state == JobState::Finished as usize { let err = unsafe { (&mut *self.err_or_link.get()).error.take() }; @@ -593,8 +599,8 @@ mod job { } return; } - Err(state) => { - debug_assert_ne!(state, JobState::Empty as usize); + Err(_) => { + // debug_assert_ne!(state, JobState::Empty as usize); eprintln!("######## what the sigma?"); spin.spin(); @@ -617,7 +623,7 @@ mod job { } } - fn complete(&self, result: std::thread::Result) { + pub(crate) fn complete(&self, result: std::thread::Result) { let mut spin = SpinWait::new(); loop { match self.harness_and_state.compare_exchange_weak_tag( @@ -630,8 +636,8 @@ mod job { debug_assert_eq!(state, JobState::Pending as usize); break; } - Err(state) => { - debug_assert_ne!(state, JobState::Pending as usize); + Err(_) => { + // debug_assert_ne!(state, JobState::Pending as usize); spin.spin(); } } @@ -662,8 +668,6 @@ mod job { } } - impl Job {} - #[allow(dead_code)] pub struct HeapJob { f: F, @@ -702,12 +706,14 @@ mod job { pub struct StackJob { f: UnsafeCell>, + _phantom: PhantomPinned, } impl StackJob { pub fn new(f: F) -> Self { Self { f: UnsafeCell::new(ManuallyDrop::new(f)), + _phantom: PhantomPinned, } } @@ -715,7 +721,7 @@ mod job { unsafe { ManuallyDrop::take(&mut *self.f.get()) } } - pub fn as_job(&self) -> Job<()> + pub fn as_job(self: Pin<&Self>) -> Job<()> where F: FnOnce(&super::Scope) -> T + Send, T: Send, @@ -735,7 +741,7 @@ mod job { } Job::new(harness::, unsafe { - NonNull::new_unchecked(self as *const _ as *mut ()) + NonNull::new_unchecked(&*self as *const _ as *mut ()) }) } } @@ -921,13 +927,13 @@ impl Scope { A: FnOnce(&Self) -> RA + Send, B: FnOnce(&Self) -> RB + Send, { - let a = StackJob::new(move |scope: &Scope| { + let a = pin!(StackJob::new(move |scope: &Scope| { scope.tick(); a(scope) - }); + })); - let job = pin!(a.as_job()); + let job = pin!(a.as_ref().as_job()); self.push_front(job.as_ref()); let rb = b(self); diff --git a/src/praetor/tests.rs b/src/praetor/tests.rs index f41cf03..dbabd5f 100644 --- a/src/praetor/tests.rs +++ b/src/praetor/tests.rs @@ -7,6 +7,28 @@ fn pin_ptr(pin: &Pin<&mut T>) -> NonNull { unsafe { NonNull::new_unchecked(a.cast_mut()) } } +#[test] +fn new_job() { + let mut list = JobList::new(); + + let stack = pin!(StackJob::new(|_: &Scope| 3 + 3)); + + let job = pin!(stack.as_ref().as_job()); + unsafe { + list.push_front(job.as_ref()); + } + + unsafe { + let job_ref = list.pop_front().unwrap().cast::>().as_ref(); + job_ref.set_pending(); + + _ = stack.unwrap(); + job_ref.complete(Ok(6)); + let result = job_ref.wait(); + assert_eq!(result.ok(), Some(6)); + } +} + #[test] fn job_list_pop_back() { let mut list = JobList::new(); @@ -144,6 +166,8 @@ fn unlink_job_single() { #[test] fn tagged_ptr_exchange() { let boxed = Box::into_raw(Box::new(42usize)); + let _guard = DropGuard::new(|| drop(unsafe { Box::from_raw(boxed) })); + let ptr = TaggedAtomicPtr::<_, 3>::new(boxed, 1usize); assert_eq!(ptr.tag(Ordering::Relaxed), 1); @@ -151,7 +175,9 @@ fn tagged_ptr_exchange() { ptr.set_tag(3, Ordering::Release, Ordering::Relaxed); assert_eq!(ptr.tag(Ordering::Relaxed), 3); assert_eq!(ptr.ptr(Ordering::Relaxed).as_ptr(), boxed); - let old = ptr.compare_exchange_weak_tag(3, 4, Ordering::Release, Ordering::Relaxed); + + let old = ptr.compare_exchange_tag(3, 4, Ordering::Release, Ordering::Relaxed); + assert_eq!(old, Ok(3)); assert_eq!(ptr.tag(Ordering::Relaxed), 4); assert_eq!(ptr.ptr(Ordering::Relaxed).as_ptr(), boxed); @@ -160,12 +186,14 @@ fn tagged_ptr_exchange() { #[test] fn tagged_ptr_exchange_failure() { let boxed = Box::into_raw(Box::new(42usize)); + let _guard = DropGuard::new(|| drop(unsafe { Box::from_raw(boxed) })); + let ptr = TaggedAtomicPtr::<_, 3>::new(boxed, 1usize); assert_eq!(ptr.tag(Ordering::Relaxed), 1); assert_eq!(ptr.ptr(Ordering::Relaxed).as_ptr(), boxed); - let old = ptr.compare_exchange_weak_tag(3, 4, Ordering::Release, Ordering::Relaxed); + let old = ptr.compare_exchange_tag(3, 4, Ordering::Release, Ordering::Relaxed); assert_eq!(old, Err(1)); assert_eq!(ptr.tag(Ordering::Relaxed), 1); assert_eq!(ptr.ptr(Ordering::Relaxed).as_ptr(), boxed);