refactoring compare-exchange function

This commit is contained in:
Janis 2025-02-21 10:24:22 +01:00
parent f735348762
commit 2ef2744eca
3 changed files with 100 additions and 67 deletions

View file

@ -14,7 +14,6 @@ never-local = []
[profile.bench]
debug = true
# opt-level = 0
[dependencies]

View file

@ -4,7 +4,6 @@ mod util {
marker::PhantomData,
mem::ManuallyDrop,
num::NonZero,
ops::{Deref, DerefMut},
ptr::NonNull,
sync::atomic::{AtomicPtr, Ordering},
};
@ -59,6 +58,54 @@ mod util {
self.0.load(order).addr() & Self::mask()
}
/// returns tag
#[inline(always)]
fn compare_exchange_tag_inner(
&self,
old: usize,
new: usize,
success: Ordering,
failure: Ordering,
cmpxchg: fn(
&AtomicPtr<()>,
*mut (),
*mut (),
Ordering,
Ordering,
) -> Result<*mut (), *mut ()>,
) -> Result<usize, usize> {
let mask = Self::mask();
let old_ptr = self.0.load(failure);
let old = old_ptr.with_addr((old_ptr.addr() & !mask) | (old & mask));
let new = old_ptr.with_addr((old_ptr.addr() & !mask) | (new & mask));
let result = cmpxchg(&self.0, old, new, success, failure);
result
.map(|ptr| ptr.addr() & mask)
.map_err(|ptr| ptr.addr() & mask)
}
/// returns tag
#[inline]
#[allow(dead_code)]
pub fn compare_exchange_tag(
&self,
old: usize,
new: usize,
success: Ordering,
failure: Ordering,
) -> Result<usize, usize> {
self.compare_exchange_tag_inner(
old,
new,
success,
failure,
AtomicPtr::<()>::compare_exchange,
)
}
/// returns tag
#[inline]
pub fn compare_exchange_weak_tag(
@ -68,17 +115,13 @@ mod util {
success: Ordering,
failure: Ordering,
) -> Result<usize, usize> {
let mask = Self::mask();
let old_ptr = self.0.load(failure);
let old = old_ptr.with_addr((old_ptr.addr() & !mask) | (old & mask));
let new = old_ptr.with_addr((old_ptr.addr() & !mask) | (new & mask));
let result = self.0.compare_exchange_weak(old, new, success, failure);
result
.map(|ptr| ptr.addr() & mask)
.map_err(|ptr| ptr.addr() & mask)
self.compare_exchange_tag_inner(
old,
new,
success,
failure,
AtomicPtr::<()>::compare_exchange_weak,
)
}
#[allow(dead_code)]
@ -122,43 +165,6 @@ mod util {
(ptr, tag)
}
}
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[repr(transparent)]
pub struct SendPtr<T>(NonNull<T>);
impl<T> SendPtr<T> {
#[allow(dead_code)]
pub fn as_ptr(&self) -> *mut T {
self.0.as_ptr()
}
#[allow(dead_code)]
pub unsafe fn new_unchecked(t: *const T) -> Self {
unsafe { Self(NonNull::new_unchecked(t.cast_mut())) }
}
#[allow(dead_code)]
pub fn new(t: *const T) -> Option<Self> {
NonNull::new(t.cast_mut()).map(Self)
}
#[allow(dead_code)]
pub fn cast<U>(self) -> SendPtr<U> {
SendPtr(self.0.cast::<U>())
}
}
impl<T> Deref for SendPtr<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
unsafe { &*self.0.as_ptr() }
}
}
impl<T> DerefMut for SendPtr<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { &mut *self.0.as_ptr() }
}
}
}
mod job {
@ -216,9 +222,9 @@ mod job {
let inline = Self::is_inline();
if inline {
let this = MaybeUninit::new(Self(MaybeUninit::uninit()));
let mut this = MaybeUninit::new(Self(MaybeUninit::uninit()));
unsafe {
this.as_ptr().cast::<T>().cast_mut().write(value);
this.as_mut_ptr().cast::<T>().write(value);
this.assume_init()
}
} else {
@ -549,7 +555,7 @@ mod job {
// after sleeping, state should be `Finished`
}
Err(state) => {
debug_assert_ne!(state, JobState::Pending as usize);
// debug_assert_ne!(state, JobState::Pending as usize);
if state == JobState::Finished as usize {
let err = unsafe { (&mut *self.err_or_link.get()).error.take() };
@ -593,8 +599,8 @@ mod job {
}
return;
}
Err(state) => {
debug_assert_ne!(state, JobState::Empty as usize);
Err(_) => {
// debug_assert_ne!(state, JobState::Empty as usize);
eprintln!("######## what the sigma?");
spin.spin();
@ -617,7 +623,7 @@ mod job {
}
}
fn complete(&self, result: std::thread::Result<T>) {
pub(crate) fn complete(&self, result: std::thread::Result<T>) {
let mut spin = SpinWait::new();
loop {
match self.harness_and_state.compare_exchange_weak_tag(
@ -630,8 +636,8 @@ mod job {
debug_assert_eq!(state, JobState::Pending as usize);
break;
}
Err(state) => {
debug_assert_ne!(state, JobState::Pending as usize);
Err(_) => {
// debug_assert_ne!(state, JobState::Pending as usize);
spin.spin();
}
}
@ -662,8 +668,6 @@ mod job {
}
}
impl Job {}
#[allow(dead_code)]
pub struct HeapJob<F> {
f: F,
@ -702,12 +706,14 @@ mod job {
pub struct StackJob<F> {
f: UnsafeCell<ManuallyDrop<F>>,
_phantom: PhantomPinned,
}
impl<F> StackJob<F> {
pub fn new(f: F) -> Self {
Self {
f: UnsafeCell::new(ManuallyDrop::new(f)),
_phantom: PhantomPinned,
}
}
@ -715,7 +721,7 @@ mod job {
unsafe { ManuallyDrop::take(&mut *self.f.get()) }
}
pub fn as_job<T>(&self) -> Job<()>
pub fn as_job<T>(self: Pin<&Self>) -> Job<()>
where
F: FnOnce(&super::Scope) -> T + Send,
T: Send,
@ -735,7 +741,7 @@ mod job {
}
Job::new(harness::<F, T>, unsafe {
NonNull::new_unchecked(self as *const _ as *mut ())
NonNull::new_unchecked(&*self as *const _ as *mut ())
})
}
}
@ -921,13 +927,13 @@ impl Scope {
A: FnOnce(&Self) -> RA + Send,
B: FnOnce(&Self) -> RB + Send,
{
let a = StackJob::new(move |scope: &Scope| {
let a = pin!(StackJob::new(move |scope: &Scope| {
scope.tick();
a(scope)
});
}));
let job = pin!(a.as_job());
let job = pin!(a.as_ref().as_job());
self.push_front(job.as_ref());
let rb = b(self);

View file

@ -7,6 +7,28 @@ fn pin_ptr<T>(pin: &Pin<&mut T>) -> NonNull<T> {
unsafe { NonNull::new_unchecked(a.cast_mut()) }
}
#[test]
fn new_job() {
let mut list = JobList::new();
let stack = pin!(StackJob::new(|_: &Scope| 3 + 3));
let job = pin!(stack.as_ref().as_job());
unsafe {
list.push_front(job.as_ref());
}
unsafe {
let job_ref = list.pop_front().unwrap().cast::<Job<i32>>().as_ref();
job_ref.set_pending();
_ = stack.unwrap();
job_ref.complete(Ok(6));
let result = job_ref.wait();
assert_eq!(result.ok(), Some(6));
}
}
#[test]
fn job_list_pop_back() {
let mut list = JobList::new();
@ -144,6 +166,8 @@ fn unlink_job_single() {
#[test]
fn tagged_ptr_exchange() {
let boxed = Box::into_raw(Box::new(42usize));
let _guard = DropGuard::new(|| drop(unsafe { Box::from_raw(boxed) }));
let ptr = TaggedAtomicPtr::<_, 3>::new(boxed, 1usize);
assert_eq!(ptr.tag(Ordering::Relaxed), 1);
@ -151,7 +175,9 @@ fn tagged_ptr_exchange() {
ptr.set_tag(3, Ordering::Release, Ordering::Relaxed);
assert_eq!(ptr.tag(Ordering::Relaxed), 3);
assert_eq!(ptr.ptr(Ordering::Relaxed).as_ptr(), boxed);
let old = ptr.compare_exchange_weak_tag(3, 4, Ordering::Release, Ordering::Relaxed);
let old = ptr.compare_exchange_tag(3, 4, Ordering::Release, Ordering::Relaxed);
assert_eq!(old, Ok(3));
assert_eq!(ptr.tag(Ordering::Relaxed), 4);
assert_eq!(ptr.ptr(Ordering::Relaxed).as_ptr(), boxed);
@ -160,12 +186,14 @@ fn tagged_ptr_exchange() {
#[test]
fn tagged_ptr_exchange_failure() {
let boxed = Box::into_raw(Box::new(42usize));
let _guard = DropGuard::new(|| drop(unsafe { Box::from_raw(boxed) }));
let ptr = TaggedAtomicPtr::<_, 3>::new(boxed, 1usize);
assert_eq!(ptr.tag(Ordering::Relaxed), 1);
assert_eq!(ptr.ptr(Ordering::Relaxed).as_ptr(), boxed);
let old = ptr.compare_exchange_weak_tag(3, 4, Ordering::Release, Ordering::Relaxed);
let old = ptr.compare_exchange_tag(3, 4, Ordering::Release, Ordering::Relaxed);
assert_eq!(old, Err(1));
assert_eq!(ptr.tag(Ordering::Relaxed), 1);
assert_eq!(ptr.ptr(Ordering::Relaxed).as_ptr(), boxed);