Compare commits

...

5 commits

10 changed files with 767 additions and 497 deletions

View file

@ -4,6 +4,7 @@ version = "0.1.0"
edition = "2024" edition = "2024"
[profile.bench] [profile.bench]
opt-level = 0
debug = true debug = true
[profile.release] [profile.release]

View file

@ -17,62 +17,9 @@ enum State {
Taken, Taken,
} }
// taken from `std` pub use werkzeug::sync::Parker;
#[derive(Debug)]
#[repr(transparent)]
pub struct Parker {
mutex: AtomicU32,
}
impl Parker { use crate::queue::Queue;
const PARKED: u32 = u32::MAX;
const EMPTY: u32 = 0;
const NOTIFIED: u32 = 1;
pub fn new() -> Self {
Self {
mutex: AtomicU32::new(Self::EMPTY),
}
}
pub fn is_parked(&self) -> bool {
self.mutex.load(Ordering::Acquire) == Self::PARKED
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all, fields(this = self as *const Self as usize)))]
pub fn park(&self) {
if self.mutex.fetch_sub(1, Ordering::Acquire) == Self::NOTIFIED {
// The thread was notified, so we can return immediately.
return;
}
loop {
atomic_wait::wait(&self.mutex, Self::PARKED);
// We check whether we were notified or woke up spuriously with
// acquire ordering in order to make-visible any writes made by the
// thread that notified us.
if self.mutex.swap(Self::EMPTY, Ordering::Acquire) == Self::NOTIFIED {
// The thread was notified, so we can return immediately.
return;
} else {
// spurious wakeup, so we need to re-park.
continue;
}
}
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all, fields(this = self as *const Self as usize)))]
pub fn unpark(&self) {
// write with Release ordering to ensure that any writes made by this
// thread are made-available to the unparked thread.
if self.mutex.swap(Self::NOTIFIED, Ordering::Release) == Self::PARKED {
// The thread was parked, so we need to notify it.
atomic_wait::wake_one(&self.mutex);
} else {
// The thread was not parked, so we don't need to do anything.
}
}
}
#[derive(Debug)] #[derive(Debug)]
#[repr(C)] #[repr(C)]
@ -104,6 +51,10 @@ impl<T: Send> Receiver<T> {
self.0.state.load(Ordering::Acquire) != State::Ready as u8 self.0.state.load(Ordering::Acquire) != State::Ready as u8
} }
pub fn sender(&self) -> Sender<T> {
Sender(self.0.clone())
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))] #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn wait(&self) { pub fn wait(&self) {
loop { loop {
@ -182,15 +133,15 @@ impl<T: Send> Receiver<T> {
// `State::Ready`. // `State::Ready`.
// //
// In either case, this thread now has unique access to `val`. // In either case, this thread now has unique access to `val`.
unsafe { self.take() }
}
unsafe fn take(&self) -> thread::Result<T> {
assert_eq!( assert_eq!(
self.0.state.swap(State::Taken as u8, Ordering::Acquire), self.0.state.swap(State::Taken as u8, Ordering::Acquire),
State::Ready as u8 State::Ready as u8
); );
unsafe { self.take() }
}
unsafe fn take(&self) -> thread::Result<T> {
let result = unsafe { (*self.0.val.get()).take().map(|b| *b).unwrap() }; let result = unsafe { (*self.0.val.get()).take().map(|b| *b).unwrap() };
result result
@ -221,6 +172,30 @@ impl<T: Send> Sender<T> {
} }
} }
} }
pub unsafe fn parker(&self) -> &Parker {
unsafe { self.0.waiting_thread.as_ref() }
}
/// The caller must ensure that this function or `send` is only ever called once.
pub unsafe fn send_as_ref(&self, val: thread::Result<T>) {
// SAFETY:
// Only this thread can write to `val` and none can read it
// yet.
unsafe {
*self.0.val.get() = Some(Box::new(val));
}
if self.0.state.swap(State::Ready as u8, Ordering::AcqRel) == State::Waiting as u8 {
// SAFETY:
// A `Receiver` already wrote its thread to `waiting_thread`
// *before* setting the `state` to `State::Waiting`.
unsafe {
let thread = self.0.waiting_thread.as_ref();
thread.unpark();
}
}
}
} }
pub fn channel<T: Send>(thread: NonNull<Parker>) -> (Sender<T>, Receiver<T>) { pub fn channel<T: Send>(thread: NonNull<Parker>) -> (Sender<T>, Receiver<T>) {
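Editorial note (not part of the diff): the channel above is a one-shot, Parker-backed cell; `Parker` is now re-exported from `werkzeug::sync`, and `send`/`send_as_ref` may be called at most once per channel. A minimal sketch of the intended call sequence, written as an in-crate unit test; the test name and its placement are illustrative assumptions, the items used (`channel`, `Parker`, `poll`, `send_as_ref`, `crate::util::unwrap_or_panic`) are the ones shown in this diff.

#[test]
fn one_shot_channel_roundtrip() {
    use std::ptr::NonNull;

    use crate::channel::{Parker, channel};

    // The receiver parks on this Parker; it must stay alive (and at a stable
    // address) until the value has been taken.
    let parker = Parker::new();
    let (sender, receiver) = channel::<i32>(NonNull::from(&parker));

    // SAFETY: `send_as_ref` is called exactly once for this channel.
    unsafe { sender.send_as_ref(Ok(42)) };

    // The state is already `Ready`, so `poll` returns the result without parking.
    let result = receiver.poll().expect("value was already sent");
    assert_eq!(crate::util::unwrap_or_panic(result), 42);
}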

View file

@ -1,4 +1,9 @@
use std::{ use std::{
cell::UnsafeCell,
marker::PhantomPinned,
mem::{self, ManuallyDrop},
panic::{AssertUnwindSafe, catch_unwind},
pin::Pin,
ptr::NonNull, ptr::NonNull,
sync::{ sync::{
Arc, OnceLock, Arc, OnceLock,
@ -9,21 +14,28 @@ use std::{
use alloc::collections::BTreeMap; use alloc::collections::BTreeMap;
use async_task::Runnable; use async_task::Runnable;
use parking_lot::{Condvar, Mutex};
use crate::{ use crate::{
channel::{Parker, Sender}, channel::{Parker, Sender},
heartbeat::HeartbeatList, heartbeat::HeartbeatList,
job::{HeapJob, Job2 as Job, SharedJob, StackJob}, job::{HeapJob, Job2 as Job, SharedJob, StackJob},
queue::ReceiverToken,
util::DropGuard, util::DropGuard,
workerthread::{HeartbeatThread, WorkerThread}, workerthread::{HeartbeatThread, WorkerThread},
}; };
pub struct Context { pub struct Context {
shared: Mutex<Shared>,
pub shared_job: Condvar,
should_exit: AtomicBool, should_exit: AtomicBool,
pub heartbeats: HeartbeatList, pub heartbeats: HeartbeatList,
pub(crate) queue: Arc<crate::queue::Queue<Message>>,
pub(crate) heartbeat: Parker,
}
pub(crate) enum Message {
Shared(SharedJob),
WakeUp,
Exit,
ScopeFinished,
} }
pub(crate) struct Shared { pub(crate) struct Shared {
@ -52,22 +64,15 @@ impl Shared {
} }
impl Context { impl Context {
pub(crate) fn shared(&self) -> parking_lot::MutexGuard<'_, Shared> {
self.shared.lock()
}
pub fn new_with_threads(num_threads: usize) -> Arc<Self> { pub fn new_with_threads(num_threads: usize) -> Arc<Self> {
#[cfg(feature = "tracing")] #[cfg(feature = "tracing")]
tracing::trace!("Creating context with {} threads", num_threads); tracing::trace!("Creating context with {} threads", num_threads);
let this = Arc::new(Self { let this = Arc::new(Self {
shared: Mutex::new(Shared {
jobs: BTreeMap::new(),
injected_jobs: Vec::new(),
}),
shared_job: Condvar::new(),
should_exit: AtomicBool::new(false), should_exit: AtomicBool::new(false),
heartbeats: HeartbeatList::new(), heartbeats: HeartbeatList::new(),
queue: crate::queue::Queue::new(),
heartbeat: Parker::new(),
}); });
// Create a barrier to synchronize the worker threads and the heartbeat thread // Create a barrier to synchronize the worker threads and the heartbeat thread
@ -94,7 +99,7 @@ impl Context {
std::thread::Builder::new() std::thread::Builder::new()
.name("heartbeat-thread".to_string()) .name("heartbeat-thread".to_string())
.spawn(move || { .spawn(move || {
HeartbeatThread::new(ctx).run(barrier); HeartbeatThread::new(ctx, num_threads).run(barrier);
}) })
.expect("Failed to spawn heartbeat thread"); .expect("Failed to spawn heartbeat thread");
} }
@ -106,7 +111,7 @@ impl Context {
pub fn set_should_exit(&self) { pub fn set_should_exit(&self) {
self.should_exit.store(true, Ordering::Relaxed); self.should_exit.store(true, Ordering::Relaxed);
self.heartbeats.notify_all(); self.queue.as_sender().broadcast_with(|| Message::Exit);
} }
pub fn should_exit(&self) -> bool { pub fn should_exit(&self) -> bool {
@ -124,31 +129,7 @@ impl Context {
} }
pub fn inject_job(&self, job: SharedJob) { pub fn inject_job(&self, job: SharedJob) {
let mut shared = self.shared.lock(); self.queue.as_sender().anycast(Message::Shared(job));
shared.injected_jobs.push(job);
unsafe {
// SAFETY: we are holding the shared lock, so it is safe to notify
self.notify_job_shared();
}
}
/// caller should hold the shared lock while calling this
pub unsafe fn notify_job_shared(&self) {
let heartbeats = self.heartbeats.inner();
if let Some((i, sender)) = heartbeats
.iter()
.find(|(_, heartbeat)| heartbeat.is_waiting())
.or_else(|| heartbeats.iter().next())
{
_ = i;
#[cfg(feature = "tracing")]
tracing::trace!("Notifying worker thread {} about job sharing", i);
sender.wake();
} else {
#[cfg(feature = "tracing")]
tracing::warn!("No worker found to notify about job sharing");
}
} }
/// Runs closure in this context, processing the other context's worker's jobs while waiting for the result. /// Runs closure in this context, processing the other context's worker's jobs while waiting for the result.
@ -160,13 +141,18 @@ impl Context {
// current thread is not in the same context, create a job and inject it into the other thread's context, then wait while working on our jobs. // current thread is not in the same context, create a job and inject it into the other thread's context, then wait while working on our jobs.
// SAFETY: we are waiting on this latch in this thread. // SAFETY: we are waiting on this latch in this thread.
let job = StackJob::new(move |worker: &WorkerThread| f(worker)); let _pinned = StackJob::new(move |worker: &WorkerThread| f(worker));
let job = unsafe { Pin::new_unchecked(&_pinned) };
let job = Job::from_stackjob(&job); let job = Job::from_stackjob(&job);
unsafe {
self.inject_job(job.share(Some(worker.receiver.get_token().as_parker())));
}
self.inject_job(job.share(Some(worker.heartbeat.parker()))); let t = worker.wait_until_recv(job.take_receiver().expect("Job should have a receiver"));
let t = worker.wait_until_shared_job(&job).unwrap(); // touch the job to ensure it is dropped after we are done with it.
drop(_pinned);
crate::util::unwrap_or_panic(t) crate::util::unwrap_or_panic(t)
} }
@ -180,15 +166,69 @@ impl Context {
// current thread isn't a worker thread, create job and inject into context // current thread isn't a worker thread, create job and inject into context
let parker = Parker::new(); let parker = Parker::new();
let job = StackJob::new(move |worker: &WorkerThread| f(worker)); struct CrossJob<F> {
f: UnsafeCell<ManuallyDrop<F>>,
_pin: PhantomPinned,
}
let job = Job::from_stackjob(&job); impl<F> CrossJob<F> {
fn new(f: F) -> Self {
Self {
f: UnsafeCell::new(ManuallyDrop::new(f)),
_pin: PhantomPinned,
}
}
self.inject_job(job.share(Some(&parker))); fn into_job<T>(self: &Self) -> Job<T>
where
F: FnOnce(&WorkerThread) -> T + Send,
T: Send,
{
Job::from_harness(Self::harness, NonNull::from(&*self).cast())
}
let recv = job.take_receiver().unwrap(); unsafe fn unwrap(&self) -> F {
unsafe { ManuallyDrop::take(&mut *self.f.get()) }
}
crate::util::unwrap_or_panic(recv.recv()) #[align(8)]
unsafe fn harness<T>(worker: &WorkerThread, this: NonNull<()>, sender: Option<Sender>)
where
F: FnOnce(&WorkerThread) -> T + Send,
T: Send,
{
let this: &CrossJob<F> = unsafe { this.cast().as_ref() };
let f = unsafe { this.unwrap() };
let sender: Option<Sender<T>> = unsafe { mem::transmute(sender) };
let result = catch_unwind(AssertUnwindSafe(|| f(worker)));
let sender = sender.unwrap();
unsafe {
sender.send_as_ref(result);
worker
.context
.queue
.as_sender()
.unicast(Message::WakeUp, ReceiverToken::from_parker(sender.parker()));
}
}
}
let pinned = CrossJob::new(move |worker: &WorkerThread| f(worker));
let job2 = pinned.into_job();
self.inject_job(job2.share(Some(&parker)));
let recv = job2.take_receiver().unwrap();
let out = crate::util::unwrap_or_panic(recv.recv());
// touch the job to ensure it is dropped after we are done with it.
drop(pinned);
drop(parker);
out
} }
/// Run closure in this context. /// Run closure in this context.
@ -375,9 +415,12 @@ mod tests {
ctx.inject_job(job.share(Some(&parker))); ctx.inject_job(job.share(Some(&parker)));
// Wait for the job to be executed // Wait for the job to be executed
let recv = job.take_receiver().unwrap(); let recv = job.take_receiver().expect("Job should have a receiver");
let result = recv.recv(); let Some(result) = recv.poll() else {
let result = crate::util::unwrap_or_panic(result); panic!("Expected a finished message");
};
let result = crate::util::unwrap_or_panic::<i32>(result);
assert_eq!(result, 42); assert_eq!(result, 42);
assert_eq!(counter.load(Ordering::SeqCst), 1); assert_eq!(counter.load(Ordering::SeqCst), 1);
} }

View file

@ -4,13 +4,15 @@ use core::{
mem::{self, ManuallyDrop}, mem::{self, ManuallyDrop},
ptr::NonNull, ptr::NonNull,
}; };
use std::cell::Cell; use std::{cell::Cell, marker::PhantomPinned};
use alloc::boxed::Box; use alloc::boxed::Box;
use crate::{ use crate::{
WorkerThread, WorkerThread,
channel::{Parker, Sender}, channel::{Parker, Receiver, Sender},
context::Message,
queue::ReceiverToken,
}; };
#[repr(transparent)] #[repr(transparent)]
@ -43,65 +45,89 @@ impl<F> HeapJob<F> {
} }
} }
type JobHarness = type JobHarness = unsafe fn(&WorkerThread, this: NonNull<()>, sender: Option<Sender>);
unsafe fn(&WorkerThread, this: NonNull<()>, sender: Option<crate::channel::Sender>);
#[repr(C)] #[repr(C)]
pub struct Job2<T = ()> { pub struct Job2<T = ()> {
harness: JobHarness, inner: UnsafeCell<Job2Inner<T>>,
this: NonNull<()>,
receiver: Cell<Option<crate::channel::Receiver<T>>>,
} }
impl<T> Debug for Job2<T> { impl<T> Debug for Job2<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Job2") f.debug_struct("Job2").field("inner", &self.inner).finish()
.field("harness", &self.harness)
.field("this", &self.this)
.finish_non_exhaustive()
} }
} }
#[repr(C)]
pub enum Job2Inner<T = ()> {
Local {
harness: JobHarness,
this: NonNull<()>,
_pin: PhantomPinned,
},
Shared {
receiver: Cell<Option<Receiver<T>>>,
},
}
#[derive(Debug)] #[derive(Debug)]
pub struct SharedJob { pub struct SharedJob {
harness: JobHarness, harness: JobHarness,
this: NonNull<()>, this: NonNull<()>,
sender: Option<crate::channel::Sender>, sender: Option<Sender<()>>,
} }
unsafe impl Send for SharedJob {}
impl<T: Send> Job2<T> { impl<T: Send> Job2<T> {
fn new(harness: JobHarness, this: NonNull<()>) -> Self { fn new(harness: JobHarness, this: NonNull<()>) -> Self {
let this = Self { let this = Self {
harness, inner: UnsafeCell::new(Job2Inner::Local {
this, harness: harness,
receiver: Cell::new(None), this,
_pin: PhantomPinned,
}),
}; };
#[cfg(feature = "tracing")]
tracing::trace!("new job: {:?}", this);
this this
} }
pub fn share(&self, parker: Option<&Parker>) -> SharedJob { pub fn share(&self, parker: Option<&Parker>) -> SharedJob {
#[cfg(feature = "tracing")]
tracing::trace!("sharing job: {:?}", self);
let (sender, receiver) = parker let (sender, receiver) = parker
.map(|parker| crate::channel::channel::<T>(parker.into())) .map(|parker| crate::channel::channel::<T>(parker.into()))
.unzip(); .unzip();
self.receiver.set(receiver); // self.receiver.set(receiver);
if let Job2Inner::Local {
SharedJob { harness,
harness: self.harness, this,
this: self.this, _pin: _,
sender: unsafe { mem::transmute(sender) }, } = unsafe {
self.inner.replace(Job2Inner::Shared {
receiver: Cell::new(receiver),
})
} {
// SAFETY: `this` is a valid pointer to the job.
unsafe {
SharedJob {
harness,
this,
sender: mem::transmute(sender), // Convert `Option<Sender<T>>` to `Option<Sender<()>>`
}
}
} else {
panic!("Job2 is already shared");
} }
} }
pub fn take_receiver(&self) -> Option<crate::channel::Receiver<T>> { pub fn take_receiver(&self) -> Option<Receiver<T>> {
self.receiver.take() unsafe {
if let Job2Inner::Shared { receiver } = self.inner.as_ref_unchecked() {
receiver.take()
} else {
None
}
}
} }
pub fn from_stackjob<F>(job: &StackJob<F>) -> Self pub fn from_stackjob<F>(job: &StackJob<F>) -> Self
@ -119,9 +145,9 @@ impl<T: Send> Job2<T> {
T: Send, T: Send,
{ {
use std::panic::{AssertUnwindSafe, catch_unwind}; use std::panic::{AssertUnwindSafe, catch_unwind};
let sender: Option<Sender<T>> = unsafe { mem::transmute(sender) };
let f = unsafe { this.cast::<StackJob<F>>().as_ref().unwrap() }; let f = unsafe { this.cast::<StackJob<F>>().as_ref().unwrap() };
let sender: Sender<T> = unsafe { mem::transmute(sender) };
// #[cfg(feature = "metrics")] // #[cfg(feature = "metrics")]
// if worker.heartbeat.parker() == mutex { // if worker.heartbeat.parker() == mutex {
@ -132,7 +158,18 @@ impl<T: Send> Job2<T> {
// tracing::trace!("job sent to self"); // tracing::trace!("job sent to self");
// } // }
sender.send(catch_unwind(AssertUnwindSafe(|| f(worker)))); let result = catch_unwind(AssertUnwindSafe(|| f(worker)));
if let Some(sender) = sender {
unsafe {
sender.send_as_ref(result);
worker
.context
.queue
.as_sender()
.unicast(Message::WakeUp, ReceiverToken::from_parker(sender.parker()));
}
}
} }
Self::new(harness::<F, T>, NonNull::from(job).cast()) Self::new(harness::<F, T>, NonNull::from(job).cast())
@ -153,6 +190,7 @@ impl<T: Send> Job2<T> {
T: Send, T: Send,
{ {
use std::panic::{AssertUnwindSafe, catch_unwind}; use std::panic::{AssertUnwindSafe, catch_unwind};
let sender: Option<Sender<T>> = unsafe { mem::transmute(sender) };
// expect MIRI to complain about this, but it is actually correct. // expect MIRI to complain about this, but it is actually correct.
// because I am so much smarter than MIRI, naturally, obviously. // because I am so much smarter than MIRI, naturally, obviously.
@ -160,9 +198,15 @@ impl<T: Send> Job2<T> {
let f = unsafe { (*Box::from_non_null(this.cast::<HeapJob<F>>())).into_inner() }; let f = unsafe { (*Box::from_non_null(this.cast::<HeapJob<F>>())).into_inner() };
let result = catch_unwind(AssertUnwindSafe(|| f(worker))); let result = catch_unwind(AssertUnwindSafe(|| f(worker)));
let sender: Option<Sender<T>> = unsafe { mem::transmute(sender) };
if let Some(sender) = sender { if let Some(sender) = sender {
sender.send(result); unsafe {
sender.send_as_ref(result);
_ = worker
.context
.queue
.as_sender()
.unicast(Message::WakeUp, ReceiverToken::from_parker(sender.parker()));
}
} }
} }
@ -177,10 +221,6 @@ impl<T: Send> Job2<T> {
pub fn from_harness(harness: JobHarness, this: NonNull<()>) -> Self { pub fn from_harness(harness: JobHarness, this: NonNull<()>) -> Self {
Self::new(harness, this) Self::new(harness, this)
} }
pub fn is_shared(&self) -> bool {
unsafe { (&*self.receiver.as_ptr()).is_some() }
}
} }
impl SharedJob { impl SharedJob {

View file

@ -1,7 +1,7 @@
#[cfg(feature = "metrics")] #[cfg(feature = "metrics")]
use std::sync::atomic::Ordering; use std::sync::atomic::Ordering;
use std::{hint::cold_path, sync::Arc}; use std::{hint::cold_path, pin::Pin, sync::Arc};
use crate::{ use crate::{
context::Context, context::Context,
@ -84,7 +84,6 @@ impl WorkerThread {
// SAFETY: this function runs in a worker thread, so we can access the queue safely. // SAFETY: this function runs in a worker thread, so we can access the queue safely.
if count == 0 || queue_len < 3 { if count == 0 || queue_len < 3 {
cold_path();
self.join_heartbeat2(a, b) self.join_heartbeat2(a, b)
} else { } else {
(a.run_inline(self), b(self)) (a.run_inline(self), b(self))
@ -103,12 +102,14 @@ impl WorkerThread {
#[cfg(feature = "metrics")] #[cfg(feature = "metrics")]
self.metrics.num_joins.fetch_add(1, Ordering::Relaxed); self.metrics.num_joins.fetch_add(1, Ordering::Relaxed);
let job = a.into_job(); let _pinned = a.into_job();
let job = unsafe { Pin::new_unchecked(&_pinned) };
self.push_back(&job); self.push_back(&*job);
self.tick(); self.tick();
// let rb = b(self);
let rb = match catch_unwind(AssertUnwindSafe(|| b(self))) { let rb = match catch_unwind(AssertUnwindSafe(|| b(self))) {
Ok(val) => val, Ok(val) => val,
Err(payload) => { Err(payload) => {
@ -117,32 +118,16 @@ impl WorkerThread {
cold_path(); cold_path();
// if b panicked, we need to wait for a to finish // if b panicked, we need to wait for a to finish
let mut receiver = job.take_receiver(); if let Some(recv) = job.take_receiver() {
self.wait_until_pred(|| match &receiver { _ = self.wait_until_recv(recv);
Some(recv) => recv.poll().is_some(), }
None => {
receiver = job.take_receiver();
false
}
});
resume_unwind(payload); resume_unwind(payload);
} }
}; };
let ra = if let Some(recv) = job.take_receiver() { let ra = if let Some(recv) = job.take_receiver() {
match self.wait_until_recv(recv) { crate::util::unwrap_or_panic(self.wait_until_recv(recv))
Some(t) => crate::util::unwrap_or_panic(t),
None => {
#[cfg(feature = "tracing")]
tracing::trace!(
"join_heartbeat: job was shared, but reclaimed, running a() inline"
);
// the job was shared, but not yet stolen, so we get to run the
// job inline
a.run_inline(self)
}
}
} else { } else {
self.pop_back(); self.pop_back();
@ -152,6 +137,9 @@ impl WorkerThread {
a.run_inline(self) a.run_inline(self)
}; };
// touch the job to ensure it is not dropped while we are still using it.
drop(_pinned);
(ra, rb) (ra, rb)
} }
@ -183,41 +171,23 @@ impl WorkerThread {
cold_path(); cold_path();
// if b panicked, we need to wait for a to finish // if b panicked, we need to wait for a to finish
let mut receiver = job.take_receiver(); if let Some(recv) = job.take_receiver() {
self.wait_until_pred(|| match &receiver { _ = self.wait_until_recv(recv);
Some(recv) => recv.poll().is_some(), }
None => {
receiver = job.take_receiver();
false
}
});
resume_unwind(payload); resume_unwind(payload);
} }
}; };
let ra = if let Some(recv) = job.take_receiver() { let ra = if let Some(recv) = job.take_receiver() {
match self.wait_until_recv(recv) { crate::util::unwrap_or_panic(self.wait_until_recv(recv))
Some(t) => crate::util::unwrap_or_panic(t),
None => {
#[cfg(feature = "tracing")]
tracing::trace!(
"join_heartbeat: job was shared, but reclaimed, running a() inline"
);
// the job was shared, but not yet stolen, so we get to run the
// job inline
unsafe { a.unwrap()(self) }
}
}
} else { } else {
self.pop_back(); self.pop_back();
unsafe { // SAFETY: we just popped the job from the queue, so it is safe to unwrap.
// SAFETY: we just popped the job from the queue, so it is safe to unwrap. #[cfg(feature = "tracing")]
#[cfg(feature = "tracing")] tracing::trace!("join_heartbeat: job was not shared, running a() inline");
tracing::trace!("join_heartbeat: job was not shared, running a() inline"); a.run_inline(self)
a.unwrap()(self)
}
}; };
(ra, rb) (ra, rb)

View file

@ -1,68 +1,172 @@
use std::{ use std::{
cell::UnsafeCell, cell::UnsafeCell,
collections::{HashMap, HashSet}, collections::HashMap,
marker::{PhantomData, PhantomPinned}, marker::{PhantomData, PhantomPinned},
mem::{self, MaybeUninit}, mem::{self, MaybeUninit},
pin::Pin, pin::Pin,
ptr::{self, NonNull},
sync::{ sync::{
Arc, Arc,
atomic::{AtomicU8, AtomicU32, Ordering}, atomic::{AtomicU32, Ordering},
}, },
}; };
use crossbeam_utils::CachePadded; use werkzeug::CachePadded;
use werkzeug::sync::Parker;
use werkzeug::ptr::TaggedAtomicPtr;
// A Queue with multiple receivers and multiple producers, where a producer can send a message to one of any of the receivers (any-cast), or one of the receivers (uni-cast). // A Queue with multiple receivers and multiple producers, where a producer can send a message to one of any of the receivers (any-cast), or one of the receivers (uni-cast).
// After being woken up from waiting on a message, the receiver will look up the index of the message in the queue and return it. // After being woken up from waiting on a message, the receiver will look up the index of the message in the queue and return it.
struct QueueInner<T> { struct QueueInner<T> {
parked: HashSet<ReceiverToken>, receivers: HashMap<ReceiverToken, CachePadded<(Slot<T>, bool)>>,
owned: HashMap<ReceiverToken, CachePadded<Slot<T>>>,
messages: Vec<T>, messages: Vec<T>,
_phantom: std::marker::PhantomData<T>, _phantom: std::marker::PhantomData<T>,
} }
struct Queue<T> { pub struct Queue<T> {
inner: UnsafeCell<QueueInner<T>>, inner: UnsafeCell<QueueInner<T>>,
lock: AtomicU32, lock: AtomicU32,
} }
unsafe impl<T> Send for Queue<T> {}
unsafe impl<T> Sync for Queue<T> where T: Send {}
enum SlotKey { enum SlotKey {
Owned(ReceiverToken), Owned(ReceiverToken),
Indexed(usize), Indexed(usize),
} }
struct Receiver<T> { pub struct Receiver<T> {
queue: Arc<Queue<T>>, queue: Arc<Queue<T>>,
lock: Pin<Box<(AtomicU32, PhantomPinned)>>, lock: Pin<Box<(Parker, PhantomPinned)>>,
} }
struct Sender<T> { #[repr(transparent)]
pub struct Sender<T> {
queue: Arc<Queue<T>>, queue: Arc<Queue<T>>,
} }
// TODO: make this a linked list of slots so we can queue multiple messages for // TODO: make this a linked list of slots so we can queue multiple messages for
// a single receiver // a single receiver
const SLOT_ALIGN: u8 = core::mem::align_of::<usize>().ilog2() as u8;
struct Slot<T> { struct Slot<T> {
value: UnsafeCell<MaybeUninit<T>>, value: UnsafeCell<MaybeUninit<T>>,
state: AtomicU8, next_and_state: TaggedAtomicPtr<Self, SLOT_ALIGN>,
_phantom: PhantomData<Self>,
} }
impl<T> Slot<T> { impl<T> Slot<T> {
fn new() -> Self { fn new() -> Self {
Self { Self {
value: UnsafeCell::new(MaybeUninit::uninit()), value: UnsafeCell::new(MaybeUninit::uninit()),
state: AtomicU8::new(0), // 0 means empty next_and_state: TaggedAtomicPtr::new(ptr::null_mut(), 0), // 0 means empty
_phantom: PhantomData,
} }
} }
fn set(&self, value: T) {} fn is_set(&self) -> bool {
self.next_and_state.tag(Ordering::Acquire) == 1
}
unsafe fn pop(&self) -> Option<T> {
NonNull::new(self.next_and_state.ptr(Ordering::Acquire))
.and_then(|next| {
// SAFETY: The next slot is a valid pointer to a Slot<T> that was allocated by us.
unsafe { next.as_ref().pop() }
})
.or_else(|| {
if self
.next_and_state
.swap_tag(0, Ordering::AcqRel, Ordering::Relaxed)
== 1
{
// SAFETY: The value is only initialized when the state is set to 1.
Some(unsafe { self.value.as_ref_unchecked().assume_init_read() })
} else {
None
}
})
}
/// the caller must ensure that they have exclusive access to the slot
unsafe fn push(&self, value: T) {
if self.is_set() {
let next = self.next_ptr();
unsafe {
(next.as_ref()).push(value);
}
} else {
// SAFETY: The value is only initialized when the state is set to 1.
unsafe { self.value.as_mut_unchecked().write(value) };
self.next_and_state
.set_tag(1, Ordering::Release, Ordering::Relaxed);
}
}
fn next_ptr(&self) -> NonNull<Slot<T>> {
if let Some(next) = NonNull::new(self.next_and_state.ptr(Ordering::Acquire)) {
next.cast()
} else {
self.alloc_next()
}
}
fn alloc_next(&self) -> NonNull<Slot<T>> {
let next = Box::into_raw(Box::new(Slot::new()));
let next = loop {
match self.next_and_state.compare_exchange_weak_ptr(
ptr::null_mut(),
next,
Ordering::Release,
Ordering::Acquire,
) {
Ok(_) => break next,
Err(other) => {
if other.is_null() {
eprintln!("Slot::alloc_next: spurious compare_exchange failure (other is null), retrying");
continue;
}
// next was allocated under us, so we need to drop the slot we just allocated again.
#[cfg(feature = "tracing")]
tracing::trace!(
"Slot::alloc_next: next was allocated under us, dropping it. ours: {:p}, other: {:p}",
next,
other
);
_ = unsafe { Box::from_raw(next) };
break other;
}
}
};
unsafe {
// SAFETY: The next slot is a valid pointer to a Slot<T> that was allocated by us.
NonNull::new_unchecked(next)
}
}
} }
impl<T> Drop for Slot<T> { impl<T> Drop for Slot<T> {
fn drop(&mut self) { fn drop(&mut self) {
// drop next chain
if let Some(next) = NonNull::new(self.next_and_state.swap_ptr(
ptr::null_mut(),
Ordering::Release,
Ordering::Relaxed,
)) {
// SAFETY: The next slot is a valid pointer to a Slot<T> that was allocated by us.
// Drop the rest of the chain in place, then free the allocation through
// `ManuallyDrop` so the slot's destructor does not run a second time.
unsafe {
next.drop_in_place();
_ = Box::<mem::ManuallyDrop<Self>>::from_non_null(next.cast());
}
}
// SAFETY: The value is only initialized when the state is set to 1. // SAFETY: The value is only initialized when the state is set to 1.
if mem::needs_drop::<T>() && self.state.load(Ordering::Acquire) == 1 { if mem::needs_drop::<T>() && self.next_and_state.tag(Ordering::Acquire) == 1 {
unsafe { self.value.as_mut_unchecked().assume_init_drop() }; unsafe { self.value.as_mut_unchecked().assume_init_drop() };
} }
} }
@ -77,19 +181,35 @@ impl<T> Drop for Slot<T> {
/// A token that can be used to identify a specific receiver in a queue. /// A token that can be used to identify a specific receiver in a queue.
#[repr(transparent)] #[repr(transparent)]
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ReceiverToken(werkzeug::util::Send<*const u32>); pub struct ReceiverToken(werkzeug::util::Send<NonNull<u32>>);
impl ReceiverToken {
pub fn as_ptr(&self) -> *mut u32 {
self.0.into_inner().as_ptr()
}
pub unsafe fn as_parker(&self) -> &Parker {
// SAFETY: The pointer is guaranteed to be valid and aligned, as it comes from a pinned Parker.
unsafe { Parker::from_ptr(self.as_ptr()) }
}
pub unsafe fn from_parker(parker: &Parker) -> Self {
// SAFETY: The pointer is guaranteed to be valid and aligned, as it comes from a pinned Parker.
let ptr = NonNull::from(parker).cast::<u32>();
ReceiverToken(werkzeug::util::Send(ptr))
}
}
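Editorial note (not part of the file): a `ReceiverToken` is nothing more than the address of a receiver's pinned `Parker`, which is what lets the job harnesses in this diff wake the exact worker that owns a channel's parker via `unicast(Message::WakeUp, ReceiverToken::from_parker(sender.parker()))`. A small sketch of the round trip, assuming it lives next to the existing tests in this module:

#[test]
fn token_parker_roundtrip() {
    let queue = Queue::<i32>::new();
    let receiver = queue.new_receiver();
    let token = receiver.get_token();

    // SAFETY: the token was produced from the live, pinned receiver above.
    let parker: &Parker = unsafe { token.as_parker() };

    // SAFETY: the Parker reference was just derived from a valid token, so the
    // round trip recovers the same address.
    assert_eq!(unsafe { ReceiverToken::from_parker(parker) }, token);
}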
impl<T> Queue<T> { impl<T> Queue<T> {
pub fn new() -> Self { pub fn new() -> Arc<Self> {
Self { Arc::new(Self {
inner: UnsafeCell::new(QueueInner { inner: UnsafeCell::new(QueueInner {
parked: HashSet::new(),
messages: Vec::new(), messages: Vec::new(),
owned: HashMap::new(), receivers: HashMap::new(),
_phantom: PhantomData, _phantom: PhantomData,
}), }),
lock: AtomicU32::new(0), lock: AtomicU32::new(0),
} })
} }
pub fn new_sender(self: &Arc<Self>) -> Sender<T> { pub fn new_sender(self: &Arc<Self>) -> Sender<T> {
@ -98,22 +218,28 @@ impl<T> Queue<T> {
} }
} }
pub fn num_receivers(self: &Arc<Self>) -> usize {
let _guard = self.lock();
self.inner().receivers.len()
}
pub fn as_sender(self: &Arc<Self>) -> &Sender<T> {
unsafe { mem::transmute::<&Arc<Self>, &Sender<T>>(self) }
}
pub fn new_receiver(self: &Arc<Self>) -> Receiver<T> { pub fn new_receiver(self: &Arc<Self>) -> Receiver<T> {
let recv = Receiver { let recv = Receiver {
queue: self.clone(), queue: self.clone(),
lock: Box::pin((AtomicU32::new(0), PhantomPinned)), lock: Box::pin((Parker::new(), PhantomPinned)),
}; };
// allocate slot for the receiver // allocate slot for the receiver
let token = recv.get_token(); let token = recv.get_token();
let _guard = recv.queue.lock(); let _guard = recv.queue.lock();
recv.queue.inner().owned.insert( recv.queue
token, .inner()
CachePadded::new(Slot { .receivers
value: UnsafeCell::new(MaybeUninit::uninit()), .insert(token, CachePadded::new((Slot::new(), false)));
state: AtomicU8::new(0), // 0 means empty
}),
);
drop(_guard); drop(_guard);
recv recv
@ -134,28 +260,27 @@ impl<T> Queue<T> {
} }
impl<T> QueueInner<T> { impl<T> QueueInner<T> {
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
fn poll(&mut self, token: ReceiverToken) -> Option<T> { fn poll(&mut self, token: ReceiverToken) -> Option<T> {
// check if someone has sent a message to this receiver // check if someone has sent a message to this receiver
let slot = self.owned.get(&token).unwrap(); let CachePadded((slot, _)) = self.receivers.get(&token)?;
if slot.state.swap(0, Ordering::Acquire) == 1 {
// SAFETY: the slot is owned by this receiver and contains a message. unsafe { slot.pop() }.or_else(|| {
return Some(unsafe { slot.value.as_ref_unchecked().assume_init_read() }); // if the slot is empty, we can check the indexed messages
} else if let Some(t) = self.messages.pop() { #[cfg(feature = "tracing")]
return Some(t); tracing::trace!("QueueInner::poll: checking open messages");
} else {
None self.messages.pop()
} })
} }
} }
impl<T> Receiver<T> { impl<T> Receiver<T> {
fn get_token(&self) -> ReceiverToken { pub fn get_token(&self) -> ReceiverToken {
// the token is just the pointer to the lock of this receiver. // the token is just the pointer to the lock of this receiver.
// the lock is pinned, so its address is stable across calls to `recv`. // the lock is pinned, so its address is stable across calls to `recv`.
ReceiverToken(werkzeug::util::Send( ReceiverToken(werkzeug::util::Send(NonNull::from(&self.lock.0).cast()))
&self.lock.0 as *const AtomicU32 as *const u32,
))
} }
} }
@ -167,12 +292,13 @@ impl<T> Drop for Receiver<T> {
let queue = self.queue.inner(); let queue = self.queue.inner();
// remove the receiver from the queue // remove the receiver from the queue
_ = queue.owned.remove(&self.get_token()); _ = queue.receivers.remove(&self.get_token());
} }
} }
} }
impl<T: Send> Receiver<T> { impl<T: Send> Receiver<T> {
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn recv(&self) -> T { pub fn recv(&self) -> T {
let token = self.get_token(); let token = self.get_token();
@ -183,22 +309,23 @@ impl<T: Send> Receiver<T> {
// check if someone has sent a message to this receiver // check if someone has sent a message to this receiver
if let Some(t) = queue.poll(token) { if let Some(t) = queue.poll(token) {
queue.parked.remove(&token); queue.receivers.get_mut(&token).unwrap().1 = false; // mark the slot as not parked
return t; return t;
} }
// there was no message for this receiver, so we need to park it // there was no message for this receiver, so we need to park it
queue.parked.insert(token); queue.receivers.get_mut(&token).unwrap().1 = true; // mark the slot as parked
// wait for a message to be sent to this receiver self.lock.0.park_with_callback(move || {
drop(_guard); // drop the lock guard after having set the lock state to waiting.
unsafe { // this avoids a deadlock if the sender tries to send a message
let lock = werkzeug::sync::Lock::from_ptr(token.0.into_inner().cast_mut()); // while the receiver is in the process of parking (I think..)
lock.wait(); drop(_guard);
} });
} }
} }
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn try_recv(&self) -> Option<T> { pub fn try_recv(&self) -> Option<T> {
let token = self.get_token(); let token = self.get_token();
@ -214,61 +341,92 @@ impl<T: Send> Receiver<T> {
impl<T: Send> Sender<T> { impl<T: Send> Sender<T> {
/// Sends a message to one of the receivers in the queue, or makes it /// Sends a message to one of the receivers in the queue, or makes it
/// available to any receiver that will park in the future. /// available to any receiver that will park in the future.
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn anycast(&self, value: T) { pub fn anycast(&self, value: T) {
// look for a receiver that is parked
let _guard = self.queue.lock(); let _guard = self.queue.lock();
// SAFETY: The queue is locked, so we can safely access the inner queue.
match unsafe { self.try_anycast_inner(value) } {
Ok(_) => {}
Err(value) => {
#[cfg(feature = "tracing")]
tracing::trace!(
"Queue::anycast: no parked receiver found, adding message to indexed slots"
);
// no parked receiver found, so we want to add the message to the indexed slots
let queue = self.queue.inner();
queue.messages.push(value);
// waking up a parked receiver is not necessary here, as any
// receivers that don't have a free slot are currently waking up.
}
}
}
pub fn try_anycast(&self, value: T) -> Result<(), T> {
// lock the queue
let _guard = self.queue.lock();
// SAFETY: The queue is locked, so we can safely access the inner queue.
unsafe { self.try_anycast_inner(value) }
}
/// The caller must hold the lock on the queue for the duration of this function.
unsafe fn try_anycast_inner(&self, value: T) -> Result<(), T> {
// look for a receiver that is parked
let queue = self.queue.inner(); let queue = self.queue.inner();
if let Some((token, slot)) = queue.parked.iter().find_map(|token| { if let Some((token, slot)) =
// ensure the slot is available queue
queue.owned.get(token).and_then(|s| { .receivers
if s.state.load(Ordering::Acquire) == 0 { .iter()
Some((*token, s)) .find_map(|(token, CachePadded((slot, is_parked)))| {
} else { // ensure the slot is available
None if *is_parked && !slot.is_set() {
} Some((*token, slot))
}) } else {
}) { None
}
})
{
// we found a receiver that is parked, so we can send the message to it // we found a receiver that is parked, so we can send the message to it
unsafe { unsafe {
slot.value.as_mut_unchecked().write(value); slot.value.as_mut_unchecked().write(value);
slot.state.store(1, Ordering::Release); slot.next_and_state
werkzeug::sync::Lock::from_ptr(token.0.into_inner().cast_mut()).wake_one(); .set_tag(1, Ordering::Release, Ordering::Relaxed);
Parker::from_ptr(token.0.into_inner().as_ptr()).unpark();
} }
return; return Ok(());
} else { } else {
// no parked receiver found, so we want to add the message to the indexed slots return Err(value);
queue.messages.push(value);
// waking up a parked receiver is not necessary here, as any
// receivers that don't have a free slot are currently waking up.
} }
} }
/// Sends a message to a specific receiver, waking it if it is parked. /// Sends a message to a specific receiver, waking it if it is parked.
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn unicast(&self, value: T, receiver: ReceiverToken) -> Result<(), T> { pub fn unicast(&self, value: T, receiver: ReceiverToken) -> Result<(), T> {
// lock the queue // lock the queue
let _guard = self.queue.lock(); let _guard = self.queue.lock();
let queue = self.queue.inner(); let queue = self.queue.inner();
let Some(slot) = queue.owned.get_mut(&receiver) else { let Some(CachePadded((slot, _))) = queue.receivers.get_mut(&receiver) else {
return Err(value); return Err(value);
}; };
// SAFETY: The slot is owned by this receiver.
unsafe { slot.value.as_mut_unchecked().write(value) };
slot.state.store(1, Ordering::Release);
// check if the receiver is parked unsafe {
if queue.parked.contains(&receiver) { slot.push(value);
// wake the receiver }
unsafe {
werkzeug::sync::Lock::from_ptr(receiver.0.into_inner().cast_mut()).wake_one(); // wake the receiver
} unsafe {
Parker::from_ptr(receiver.0.into_inner().as_ptr()).unpark();
} }
Ok(()) Ok(())
} }
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn broadcast(&self, value: T) pub fn broadcast(&self, value: T)
where where
T: Clone, T: Clone,
@ -278,25 +436,37 @@ impl<T: Send> Sender<T> {
let queue = self.queue.inner(); let queue = self.queue.inner();
// send the message to all receivers // send the message to all receivers
for (token, slot) in queue.owned.iter() { for (token, CachePadded((slot, _))) in queue.receivers.iter() {
// SAFETY: The slot is owned by this receiver. // SAFETY: The slot is owned by this receiver.
if slot.state.load(Ordering::Acquire) != 0 { unsafe { slot.push(value.clone()) };
// the slot is not available, so we skip it
continue;
}
// wake the receiver
unsafe { unsafe {
slot.value.as_mut_unchecked().write(value.clone()); Parker::from_ptr(token.0.into_inner().as_ptr()).unpark();
} }
slot.state.store(1, Ordering::Release); }
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn broadcast_with<F>(&self, mut f: F)
where
F: FnMut() -> T,
{
// lock the queue
let _guard = self.queue.lock();
let queue = self.queue.inner();
// send the message to all receivers
for (token, CachePadded((slot, _))) in queue.receivers.iter() {
// SAFETY: The slot is owned by this receiver.
unsafe { slot.push(f()) };
// check if the receiver is parked // check if the receiver is parked
if queue.parked.contains(token) { // wake the receiver
// wake the receiver unsafe {
unsafe { Parker::from_ptr(token.0.into_inner().as_ptr()).unpark();
werkzeug::sync::Lock::from_ptr(token.0.into_inner().cast_mut()).wake_one();
}
} }
} }
} }
@ -308,13 +478,12 @@ mod tests {
#[test] #[test]
fn test_queue() { fn test_queue() {
let queue = Arc::new(Queue::<i32>::new()); let queue = Queue::<i32>::new();
let sender = queue.new_sender(); let sender = queue.new_sender();
let receiver1 = queue.new_receiver(); let receiver1 = queue.new_receiver();
let receiver2 = queue.new_receiver(); let receiver2 = queue.new_receiver();
let token1 = receiver1.get_token();
let token2 = receiver2.get_token(); let token2 = receiver2.get_token();
sender.anycast(42); sender.anycast(42);
@ -325,4 +494,146 @@ mod tests {
assert_eq!(receiver1.try_recv(), None); assert_eq!(receiver1.try_recv(), None);
assert_eq!(receiver2.recv(), 100); assert_eq!(receiver2.recv(), 100);
} }
#[test]
fn queue_broadcast() {
let queue = Queue::<i32>::new();
let sender = queue.new_sender();
let receiver1 = queue.new_receiver();
let receiver2 = queue.new_receiver();
sender.broadcast(42);
assert_eq!(receiver1.recv(), 42);
assert_eq!(receiver2.recv(), 42);
}
#[test]
fn queue_multiple_messages() {
let queue = Queue::<i32>::new();
let sender = queue.new_sender();
let receiver = queue.new_receiver();
sender.anycast(1);
sender.unicast(2, receiver.get_token()).unwrap();
assert_eq!(receiver.recv(), 2);
assert_eq!(receiver.recv(), 1);
}
#[test]
fn queue_threaded() {
#[derive(Debug, Clone, Copy)]
enum Message {
Send(i32),
Exit,
}
let queue = Queue::<Message>::new();
let sender = queue.new_sender();
let threads = (0..5)
.map(|_| {
let queue_clone = queue.clone();
let receiver = queue_clone.new_receiver();
std::thread::spawn(move || {
loop {
match receiver.recv() {
Message::Send(value) => {
println!("Receiver {:?} Received: {}", receiver.get_token(), value);
}
Message::Exit => {
println!("Exiting thread");
break;
}
}
}
})
})
.collect::<Vec<_>>();
// Send messages to the receivers
for i in 0..10 {
sender.anycast(Message::Send(i));
}
// Send exit messages to all receivers
sender.broadcast(Message::Exit);
for thread in threads {
thread.join().unwrap();
}
println!("All threads have exited.");
}
#[test]
fn drop_slot() {
// Test that dropping a slot does not cause a double free or panic
let slot = Slot::<i32>::new();
unsafe {
slot.push(42);
drop(slot);
}
}
#[test]
fn drop_slot_chain() {
struct DropCheck<'a>(&'a AtomicU32);
impl Drop for DropCheck<'_> {
fn drop(&mut self) {
self.0.fetch_sub(1, Ordering::SeqCst);
}
}
impl<'a> DropCheck<'a> {
fn new(counter: &'a AtomicU32) -> Self {
counter.fetch_add(1, Ordering::SeqCst);
Self(counter)
}
}
let counter = AtomicU32::new(0);
let slot = Slot::<DropCheck>::new();
for _ in 0..10 {
unsafe {
slot.push(DropCheck::new(&counter));
}
}
assert_eq!(counter.load(Ordering::SeqCst), 10);
drop(slot);
assert_eq!(
counter.load(Ordering::SeqCst),
0,
"All DropCheck instances should have been dropped"
);
}
#[test]
fn send_self() {
// Test that sending a message to self works
let queue = Queue::<i32>::new();
let sender = queue.new_sender();
let receiver = queue.new_receiver();
sender.unicast(42, receiver.get_token()).unwrap();
assert_eq!(receiver.recv(), 42);
}
#[test]
fn send_self_many() {
// Test that sending multiple messages to self works
let queue = Queue::<i32>::new();
let sender = queue.new_sender();
let receiver = queue.new_receiver();
for i in 0..10 {
sender.unicast(i, receiver.get_token()).unwrap();
}
for i in (0..10).rev() {
assert_eq!(receiver.recv(), i);
}
}
} }
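Editorial note (not part of the diff): `broadcast_with` is added above but not exercised by these tests; a sketch of what it guarantees, in the same style as the tests in this module:

#[test]
fn queue_broadcast_with() {
    // Unlike `broadcast`, `broadcast_with` invokes the closure once per
    // registered receiver, so every receiver gets a freshly built message and
    // `T: Clone` is not required; `Context::set_should_exit` relies on this
    // when it broadcasts `Message::Exit`.
    let queue = Queue::<i32>::new();
    let sender = queue.new_sender();

    let receiver1 = queue.new_receiver();
    let receiver2 = queue.new_receiver();

    let mut counter = 0;
    sender.broadcast_with(|| {
        counter += 1;
        counter
    });

    // Receivers are stored in a HashMap, so only the set of delivered values
    // is deterministic, not which receiver gets which.
    let mut got = [receiver1.recv(), receiver2.recv()];
    got.sort();
    assert_eq!(got, [1, 2]);
}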

View file

@ -1,6 +1,6 @@
use std::{ use std::{
any::Any, any::Any,
marker::PhantomData, marker::{PhantomData, PhantomPinned},
panic::{AssertUnwindSafe, catch_unwind}, panic::{AssertUnwindSafe, catch_unwind},
pin::{self, Pin}, pin::{self, Pin},
ptr::{self, NonNull}, ptr::{self, NonNull},
@ -11,15 +11,17 @@ use std::{
}; };
use async_task::Runnable; use async_task::Runnable;
use werkzeug::util;
use crate::{ use crate::{
channel::Sender, channel::Sender,
context::Context, context::{Context, Message},
job::{ job::{
HeapJob, Job2 as Job, HeapJob, Job2 as Job, SharedJob,
traits::{InlineJob, IntoJob}, traits::{InlineJob, IntoJob},
}, },
latch::{CountLatch, Probe}, latch::{CountLatch, Probe},
queue::ReceiverToken,
util::{DropGuard, SendPtr}, util::{DropGuard, SendPtr},
workerthread::WorkerThread, workerthread::WorkerThread,
}; };
@ -47,7 +49,7 @@ use crate::{
struct ScopeInner { struct ScopeInner {
outstanding_jobs: AtomicUsize, outstanding_jobs: AtomicUsize,
parker: NonNull<crate::channel::Parker>, parker: ReceiverToken,
panic: AtomicPtr<Box<dyn Any + Send + 'static>>, panic: AtomicPtr<Box<dyn Any + Send + 'static>>,
} }
@ -66,7 +68,7 @@ impl ScopeInner {
fn from_worker(worker: &WorkerThread) -> Self { fn from_worker(worker: &WorkerThread) -> Self {
Self { Self {
outstanding_jobs: AtomicUsize::new(0), outstanding_jobs: AtomicUsize::new(0),
parker: worker.heartbeat.parker().into(), parker: worker.receiver.get_token(),
panic: AtomicPtr::new(ptr::null_mut()), panic: AtomicPtr::new(ptr::null_mut()),
} }
} }
@ -75,11 +77,13 @@ impl ScopeInner {
self.outstanding_jobs.fetch_add(1, Ordering::Relaxed); self.outstanding_jobs.fetch_add(1, Ordering::Relaxed);
} }
fn decrement(&self) { fn decrement(&self, worker: &WorkerThread) {
if self.outstanding_jobs.fetch_sub(1, Ordering::Relaxed) == 1 { if self.outstanding_jobs.fetch_sub(1, Ordering::Relaxed) == 1 {
unsafe { worker
self.parker.as_ref().unpark(); .context
} .queue
.as_sender()
.unicast(Message::ScopeFinished, self.parker);
} }
} }
@ -196,19 +200,31 @@ impl<'scope, 'env> Scope<'scope, 'env> {
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))] #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
fn wait_for_jobs(&self) { fn wait_for_jobs(&self) {
#[cfg(feature = "tracing")] loop {
tracing::trace!(
"waiting for {} jobs to finish.",
self.inner().outstanding_jobs.load(Ordering::Relaxed)
);
self.worker().wait_until_pred(|| {
// SAFETY: we are in a worker thread, so the inner is valid.
let count = self.inner().outstanding_jobs.load(Ordering::Relaxed); let count = self.inner().outstanding_jobs.load(Ordering::Relaxed);
#[cfg(feature = "tracing")] #[cfg(feature = "tracing")]
tracing::trace!("waiting for {} jobs to finish.", count); tracing::trace!("waiting for {} jobs to finish.", count);
count == 0 if count == 0 {
}); break;
}
match self.worker().receiver.recv() {
Message::Shared(shared_job) => unsafe {
SharedJob::execute(shared_job, self.worker());
},
Message::ScopeFinished => {
#[cfg(feature = "tracing")]
tracing::trace!("scope finished, decrementing outstanding jobs.");
assert_eq!(self.inner().outstanding_jobs.load(Ordering::Acquire), 0);
break;
}
Message::WakeUp | Message::Exit => {}
}
}
}
fn decrement(&self) {
self.inner().decrement(self.worker());
} }
fn inner(&self) -> &ScopeInner { fn inner(&self) -> &ScopeInner {
@ -243,6 +259,7 @@ impl<'scope, 'env> Scope<'scope, 'env> {
) )
} }
#[align(8)]
unsafe fn harness<'scope, 'env, T>( unsafe fn harness<'scope, 'env, T>(
worker: &WorkerThread, worker: &WorkerThread,
this: NonNull<()>, this: NonNull<()>,
@ -268,7 +285,7 @@ impl<'scope, 'env> Scope<'scope, 'env> {
scope.inner().panicked(payload); scope.inner().panicked(payload);
} }
scope.inner().decrement(); scope.decrement();
}, },
self.inner, self.inner,
); );
@ -309,7 +326,7 @@ impl<'scope, 'env> Scope<'scope, 'env> {
let future = async move { let future = async move {
let _guard = DropGuard::new(move || { let _guard = DropGuard::new(move || {
scope.inner().decrement(); scope.decrement();
}); });
// TODO: handle panics here // TODO: handle panics here
@ -358,6 +375,7 @@ impl<'scope, 'env> Scope<'scope, 'env> {
struct ScopeJob<F> { struct ScopeJob<F> {
f: UnsafeCell<ManuallyDrop<F>>, f: UnsafeCell<ManuallyDrop<F>>,
inner: SendPtr<ScopeInner>, inner: SendPtr<ScopeInner>,
_pin: PhantomPinned,
} }
impl<F> ScopeJob<F> { impl<F> ScopeJob<F> {
@ -365,22 +383,24 @@ impl<'scope, 'env> Scope<'scope, 'env> {
Self { Self {
f: UnsafeCell::new(ManuallyDrop::new(f)), f: UnsafeCell::new(ManuallyDrop::new(f)),
inner, inner,
_pin: PhantomPinned,
} }
} }
fn into_job<'scope, 'env, T>(&self) -> Job<T> fn into_job<'scope, 'env, T>(self: Pin<&Self>) -> Job<T>
where where
F: FnOnce(Scope<'scope, 'env>) -> T + Send, F: FnOnce(Scope<'scope, 'env>) -> T + Send,
'env: 'scope, 'env: 'scope,
T: Send, T: Send,
{ {
Job::from_harness(Self::harness, NonNull::from(self).cast()) Job::from_harness(Self::harness, NonNull::from(&*self).cast())
} }
unsafe fn unwrap(&self) -> F { unsafe fn unwrap(&self) -> F {
unsafe { ManuallyDrop::take(&mut *self.f.get()) } unsafe { ManuallyDrop::take(&mut *self.f.get()) }
} }
#[align(8)]
unsafe fn harness<'scope, 'env, T>( unsafe fn harness<'scope, 'env, T>(
worker: &WorkerThread, worker: &WorkerThread,
this: NonNull<()>, this: NonNull<()>,
@ -391,16 +411,25 @@ impl<'scope, 'env> Scope<'scope, 'env> {
T: Send, T: Send,
{ {
let this: &ScopeJob<F> = unsafe { this.cast().as_ref() }; let this: &ScopeJob<F> = unsafe { this.cast().as_ref() };
let sender: Option<Sender<T>> = unsafe { mem::transmute(sender) };
let f = unsafe { this.unwrap() }; let f = unsafe { this.unwrap() };
let scope = unsafe { Scope::<'scope, 'env>::new_unchecked(worker, this.inner) }; let scope = unsafe { Scope::<'scope, 'env>::new_unchecked(worker, this.inner) };
let sender: Sender<T> = unsafe { mem::transmute(sender) };
// SAFETY: we are in a worker thread, so the inner is valid. let result = catch_unwind(AssertUnwindSafe(|| f(scope)));
sender.send(catch_unwind(AssertUnwindSafe(|| f(scope))));
let sender = sender.unwrap();
unsafe {
sender.send_as_ref(result);
worker
.context
.queue
.as_sender()
.unicast(Message::WakeUp, ReceiverToken::from_parker(sender.parker()));
}
} }
} }
impl<'scope, 'env, F, T> IntoJob<T> for &ScopeJob<F> impl<'scope, 'env, F, T> IntoJob<T> for Pin<&ScopeJob<F>>
where where
F: FnOnce(Scope<'scope, 'env>) -> T + Send, F: FnOnce(Scope<'scope, 'env>) -> T + Send,
'env: 'scope, 'env: 'scope,
@ -411,7 +440,7 @@ impl<'scope, 'env> Scope<'scope, 'env> {
} }
} }
impl<'scope, 'env, F, T> InlineJob<T> for &ScopeJob<F> impl<'scope, 'env, F, T> InlineJob<T> for Pin<&ScopeJob<F>>
where where
F: FnOnce(Scope<'scope, 'env>) -> T + Send, F: FnOnce(Scope<'scope, 'env>) -> T + Send,
'env: 'scope, 'env: 'scope,
@ -422,8 +451,14 @@ impl<'scope, 'env> Scope<'scope, 'env> {
} }
} }
return worker let _pinned = ScopeJob::new(a, self.inner);
.join_heartbeat2_every::<_, _, _, _, 64>(&ScopeJob::new(a, self.inner), |_| b(*self)); let job = unsafe { Pin::new_unchecked(&_pinned) };
let (a, b) = worker.join_heartbeat2(job, |_| b(*self));
// touch job here to ensure it is not dropped before we run the join.
drop(_pinned);
(a, b)
// let stack = ScopeJob::new(a, self.inner); // let stack = ScopeJob::new(a, self.inner);
// let job = ScopeJob::into_job(&stack); // let job = ScopeJob::into_job(&stack);
@ -528,13 +563,17 @@ mod tests {
#[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)] #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
fn scope_join_one() { fn scope_join_one() {
let pool = ThreadPool::new_with_threads(1); let pool = ThreadPool::new_with_threads(1);
let count = AtomicU8::new(0);
let a = pool.scope(|scope| { let a = pool.scope(|scope| {
let (a, b) = scope.join(|_| 3 + 4, |_| 5 + 6); let (a, b) = scope.join(
|_| count.fetch_add(1, Ordering::Relaxed) + 4,
|_| count.fetch_add(2, Ordering::Relaxed) + 6,
);
a + b a + b
}); });
assert_eq!(a, 18); assert_eq!(count.load(Ordering::Relaxed), 3);
} }
#[test] #[test]
@ -553,9 +592,9 @@ mod tests {
} }
pool.scope(|scope| { pool.scope(|scope| {
let total = sum(scope, 10); let total = sum(scope, 5);
assert_eq!(total, 1023); // assert_eq!(total, 1023);
// eprintln!("Total sum: {}", total); eprintln!("Total sum: {}", total);
}); });
} }

View file

@ -2,6 +2,7 @@ use std::sync::Arc;
use crate::{Scope, context::Context, scope::scope_with_context}; use crate::{Scope, context::Context, scope::scope_with_context};
#[repr(transparent)]
pub struct ThreadPool { pub struct ThreadPool {
pub(crate) context: Arc<Context>, pub(crate) context: Arc<Context>,
} }
@ -9,7 +10,7 @@ pub struct ThreadPool {
impl Drop for ThreadPool { impl Drop for ThreadPool {
fn drop(&mut self) { fn drop(&mut self) {
// TODO: Ensure that the context is properly cleaned up when the thread pool is dropped. // TODO: Ensure that the context is properly cleaned up when the thread pool is dropped.
// self.context.set_should_exit(); self.context.set_should_exit();
} }
} }
@ -25,9 +26,9 @@ impl ThreadPool {
Self { context } Self { context }
} }
pub fn global() -> Self { pub fn global() -> &'static Self {
let context = Context::global_context().clone(); // SAFETY: ThreadPool is a transparent wrapper around Arc<Context>,
Self { context } unsafe { core::mem::transmute(Context::global_context()) }
} }
pub fn scope<'env, F, R>(&self, f: F) -> R pub fn scope<'env, F, R>(&self, f: F) -> R
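Editorial note (not part of the diff): `global()` now hands out a shared `&'static ThreadPool` instead of constructing a new wrapper around a cloned `Arc<Context>`, so call sites can use it inline. A minimal sketch reusing the `scope`/`join` API exercised by the scope tests; the helper function is purely illustrative.

// Hypothetical helper: runs two closures on the global pool and sums the results.
fn sum_on_global_pool() -> i32 {
    let (a, b) = ThreadPool::global().scope(|scope| scope.join(|_| 3 + 4, |_| 5 + 6));
    a + b
}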

View file

@ -8,18 +8,21 @@ use std::{
time::Duration, time::Duration,
}; };
use crossbeam_utils::CachePadded; #[cfg(feature = "metrics")]
use werkzeug::CachePadded;
use crate::{ use crate::{
channel::Receiver, channel::Receiver,
context::Context, context::{Context, Message},
heartbeat::OwnedHeartbeatReceiver, heartbeat::OwnedHeartbeatReceiver,
job::{Job2 as Job, JobQueue as JobList, SharedJob}, job::{Job2 as Job, JobQueue as JobList, SharedJob},
queue,
util::DropGuard, util::DropGuard,
}; };
pub struct WorkerThread { pub struct WorkerThread {
pub(crate) context: Arc<Context>, pub(crate) context: Arc<Context>,
pub(crate) receiver: queue::Receiver<Message>,
pub(crate) queue: UnsafeCell<JobList>, pub(crate) queue: UnsafeCell<JobList>,
pub(crate) heartbeat: OwnedHeartbeatReceiver, pub(crate) heartbeat: OwnedHeartbeatReceiver,
pub(crate) join_count: Cell<u8>, pub(crate) join_count: Cell<u8>,
@ -37,6 +40,7 @@ impl WorkerThread {
let heartbeat = context.heartbeats.new_heartbeat(); let heartbeat = context.heartbeats.new_heartbeat();
Self { Self {
receiver: context.queue.new_receiver(),
context, context,
queue: UnsafeCell::new(JobList::new()), queue: UnsafeCell::new(JobList::new()),
heartbeat, heartbeat,
@ -82,85 +86,26 @@ impl WorkerThread {
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))] #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
fn run_inner(&self) { fn run_inner(&self) {
let mut job = None; loop {
'outer: loop { if self.context.should_exit() {
if let Some(job) = job.take() { break;
self.execute(job);
} }
// no more jobs, wait to be notified of a new job or a heartbeat. match self.receiver.recv() {
while job.is_none() { Message::Shared(shared_job) => {
if self.context.should_exit() { self.execute(shared_job);
// if the context is stopped, break out of the outer loop which
// will exit the thread.
break 'outer;
} }
Message::Exit => break,
job = self.find_work_or_wait(); Message::WakeUp | Message::ScopeFinished => {}
} }
} }
} }
} }
impl WorkerThread { impl WorkerThread {
/// Looks for work in the local queue, then in the shared context, and if no /// Checks if the worker thread has received a heartbeat, and if so,
/// work is found, waits for the thread to be notified of a new job, after /// attempts to share a job with other workers. If a job was popped from
/// which it returns `None`. /// the queue, but not shared, this function runs the job locally.
/// The caller should then check for `should_exit` to determine if the
/// thread should exit, or look for work again.
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub(crate) fn find_work_or_wait(&self) -> Option<SharedJob> {
if let Some(job) = self.find_work() {
return Some(job);
}
#[cfg(feature = "tracing")]
tracing::trace!("waiting for new job");
self.heartbeat.parker().park();
#[cfg(feature = "tracing")]
tracing::trace!("woken up from wait");
None
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub(crate) fn find_work_or_wait_unless<F>(&self, mut pred: F) -> Option<SharedJob>
where
F: FnMut() -> bool,
{
if let Some(job) = self.find_work() {
return Some(job);
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// Check the predicate while holding the lock. This is very important,
// because the lock must be held when notifying us of the result of a
// job we scheduled.
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// no jobs found, wait for a heartbeat or a new job
#[cfg(feature = "tracing")]
tracing::trace!(worker = self.heartbeat.index(), "waiting for new job");
if !pred() {
self.heartbeat.parker().park();
}
#[cfg(feature = "tracing")]
tracing::trace!(worker = self.heartbeat.index(), "woken up from wait");
None
}
fn find_work(&self) -> Option<SharedJob> {
let mut guard = self.context.shared();
if let Some(job) = guard.pop_job() {
#[cfg(feature = "metrics")]
self.metrics.num_jobs_stolen.fetch_add(1, Ordering::Relaxed);
#[cfg(feature = "tracing")]
tracing::trace!("WorkerThread::find_work_inner: found shared job: {:?}", job);
return Some(job);
}
None
}
pub(crate) fn tick(&self) { pub(crate) fn tick(&self) {
if self.heartbeat.take() { if self.heartbeat.take() {
#[cfg(feature = "metrics")] #[cfg(feature = "metrics")]
@ -170,6 +115,7 @@ impl WorkerThread {
"received heartbeat, thread id: {:?}", "received heartbeat, thread id: {:?}",
self.heartbeat.index() self.heartbeat.index()
); );
self.heartbeat_cold(); self.heartbeat_cold();
} }
} }
@ -177,28 +123,31 @@ impl WorkerThread {
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))] #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
fn execute(&self, job: SharedJob) { fn execute(&self, job: SharedJob) {
unsafe { SharedJob::execute(job, self) }; unsafe { SharedJob::execute(job, self) };
self.tick(); // TODO: maybe tick here?
} }
/// Attempts to share a job with other workers within the same context.
/// returns `true` if the job was shared, `false` if it was not.
#[cold] #[cold]
fn heartbeat_cold(&self) { fn heartbeat_cold(&self) {
let mut guard = self.context.shared(); if let Some(job) = self.pop_back() {
#[cfg(feature = "tracing")]
tracing::trace!("heartbeat: sharing job: {:?}", job);
if !guard.jobs.contains_key(&self.heartbeat.id()) { #[cfg(feature = "metrics")]
if let Some(job) = self.pop_back() { self.metrics.num_jobs_shared.fetch_add(1, Ordering::Relaxed);
#[cfg(feature = "tracing")]
tracing::trace!("heartbeat: sharing job: {:?}", job);
#[cfg(feature = "metrics")]
self.metrics.num_jobs_shared.fetch_add(1, Ordering::Relaxed);
if let Err(Message::Shared(job)) =
self.context
.queue
.as_sender()
.try_anycast(Message::Shared(unsafe {
job.as_ref()
.share(Some(self.receiver.get_token().as_parker()))
}))
{
unsafe { unsafe {
guard.jobs.insert( SharedJob::execute(job, self);
self.heartbeat.id(),
job.as_ref().share(Some(self.heartbeat.parker())),
);
// SAFETY: we are holding the lock on the shared context.
self.context.notify_job_shared();
} }
} }
} }
@ -264,13 +213,14 @@ impl WorkerThread {
pub struct HeartbeatThread { pub struct HeartbeatThread {
ctx: Arc<Context>, ctx: Arc<Context>,
num_workers: usize,
} }
impl HeartbeatThread { impl HeartbeatThread {
const HEARTBEAT_INTERVAL: Duration = Duration::from_micros(100); const HEARTBEAT_INTERVAL: Duration = Duration::from_micros(100);
pub fn new(ctx: Arc<Context>) -> Self { pub fn new(ctx: Arc<Context>, num_workers: usize) -> Self {
Self { ctx } Self { ctx, num_workers }
} }
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))] #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
@ -282,6 +232,14 @@ impl HeartbeatThread {
let mut i = 0; let mut i = 0;
loop { loop {
let sleep_for = { let sleep_for = {
// loop {
// if self.ctx.should_exit() || self.ctx.queue.num_receivers() != self.num_workers
// {
// break;
// }
// self.ctx.heartbeat.park();
// }
if self.ctx.should_exit() { if self.ctx.should_exit() {
break; break;
} }
@ -306,88 +264,18 @@ impl HeartbeatThread {
} }
impl WorkerThread { impl WorkerThread {
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))] #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn wait_until_shared_job<T: Send>(&self, job: &Job<T>) -> Option<std::thread::Result<T>> { pub fn wait_until_recv<T: Send>(&self, recv: Receiver<T>) -> std::thread::Result<T> {
let recv = (*job).take_receiver().unwrap(); loop {
if let Some(result) = recv.poll() {
let mut out = recv.poll(); break result;
while std::hint::unlikely(out.is_none()) {
if let Some(job) = self.find_work() {
unsafe {
SharedJob::execute(job, self);
}
} }
out = recv.poll(); match self.receiver.recv() {
} Message::Shared(shared_job) => unsafe {
SharedJob::execute(shared_job, self);
out },
} Message::WakeUp | Message::Exit | Message::ScopeFinished => {}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn wait_until_recv<T: Send>(&self, recv: Receiver<T>) -> Option<std::thread::Result<T>> {
if self
.context
.shared()
.jobs
.remove(&self.heartbeat.id())
.is_some()
{
#[cfg(feature = "tracing")]
tracing::trace!("reclaiming shared job");
return None;
}
while recv.is_empty() {
if let Some(job) = self.find_work() {
unsafe {
SharedJob::execute(job, self);
}
continue;
}
recv.wait();
}
Some(recv.recv())
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn wait_until_pred<F>(&self, mut pred: F)
where
F: FnMut() -> bool,
{
if !pred() {
#[cfg(feature = "tracing")]
tracing::trace!("thread {:?} waiting on predicate", self.heartbeat.index());
self.wait_until_latch_cold(pred);
}
}
#[cold]
fn wait_until_latch_cold<F>(&self, mut pred: F)
where
F: FnMut() -> bool,
{
if let Some(shared_job) = self.context.shared().jobs.remove(&self.heartbeat.id()) {
#[cfg(feature = "tracing")]
tracing::trace!(
"thread {:?} reclaiming shared job: {:?}",
self.heartbeat.index(),
shared_job
);
unsafe { SharedJob::execute(shared_job, self) };
}
// do the usual thing and wait for the job's latch
// do the usual thing??? chatgipity really said this..
while !pred() {
// check local jobs before locking shared context
if let Some(job) = self.find_work() {
unsafe {
SharedJob::execute(job, self);
}
} }
} }
} }

View file

@ -87,6 +87,7 @@ fn join_distaff(tree_size: usize) {
let sum = sum(&tree, tree.root().unwrap(), s); let sum = sum(&tree, tree.root().unwrap(), s);
sum sum
}); });
eprintln!("sum: {sum}");
std::hint::black_box(sum); std::hint::black_box(sum);
} }
} }
@ -134,7 +135,7 @@ fn join_rayon(tree_size: usize) {
} }
fn main() { fn main() {
//tracing_subscriber::fmt::init(); // tracing_subscriber::fmt::init();
use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::layer::SubscriberExt;
tracing::subscriber::set_global_default( tracing::subscriber::set_global_default(
tracing_subscriber::registry().with(tracing_tracy::TracyLayer::default()), tracing_subscriber::registry().with(tracing_tracy::TracyLayer::default()),
@ -166,6 +167,7 @@ fn main() {
} }
eprintln!("Done!"); eprintln!("Done!");
println!("Done!");
// // wait for user input before exiting // // wait for user input before exiting
// std::io::stdin().read_line(&mut String::new()).unwrap(); // std::io::stdin().read_line(&mut String::new()).unwrap();
} }