|
|
|
@ -0,0 +1,328 @@
|
|
|
|
|
use std::{
|
|
|
|
|
cell::UnsafeCell,
|
|
|
|
|
collections::{HashMap, HashSet},
|
|
|
|
|
marker::{PhantomData, PhantomPinned},
|
|
|
|
|
mem::{self, MaybeUninit},
|
|
|
|
|
pin::Pin,
|
|
|
|
|
sync::{
|
|
|
|
|
Arc,
|
|
|
|
|
atomic::{AtomicU8, AtomicU32, Ordering},
|
|
|
|
|
},
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
use crossbeam_utils::CachePadded;
|
|
|
|
|
|
|
|
|
|
// A Queue with multiple receivers and multiple producers, where a producer can send a message to one of any of the receivers (any-cast), or one of the receivers (uni-cast).
|
|
|
|
|
// After being woken up from waiting on a message, the receiver will look up the index of the message in the queue and return it.
|
|
|
|
|
|
|
|
|
|
struct QueueInner<T> {
|
|
|
|
|
parked: HashSet<ReceiverToken>,
|
|
|
|
|
owned: HashMap<ReceiverToken, CachePadded<Slot<T>>>,
|
|
|
|
|
messages: Vec<T>,
|
|
|
|
|
_phantom: std::marker::PhantomData<T>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
struct Queue<T> {
|
|
|
|
|
inner: UnsafeCell<QueueInner<T>>,
|
|
|
|
|
lock: AtomicU32,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
enum SlotKey {
|
|
|
|
|
Owned(ReceiverToken),
|
|
|
|
|
Indexed(usize),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
struct Receiver<T> {
|
|
|
|
|
queue: Arc<Queue<T>>,
|
|
|
|
|
lock: Pin<Box<(AtomicU32, PhantomPinned)>>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
struct Sender<T> {
|
|
|
|
|
queue: Arc<Queue<T>>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO: make this a linked list of slots so we can queue multiple messages for
|
|
|
|
|
// a single receiver
|
|
|
|
|
struct Slot<T> {
|
|
|
|
|
value: UnsafeCell<MaybeUninit<T>>,
|
|
|
|
|
state: AtomicU8,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl<T> Slot<T> {
|
|
|
|
|
fn new() -> Self {
|
|
|
|
|
Self {
|
|
|
|
|
value: UnsafeCell::new(MaybeUninit::uninit()),
|
|
|
|
|
state: AtomicU8::new(0), // 0 means empty
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn set(&self, value: T) {}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl<T> Drop for Slot<T> {
|
|
|
|
|
fn drop(&mut self) {
|
|
|
|
|
// SAFETY: The value is only initialized when the state is set to 1.
|
|
|
|
|
if mem::needs_drop::<T>() && self.state.load(Ordering::Acquire) == 1 {
|
|
|
|
|
unsafe { self.value.as_mut_unchecked().assume_init_drop() };
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// const BLOCK_SIZE: usize = 8;
|
|
|
|
|
// struct Block<T> {
|
|
|
|
|
// next: AtomicPtr<Block<T>>,
|
|
|
|
|
// slots: [CachePadded<Slot<T>>; BLOCK_SIZE],
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
/// A token that can be used to identify a specific receiver in a queue.
|
|
|
|
|
#[repr(transparent)]
|
|
|
|
|
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
|
|
|
|
|
pub struct ReceiverToken(werkzeug::util::Send<*const u32>);
|
|
|
|
|
|
|
|
|
|
impl<T> Queue<T> {
|
|
|
|
|
pub fn new() -> Self {
|
|
|
|
|
Self {
|
|
|
|
|
inner: UnsafeCell::new(QueueInner {
|
|
|
|
|
parked: HashSet::new(),
|
|
|
|
|
messages: Vec::new(),
|
|
|
|
|
owned: HashMap::new(),
|
|
|
|
|
_phantom: PhantomData,
|
|
|
|
|
}),
|
|
|
|
|
lock: AtomicU32::new(0),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn new_sender(self: &Arc<Self>) -> Sender<T> {
|
|
|
|
|
Sender {
|
|
|
|
|
queue: self.clone(),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn new_receiver(self: &Arc<Self>) -> Receiver<T> {
|
|
|
|
|
let recv = Receiver {
|
|
|
|
|
queue: self.clone(),
|
|
|
|
|
lock: Box::pin((AtomicU32::new(0), PhantomPinned)),
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// allocate slot for the receiver
|
|
|
|
|
let token = recv.get_token();
|
|
|
|
|
let _guard = recv.queue.lock();
|
|
|
|
|
recv.queue.inner().owned.insert(
|
|
|
|
|
token,
|
|
|
|
|
CachePadded::new(Slot {
|
|
|
|
|
value: UnsafeCell::new(MaybeUninit::uninit()),
|
|
|
|
|
state: AtomicU8::new(0), // 0 means empty
|
|
|
|
|
}),
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
drop(_guard);
|
|
|
|
|
recv
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn lock(&self) -> impl Drop {
|
|
|
|
|
unsafe {
|
|
|
|
|
let lock = werkzeug::sync::Lock::from_ptr(&self.lock as *const _ as _);
|
|
|
|
|
lock.lock();
|
|
|
|
|
werkzeug::drop_guard::DropGuard::new(|| lock.unlock())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn inner(&self) -> &mut QueueInner<T> {
|
|
|
|
|
// SAFETY: The inner is only accessed while the queue is locked.
|
|
|
|
|
unsafe { &mut *self.inner.get() }
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl<T> QueueInner<T> {
|
|
|
|
|
fn poll(&mut self, token: ReceiverToken) -> Option<T> {
|
|
|
|
|
// check if someone has sent a message to this receiver
|
|
|
|
|
let slot = self.owned.get(&token).unwrap();
|
|
|
|
|
if slot.state.swap(0, Ordering::Acquire) == 1 {
|
|
|
|
|
// SAFETY: the slot is owned by this receiver and contains a message.
|
|
|
|
|
return Some(unsafe { slot.value.as_ref_unchecked().assume_init_read() });
|
|
|
|
|
} else if let Some(t) = self.messages.pop() {
|
|
|
|
|
return Some(t);
|
|
|
|
|
} else {
|
|
|
|
|
None
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl<T> Receiver<T> {
|
|
|
|
|
fn get_token(&self) -> ReceiverToken {
|
|
|
|
|
// the token is just the pointer to the lock of this receiver.
|
|
|
|
|
// the lock is pinned, so it's address is stable across calls to `receive`.
|
|
|
|
|
|
|
|
|
|
ReceiverToken(werkzeug::util::Send(
|
|
|
|
|
&self.lock.0 as *const AtomicU32 as *const u32,
|
|
|
|
|
))
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl<T> Drop for Receiver<T> {
|
|
|
|
|
fn drop(&mut self) {
|
|
|
|
|
if mem::needs_drop::<T>() {
|
|
|
|
|
// lock the queue
|
|
|
|
|
let _guard = self.queue.lock();
|
|
|
|
|
let queue = self.queue.inner();
|
|
|
|
|
|
|
|
|
|
// remove the receiver from the queue
|
|
|
|
|
_ = queue.owned.remove(&self.get_token());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl<T: Send> Receiver<T> {
|
|
|
|
|
pub fn recv(&self) -> T {
|
|
|
|
|
let token = self.get_token();
|
|
|
|
|
|
|
|
|
|
loop {
|
|
|
|
|
// lock the queue
|
|
|
|
|
let _guard = self.queue.lock();
|
|
|
|
|
let queue = self.queue.inner();
|
|
|
|
|
|
|
|
|
|
// check if someone has sent a message to this receiver
|
|
|
|
|
if let Some(t) = queue.poll(token) {
|
|
|
|
|
queue.parked.remove(&token);
|
|
|
|
|
return t;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// there was no message for this receiver, so we need to park it
|
|
|
|
|
queue.parked.insert(token);
|
|
|
|
|
|
|
|
|
|
// wait for a message to be sent to this receiver
|
|
|
|
|
drop(_guard);
|
|
|
|
|
unsafe {
|
|
|
|
|
let lock = werkzeug::sync::Lock::from_ptr(token.0.into_inner().cast_mut());
|
|
|
|
|
lock.wait();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn try_recv(&self) -> Option<T> {
|
|
|
|
|
let token = self.get_token();
|
|
|
|
|
|
|
|
|
|
// lock the queue
|
|
|
|
|
let _guard = self.queue.lock();
|
|
|
|
|
let queue = self.queue.inner();
|
|
|
|
|
|
|
|
|
|
// check if someone has sent a message to this receiver
|
|
|
|
|
queue.poll(token)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl<T: Send> Sender<T> {
|
|
|
|
|
/// Sends a message to one of the receivers in the queue, or makes it
|
|
|
|
|
/// available to any receiver that will park in the future.
|
|
|
|
|
pub fn anycast(&self, value: T) {
|
|
|
|
|
// look for a receiver that is parked
|
|
|
|
|
let _guard = self.queue.lock();
|
|
|
|
|
let queue = self.queue.inner();
|
|
|
|
|
if let Some((token, slot)) = queue.parked.iter().find_map(|token| {
|
|
|
|
|
// ensure the slot is available
|
|
|
|
|
queue.owned.get(token).and_then(|s| {
|
|
|
|
|
if s.state.load(Ordering::Acquire) == 0 {
|
|
|
|
|
Some((*token, s))
|
|
|
|
|
} else {
|
|
|
|
|
None
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
}) {
|
|
|
|
|
// we found a receiver that is parked, so we can send the message to it
|
|
|
|
|
unsafe {
|
|
|
|
|
slot.value.as_mut_unchecked().write(value);
|
|
|
|
|
slot.state.store(1, Ordering::Release);
|
|
|
|
|
werkzeug::sync::Lock::from_ptr(token.0.into_inner().cast_mut()).wake_one();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return;
|
|
|
|
|
} else {
|
|
|
|
|
// no parked receiver found, so we want to add the message to the indexed slots
|
|
|
|
|
queue.messages.push(value);
|
|
|
|
|
|
|
|
|
|
// waking up a parked receiver is not necessary here, as any
|
|
|
|
|
// receivers that don't have a free slot are currently waking up.
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Sends a message to a specific receiver, waking it if it is parked.
|
|
|
|
|
pub fn unicast(&self, value: T, receiver: ReceiverToken) -> Result<(), T> {
|
|
|
|
|
// lock the queue
|
|
|
|
|
let _guard = self.queue.lock();
|
|
|
|
|
let queue = self.queue.inner();
|
|
|
|
|
|
|
|
|
|
let Some(slot) = queue.owned.get_mut(&receiver) else {
|
|
|
|
|
return Err(value);
|
|
|
|
|
};
|
|
|
|
|
// SAFETY: The slot is owned by this receiver.
|
|
|
|
|
unsafe { slot.value.as_mut_unchecked().write(value) };
|
|
|
|
|
slot.state.store(1, Ordering::Release);
|
|
|
|
|
|
|
|
|
|
// check if the receiver is parked
|
|
|
|
|
if queue.parked.contains(&receiver) {
|
|
|
|
|
// wake the receiver
|
|
|
|
|
unsafe {
|
|
|
|
|
werkzeug::sync::Lock::from_ptr(receiver.0.into_inner().cast_mut()).wake_one();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn broadcast(&self, value: T)
|
|
|
|
|
where
|
|
|
|
|
T: Clone,
|
|
|
|
|
{
|
|
|
|
|
// lock the queue
|
|
|
|
|
let _guard = self.queue.lock();
|
|
|
|
|
let queue = self.queue.inner();
|
|
|
|
|
|
|
|
|
|
// send the message to all receivers
|
|
|
|
|
for (token, slot) in queue.owned.iter() {
|
|
|
|
|
// SAFETY: The slot is owned by this receiver.
|
|
|
|
|
|
|
|
|
|
if slot.state.load(Ordering::Acquire) != 0 {
|
|
|
|
|
// the slot is not available, so we skip it
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
unsafe {
|
|
|
|
|
slot.value.as_mut_unchecked().write(value.clone());
|
|
|
|
|
}
|
|
|
|
|
slot.state.store(1, Ordering::Release);
|
|
|
|
|
|
|
|
|
|
// check if the receiver is parked
|
|
|
|
|
if queue.parked.contains(token) {
|
|
|
|
|
// wake the receiver
|
|
|
|
|
unsafe {
|
|
|
|
|
werkzeug::sync::Lock::from_ptr(token.0.into_inner().cast_mut()).wake_one();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[cfg(test)]
|
|
|
|
|
mod tests {
|
|
|
|
|
use super::*;
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn test_queue() {
|
|
|
|
|
let queue = Arc::new(Queue::<i32>::new());
|
|
|
|
|
|
|
|
|
|
let sender = queue.new_sender();
|
|
|
|
|
let receiver1 = queue.new_receiver();
|
|
|
|
|
let receiver2 = queue.new_receiver();
|
|
|
|
|
|
|
|
|
|
let token1 = receiver1.get_token();
|
|
|
|
|
let token2 = receiver2.get_token();
|
|
|
|
|
|
|
|
|
|
sender.anycast(42);
|
|
|
|
|
|
|
|
|
|
assert_eq!(receiver1.recv(), 42);
|
|
|
|
|
|
|
|
|
|
sender.unicast(100, token2).unwrap();
|
|
|
|
|
assert_eq!(receiver1.try_recv(), None);
|
|
|
|
|
assert_eq!(receiver2.recv(), 100);
|
|
|
|
|
}
|
|
|
|
|
}
|