use core::{
    mem,
    sync::atomic::{AtomicU32, Ordering},
};

/// A simple lock implementation using an atomic u32.
#[repr(transparent)]
pub struct Lock {
    inner: AtomicU32,
}

impl Lock {
    const LOCKED: u32 = 0b001;
    const EMPTY: u32 = 0;

    /// Creates a new lock in the unlocked state.
    pub const fn new() -> Self {
        Self {
            inner: AtomicU32::new(Self::EMPTY),
        }
    }

    pub fn as_ptr(&self) -> *mut u32 {
        self.inner.as_ptr()
    }

    pub unsafe fn from_ptr<'a>(ptr: *mut u32) -> &'a Self {
        // SAFETY: The caller must ensure that `ptr` is not aliased and lives
        // for the lifetime of the returned `Lock`.
        unsafe { mem::transmute(AtomicU32::from_ptr(ptr)) }
    }

    /// Acquires the lock, blocking until it is available.
    pub fn lock(&self) {
        // Attempt to acquire the lock, assuming no contention.
        if self
            .inner
            .compare_exchange_weak(
                Self::EMPTY,
                Self::LOCKED,
                Ordering::Acquire,
                Ordering::Relaxed,
            )
            .is_ok()
        {
            // We successfully acquired the lock.
            return;
        } else {
            self.lock_slow();
        }
    }

    pub fn unlock(&self) {
        // Use release semantics to ensure that all previous writes are
        // made available to other threads.
        self.inner.fetch_and(!Self::LOCKED, Ordering::Release);
        // Wake one parked waiter, if any, so the slow path can make progress.
        atomic_wait::wake_one(&self.inner);
    }

    fn lock_slow(&self) {
        // The lock is either locked, or someone is waiting for it:
        let mut spin_wait = SpinWait::new();
        let mut state = self.inner.load(Ordering::Acquire);

        loop {
            // If the lock isn't held, we can try to acquire it.
            if state & Self::LOCKED == 0 {
                // Try to acquire the lock.
                match self.inner.compare_exchange_weak(
                    state,
                    state | Self::LOCKED,
                    Ordering::Acquire,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => {
                        // We successfully acquired the lock.
                        return;
                    }
                    Err(new_state) => {
                        // The state changed, so check again.
                        state = new_state;
                        continue;
                    }
                }
            }

            #[cfg(feature = "std")]
            let spun = spin_wait.spin_yield();
            #[cfg(not(feature = "std"))]
            let spun = spin_wait.spin();

            if spun {
                // We can spin for a little while and see if the lock becomes available.
                state = self.inner.load(Ordering::Relaxed);
                continue;
            }

            // If we reach here, we need to park the thread.
            atomic_wait::wait(&self.inner, Self::LOCKED);

            if self
                .inner
                .compare_exchange_weak(
                    state,
                    state | Self::LOCKED,
                    Ordering::Acquire,
                    Ordering::Relaxed,
                )
                .is_ok()
            {
                // We successfully acquired the lock after being woken up.
                return;
            }

            spin_wait.reset();
            state = self.inner.load(Ordering::Relaxed);
        }
    }

    pub fn wait(&self) {
        let state = self.inner.load(Ordering::Acquire);
        atomic_wait::wait(&self.inner, state);
    }

    pub fn wake_one(&self) {
        // Notify one thread waiting on this lock.
        atomic_wait::wake_one(&self.inner);
    }
}

// from parking_lot_core
pub struct SpinWait {
    counter: u32,
}

impl SpinWait {
    /// Creates a new `SpinWait` with an initial counter value.
    pub const fn new() -> Self {
        Self { counter: 0 }
    }

    /// Resets the counter to zero.
    pub fn reset(&mut self) {
        self.counter = 0;
    }

    pub fn spin(&mut self) -> bool {
        if self.counter >= 10 {
            // The counter is too high; signal the caller to consider parking.
            return false;
        }
        self.counter += 1;

        // Spin for a small number of iterations based on the counter value.
        for _ in 0..(1 << self.counter) {
            core::hint::spin_loop();
        }

        true
    }

    #[cfg(feature = "std")]
    pub fn spin_yield(&mut self) -> bool {
        if self.counter >= 10 {
            // The counter is too high; signal the caller to consider parking.
            return false;
        }
        self.counter += 1;

        if self.counter >= 3 {
            // Spin for a small number of iterations based on the counter value.
            for _ in 0..(1 << self.counter) {
                core::hint::spin_loop();
            }
        } else {
            // Yield the thread and wait for the OS to reschedule us.
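            // Yielding is cheaper than burning cycles while the counter is still small.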
            std::thread::yield_now();
        }

        true
    }
}

// taken from `std`
#[derive(Debug)]
#[repr(transparent)]
pub struct Parker {
    mutex: AtomicU32,
}

impl Parker {
    const PARKED: u32 = u32::MAX;
    const EMPTY: u32 = 0;
    const NOTIFIED: u32 = 1;

    pub fn new() -> Self {
        Self {
            mutex: AtomicU32::new(Self::EMPTY),
        }
    }

    pub fn as_ptr(&self) -> *mut u32 {
        self.mutex.as_ptr()
    }

    pub unsafe fn from_ptr<'a>(ptr: *mut u32) -> &'a Self {
        // SAFETY: The caller must ensure that `ptr` is not aliased and lives
        // for the lifetime of the returned `Parker`.
        unsafe { mem::transmute(AtomicU32::from_ptr(ptr)) }
    }

    pub fn is_parked(&self) -> bool {
        self.mutex.load(Ordering::Acquire) == Self::PARKED
    }

    pub fn park(&self) {
        self.park_inner(|| ());
    }

    pub fn park_with_callback<F>(&self, before_sleep: F)
    where
        F: FnOnce(),
    {
        // `before_sleep` is invoked after the state is set to `PARKED`, but
        // before the thread actually goes to sleep.
        self.park_inner(before_sleep);
    }

    // If the caller wants to park the thread on a mutex-guarded condition, a
    // deadlock can occur if the mutex is dropped just before this parker
    // atomically changes its state to `PARKED`. For that reason, we use a
    // callback that lets the caller perform any necessary actions after the
    // parker is set to `PARKED` but before the thread goes to sleep. If another
    // thread then checks the condition and decides this thread should be
    // woken, we will be notified immediately.
    // Thus, callers must synchronise any condition with an external mutex and
    // unlock that mutex inside `before_sleep`.
    fn park_inner<F>(&self, before_sleep: F)
    where
        F: FnOnce(),
    {
        if self.mutex.fetch_sub(1, Ordering::Acquire) == Self::NOTIFIED {
            before_sleep();
            // The thread was already notified, so we can return immediately.
            return;
        }

        before_sleep();

        loop {
            atomic_wait::wait(&self.mutex, Self::PARKED);

            // We check whether we were notified or woke up spuriously with
            // acquire ordering in order to make visible any writes made by the
            // thread that notified us.
            if self.mutex.swap(Self::EMPTY, Ordering::Acquire) == Self::NOTIFIED {
                // The thread was notified, so we can return.
                return;
            } else {
                // Spurious wakeup, so we need to re-park.
                continue;
            }
        }
    }

    pub fn unpark(&self) {
        // Swap with Release ordering to ensure that any writes made by this
        // thread are made available to the unparked thread.
        if self.mutex.swap(Self::NOTIFIED, Ordering::Release) == Self::PARKED {
            // The thread was parked, so we need to wake it.
            atomic_wait::wake_one(&self.mutex);
        } else {
            // The thread was not parked, so there is nothing to do.
        }
    }
}

#[cfg(feature = "alloc")]
pub mod channel {
    use alloc::sync::Arc;
    use core::{
        cell::{Cell, UnsafeCell},
        marker::PhantomData,
        mem::MaybeUninit,
        sync::atomic::{AtomicU32, Ordering},
    };

    #[repr(C)]
    #[derive(Debug)]
    struct Channel<T> {
        state: AtomicU32,
        val: UnsafeCell<MaybeUninit<T>>,
    }

    unsafe impl<T: Send> Send for Channel<T> {}
    unsafe impl<T: Send> Sync for Channel<T> {}

    impl<T> Channel<T> {
        const OCCUPIED_BIT: u32 = 0b01;
        const WAITING_BIT: u32 = 0b10;

        fn new() -> Self {
            Self {
                state: AtomicU32::new(0),
                val: UnsafeCell::new(MaybeUninit::uninit()),
            }
        }
    }

    pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
        let channel = Arc::new(Channel::<T>::new());
        let receiver = Receiver(channel.clone(), PhantomData);
        let sender = Sender(channel);

        (sender, receiver)
    }

    #[derive(Debug)]
    #[repr(transparent)]
    // `PhantomData<Cell<()>>` is used to ensure that `Receiver` is `!Sync` but `Send`.
    pub struct Receiver<T>(Arc<Channel<T>>, PhantomData<Cell<()>>);

    #[derive(Debug)]
    #[repr(transparent)]
    pub struct Sender<T>(Arc<Channel<T>>);

    impl<T> Receiver<T> {
        pub fn is_empty(&self) -> bool {
            self.0.state.load(Ordering::Acquire) & Channel::<T>::OCCUPIED_BIT == 0
        }

        pub fn as_sender(self) -> Sender<T> {
            Sender(self.0.clone())
        }

        fn wait(&mut self) {
            loop {
                let state = self
                    .0
                    .state
                    .fetch_or(Channel::<T>::WAITING_BIT, Ordering::Acquire);

                if state & Channel::<T>::OCCUPIED_BIT == 0 {
                    // The channel is empty, so we need to wait for a value to be sent.
                    // We will block until the sender wakes us up.
                    atomic_wait::wait(&self.0.state, Channel::<T>::WAITING_BIT);
                } else {
                    // The channel is occupied, so we can return.
                    self.0
                        .state
                        .fetch_and(!Channel::<T>::WAITING_BIT, Ordering::Release);
                    break;
                }
            }
        }

        /// Takes the value from the channel, if it is present.
        /// This function must only ever return `Some` once.
        pub unsafe fn take(&mut self) -> Option<T> {
            // Unset the OCCUPIED_BIT to indicate that we are taking the value, if any is present.
            if self
                .0
                .state
                .fetch_and(!Channel::<T>::OCCUPIED_BIT, Ordering::Acquire)
                & Channel::<T>::OCCUPIED_BIT
                == 0
            {
                // The channel was empty, so we return None.
                None
            } else {
                // SAFETY: we only ever access this field by pointer.
                // The OCCUPIED_BIT was set, so we can safely read the value.
                // This function is only called once, within `recv`,
                // guaranteeing that the value will only be dropped once.
                unsafe { Some(self.0.val.get().read().assume_init_read()) }
            }
        }

        pub fn recv(mut self) -> T {
            loop {
                // SAFETY: `recv` can only be called once, since it takes ownership of `self`.
                // If `take` returns a value, it will never be called again.
                if let Some(t) = unsafe { self.take() } {
                    return t;
                }

                self.wait();
            }
        }
    }

    impl<T> Sender<T> {
        pub fn send(self, value: T) {
            unsafe {
                self.0.val.get().write(MaybeUninit::new(value));
            }

            // Set the OCCUPIED_BIT to indicate that a value is present.
            let state = self
                .0
                .state
                .fetch_or(Channel::<T>::OCCUPIED_BIT, Ordering::Release);
            assert!(
                state & Channel::<T>::OCCUPIED_BIT == 0,
                "Channel is already occupied"
            );

            // If there are any receivers waiting, we need to wake them up.
            if state & Channel::<T>::WAITING_BIT != 0 {
                // There are receivers waiting, so we need to wake them up.
                atomic_wait::wake_all(&self.0.state);
            }
        }
    }
}

#[cfg(feature = "alloc")]
pub mod queue {
    //! A queue with multiple receivers and multiple producers, where a
    //! producer can send a message to any one of the receivers (any-cast) or
    //! to one specific receiver (uni-cast).
    //! After being woken up from waiting on a message, the receiver looks up
    //! the index of the message in the queue and returns it.
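    //!
    //! A rough usage sketch (mirroring the tests at the end of this module;
    //! marked `ignore` since the paths here are illustrative):
    //!
    //! ```ignore
    //! let queue = Queue::<i32>::new();
    //! let sender = queue.new_sender();
    //! let receiver = queue.new_receiver();
    //!
    //! // Any-cast: any receiver may pick this message up.
    //! sender.anycast(1);
    //! // Uni-cast: this message is delivered to `receiver` specifically.
    //! sender.unicast(2, receiver.get_token()).unwrap();
    //!
    //! assert_eq!(receiver.recv(), 2);
    //! assert_eq!(receiver.recv(), 1);
    //! ```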
    use alloc::{boxed::Box, sync::Arc, vec::Vec};
    use core::{
        cell::UnsafeCell,
        marker::{PhantomData, PhantomPinned},
        mem::{self, MaybeUninit},
        pin::Pin,
        ptr::{self, NonNull},
        sync::atomic::{AtomicU32, Ordering},
    };

    use hashbrown::HashMap;

    use crate::{CachePadded, ptr::TaggedAtomicPtr};

    use super::Parker;

    struct QueueInner<T> {
        receivers: HashMap<ReceiverToken, CachePadded<(Slot<T>, bool)>>,
        messages: Vec<T>,
        _phantom: core::marker::PhantomData<T>,
    }

    pub struct Queue<T> {
        inner: UnsafeCell<QueueInner<T>>,
        lock: AtomicU32,
    }

    unsafe impl<T: Send> Send for Queue<T> {}
    unsafe impl<T> Sync for Queue<T> where T: Send {}

    pub struct Receiver<T> {
        queue: Arc<Queue<T>>,
        lock: Pin<Box<(Parker, PhantomPinned)>>,
    }

    #[repr(transparent)]
    pub struct Sender<T> {
        queue: Arc<Queue<T>>,
    }

    // TODO: make this a linked list of slots so we can queue multiple messages for
    // a single receiver
    const SLOT_ALIGN: u8 = core::mem::align_of::<usize>().ilog2() as u8;

    struct Slot<T> {
        value: UnsafeCell<MaybeUninit<T>>,
        next_and_state: TaggedAtomicPtr<Self, SLOT_ALIGN>,
        _phantom: PhantomData<T>,
    }

    impl<T> Slot<T> {
        fn new() -> Self {
            Self {
                value: UnsafeCell::new(MaybeUninit::uninit()),
                next_and_state: TaggedAtomicPtr::new(ptr::null_mut(), 0), // 0 means empty
                _phantom: PhantomData,
            }
        }

        fn is_set(&self) -> bool {
            self.next_and_state.tag(Ordering::Acquire) == 1
        }

        unsafe fn pop(&self) -> Option<T> {
            NonNull::new(self.next_and_state.ptr(Ordering::Acquire))
                .and_then(|next| {
                    // SAFETY: The next slot is a valid pointer to a Slot that was allocated by us.
                    unsafe { next.as_ref().pop() }
                })
                .or_else(|| {
                    if self
                        .next_and_state
                        .swap_tag(0, Ordering::AcqRel, Ordering::Relaxed)
                        == 1
                    {
                        // SAFETY: The value is only initialized when the state is set to 1.
                        Some(unsafe { (&mut *self.value.get()).assume_init_read() })
                    } else {
                        None
                    }
                })
        }

        /// This operation isn't atomic.
        #[allow(dead_code)]
        unsafe fn pop_front(&self) -> Option<T> {
            // Swap the slot at `next` with self, and return the value of self.

            // Get the next pointer, if it is non-null.
            if let Some(next) = NonNull::new(self.next_and_state.ptr(Ordering::Acquire)) {
                unsafe {
                    // Copy the next slot's next_and_state into self's next_and_state.
                    let (_, old) = self.next_and_state.copy_from(
                        &next.as_ref().next_and_state,
                        Ordering::Acquire,
                        Ordering::Release,
                    );

                    // Copy the next slot's value into self's value.
                    mem::swap(&mut *self.value.get(), &mut *next.as_ref().value.get());

                    if old == 1 {
                        // SAFETY: The value is only initialized when the state is set to 1.
                        Some(next.as_ref().value.get().read().assume_init())
                    } else {
                        // next was empty, so we return None.
                        None
                    }
                }
            } else {
                // next is null, so popping from the back or front is the same.
                unsafe { self.pop() }
            }
        }

        /// The caller must ensure that they have exclusive access to the slot.
        unsafe fn push(&self, value: T) {
            if self.is_set() {
                let next = self.next_ptr();
                unsafe {
                    (next.as_ref()).push(value);
                }
            } else {
                // SAFETY: The value is only initialized when the state is set to 1.
                unsafe { (&mut *self.value.get()).write(value) };
                self.next_and_state
                    .set_tag(1, Ordering::Release, Ordering::Relaxed);
            }
        }

        fn next_ptr(&self) -> NonNull<Slot<T>> {
            if let Some(next) = NonNull::new(self.next_and_state.ptr(Ordering::Acquire)) {
                next.cast()
            } else {
                self.alloc_next()
            }
        }

        fn alloc_next(&self) -> NonNull<Slot<T>> {
            let next = Box::into_raw(Box::new(Slot::new()));

            let next = loop {
                match self.next_and_state.compare_exchange_weak_ptr(
                    ptr::null_mut(),
                    next,
                    Ordering::Release,
                    Ordering::Acquire,
                ) {
                    Ok(_) => break next,
                    Err(other) => {
                        if other.is_null() {
                            continue;
                        }

                        // The next slot was allocated under us, so we need to drop
                        // the slot we just allocated and use the existing one.
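                        // Reconstituting the Box frees our unused allocation immediately.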
                        _ = unsafe { Box::from_raw(next) };
                        break other;
                    }
                }
            };

            unsafe {
                // SAFETY: The next slot is a valid pointer to a Slot that was allocated by us.
                NonNull::new_unchecked(next)
            }
        }
    }

    impl<T> Drop for Slot<T> {
        fn drop(&mut self) {
            // Drop the chain of next slots, if any.
            if let Some(next) = NonNull::new(self.next_and_state.swap_ptr(
                ptr::null_mut(),
                Ordering::Release,
                Ordering::Relaxed,
            )) {
                // SAFETY: The next slot is a valid pointer to a Slot that was allocated by us.
                // Reconstituting the Box drops the rest of the chain recursively and
                // frees the allocation.
                unsafe {
                    _ = Box::<Slot<T>>::from_raw(next.as_ptr());
                }
            }

            // SAFETY: The value is only initialized when the state is set to 1.
            if mem::needs_drop::<T>() && self.next_and_state.tag(Ordering::Acquire) == 1 {
                unsafe { (&mut *self.value.get()).assume_init_drop() };
            }
        }
    }

    // const BLOCK_SIZE: usize = 8;
    // struct Block<T> {
    //     next: AtomicPtr<Block<T>>,
    //     slots: [CachePadded<Slot<T>>; BLOCK_SIZE],
    // }

    /// A token that can be used to identify a specific receiver in a queue.
    #[repr(transparent)]
    #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
    pub struct ReceiverToken(crate::util::Send<NonNull<u32>>);

    impl ReceiverToken {
        pub fn as_ptr(&self) -> *mut u32 {
            self.0.into_inner().as_ptr()
        }

        pub unsafe fn as_parker(&self) -> &Parker {
            // SAFETY: The pointer is guaranteed to be valid and aligned, as it comes from a pinned Parker.
            unsafe { Parker::from_ptr(self.as_ptr()) }
        }

        pub unsafe fn from_parker(parker: &Parker) -> Self {
            // SAFETY: The pointer is guaranteed to be valid and aligned, as it comes from a pinned Parker.
            let ptr = NonNull::from(parker).cast::<u32>();

            ReceiverToken(crate::util::Send(ptr))
        }
    }

    impl<T> Queue<T> {
        pub fn new() -> Arc<Self> {
            Arc::new(Self {
                inner: UnsafeCell::new(QueueInner {
                    messages: Vec::new(),
                    receivers: HashMap::new(),
                    _phantom: PhantomData,
                }),
                lock: AtomicU32::new(0),
            })
        }

        pub fn new_sender(self: &Arc<Self>) -> Sender<T> {
            Sender {
                queue: self.clone(),
            }
        }

        pub fn num_receivers(self: &Arc<Self>) -> usize {
            let _guard = self.lock();
            self.inner().receivers.len()
        }

        pub fn as_sender(self: &Arc<Self>) -> &Sender<T> {
            unsafe { mem::transmute::<&Arc<Self>, &Sender<T>>(self) }
        }

        pub fn new_receiver(self: &Arc<Self>) -> Receiver<T> {
            let recv = Receiver {
                queue: self.clone(),
                lock: Box::pin((Parker::new(), PhantomPinned)),
            };

            // Allocate a slot for the receiver.
            let token = recv.get_token();
            let _guard = recv.queue.lock();
            recv.queue
                .inner()
                .receivers
                .insert(token, CachePadded::new((Slot::<T>::new(), false)));

            drop(_guard);

            recv
        }

        fn lock(&self) -> impl Drop {
            unsafe {
                let lock = crate::sync::Lock::from_ptr(&self.lock as *const _ as _);
                lock.lock();

                crate::drop_guard::DropGuard::new(|| lock.unlock())
            }
        }

        fn inner(&self) -> &mut QueueInner<T> {
            // SAFETY: The inner queue is only accessed while the queue is locked.
            unsafe { &mut *self.inner.get() }
        }
    }

    impl<T> QueueInner<T> {
        fn poll(&mut self, token: ReceiverToken) -> Option<T> {
            // Check if someone has sent a message to this receiver.
            let CachePadded((slot, _)) = self.receivers.get(&token)?;

            unsafe { slot.pop() }.or_else(|| {
                // If the slot is empty, check the indexed messages.
                self.messages.pop()
            })
        }
    }

    impl<T> Receiver<T> {
        pub fn get_token(&self) -> ReceiverToken {
            // The token is just the pointer to the lock of this receiver.
            // The lock is pinned, so its address is stable across calls to `recv`.
            ReceiverToken(crate::util::Send(NonNull::from(&self.lock.0).cast()))
        }
    }

    impl<T> Drop for Receiver<T> {
        fn drop(&mut self) {
            // Lock the queue and deregister this receiver so that senders can
            // no longer observe its token.
            let _guard = self.queue.lock();
            let queue = self.queue.inner();

            _ = queue.receivers.remove(&self.get_token());
        }
    }

    impl<T> Receiver<T> {
        pub fn recv(&self) -> T {
            let token = self.get_token();

            loop {
                // lock the queue
                let _guard = self.queue.lock();
                let queue = self.queue.inner();

                // Check if someone has sent a message to this receiver.
                if let Some(t) = queue.poll(token) {
                    queue.receivers.get_mut(&token).unwrap().1 = false; // mark the slot as not parked
                    return t;
                }

                // There was no message for this receiver, so we need to park it.
                queue.receivers.get_mut(&token).unwrap().1 = true; // mark the slot as parked

                self.lock.0.park_with_callback(move || {
                    // Drop the lock guard after the parker state has been set.
                    // This avoids a deadlock where a sender takes the queue lock
                    // and tries to unpark this receiver before it has actually
                    // parked.
                    drop(_guard);
                });
            }
        }

        pub fn try_recv(&self) -> Option<T> {
            let token = self.get_token();

            // lock the queue
            let _guard = self.queue.lock();
            let queue = self.queue.inner();

            // Check if someone has sent a message to this receiver.
            queue.poll(token)
        }
    }

    impl<T> Sender<T> {
        /// Sends a message to one of the receivers in the queue, or makes it
        /// available to any receiver that will park in the future.
        pub fn anycast(&self, value: T) {
            let _guard = self.queue.lock();

            // SAFETY: The queue is locked, so we can safely access the inner queue.
            match unsafe { self.try_anycast_inner(value) } {
                Ok(_) => {}
                Err(value) => {
                    // No parked receiver was found, so add the message to the indexed slots.
                    let queue = self.queue.inner();
                    queue.messages.push(value);

                    // Waking up a parked receiver is not necessary here, as any
                    // receivers that don't have a free slot are currently waking up.
                }
            }
        }

        pub fn try_anycast(&self, value: T) -> Result<(), T> {
            // lock the queue
            let _guard = self.queue.lock();

            // SAFETY: The queue is locked, so we can safely access the inner queue.
            unsafe { self.try_anycast_inner(value) }
        }

        /// The caller must hold the lock on the queue for the duration of this function.
        unsafe fn try_anycast_inner(&self, value: T) -> Result<(), T> {
            // Look for a receiver that is parked.
            let queue = self.queue.inner();
            if let Some((token, slot)) = queue
                .receivers
                .iter()
                .find_map(|(token, CachePadded((slot, is_parked)))| {
                    // Ensure the receiver is parked and its slot is available.
                    if *is_parked && !slot.is_set() {
                        Some((*token, slot))
                    } else {
                        None
                    }
                })
            {
                // We found a parked receiver, so send the message to it and wake it.
                unsafe {
                    (&mut *slot.value.get()).write(value);
                    slot.next_and_state
                        .set_tag(1, Ordering::Release, Ordering::Relaxed);

                    Parker::from_ptr(token.0.into_inner().as_ptr()).unpark();
                }

                return Ok(());
            } else {
                return Err(value);
            }
        }

        /// Sends a message to a specific receiver, waking it if it is parked.
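        ///
        /// A minimal sketch of the intended call pattern (mirroring the
        /// `send_self` test below; marked `ignore` since the paths here are
        /// illustrative):
        ///
        /// ```ignore
        /// let queue = Queue::<i32>::new();
        /// let sender = queue.new_sender();
        /// let receiver = queue.new_receiver();
        ///
        /// sender.unicast(42, receiver.get_token()).unwrap();
        /// assert_eq!(receiver.recv(), 42);
        /// ```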
        pub fn unicast(&self, value: T, receiver: ReceiverToken) -> Result<(), T> {
            // lock the queue
            let _guard = self.queue.lock();
            let queue = self.queue.inner();

            let Some(CachePadded((slot, _))) = queue.receivers.get_mut(&receiver) else {
                return Err(value);
            };

            unsafe {
                slot.push(value);
            }

            // wake the receiver
            unsafe {
                Parker::from_ptr(receiver.0.into_inner().as_ptr()).unpark();
            }

            Ok(())
        }

        pub fn broadcast(&self, value: T)
        where
            T: Clone,
        {
            // lock the queue
            let _guard = self.queue.lock();
            let queue = self.queue.inner();

            // Send the message to all receivers.
            for (token, CachePadded((slot, _))) in queue.receivers.iter() {
                // SAFETY: The slot is owned by this receiver.
                unsafe { slot.push(value.clone()) };

                // wake the receiver
                unsafe {
                    Parker::from_ptr(token.0.into_inner().as_ptr()).unpark();
                }
            }
        }

        pub fn broadcast_with<F>(&self, mut f: F)
        where
            F: FnMut() -> T,
        {
            // lock the queue
            let _guard = self.queue.lock();
            let queue = self.queue.inner();

            // Send a message to all receivers.
            for (token, CachePadded((slot, _))) in queue.receivers.iter() {
                // SAFETY: The slot is owned by this receiver.
                unsafe { slot.push(f()) };

                // wake the receiver
                unsafe {
                    Parker::from_ptr(token.0.into_inner().as_ptr()).unpark();
                }
            }
        }
    }

    #[cfg(test)]
    mod tests {
        use std::println;

        use super::*;

        #[test]
        fn test_queue() {
            let queue = Queue::<i32>::new();
            let sender = queue.new_sender();
            let receiver1 = queue.new_receiver();
            let receiver2 = queue.new_receiver();

            let token2 = receiver2.get_token();

            sender.anycast(42);

            assert_eq!(receiver1.recv(), 42);

            sender.unicast(100, token2).unwrap();
            assert_eq!(receiver1.try_recv(), None);
            assert_eq!(receiver2.recv(), 100);
        }

        #[test]
        fn queue_broadcast() {
            let queue = Queue::<i32>::new();
            let sender = queue.new_sender();
            let receiver1 = queue.new_receiver();
            let receiver2 = queue.new_receiver();

            sender.broadcast(42);

            assert_eq!(receiver1.recv(), 42);
            assert_eq!(receiver2.recv(), 42);
        }

        #[test]
        fn queue_multiple_messages() {
            let queue = Queue::<i32>::new();
            let sender = queue.new_sender();
            let receiver = queue.new_receiver();

            sender.anycast(1);
            sender.unicast(2, receiver.get_token()).unwrap();

            assert_eq!(receiver.recv(), 2);
            assert_eq!(receiver.recv(), 1);
        }

        #[test]
        fn queue_threaded() {
            #[derive(Debug, Clone, Copy)]
            enum Message {
                Send(i32),
                Exit,
            }

            let queue = Queue::<Message>::new();
            let sender = queue.new_sender();

            let threads = (0..5)
                .map(|_| {
                    let queue_clone = queue.clone();
                    let receiver = queue_clone.new_receiver();

                    std::thread::spawn(move || {
                        loop {
                            match receiver.recv() {
                                Message::Send(value) => {
                                    println!(
                                        "Receiver {:?} Received: {}",
                                        receiver.get_token(),
                                        value
                                    );
                                }
                                Message::Exit => {
                                    println!("Exiting thread");
                                    break;
                                }
                            }
                        }
                    })
                })
                .collect::<Vec<_>>();

            // Send messages to the receivers.
            for i in 0..10 {
                sender.anycast(Message::Send(i));
            }

            // Send exit messages to all receivers.
            sender.broadcast(Message::Exit);

            for thread in threads {
                thread.join().unwrap();
            }

            println!("All threads have exited.");
        }

        #[test]
        fn drop_slot() {
            // Test that dropping a slot does not cause a double free or panic.
            let slot = Slot::<i32>::new();

            unsafe {
                slot.push(42);
                drop(slot);
            }
        }

        #[test]
        fn drop_slot_chain() {
            struct DropCheck<'a>(&'a AtomicU32);

            impl Drop for DropCheck<'_> {
                fn drop(&mut self) {
                    self.0.fetch_sub(1, Ordering::SeqCst);
                }
            }

            impl<'a> DropCheck<'a> {
                fn new(counter: &'a AtomicU32) -> Self {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Self(counter)
                }
            }

            let counter = AtomicU32::new(0);
            let slot = Slot::<DropCheck>::new();

            for _ in 0..10 {
                unsafe {
                    slot.push(DropCheck::new(&counter));
                }
            }
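            // Nothing has been popped yet, so all ten `DropCheck` instances are still live.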
            assert_eq!(counter.load(Ordering::SeqCst), 10);

            drop(slot);

            assert_eq!(
                counter.load(Ordering::SeqCst),
                0,
                "All DropCheck instances should have been dropped"
            );
        }

        #[test]
        fn send_self() {
            // Test that sending a message to self works.
            let queue = Queue::<i32>::new();
            let sender = queue.new_sender();
            let receiver = queue.new_receiver();

            sender.unicast(42, receiver.get_token()).unwrap();
            assert_eq!(receiver.recv(), 42);
        }

        #[test]
        fn send_self_many() {
            // Test that sending multiple messages to self works.
            let queue = Queue::<i32>::new();
            let sender = queue.new_sender();
            let receiver = queue.new_receiver();

            for i in 0..10 {
                sender.unicast(i, receiver.get_token()).unwrap();
            }

            for i in (0..10).rev() {
                assert_eq!(receiver.recv(), i);
            }
        }

        #[test]
        fn slot_pop_front() {
            // Test that popping from the front of a slot works correctly.
            let slot = Slot::<i32>::new();

            unsafe {
                slot.push(1);
                slot.push(2);
                slot.push(3);
            }

            assert_eq!(unsafe { slot.pop_front() }, Some(1));
            assert_eq!(unsafe { slot.pop_front() }, Some(2));
            assert_eq!(unsafe { slot.pop_front() }, Some(3));
            assert_eq!(unsafe { slot.pop_front() }, None);
        }
    }
}
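
// A minimal usage sketch for `Parker` and the one-shot `channel` (illustrative
// only; the module name below is arbitrary and not part of the original tests).
#[cfg(all(test, feature = "alloc"))]
mod usage_sketch {
    use alloc::sync::Arc;

    use super::Parker;
    use super::channel::channel;

    #[test]
    fn parker_unpark_wakes_parked_thread() {
        let parker = Arc::new(Parker::new());
        let parker2 = parker.clone();

        let handle = std::thread::spawn(move || {
            // Blocks until `unpark` is called, or returns immediately if the
            // notification already happened.
            parker2.park();
        });

        // Give the spawned thread a moment to park, then wake it.
        std::thread::sleep(std::time::Duration::from_millis(10));
        parker.unpark();

        handle.join().unwrap();
    }

    #[test]
    fn one_shot_channel_delivers_value() {
        let (sender, receiver) = channel::<u32>();

        let handle = std::thread::spawn(move || sender.send(7));

        // `recv` consumes the receiver and blocks until the value arrives.
        assert_eq!(receiver.recv(), 7);
        handle.join().unwrap();
    }
}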