werkzeug/src/ptr.rs
2025-09-16 18:55:26 +02:00

646 lines
18 KiB
Rust

use core::{
cmp::Ordering,
fmt, hash,
marker::{PhantomData, Send},
mem::{self, ManuallyDrop},
num::NonZero,
ops::{Deref, DerefMut},
pin::Pin,
ptr::NonNull,
sync::atomic::{self, AtomicPtr},
};
/// A wrapper around [`NonNull<T>`] that is `Send` even if `T` is not `Send`.
///
/// This is useful for types that store a `NonNull<T>` internally but are nevertheless safe to
/// move to another thread. The `Send` impl below is unconditional, so code that sends a
/// `SendNonNull<T>` is responsible for making sure that accessing the pointee from the
/// receiving thread is sound. The wrapper is not `Sync`.
///
/// The type is `#[repr(transparent)]` and dereferences to the inner `NonNull<T>`, so
/// `NonNull`'s methods (`as_ptr`, `as_ref`, `addr`, ...) are available on it directly.
#[repr(transparent)]
pub struct SendNonNull<T>(NonNull<T>);

impl<T> fmt::Debug for SendNonNull<T> {
    /// Formats the address, exactly like `{:p}` does.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Pointer::fmt(&self.as_ptr(), f)
    }
}

impl<T> fmt::Pointer for SendNonNull<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <NonNull<T> as fmt::Pointer>::fmt(&self.0, f)
    }
}

// The following impls are written by hand because `#[derive]` would add `T: Trait` bounds
// that a pointer wrapper does not need. Equality, ordering and hashing use the address only.
impl<T> Copy for SendNonNull<T> {}

impl<T> Clone for SendNonNull<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Eq for SendNonNull<T> {}

impl<T> PartialEq for SendNonNull<T> {
    fn eq(&self, other: &Self) -> bool {
        self.as_ptr() == other.as_ptr()
    }
}

impl<T> Ord for SendNonNull<T> {
    fn cmp(&self, other: &Self) -> Ordering {
        self.as_ptr().cmp(&other.as_ptr())
    }
}

impl<T> PartialOrd for SendNonNull<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Canonical form: keep `PartialOrd` consistent with `Ord` by construction.
        Some(self.cmp(other))
    }
}

impl<T> hash::Hash for SendNonNull<T> {
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        // Fully qualified so this does not depend on the `Hash` trait being imported.
        hash::Hash::hash(&self.as_ptr(), state);
    }
}

impl<T> From<NonNull<T>> for SendNonNull<T> {
    fn from(ptr: NonNull<T>) -> Self {
        Self(ptr)
    }
}

impl<T> From<SendNonNull<T>> for NonNull<T> {
    fn from(ptr: SendNonNull<T>) -> Self {
        ptr.0
    }
}

impl<T> From<&mut T> for SendNonNull<T> {
    fn from(ptr: &mut T) -> Self {
        Self(NonNull::from(ptr))
    }
}

impl<T> From<&T> for SendNonNull<T> {
    fn from(ptr: &T) -> Self {
        Self(NonNull::from(ptr))
    }
}

// SAFETY: this is the whole point of the type; see the type-level documentation for the
// responsibility this places on users.
unsafe impl<T> Send for SendNonNull<T> {}

impl<T> Deref for SendNonNull<T> {
    type Target = NonNull<T>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<T> DerefMut for SendNonNull<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl<T> SendNonNull<T> {
    /// Creates a new `SendNonNull<T>` if `ptr` is non-null, otherwise returns `None`.
    pub const fn new(ptr: *mut T) -> Option<Self> {
        match NonNull::new(ptr) {
            Some(ptr) => Some(Self(ptr)),
            None => None,
        }
    }

    /// Like [`SendNonNull::new`], but takes a `*const T`.
    pub const fn new_const(ptr: *const T) -> Option<Self> {
        Self::new(ptr.cast_mut())
    }

    /// Creates a new `SendNonNull<T>` without checking that `ptr` is non-null.
    ///
    /// # Safety
    ///
    /// `ptr` must be non-null.
    pub const unsafe fn new_unchecked(ptr: *mut T) -> Self {
        // SAFETY: the caller guarantees that `ptr` is non-null.
        unsafe { Self(NonNull::new_unchecked(ptr)) }
    }

    /// Creates a new `SendNonNull<T>` that is dangling but well-aligned.
    ///
    /// The result does not point to a valid `T`. It is only useful as a placeholder.
    pub const fn dangling() -> Self {
        Self(NonNull::dangling())
    }

    /// Casts the pointer to a pointer of a different type, keeping the address.
    pub const fn cast<U>(self) -> SendNonNull<U> {
        SendNonNull(self.0.cast())
    }

    /// Returns a pointer with address `addr` and the provenance of `self`.
    ///
    /// See `<*mut T>::with_addr` for details.
    pub fn with_addr(self, addr: NonZero<usize>) -> Self {
        let ptr = self.as_ptr().with_addr(addr.get());
        // SAFETY: the address of `ptr` is `addr`, which is non-zero, so `ptr` is non-null.
        Self(unsafe { NonNull::new_unchecked(ptr) })
    }

    /// Maps the address of the pointer through `f`, keeping the provenance of `self`.
    pub fn map_addr(self, f: impl FnOnce(NonZero<usize>) -> NonZero<usize>) -> Self {
        self.with_addr(f(self.addr()))
    }

    /// Returns a new pointer, offset from `self` by `offset` elements of `T`.
    ///
    /// # Safety
    ///
    /// Same contract as `<*mut T>::offset`. The resulting pointer must be in bounds of, or
    /// one past the end of, the same allocation as `self`. The byte offset
    /// `offset * size_of::<T>()` must not overflow an `isize`.
    pub unsafe fn offset(self, offset: isize) -> Self {
        // SAFETY: the caller upholds `<*mut T>::offset`'s contract, so the result stays within
        // the allocation of `self` and therefore cannot be null.
        unsafe { Self(NonNull::new_unchecked(self.as_ptr().offset(offset))) }
    }

    /// Returns a new pointer, offset from `self` by `offset` bytes.
    ///
    /// # Safety
    ///
    /// Same contract as `<*mut T>::byte_offset`. The resulting pointer must be in bounds of,
    /// or one past the end of, the same allocation as `self`.
    pub unsafe fn byte_offset(self, offset: isize) -> Self {
        // SAFETY: the caller upholds `<*mut T>::byte_offset`'s contract, so the result stays
        // within the allocation of `self` and therefore cannot be null.
        unsafe { Self(NonNull::new_unchecked(self.as_ptr().byte_offset(offset))) }
    }

    /// Returns a new pointer, advanced from `self` by `count` elements of `T`.
    ///
    /// # Safety
    ///
    /// Same contract as `<*mut T>::add`. The resulting pointer must be in bounds of, or one
    /// past the end of, the same allocation as `self`. The byte offset
    /// `count * size_of::<T>()` must not overflow an `isize`.
    pub unsafe fn add(self, count: usize) -> Self {
        // SAFETY: the caller upholds `<*mut T>::add`'s contract, so the result stays within the
        // allocation of `self` and therefore cannot be null.
        unsafe { Self(NonNull::new_unchecked(self.as_ptr().add(count))) }
    }

    /// Returns a new pointer, advanced from `self` by `count` bytes.
    ///
    /// # Safety
    ///
    /// Same contract as `<*mut T>::byte_add`. The resulting pointer must be in bounds of, or
    /// one past the end of, the same allocation as `self`.
    pub unsafe fn byte_add(self, count: usize) -> Self {
        // SAFETY: the caller upholds `<*mut T>::byte_add`'s contract, so the result stays
        // within the allocation of `self` and therefore cannot be null.
        unsafe { Self(NonNull::new_unchecked(self.as_ptr().byte_add(count))) }
    }
}
/// A tagged atomic pointer that stores a `*mut T` and a `BITS`-wide tag in the same word.
///
/// The tag occupies the low `BITS` bits of the address. Stored pointers must therefore be
/// aligned to at least `2^BITS` bytes, i.e. `align_of::<T>() >= 2^BITS`. The address bits are
/// only changed through `with_addr`/`map_addr` (or whole-word atomic operations), so a stored
/// pointer keeps its provenance.
#[repr(transparent)]
#[derive(Debug)]
pub struct TaggedAtomicPtr<T, const BITS: u8> {
    ptr: AtomicPtr<()>,
    _pd: PhantomData<T>,
}

impl<T, const BITS: u8> TaggedAtomicPtr<T, BITS> {
    /// Bit mask selecting the low `BITS` bits of an address, where the tag is stored.
    const fn mask() -> usize {
        !(!0usize << BITS)
    }

    /// Creates a new tagged pointer holding `ptr` and the low `BITS` bits of `tag`.
    ///
    /// Low bits of `ptr` that overlap the tag are discarded, and bits of `tag` above `BITS`
    /// are ignored. Debug builds assert that `T` is aligned enough to leave room for the tag.
    pub fn new(ptr: *mut T, tag: usize) -> TaggedAtomicPtr<T, BITS> {
        debug_assert!(mem::align_of::<T>().ilog2() as u8 >= BITS);
        let mask = Self::mask();
        Self {
            ptr: AtomicPtr::new(ptr.with_addr((ptr.addr() & !mask) | (tag & mask)).cast()),
            _pd: PhantomData,
        }
    }

    /// Loads the pointer with the tag bits cleared.
    #[doc(alias = "load_ptr")]
    pub fn ptr(&self, order: atomic::Ordering) -> *mut T {
        self.ptr
            .load(order)
            .map_addr(|addr| addr & !Self::mask())
            .cast()
    }

    /// Loads the tag.
    #[doc(alias = "load_tag")]
    pub fn tag(&self, order: atomic::Ordering) -> usize {
        self.ptr.load(order).addr() & Self::mask()
    }

    /// Atomically ORs `tag` (truncated to `BITS` bits) into the stored tag and returns the
    /// previous tag. The pointer is left unchanged.
    pub fn fetch_or_tag(&self, tag: usize, order: atomic::Ordering) -> usize {
        let mask = Self::mask();
        let bits = tag & mask;
        #[cfg(feature = "nightly")]
        let old = self.ptr.fetch_or(bits, order);
        #[cfg(not(feature = "nightly"))]
        let old = self.fetch_map_addr(order, |addr| addr | bits);
        old.addr() & mask
    }

    /// Atomically clears the tag and returns its previous value. The pointer is left
    /// unchanged.
    pub fn take_tag(&self, order: atomic::Ordering) -> usize {
        let mask = Self::mask();
        #[cfg(feature = "nightly")]
        let old = self.ptr.fetch_and(!mask, order);
        #[cfg(not(feature = "nightly"))]
        let old = self.fetch_map_addr(order, |addr| addr & !mask);
        old.addr() & mask
    }

    /// Atomically replaces the stored address with `f(address)`, keeping the stored pointer's
    /// provenance. Returns the previous stored word (pointer and tag bits).
    ///
    /// This is the stable stand-in for `AtomicPtr::fetch_or`/`fetch_and`. Performing the
    /// read-modify-write through an `AtomicUsize` view of the same memory would store a plain
    /// integer and strip the pointer's provenance.
    #[cfg(not(feature = "nightly"))]
    fn fetch_map_addr(&self, order: atomic::Ordering, f: impl Fn(usize) -> usize) -> *mut () {
        // Loads (and CAS failures) may not use `Release`/`AcqRel`, so map `order` to the
        // matching load ordering.
        let load = match order {
            atomic::Ordering::Release => atomic::Ordering::Relaxed,
            atomic::Ordering::AcqRel => atomic::Ordering::Acquire,
            other => other,
        };
        // The closure always returns `Some`, so this never takes the `Err` path in practice.
        match self.ptr.fetch_update(order, load, |ptr| Some(ptr.map_addr(&f))) {
            Ok(prev) | Err(prev) => prev,
        }
    }

    /// Shared implementation of [`Self::compare_exchange_ptr`] and
    /// [`Self::compare_exchange_weak_ptr`]. Returns the previous/current *pointer* with the
    /// tag stripped.
    ///
    /// The currently stored tag is carried over into both the expected and the new word. If
    /// the exchange fails only because the tag changed concurrently (the pointer still equals
    /// `old`), it is retried with the freshly observed tag. A strong exchange therefore never
    /// fails while the pointer matches.
    #[inline(always)]
    fn compare_exchange_ptr_inner(
        &self,
        old: *mut T,
        new: *mut T,
        success: atomic::Ordering,
        failure: atomic::Ordering,
        cmpxchg: fn(
            &AtomicPtr<()>,
            *mut (),
            *mut (),
            atomic::Ordering,
            atomic::Ordering,
        ) -> Result<*mut (), *mut ()>,
    ) -> Result<*mut T, *mut T> {
        let mask = Self::mask();
        let old = old.cast::<()>();
        let new = new.cast::<()>();
        let mut tag = self.ptr.load(failure).addr() & mask;
        loop {
            let expected = old.map_addr(|addr| (addr & !mask) | tag);
            let desired = new.map_addr(|addr| (addr & !mask) | tag);
            match cmpxchg(&self.ptr, expected, desired, success, failure) {
                Ok(prev) => return Ok(prev.map_addr(|addr| addr & !mask).cast()),
                Err(actual) => {
                    let actual_tag = actual.addr() & mask;
                    let same_ptr = (actual.addr() & !mask) == (old.addr() & !mask);
                    if same_ptr && actual_tag != tag {
                        // Only the tag moved underneath us: retry with the new tag.
                        tag = actual_tag;
                        continue;
                    }
                    return Err(actual.map_addr(|addr| addr & !mask).cast());
                }
            }
        }
    }

    /// Replaces the pointer with `new` if it currently equals `old`, keeping the tag.
    ///
    /// Returns `Ok(previous pointer)` on success and `Err(current pointer)` on failure, both
    /// with the tag stripped. Concurrent changes to the tag alone do not cause a failure.
    /// `old` and `new` are expected to be aligned so that their tag bits are zero.
    pub fn compare_exchange_ptr(
        &self,
        old: *mut T,
        new: *mut T,
        success: atomic::Ordering,
        failure: atomic::Ordering,
    ) -> Result<*mut T, *mut T> {
        self.compare_exchange_ptr_inner(
            old,
            new,
            success,
            failure,
            AtomicPtr::<()>::compare_exchange,
        )
    }

    /// Like [`Self::compare_exchange_ptr`], but may fail spuriously.
    pub fn compare_exchange_weak_ptr(
        &self,
        old: *mut T,
        new: *mut T,
        success: atomic::Ordering,
        failure: atomic::Ordering,
    ) -> Result<*mut T, *mut T> {
        self.compare_exchange_ptr_inner(
            old,
            new,
            success,
            failure,
            AtomicPtr::<()>::compare_exchange_weak,
        )
    }

    /// Shared implementation of [`Self::compare_exchange_tag`] and
    /// [`Self::compare_exchange_weak_tag`]. Returns the previous/current tag.
    ///
    /// The currently stored pointer is carried over into both the expected and the new word.
    /// If the exchange fails only because the pointer changed concurrently (the tag still
    /// equals `old`), it is retried with the freshly observed pointer.
    #[inline(always)]
    fn compare_exchange_tag_inner(
        &self,
        old: usize,
        new: usize,
        success: atomic::Ordering,
        failure: atomic::Ordering,
        cmpxchg: fn(
            &AtomicPtr<()>,
            *mut (),
            *mut (),
            atomic::Ordering,
            atomic::Ordering,
        ) -> Result<*mut (), *mut ()>,
    ) -> Result<usize, usize> {
        let mask = Self::mask();
        let old = old & mask;
        let new = new & mask;
        let mut current = self.ptr.load(failure);
        loop {
            let expected = current.map_addr(|addr| (addr & !mask) | old);
            let desired = current.map_addr(|addr| (addr & !mask) | new);
            match cmpxchg(&self.ptr, expected, desired, success, failure) {
                Ok(prev) => return Ok(prev.addr() & mask),
                Err(actual) => {
                    let actual_tag = actual.addr() & mask;
                    if actual_tag == old && actual.addr() != expected.addr() {
                        // Only the pointer moved underneath us: retry against it.
                        current = actual;
                        continue;
                    }
                    return Err(actual_tag);
                }
            }
        }
    }

    /// Replaces the tag with `new` if it currently equals `old` (both truncated to `BITS`
    /// bits), keeping the pointer.
    ///
    /// Returns `Ok(previous tag)` on success and `Err(current tag)` on failure. Concurrent
    /// changes to the pointer alone do not cause a failure.
    pub fn compare_exchange_tag(
        &self,
        old: usize,
        new: usize,
        success: atomic::Ordering,
        failure: atomic::Ordering,
    ) -> Result<usize, usize> {
        self.compare_exchange_tag_inner(
            old,
            new,
            success,
            failure,
            AtomicPtr::<()>::compare_exchange,
        )
    }

    /// Like [`Self::compare_exchange_tag`], but may fail spuriously.
    pub fn compare_exchange_weak_tag(
        &self,
        old: usize,
        new: usize,
        success: atomic::Ordering,
        failure: atomic::Ordering,
    ) -> Result<usize, usize> {
        self.compare_exchange_tag_inner(
            old,
            new,
            success,
            failure,
            AtomicPtr::<()>::compare_exchange_weak,
        )
    }

    /// Stores `ptr` while preserving the current tag, using a CAS loop.
    ///
    /// `success` is the ordering of the successful CAS. `failure` is used for the loads and
    /// failed CAS attempts. Low bits of `ptr` that overlap the tag are discarded.
    #[doc(alias = "store_ptr")]
    pub fn set_ptr(&self, ptr: *mut T, success: atomic::Ordering, failure: atomic::Ordering) {
        let mask = Self::mask();
        let ptr = ptr.cast::<()>();
        loop {
            let old = self.ptr.load(failure);
            let new = ptr.map_addr(|addr| (addr & !mask) | (old.addr() & mask));
            if self
                .ptr
                .compare_exchange_weak(old, new, success, failure)
                .is_ok()
            {
                break;
            }
        }
    }

    /// Stores `tag` (truncated to `BITS` bits) while preserving the current pointer, using a
    /// CAS loop.
    #[doc(alias = "store_tag")]
    pub fn set_tag(&self, tag: usize, success: atomic::Ordering, failure: atomic::Ordering) {
        let mask = Self::mask();
        loop {
            let ptr = self.ptr.load(failure);
            let new = ptr.map_addr(|addr| (addr & !mask) | (tag & mask));
            if self
                .ptr
                .compare_exchange_weak(ptr, new, success, failure)
                .is_ok()
            {
                break;
            }
        }
    }

    /// Stores `new` (truncated to `BITS` bits) as the tag, preserving the pointer, and
    /// returns the previous tag.
    pub fn swap_tag(
        &self,
        new: usize,
        success: atomic::Ordering,
        failure: atomic::Ordering,
    ) -> usize {
        let mask = Self::mask();
        loop {
            let ptr = self.ptr.load(failure);
            let new = ptr.map_addr(|addr| (addr & !mask) | (new & mask));
            if let Ok(old) = self.ptr.compare_exchange_weak(ptr, new, success, failure) {
                break old.addr() & mask;
            }
        }
    }

    /// Stores `new` as the pointer, preserving the tag, and returns the previous pointer with
    /// the tag stripped.
    pub fn swap_ptr(
        &self,
        new: *mut T,
        success: atomic::Ordering,
        failure: atomic::Ordering,
    ) -> *mut T {
        let mask = Self::mask();
        let new = new.cast::<()>();
        loop {
            let old = self.ptr.load(failure);
            let new = new.map_addr(|addr| (addr & !mask) | (old.addr() & mask));
            if let Ok(old) = self.ptr.compare_exchange_weak(old, new, success, failure) {
                break old.map_addr(|addr| addr & !mask).cast();
            }
        }
    }

    /// Loads the pointer and the tag with a single atomic load.
    ///
    /// # Panics
    ///
    /// Panics if the stored pointer is null, e.g. when the value was created from or set to
    /// a null pointer.
    pub fn ptr_and_tag(&self, order: atomic::Ordering) -> (NonNull<T>, usize) {
        let mask = Self::mask();
        let raw = self.ptr.load(order);
        let tag = raw.addr() & mask;
        // A null pointer is a legitimate stored value (`new` accepts one), so this must be
        // checked rather than assumed.
        let ptr = NonNull::new(raw.map_addr(|addr| addr & !mask).cast::<T>())
            .expect("TaggedAtomicPtr::ptr_and_tag: stored pointer is null");
        (ptr, tag)
    }

    /// Copies the whole word (pointer and tag) from `other` into `self`.
    ///
    /// Returns `self`'s previous pointer (tag stripped) and tag. The load from `other`
    /// (ordered by `load`) and the swap into `self` (ordered by `store`) are two separate
    /// atomic operations.
    pub fn copy_from(
        &self,
        other: &Self,
        load: atomic::Ordering,
        store: atomic::Ordering,
    ) -> (*mut T, usize) {
        let old = self.ptr.swap(other.ptr.load(load), store);
        let mask = Self::mask();
        (old.map_addr(|addr| addr & !mask).cast(), old.addr() & mask)
    }
}
/// An owning pointer to a `T` that lives in storage borrowed mutably for `'a`.
///
/// `UniquePtr` dereferences to the `T` and drops it in place when the `UniquePtr` itself is
/// dropped. The safe constructors ([`UniquePtr::new`], [`UniquePtr::new_pinned`],
/// [`UniquePtr::map`]) take that storage as a `ManuallyDrop<T>`, so the value is dropped
/// exactly once: by the `UniquePtr`.
#[repr(transparent)]
pub struct UniquePtr<'a, T> {
    ptr: NonNull<T>,
    _marker: PhantomData<&'a mut T>,
}

impl<'a, T> UniquePtr<'a, T> {
    /// Moves `value` into temporary storage, passes a `UniquePtr` to it to `f`, and returns
    /// `f`'s result.
    ///
    /// The value is dropped when that `UniquePtr` is dropped. If `f` leaks the pointer, the
    /// value is leaked too.
    #[inline]
    pub fn map<U, F>(value: T, f: F) -> U
    where
        F: FnOnce(UniquePtr<'_, T>) -> U,
    {
        let mut slot = ManuallyDrop::new(value);
        f(UniquePtr::new(&mut slot))
    }

    /// Creates a pinned `UniquePtr` from pinned `ManuallyDrop` storage.
    ///
    /// NOTE(review): if the returned `Pin<Self>` is leaked (e.g. `mem::forget`), the pinned
    /// `T` is never dropped, even though its storage is released after `'a`. This does not
    /// uphold `Pin`'s drop guarantee for `T`. Confirm that callers never leak it.
    pub fn new_pinned(inner: Pin<&'a mut ManuallyDrop<T>>) -> Pin<Self> {
        // SAFETY: the reference is only turned into a pointer below; the value is never moved
        // out of its storage.
        let inner = unsafe { inner.get_unchecked_mut() };
        let ptr = NonNull::from(&mut **inner);
        // SAFETY: the pointee was pinned by the caller. `UniquePtr` never moves its pointee and
        // only hands it out through `Deref`/`DerefMut`, which `Pin` restricts accordingly.
        unsafe {
            Pin::new_unchecked(Self {
                ptr,
                _marker: PhantomData,
            })
        }
    }

    /// Creates a `UniquePtr` that takes over ownership of the value inside `inner`.
    ///
    /// After the returned `UniquePtr` is dropped, the `T` inside `inner` has been dropped.
    /// NOTE(review): nothing prevents the caller from using `inner` as a live `T` afterwards
    /// (e.g. `ManuallyDrop::into_inner`), which would be a use-after-drop. Consider whether
    /// this constructor should be `unsafe`.
    pub fn new(inner: &'a mut ManuallyDrop<T>) -> Self {
        Self {
            ptr: NonNull::from(&mut **inner),
            _marker: PhantomData,
        }
    }

    /// Creates a `UniquePtr` from a raw pointer.
    ///
    /// # Safety
    ///
    /// For all of `'a`, `ptr` must be:
    /// - non-null and valid for reads and writes of an initialized `T`;
    /// - not accessed through any other pointer.
    ///
    /// The `T` must also be sound to drop in place when the returned `UniquePtr` is dropped.
    pub unsafe fn new_unchecked(ptr: *mut T) -> Self {
        Self {
            // SAFETY: the caller guarantees that `ptr` is non-null.
            ptr: unsafe { NonNull::new_unchecked(ptr) },
            _marker: PhantomData,
        }
    }

    /// Returns the raw pointer to the owned value.
    pub fn as_ptr(&self) -> *mut T {
        self.ptr.as_ptr()
    }

    /// Returns the pointer to the owned value as a `NonNull`.
    pub fn as_non_null(&self) -> NonNull<T> {
        self.ptr
    }

    /// Reinterprets the owned value as a `U`, transferring ownership to the returned pointer.
    ///
    /// # Safety
    ///
    /// The pointee must be a valid `U`, including size and alignment. It must be sound to
    /// access it, and eventually drop it, as a `U`.
    pub unsafe fn cast<U>(self) -> UniquePtr<'a, U> {
        // `self` must not run its destructor. Ownership of the pointee moves to the returned
        // pointer, which drops it (as a `U`) instead. Otherwise it would be dropped twice.
        let this = ManuallyDrop::new(self);
        UniquePtr {
            ptr: this.ptr.cast(),
            _marker: PhantomData,
        }
    }
}

impl<'a, T> Deref for UniquePtr<'a, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        // SAFETY: the pointee is valid and exclusively owned by `self` for `'a` (see the
        // constructors), so a shared borrow tied to `&self` is sound.
        unsafe { self.ptr.as_ref() }
    }
}

impl<'a, T> DerefMut for UniquePtr<'a, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: as in `deref`, and `&mut self` guarantees this borrow is unique.
        unsafe { self.ptr.as_mut() }
    }
}

impl<'a, T> Drop for UniquePtr<'a, T> {
    fn drop(&mut self) {
        // SAFETY: `self` owns the pointee, which is still initialized. It is not accessed
        // through `self` again after this point.
        unsafe {
            core::ptr::drop_in_place(self.ptr.as_ptr());
        }
    }
}
#[cfg(test)]
mod tests {
    use core::sync::atomic::Ordering::Relaxed;
    use std::boxed::Box;

    use super::*;

    /// The tagged-pointer flavour exercised by every test: `u32` payload, 2 tag bits.
    type Tagged = TaggedAtomicPtr<u32, 2>;

    /// Heap-allocates `value` and leaks it as a raw pointer; release it with [`free`].
    fn alloc(value: u32) -> *mut u32 {
        Box::into_raw(Box::new(value))
    }

    /// Frees a pointer previously returned by [`alloc`].
    fn free(ptr: *mut u32) {
        // SAFETY: every pointer passed here comes from `alloc` and is freed exactly once.
        drop(unsafe { Box::from_raw(ptr) });
    }

    /// Asserts that `tagged` currently holds exactly `tag` and `ptr` (checked in that order).
    #[track_caller]
    fn assert_state(tagged: &Tagged, ptr: *mut u32, tag: usize) {
        assert_eq!(tagged.tag(Relaxed), tag);
        assert_eq!(tagged.ptr(Relaxed), ptr);
    }

    #[test]
    fn tagged_ptr_zero_tag() {
        let raw = alloc(42);
        let tagged = Tagged::new(raw, 0);
        assert_state(&tagged, raw, 0);
        free(raw);
    }

    #[test]
    fn tagged_ptr_take_tag() {
        let raw = alloc(42);
        let tagged = Tagged::new(raw, 0b11);
        assert_state(&tagged, raw, 0b11);

        assert_eq!(tagged.take_tag(Relaxed), 0b11);
        assert_state(&tagged, raw, 0);

        free(raw);
    }

    #[test]
    fn tagged_ptr_fetch_or_tag() {
        let raw = alloc(42);
        let tagged = Tagged::new(raw, 0b11);
        assert_state(&tagged, raw, 0b11);

        assert_eq!(tagged.fetch_or_tag(0b10, Relaxed), 0b11);
        assert_state(&tagged, raw, 0b11 | 0b10);

        free(raw);
    }

    #[test]
    fn tagged_ptr_exchange_tag() {
        let raw = alloc(42);
        let tagged = Tagged::new(raw, 0b11);
        assert_state(&tagged, raw, 0b11);

        let previous = tagged
            .compare_exchange_tag(0b11, 0b10, Relaxed, Relaxed)
            .unwrap();
        assert_eq!(previous, 0b11);
        assert_state(&tagged, raw, 0b10);

        free(raw);
    }

    #[test]
    fn tagged_ptr_exchange_ptr() {
        let first = alloc(42);
        let tagged = Tagged::new(first, 0b11);
        assert_state(&tagged, first, 0b11);

        let second = alloc(43);
        let previous = tagged
            .compare_exchange_ptr(first, second, Relaxed, Relaxed)
            .unwrap();
        assert_eq!(previous, first);
        assert_state(&tagged, second, 0b11);

        free(first);
        free(second);
    }
}