From 1ea8bcb3edb669c9f1b46c8966483aaf4d49a05f Mon Sep 17 00:00:00 2001 From: Janis Date: Thu, 3 Jul 2025 16:50:41 +0200 Subject: [PATCH] atomiccell --- src/atomic.rs | 262 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + 2 files changed, 263 insertions(+) create mode 100644 src/atomic.rs diff --git a/src/atomic.rs b/src/atomic.rs new file mode 100644 index 0000000..e351e3b --- /dev/null +++ b/src/atomic.rs @@ -0,0 +1,262 @@ +use core::{ + cell::UnsafeCell, + mem::{self, ManuallyDrop, MaybeUninit}, + sync::atomic::{AtomicU8, AtomicU16, AtomicU32, AtomicU64, AtomicUsize, Ordering}, +}; + +use crate::sync::SpinWait; + +macro_rules! atomic { + (@check, $t:ty, $atomic:ty, $a:ident, $op:expr) => { + if crate::can_transmute::<$t, $atomic>() { + let $a: &$atomic; + break $op; + } + }; + ($t:ty, $a:ident, $op:expr, $fallback:expr) => { + loop { + atomic!(@check, $t, AtomicU8, $a, $op); + atomic!(@check, $t, AtomicU16, $a, $op); + atomic!(@check, $t, AtomicU32, $a, $op); + atomic!(@check, $t, AtomicU64, $a, $op); + atomic!(@check, $t, AtomicUsize, $a, $op); + + // Fallback to the provided expression if no atomic type is found. + break $fallback; + } + }; +} + +pub struct AtomicCell { + inner: AtomicCellInner, + _phantom: core::marker::PhantomData, +} + +impl AtomicCell { + pub const fn new() -> Self { + Self { + inner: AtomicCellInner::none(), + _phantom: core::marker::PhantomData, + } + } + + pub fn set(&self, value: T) { + self.inner.set(value); + } + + pub fn take(&self) -> Option { + self.inner.take() + } + + pub fn get(&self) -> Option + where + T: Copy, + { + self.inner.get() + } + + pub fn swap(&self, value: Option) -> Option { + self.inner.swap(value) + } +} + +struct AtomicCellInner { + value: UnsafeCell>>, + state: AtomicU8, +} + +impl AtomicCellInner { + const EMPTY: u8 = 0; + const FULL: u8 = 1; + const LOCKED: u8 = 2; + + const fn none() -> Self { + Self { + value: UnsafeCell::new(ManuallyDrop::new(MaybeUninit::uninit())), + state: AtomicU8::new(Self::EMPTY), + } + } + + fn from_option(value: Option) -> Self { + match value { + Some(v) => Self { + value: UnsafeCell::new(ManuallyDrop::new(MaybeUninit::new(v))), + state: AtomicU8::new(Self::FULL), + }, + None => Self { + value: UnsafeCell::new(ManuallyDrop::new(MaybeUninit::uninit())), + state: AtomicU8::new(Self::EMPTY), + }, + } + } + + unsafe fn copy_from(&self, other: &Self, load: Ordering, store: Ordering) { + unsafe { + self.value.get().write(other.value.get().read()); + self.state.store(other.state.load(load), store); + } + } + + fn set(&self, value: T) { + self.swap(Some(value)); + } + + fn take(&self) -> Option { + self.swap(None) + } + + fn get(&self) -> Option + where + T: Copy, + { + let this: Self; + + atomic! { + Self, a, + { + unsafe { + a = &*(self as *const Self as *const _); + let old = a.load(Ordering::Acquire); + this = mem::transmute_copy(&old); + } + }, + { + let mut state = self.state.load(Ordering::Acquire); + + if state == Self::EMPTY { + this = Self::none(); + } else { + // if the state is `FULL`, we have to lock + + let mut spin_wait = SpinWait::new(); + let old = loop { + // if the state is `LOCKED`, we need to wait + if state == Self::LOCKED { + spin_wait.spin(); + continue; + } + + // if the state is `FULL`, we can try locking and swapping the value` + if self.state.compare_exchange_weak( + state, + Self::LOCKED, + Ordering::Acquire, + Ordering::Relaxed, + ).is_ok() { + break state; + } else { + // the state changed, we need to check again + state = self.state.load(Ordering::Relaxed); + continue; + } + }; + + let empty = Self::none(); + if old == Self::FULL { + // copy the value out of the cell + unsafe { + empty.copy_from(&self, Ordering::Relaxed, Ordering::Release); + + } + } + this = empty; + } + } + } + + match this.state.load(Ordering::Relaxed) { + Self::FULL => { + // SAFETY: We are returning the value only if it was previously full. + unsafe { Some(ManuallyDrop::into_inner(this.value.get().read()).assume_init()) } + } + _ => None, + } + } + + fn swap(&self, value: Option) -> Option { + let mut this = Self::from_option(value); + + atomic! { + Self, a, + { + // SAFETY: this block is only executed if `Self` can be transmuted into an atomic type. + // self.state cannot be `LOCKED` here, so we can safely swap the value. + unsafe { + // turn `self` into an atomic pointer + a = &*(self as *const Self as *const _); + // swap the value atomically + let old = a.swap(mem::transmute_copy(&this), Ordering::Release); + this = mem::transmute_copy(&old); + + if this.state.load(Ordering::Relaxed) == Self::FULL { + // SAFETY: We are returning the value only if it was previously full. + Some( ManuallyDrop::into_inner(this.value.into_inner()).assume_init() ) + } else { + None + } + } + + }, + { + // Fallback if no atomic type is found. + // we need to lock the cell to swap the value. + + // attempt to lock optimistically + match self.state.compare_exchange_weak( + Self::EMPTY, + Self::LOCKED, + Ordering::Acquire, + Ordering::Relaxed, + ) { + Ok(_) => { + // SAFETY: We are the only thread that can access this cell now. + unsafe { + self.copy_from(&this, Ordering::Relaxed, Ordering::Release); + } + None + } + Err(mut state) => { + let mut spin_wait = SpinWait::new(); + let old = loop { + // if the state is `LOCKED`, we need to wait + if state == Self::LOCKED { + spin_wait.spin(); + continue; + } + + // if the state is not `LOCKED`, we can try locking and swapping the value` + if self.state.compare_exchange_weak( + state, + Self::LOCKED, + Ordering::Acquire, + Ordering::Relaxed, + ).is_ok() { + break state; + } else { + // the state changed, we need to check again + state = self.state.load(Ordering::Relaxed); + continue; + } + }; + + let old = if old == Self::FULL { + // SAFETY: the cell is locked, and is initialised. + unsafe { + Some(ManuallyDrop::into_inner(self.value.get().read()).assume_init()) + } + } else {None}; + + // SAFETY: the cell is locked, so we can safely copy the value + unsafe { + self.copy_from(&this, Ordering::Relaxed, Ordering::Release); + } + + old + } + } + + } + }; + None + } +} diff --git a/src/lib.rs b/src/lib.rs index 94a749a..e263396 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,6 +7,7 @@ extern crate alloc; #[cfg(any(test, feature = "std"))] extern crate std; +pub mod atomic; pub mod cachepadded; pub mod drop_guard; pub mod ptr;