diff --git a/src/ptr.rs b/src/ptr.rs index dca0c4d..568d0a6 100644 --- a/src/ptr.rs +++ b/src/ptr.rs @@ -183,17 +183,15 @@ impl TaggedAtomicPtr { } } - pub fn ptr(&self, order: atomic::Ordering) -> NonNull { - unsafe { - NonNull::new_unchecked( - self.ptr - .load(order) - .map_addr(|addr| addr & !Self::mask()) - .cast(), - ) - } + #[doc(alias = "load_ptr")] + pub fn ptr(&self, order: atomic::Ordering) -> *mut T { + self.ptr + .load(order) + .map_addr(|addr| addr & !Self::mask()) + .cast() } + #[doc(alias = "load_tag")] pub fn tag(&self, order: atomic::Ordering) -> usize { self.ptr.load(order).addr() & Self::mask() } @@ -238,6 +236,68 @@ impl TaggedAtomicPtr { } } + /// returns tag + #[inline(always)] + fn compare_exchange_ptr_inner( + &self, + old: *mut T, + new: *mut T, + success: atomic::Ordering, + failure: atomic::Ordering, + cmpxchg: fn( + &AtomicPtr<()>, + *mut (), + *mut (), + atomic::Ordering, + atomic::Ordering, + ) -> Result<*mut (), *mut ()>, + ) -> Result<*mut T, *mut T> { + let mask = Self::mask(); + let old_tag = self.ptr.load(failure).addr() & mask; + + // old and new must be aligned to the mask, so no need to & with the mask. + let old = old.map_addr(|addr| addr | old_tag).cast(); + let new = new.map_addr(|addr| addr | old_tag).cast(); + + let result = cmpxchg(&self.ptr, old, new, success, failure); + + result + .map(|ptr| ptr.map_addr(|addr| addr & !mask).cast()) + .map_err(|ptr| ptr.map_addr(|addr| addr & !mask).cast()) + } + + pub fn compare_exchange_ptr( + &self, + old: *mut T, + new: *mut T, + success: atomic::Ordering, + failure: atomic::Ordering, + ) -> Result<*mut T, *mut T> { + self.compare_exchange_ptr_inner( + old, + new, + success, + failure, + AtomicPtr::<()>::compare_exchange, + ) + } + + pub fn compare_exchange_weak_ptr( + &self, + old: *mut T, + new: *mut T, + success: atomic::Ordering, + failure: atomic::Ordering, + ) -> Result<*mut T, *mut T> { + self.compare_exchange_ptr_inner( + old, + new, + success, + failure, + AtomicPtr::<()>::compare_exchange_weak, + ) + } + /// returns tag #[inline(always)] fn compare_exchange_tag_inner( @@ -268,7 +328,6 @@ impl TaggedAtomicPtr { } /// returns tag - #[allow(dead_code)] pub fn compare_exchange_tag( &self, old: usize, @@ -302,7 +361,7 @@ impl TaggedAtomicPtr { ) } - #[allow(dead_code)] + #[doc(alias = "store_ptr")] pub fn set_ptr(&self, ptr: *mut T, success: atomic::Ordering, failure: atomic::Ordering) { let mask = Self::mask(); let ptr = ptr.cast::<()>(); @@ -319,6 +378,7 @@ impl TaggedAtomicPtr { } } + #[doc(alias = "store_tag")] pub fn set_tag(&self, tag: usize, success: atomic::Ordering, failure: atomic::Ordering) { let mask = Self::mask(); loop { @@ -335,6 +395,42 @@ impl TaggedAtomicPtr { } } + pub fn swap_tag( + &self, + new: usize, + success: atomic::Ordering, + failure: atomic::Ordering, + ) -> usize { + let mask = Self::mask(); + loop { + let ptr = self.ptr.load(failure); + let new = ptr.map_addr(|addr| (addr & !mask) | (new & mask)); + + if let Ok(old) = self.ptr.compare_exchange_weak(ptr, new, success, failure) { + break old.addr() & mask; + } + } + } + + pub fn swap_ptr( + &self, + new: *mut T, + success: atomic::Ordering, + failure: atomic::Ordering, + ) -> *mut T { + let mask = Self::mask(); + let new = new.cast::<()>(); + + loop { + let old = self.ptr.load(failure); + let new = new.map_addr(|addr| (addr & !mask) | (old.addr() & mask)); + + if let Ok(old) = self.ptr.compare_exchange_weak(old, new, success, failure) { + break old.map_addr(|addr| addr & !mask).cast(); + } + } + } + pub fn ptr_and_tag(&self, order: atomic::Ordering) -> (NonNull, usize) { let mask = Self::mask(); let ptr = self.ptr.load(order); @@ -357,7 +453,7 @@ mod tests { let ptr = Box::into_raw(Box::new(42u32)); let tagged_ptr = TaggedAtomicPtr::::new(ptr, 0); assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0); - assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr); + assert_eq!(tagged_ptr.ptr(Ordering::Relaxed), ptr); unsafe { _ = Box::from_raw(ptr); @@ -369,11 +465,11 @@ mod tests { let ptr = Box::into_raw(Box::new(42u32)); let tagged_ptr = TaggedAtomicPtr::::new(ptr, 0b11); assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0b11); - assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr); + assert_eq!(tagged_ptr.ptr(Ordering::Relaxed), ptr); assert_eq!(tagged_ptr.take_tag(Ordering::Relaxed), 0b11); assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0); - assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr); + assert_eq!(tagged_ptr.ptr(Ordering::Relaxed), ptr); unsafe { _ = Box::from_raw(ptr); @@ -385,11 +481,11 @@ mod tests { let ptr = Box::into_raw(Box::new(42u32)); let tagged_ptr = TaggedAtomicPtr::::new(ptr, 0b11); assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0b11); - assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr); + assert_eq!(tagged_ptr.ptr(Ordering::Relaxed), ptr); assert_eq!(tagged_ptr.fetch_or_tag(0b10, Ordering::Relaxed), 0b11); assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0b11 | 0b10); - assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr); + assert_eq!(tagged_ptr.ptr(Ordering::Relaxed), ptr); unsafe { _ = Box::from_raw(ptr); @@ -397,11 +493,11 @@ mod tests { } #[test] - fn tagged_ptr_exchange() { + fn tagged_ptr_exchange_tag() { let ptr = Box::into_raw(Box::new(42u32)); let tagged_ptr = TaggedAtomicPtr::::new(ptr, 0b11); assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0b11); - assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr); + assert_eq!(tagged_ptr.ptr(Ordering::Relaxed), ptr); assert_eq!( tagged_ptr @@ -411,10 +507,34 @@ mod tests { ); assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0b10); - assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr); + assert_eq!(tagged_ptr.ptr(Ordering::Relaxed), ptr); unsafe { _ = Box::from_raw(ptr); } } + + #[test] + fn tagged_ptr_exchange_ptr() { + let ptr = Box::into_raw(Box::new(42u32)); + let tagged_ptr = TaggedAtomicPtr::::new(ptr, 0b11); + assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0b11); + assert_eq!(tagged_ptr.ptr(Ordering::Relaxed), ptr); + + let new_ptr = Box::into_raw(Box::new(43u32)); + assert_eq!( + tagged_ptr + .compare_exchange_ptr(ptr, new_ptr, Ordering::Relaxed, Ordering::Relaxed) + .unwrap(), + ptr + ); + + assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0b11); + assert_eq!(tagged_ptr.ptr(Ordering::Relaxed), new_ptr); + + unsafe { + _ = Box::from_raw(ptr); + _ = Box::from_raw(new_ptr); + } + } }