diff --git a/src/ptr.rs b/src/ptr.rs index 59ac4f0..f96009c 100644 --- a/src/ptr.rs +++ b/src/ptr.rs @@ -1,10 +1,12 @@ use core::{ cmp::Ordering, fmt, hash, - marker::Send, + marker::{PhantomData, Send}, + mem, num::NonZero, ops::{Deref, DerefMut}, ptr::NonNull, + sync::atomic::{self, AtomicPtr, AtomicUsize}, }; #[repr(transparent)] @@ -156,3 +158,248 @@ impl SendNonNull { unsafe { Self(NonNull::new_unchecked(ptr)) } } } + +/// A tagged atomic pointer that can store a pointer and a tag `BITS` wide in the same space +/// as the pointer. +/// The pointer must be aligned to `BITS` bits, i.e. `align_of::() >= 2^BITS`. +#[repr(transparent)] +#[derive(Debug)] +pub struct TaggedAtomicPtr { + ptr: AtomicPtr<()>, + _pd: PhantomData, +} + +impl TaggedAtomicPtr { + const fn mask() -> usize { + !(!0usize << BITS) + } + + pub fn new(ptr: *mut T, tag: usize) -> TaggedAtomicPtr { + debug_assert!(mem::align_of::().ilog2() as u8 >= BITS); + let mask = Self::mask(); + Self { + ptr: AtomicPtr::new(ptr.with_addr((ptr.addr() & !mask) | (tag & mask)).cast()), + _pd: PhantomData, + } + } + + pub fn ptr(&self, order: atomic::Ordering) -> NonNull { + unsafe { + NonNull::new_unchecked( + self.ptr + .load(order) + .map_addr(|addr| addr & !Self::mask()) + .cast(), + ) + } + } + + pub fn tag(&self, order: atomic::Ordering) -> usize { + self.ptr.load(order).addr() & Self::mask() + } + + pub fn fetch_or_tag(&self, tag: usize, order: atomic::Ordering) -> usize { + let mask = Self::mask(); + + // TODO: switch to fetch_or when stable + // let old_ptr = self.ptr.fetch_or(tag & mask, order); + + let ptr = unsafe { AtomicUsize::from_ptr(self.ptr.as_ptr() as *mut usize) }; + let old_ptr = ptr.fetch_or(tag & mask, order); + + old_ptr & mask + } + + /// returns the tag and clears it + pub fn take_tag(&self, order: atomic::Ordering) -> usize { + let mask = Self::mask(); + + // TODO: switch to fetch_and when stable + // let old_ptr = self.ptr.fetch_and(!mask, order); + + let ptr = unsafe { AtomicUsize::from_ptr(self.ptr.as_ptr() as *mut usize) }; + let old_ptr = ptr.fetch_and(!mask, order); + + old_ptr & mask + } + + /// returns tag + #[inline(always)] + fn compare_exchange_tag_inner( + &self, + old: usize, + new: usize, + success: atomic::Ordering, + failure: atomic::Ordering, + cmpxchg: fn( + &AtomicPtr<()>, + *mut (), + *mut (), + atomic::Ordering, + atomic::Ordering, + ) -> Result<*mut (), *mut ()>, + ) -> Result { + let mask = Self::mask(); + let old_ptr = self.ptr.load(failure); + + let old = old_ptr.map_addr(|addr| (addr & !mask) | (old & mask)); + let new = old_ptr.map_addr(|addr| (addr & !mask) | (new & mask)); + + let result = cmpxchg(&self.ptr, old, new, success, failure); + + result + .map(|ptr| ptr.addr() & mask) + .map_err(|ptr| ptr.addr() & mask) + } + + /// returns tag + #[allow(dead_code)] + pub fn compare_exchange_tag( + &self, + old: usize, + new: usize, + success: atomic::Ordering, + failure: atomic::Ordering, + ) -> Result { + self.compare_exchange_tag_inner( + old, + new, + success, + failure, + AtomicPtr::<()>::compare_exchange, + ) + } + + /// returns tag + pub fn compare_exchange_weak_tag( + &self, + old: usize, + new: usize, + success: atomic::Ordering, + failure: atomic::Ordering, + ) -> Result { + self.compare_exchange_tag_inner( + old, + new, + success, + failure, + AtomicPtr::<()>::compare_exchange_weak, + ) + } + + #[allow(dead_code)] + pub fn set_ptr(&self, ptr: *mut T, success: atomic::Ordering, failure: atomic::Ordering) { + let mask = Self::mask(); + let ptr = ptr.cast::<()>(); + loop { + let old = self.ptr.load(failure); + let new = ptr.map_addr(|addr| (addr & !mask) | (old.addr() & mask)); + if self + .ptr + .compare_exchange_weak(old, new, success, failure) + .is_ok() + { + break; + } + } + } + + pub fn set_tag(&self, tag: usize, success: atomic::Ordering, failure: atomic::Ordering) { + let mask = Self::mask(); + loop { + let ptr = self.ptr.load(failure); + let new = ptr.map_addr(|addr| (addr & !mask) | (tag & mask)); + + if self + .ptr + .compare_exchange_weak(ptr, new, success, failure) + .is_ok() + { + break; + } + } + } + + pub fn ptr_and_tag(&self, order: atomic::Ordering) -> (NonNull, usize) { + let mask = Self::mask(); + let ptr = self.ptr.load(order); + let tag = ptr.addr() & mask; + let ptr = ptr.map_addr(|addr| addr & !mask); + let ptr = unsafe { NonNull::new_unchecked(ptr.cast()) }; + (ptr, tag) + } +} + +#[cfg(test)] +mod tests { + use core::sync::atomic::Ordering; + use std::boxed::Box; + + use super::*; + + #[test] + fn tagged_ptr_zero_tag() { + let ptr = Box::into_raw(Box::new(42u32)); + let tagged_ptr = TaggedAtomicPtr::::new(ptr, 0); + assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0); + assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr); + + unsafe { + _ = Box::from_raw(ptr); + } + } + + #[test] + fn tagged_ptr_take_tag() { + let ptr = Box::into_raw(Box::new(42u32)); + let tagged_ptr = TaggedAtomicPtr::::new(ptr, 0b11); + assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0b11); + assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr); + + assert_eq!(tagged_ptr.take_tag(Ordering::Relaxed), 0b11); + assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0); + assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr); + + unsafe { + _ = Box::from_raw(ptr); + } + } + + #[test] + fn tagged_ptr_fetch_or_tag() { + let ptr = Box::into_raw(Box::new(42u32)); + let tagged_ptr = TaggedAtomicPtr::::new(ptr, 0b11); + assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0b11); + assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr); + + assert_eq!(tagged_ptr.fetch_or_tag(0b10, Ordering::Relaxed), 0b11); + assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0b11 | 0b10); + assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr); + + unsafe { + _ = Box::from_raw(ptr); + } + } + + #[test] + fn tagged_ptr_exchange() { + let ptr = Box::into_raw(Box::new(42u32)); + let tagged_ptr = TaggedAtomicPtr::::new(ptr, 0b11); + assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0b11); + assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr); + + assert_eq!( + tagged_ptr + .compare_exchange_tag(0b11, 0b10, Ordering::Relaxed, Ordering::Relaxed) + .unwrap(), + 0b11 + ); + + assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0b10); + assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr); + + unsafe { + _ = Box::from_raw(ptr); + } + } +}