initial queue

This commit is contained in:
Janis 2025-07-03 14:46:13 +02:00
parent f8aa8d9615
commit 41166898ff
3 changed files with 297 additions and 0 deletions

View file

@ -23,6 +23,8 @@ parking_lot_core = "0.9.10"
crossbeam-utils = "0.8.21"
either = "1.15.0"
werkzeug = {path = "../../werkzeug", features = ["std", "nightly"]}
async-task = "4.7.1"
[dev-dependencies]

View file

@ -23,6 +23,7 @@ mod join;
mod latch;
#[cfg(feature = "metrics")]
mod metrics;
mod queue;
mod scope;
mod threadpool;
pub mod util;

294
distaff/src/queue.rs Normal file
View file

@ -0,0 +1,294 @@
use std::{
cell::UnsafeCell,
collections::{HashMap, HashSet},
marker::{PhantomData, PhantomPinned},
mem::MaybeUninit,
pin::Pin,
sync::{
Arc,
atomic::{AtomicU8, AtomicU32, Ordering},
},
};
use crossbeam_utils::CachePadded;
// A queue with multiple receivers and multiple producers, where a producer can send a message to any one of the receivers (any-cast), to one specific receiver (uni-cast), or to every receiver (broadcast).
// After being woken up from waiting on a message, the receiver first checks its own slot and then the shared message list, and returns the message it finds.
/// The mutable state of a [`Queue`], reached through `Queue::inner`.
struct QueueInner<T> {
// Tokens of receivers that found no message in `recv` and are about to wait
// (or are waiting). Senders consult this set to decide whom to wake.
parked: HashSet<ReceiverToken>,
// One private slot per receiver, inserted by `Queue::new_receiver` and keyed
// by the receiver's token. Holds at most one directly addressed message.
owned: HashMap<ReceiverToken, CachePadded<Slot<T>>>,
// Messages not addressed to a particular receiver; `anycast` pushes here when
// no parked receiver with an empty slot exists, and `poll` pops from the end.
messages: Vec<T>,
// Marker for `T`.
// NOTE(review): `messages: Vec<T>` already mentions `T`, so this field looks
// redundant — confirm before removing.
_phantom: std::marker::PhantomData<T>,
}
/// The shared queue object; senders and receivers each hold an `Arc` to it.
///
/// NOTE(review): no `Send`/`Sync` impl is visible here, and `UnsafeCell`
/// makes this type `!Sync` by default — confirm how it is meant to be shared
/// across threads.
struct Queue<T> {
// State that `inner()` hands out as `&mut`; callers take `lock()` first.
inner: UnsafeCell<QueueInner<T>>,
// Lock word; `lock()` passes its address to `werkzeug::sync::Lock::from_ptr`.
lock: AtomicU32,
}
/// Identifies a slot either by the receiver that owns it or by an index.
///
/// NOTE(review): this enum is not referenced anywhere else in the visible
/// source — presumably a leftover or groundwork for indexed slots; confirm.
enum SlotKey {
// A slot belonging to the receiver identified by the token.
Owned(ReceiverToken),
// A slot addressed by position.
Indexed(usize),
}
/// The receiving end of a [`Queue`], created by `Queue::new_receiver`.
struct Receiver<T> {
// The queue this receiver is registered with.
queue: Arc<Queue<T>>,
// Pinned, heap-allocated word the receiver waits on in `recv`. Its address
// doubles as this receiver's `ReceiverToken`, hence the pinning.
lock: Pin<Box<(AtomicU32, PhantomPinned)>>,
}
/// The sending end of a [`Queue`], created by `Queue::new_sender`.
struct Sender<T> {
// The queue that messages are sent into.
queue: Arc<Queue<T>>,
}
// TODO: make this a linked list of slots so we can queue multiple messages for
// a single receiver
/// A single-message mailbox belonging to one receiver.
struct Slot<T> {
// Storage for the message; only initialized while `state` is 1.
value: UnsafeCell<MaybeUninit<T>>,
// 0 means empty, 1 means `value` holds an unread message.
state: AtomicU8,
}
// const BLOCK_SIZE: usize = 8;
// struct Block<T> {
// next: AtomicPtr<Block<T>>,
// slots: [CachePadded<Slot<T>>; BLOCK_SIZE],
// }
/// A token that can be used to identify a specific receiver in a queue.
///
/// The wrapped pointer is the address of the receiver's pinned wait word (see
/// `Receiver::get_token`); it is used both as a map key and as the address
/// handed to `werkzeug::sync::Lock::from_ptr` when waiting or waking.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ReceiverToken(werkzeug::util::Send<*const u32>);
impl<T> Queue<T> {
    /// Creates an empty queue: no registered receivers, no parked receivers,
    /// no pending messages, and a zeroed lock word.
    pub fn new() -> Self {
        let state = QueueInner {
            parked: HashSet::new(),
            messages: Vec::new(),
            owned: HashMap::new(),
            _phantom: PhantomData,
        };
        Self {
            inner: UnsafeCell::new(state),
            lock: AtomicU32::new(0),
        }
    }

    /// Creates a new sending handle that shares this queue.
    pub fn new_sender(self: &Arc<Self>) -> Sender<T> {
        let queue = Arc::clone(self);
        Sender { queue }
    }

    /// Creates a new receiver and registers an empty slot for it in the queue.
    pub fn new_receiver(self: &Arc<Self>) -> Receiver<T> {
        let receiver = Receiver {
            queue: Arc::clone(self),
            lock: Box::pin((AtomicU32::new(0), PhantomPinned)),
        };

        // The token is derived from the pinned wait word, so it is already
        // final at this point.
        let token = receiver.get_token();
        let empty_slot = CachePadded::new(Slot {
            value: UnsafeCell::new(MaybeUninit::uninit()),
            // 0 means empty
            state: AtomicU8::new(0),
        });

        {
            // Register the slot while holding the queue lock; the guard is
            // released at the end of this scope.
            let _guard = receiver.queue.lock();
            receiver.queue.inner().owned.insert(token, empty_slot);
        }

        receiver
    }

    /// Acquires the queue lock and returns a guard that releases it on drop.
    fn lock(&self) -> impl Drop {
        // SAFETY: the pointer is derived from `self.lock`, which lives as long
        // as `self`. NOTE(review): the exact contract of `Lock::from_ptr` is
        // defined in werkzeug — confirm it only needs a valid lock word.
        unsafe {
            let raw = werkzeug::sync::Lock::from_ptr(&self.lock as *const _ as _);
            raw.lock();
            werkzeug::drop_guard::DropGuard::new(|| raw.unlock())
        }
    }

    /// Returns mutable access to the queue state.
    ///
    /// Callers are expected to hold the guard returned by `lock()` for as long
    /// as they use the returned reference; nothing here enforces that.
    fn inner(&self) -> &mut QueueInner<T> {
        let state = self.inner.get();
        // SAFETY: The inner is only accessed while the queue is locked.
        unsafe { &mut *state }
    }
}
impl<T> QueueInner<T> {
    /// Takes the next message available to the receiver identified by `token`.
    ///
    /// A message sitting in the receiver's own slot wins; otherwise the most
    /// recently pushed shared message is popped. Returns `None` when neither
    /// exists.
    ///
    /// # Panics
    ///
    /// Panics if `token` has no slot registered in `owned`.
    fn poll(&mut self, token: ReceiverToken) -> Option<T> {
        let own_slot = self.owned.get(&token).unwrap();

        // Marking the slot empty and testing whether it was full is a single
        // atomic swap.
        let had_direct_message = own_slot.state.swap(0, Ordering::Acquire) == 1;
        if had_direct_message {
            // SAFETY: the slot is owned by this receiver and contains a message.
            let message = unsafe { own_slot.value.as_ref_unchecked().assume_init_read() };
            return Some(message);
        }

        // Nothing addressed to this receiver: fall back to the shared list.
        // NOTE(review): `pop` takes from the end, so shared messages come out
        // newest-first — confirm that ordering is intended.
        self.messages.pop()
    }
}
impl<T> Receiver<T> {
fn get_token(&self) -> ReceiverToken {
// the token is just the pointer to the lock of this receiver.
// the lock is pinned, so it's address is stable across calls to `receive`.
ReceiverToken(werkzeug::util::Send(
&self.lock.0 as *const AtomicU32 as *const u32,
))
}
}
impl<T: Send> Receiver<T> {
    /// Blocks until a message is available for this receiver and returns it.
    ///
    /// A message in the receiver's own slot is preferred over the shared
    /// message list (see `QueueInner::poll`).
    pub fn recv(&self) -> T {
        let token = self.get_token();
        loop {
            // lock the queue
            let guard = self.queue.lock();
            let queue = self.queue.inner();

            // We are awake and hold the lock, so we are not parked any more.
            // Clearing the entry here keeps `parked` accurate: otherwise a
            // receiver that parked once would look parked forever, and
            // senders would keep routing any-cast messages into its slot and
            // issuing wake-ups for it while it is busy elsewhere.
            queue.parked.remove(&token);

            // check if someone has sent a message to this receiver
            if let Some(message) = queue.poll(token) {
                return message;
            }

            // there was no message for this receiver, so we need to park it
            queue.parked.insert(token);

            // Release the queue lock before waiting so senders can make progress.
            drop(guard);

            // NOTE(review): a sender may deliver and call `wake_one` between
            // the `drop` above and the `wait` below. Whether that wake-up is
            // still observed depends on the semantics of werkzeug's
            // `Lock::wait` — confirm.
            // SAFETY: the token points at this receiver's pinned wait word,
            // which is alive for as long as `self` is.
            unsafe {
                let lock = werkzeug::sync::Lock::from_ptr(token.0.into_inner().cast_mut());
                lock.wait();
            }
        }
    }

    /// Returns a message if one is available right now, without blocking
    /// beyond acquiring the queue lock.
    pub fn try_recv(&self) -> Option<T> {
        let token = self.get_token();
        // lock the queue
        let _guard = self.queue.lock();
        let queue = self.queue.inner();
        // check if someone has sent a message to this receiver
        queue.poll(token)
    }
}
impl<T: Send> Sender<T> {
    /// Sends a message to one of the receivers in the queue, or makes it
    /// available to any receiver that will park in the future.
    ///
    /// If a parked receiver with an empty slot exists, the message is placed
    /// in that slot and the receiver is woken. Otherwise the message is
    /// appended to the shared message list.
    pub fn anycast(&self, value: T) {
        let _guard = self.queue.lock();
        let queue = self.queue.inner();

        // look for a receiver that is parked and whose slot is available
        if let Some((token, slot)) = queue.parked.iter().find_map(|token| {
            queue
                .owned
                .get(token)
                .filter(|slot| slot.state.load(Ordering::Acquire) == 0)
                .map(|slot| (*token, slot))
        }) {
            // SAFETY: the slot is empty (state 0) and the queue lock is held,
            // so nobody else accesses `value`. The token was taken from
            // `parked` under the lock.
            unsafe {
                slot.value.as_mut_unchecked().write(value);
                slot.state.store(1, Ordering::Release);
                werkzeug::sync::Lock::from_ptr(token.0.into_inner().cast_mut()).wake_one();
            }
        } else {
            // no parked receiver found, so we want to add the message to the
            // shared list
            queue.messages.push(value);
            // waking up a parked receiver is not necessary here, as any
            // receivers that don't have a free slot are currently waking up.
        }
    }

    /// Sends a message to a specific receiver, waking it if it is parked.
    ///
    /// The receiver's slot holds a single message: if an earlier message is
    /// still unread, it is replaced by `value` and dropped.
    ///
    /// # Errors
    ///
    /// Returns `Err(value)` if `receiver` has no slot registered in this queue.
    pub fn unicast(&self, value: T, receiver: ReceiverToken) -> Result<(), T> {
        // lock the queue
        let guard = self.queue.lock();
        let queue = self.queue.inner();

        let Some(slot) = queue.owned.get_mut(&receiver) else {
            return Err(value);
        };

        // `MaybeUninit::write` never drops what it overwrites, so an unread
        // message has to be taken out first or it would be leaked.
        let stale = if slot.state.swap(0, Ordering::Acquire) == 1 {
            // SAFETY: state 1 means the slot holds an initialized, unread
            // message, and the queue lock is held.
            Some(unsafe { slot.value.as_ref_unchecked().assume_init_read() })
        } else {
            None
        };

        // SAFETY: The slot is owned by this receiver, is now logically empty,
        // and the queue lock is held.
        unsafe { slot.value.as_mut_unchecked().write(value) };
        slot.state.store(1, Ordering::Release);

        // check if the receiver is parked
        if queue.parked.contains(&receiver) {
            // wake the receiver
            // SAFETY: the token was found in `parked` under the queue lock.
            unsafe {
                werkzeug::sync::Lock::from_ptr(receiver.0.into_inner().cast_mut()).wake_one();
            }
        }

        // Run the replaced message's destructor outside the queue lock.
        drop(guard);
        drop(stale);
        Ok(())
    }

    /// Sends a clone of `value` to every receiver whose slot is currently
    /// empty, waking those that are parked.
    ///
    /// Receivers whose slot already holds an unread message are skipped and
    /// do not receive this broadcast.
    pub fn broadcast(&self, value: T)
    where
        T: Clone,
    {
        // lock the queue
        let _guard = self.queue.lock();
        let queue = self.queue.inner();

        // send the message to all receivers
        for (token, slot) in queue.owned.iter() {
            if slot.state.load(Ordering::Acquire) != 0 {
                // the slot is not available, so we skip it
                continue;
            }

            // SAFETY: the slot is empty and the queue lock is held.
            unsafe {
                slot.value.as_mut_unchecked().write(value.clone());
            }
            slot.state.store(1, Ordering::Release);

            // check if the receiver is parked
            if queue.parked.contains(token) {
                // wake the receiver
                // SAFETY: the token was found in `parked` under the queue lock.
                unsafe {
                    werkzeug::sync::Lock::from_ptr(token.0.into_inner().cast_mut()).wake_one();
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Single-threaded smoke test: an any-cast message reaches whichever
    /// receiver polls first, and a uni-cast message reaches only its target.
    #[test]
    fn test_queue() {
        let shared = Arc::new(Queue::<i32>::new());
        let tx = shared.new_sender();
        let rx_a = shared.new_receiver();
        let rx_b = shared.new_receiver();

        let _token_a = rx_a.get_token();
        let token_b = rx_b.get_token();

        // Nobody is parked, so the message lands in the shared list and the
        // first receiver to ask gets it.
        tx.anycast(42);
        assert_eq!(rx_a.recv(), 42);

        // A uni-cast message is visible to its target only.
        tx.unicast(100, token_b).unwrap();
        assert_eq!(rx_a.try_recv(), None);
        assert_eq!(rx_b.recv(), 100);
    }
}