AtomicTaggedPtr compareexchange ptr part, remove unsafe nonnull conversions

This commit is contained in:
Janis 2025-07-04 13:29:26 +02:00
parent 1ea8bcb3ed
commit c2f1d8d749

View file

@ -183,17 +183,15 @@ impl<T, const BITS: u8> TaggedAtomicPtr<T, BITS> {
}
}
pub fn ptr(&self, order: atomic::Ordering) -> NonNull<T> {
unsafe {
NonNull::new_unchecked(
self.ptr
.load(order)
.map_addr(|addr| addr & !Self::mask())
.cast(),
)
}
#[doc(alias = "load_ptr")]
pub fn ptr(&self, order: atomic::Ordering) -> *mut T {
self.ptr
.load(order)
.map_addr(|addr| addr & !Self::mask())
.cast()
}
#[doc(alias = "load_tag")]
pub fn tag(&self, order: atomic::Ordering) -> usize {
self.ptr.load(order).addr() & Self::mask()
}
@ -238,6 +236,68 @@ impl<T, const BITS: u8> TaggedAtomicPtr<T, BITS> {
}
}
/// returns tag
#[inline(always)]
fn compare_exchange_ptr_inner(
&self,
old: *mut T,
new: *mut T,
success: atomic::Ordering,
failure: atomic::Ordering,
cmpxchg: fn(
&AtomicPtr<()>,
*mut (),
*mut (),
atomic::Ordering,
atomic::Ordering,
) -> Result<*mut (), *mut ()>,
) -> Result<*mut T, *mut T> {
let mask = Self::mask();
let old_tag = self.ptr.load(failure).addr() & mask;
// old and new must be aligned to the mask, so no need to & with the mask.
let old = old.map_addr(|addr| addr | old_tag).cast();
let new = new.map_addr(|addr| addr | old_tag).cast();
let result = cmpxchg(&self.ptr, old, new, success, failure);
result
.map(|ptr| ptr.map_addr(|addr| addr & !mask).cast())
.map_err(|ptr| ptr.map_addr(|addr| addr & !mask).cast())
}
pub fn compare_exchange_ptr(
&self,
old: *mut T,
new: *mut T,
success: atomic::Ordering,
failure: atomic::Ordering,
) -> Result<*mut T, *mut T> {
self.compare_exchange_ptr_inner(
old,
new,
success,
failure,
AtomicPtr::<()>::compare_exchange,
)
}
pub fn compare_exchange_weak_ptr(
&self,
old: *mut T,
new: *mut T,
success: atomic::Ordering,
failure: atomic::Ordering,
) -> Result<*mut T, *mut T> {
self.compare_exchange_ptr_inner(
old,
new,
success,
failure,
AtomicPtr::<()>::compare_exchange_weak,
)
}
/// returns tag
#[inline(always)]
fn compare_exchange_tag_inner(
@ -268,7 +328,6 @@ impl<T, const BITS: u8> TaggedAtomicPtr<T, BITS> {
}
/// returns tag
#[allow(dead_code)]
pub fn compare_exchange_tag(
&self,
old: usize,
@ -302,7 +361,7 @@ impl<T, const BITS: u8> TaggedAtomicPtr<T, BITS> {
)
}
#[allow(dead_code)]
#[doc(alias = "store_ptr")]
pub fn set_ptr(&self, ptr: *mut T, success: atomic::Ordering, failure: atomic::Ordering) {
let mask = Self::mask();
let ptr = ptr.cast::<()>();
@ -319,6 +378,7 @@ impl<T, const BITS: u8> TaggedAtomicPtr<T, BITS> {
}
}
#[doc(alias = "store_tag")]
pub fn set_tag(&self, tag: usize, success: atomic::Ordering, failure: atomic::Ordering) {
let mask = Self::mask();
loop {
@ -335,6 +395,42 @@ impl<T, const BITS: u8> TaggedAtomicPtr<T, BITS> {
}
}
pub fn swap_tag(
&self,
new: usize,
success: atomic::Ordering,
failure: atomic::Ordering,
) -> usize {
let mask = Self::mask();
loop {
let ptr = self.ptr.load(failure);
let new = ptr.map_addr(|addr| (addr & !mask) | (new & mask));
if let Ok(old) = self.ptr.compare_exchange_weak(ptr, new, success, failure) {
break old.addr() & mask;
}
}
}
pub fn swap_ptr(
&self,
new: *mut T,
success: atomic::Ordering,
failure: atomic::Ordering,
) -> *mut T {
let mask = Self::mask();
let new = new.cast::<()>();
loop {
let old = self.ptr.load(failure);
let new = new.map_addr(|addr| (addr & !mask) | (old.addr() & mask));
if let Ok(old) = self.ptr.compare_exchange_weak(old, new, success, failure) {
break old.map_addr(|addr| addr & !mask).cast();
}
}
}
pub fn ptr_and_tag(&self, order: atomic::Ordering) -> (NonNull<T>, usize) {
let mask = Self::mask();
let ptr = self.ptr.load(order);
@ -357,7 +453,7 @@ mod tests {
let ptr = Box::into_raw(Box::new(42u32));
let tagged_ptr = TaggedAtomicPtr::<u32, 2>::new(ptr, 0);
assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0);
assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr);
assert_eq!(tagged_ptr.ptr(Ordering::Relaxed), ptr);
unsafe {
_ = Box::from_raw(ptr);
@ -369,11 +465,11 @@ mod tests {
let ptr = Box::into_raw(Box::new(42u32));
let tagged_ptr = TaggedAtomicPtr::<u32, 2>::new(ptr, 0b11);
assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0b11);
assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr);
assert_eq!(tagged_ptr.ptr(Ordering::Relaxed), ptr);
assert_eq!(tagged_ptr.take_tag(Ordering::Relaxed), 0b11);
assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0);
assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr);
assert_eq!(tagged_ptr.ptr(Ordering::Relaxed), ptr);
unsafe {
_ = Box::from_raw(ptr);
@ -385,11 +481,11 @@ mod tests {
let ptr = Box::into_raw(Box::new(42u32));
let tagged_ptr = TaggedAtomicPtr::<u32, 2>::new(ptr, 0b11);
assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0b11);
assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr);
assert_eq!(tagged_ptr.ptr(Ordering::Relaxed), ptr);
assert_eq!(tagged_ptr.fetch_or_tag(0b10, Ordering::Relaxed), 0b11);
assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0b11 | 0b10);
assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr);
assert_eq!(tagged_ptr.ptr(Ordering::Relaxed), ptr);
unsafe {
_ = Box::from_raw(ptr);
@ -397,11 +493,11 @@ mod tests {
}
#[test]
fn tagged_ptr_exchange() {
fn tagged_ptr_exchange_tag() {
let ptr = Box::into_raw(Box::new(42u32));
let tagged_ptr = TaggedAtomicPtr::<u32, 2>::new(ptr, 0b11);
assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0b11);
assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr);
assert_eq!(tagged_ptr.ptr(Ordering::Relaxed), ptr);
assert_eq!(
tagged_ptr
@ -411,10 +507,34 @@ mod tests {
);
assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0b10);
assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr);
assert_eq!(tagged_ptr.ptr(Ordering::Relaxed), ptr);
unsafe {
_ = Box::from_raw(ptr);
}
}
#[test]
fn tagged_ptr_exchange_ptr() {
let ptr = Box::into_raw(Box::new(42u32));
let tagged_ptr = TaggedAtomicPtr::<u32, 2>::new(ptr, 0b11);
assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0b11);
assert_eq!(tagged_ptr.ptr(Ordering::Relaxed), ptr);
let new_ptr = Box::into_raw(Box::new(43u32));
assert_eq!(
tagged_ptr
.compare_exchange_ptr(ptr, new_ptr, Ordering::Relaxed, Ordering::Relaxed)
.unwrap(),
ptr
);
assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0b11);
assert_eq!(tagged_ptr.ptr(Ordering::Relaxed), new_ptr);
unsafe {
_ = Box::from_raw(ptr);
_ = Box::from_raw(new_ptr);
}
}
}