Compare commits

...

5 commits

10 changed files with 767 additions and 497 deletions

View file

@ -4,6 +4,7 @@ version = "0.1.0"
edition = "2024"
[profile.bench]
opt-level = 0
debug = true
[profile.release]

View file

@ -17,62 +17,9 @@ enum State {
Taken,
}
// taken from `std`
#[derive(Debug)]
#[repr(transparent)]
pub struct Parker {
mutex: AtomicU32,
}
pub use werkzeug::sync::Parker;
impl Parker {
    /// State while a thread is blocked in [`Parker::park`]; equals `EMPTY - 1` (wrapping).
    const PARKED: u32 = u32::MAX;
    /// Neutral state: no thread is parked and no notification is pending.
    const EMPTY: u32 = 0;
    /// State after [`Parker::unpark`] was called and before the notification is consumed.
    const NOTIFIED: u32 = 1;

    /// Creates a parker in the `EMPTY` state.
    pub fn new() -> Self {
        Self {
            mutex: AtomicU32::new(Self::EMPTY),
        }
    }

    /// Returns `true` while the state word reads `PARKED`.
    pub fn is_parked(&self) -> bool {
        self.mutex.load(Ordering::Acquire) == Self::PARKED
    }

    /// Blocks the calling thread until a notification from [`Parker::unpark`]
    /// is consumed. Returns immediately if a notification is already pending.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all, fields(this = self as *const Self as usize)))]
    pub fn park(&self) {
        // A single decrement performs either NOTIFIED -> EMPTY (consume the
        // pending notification) or EMPTY -> PARKED (announce that we sleep).
        if self.mutex.fetch_sub(1, Ordering::Acquire) == Self::NOTIFIED {
            // The thread was notified, so we can return immediately.
            return;
        }

        loop {
            atomic_wait::wait(&self.mutex, Self::PARKED);

            // Only consume the state if it really is NOTIFIED. An
            // unconditional swap would overwrite PARKED with EMPTY on a
            // spurious wakeup; every following `wait` would then return at
            // once (the value no longer matches PARKED) and this loop would
            // busy-spin until `unpark` is called.
            //
            // Acquire ordering makes visible any writes made by the thread
            // that notified us.
            if self
                .mutex
                .compare_exchange(
                    Self::NOTIFIED,
                    Self::EMPTY,
                    Ordering::Acquire,
                    Ordering::Acquire,
                )
                .is_ok()
            {
                // The thread was notified, so we can return.
                return;
            }
            // Spurious wakeup: the state is still PARKED, so wait again.
        }
    }

    /// Leaves a notification for the parked (or soon-to-park) thread and
    /// wakes it if it is currently blocked in [`Parker::park`].
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all, fields(this = self as *const Self as usize)))]
    pub fn unpark(&self) {
        // Write with Release ordering to ensure that any writes made by this
        // thread are made available to the unparked thread.
        if self.mutex.swap(Self::NOTIFIED, Ordering::Release) == Self::PARKED {
            // The thread was parked, so we need to wake it.
            atomic_wait::wake_one(&self.mutex);
        }
        // Otherwise the thread was not parked; it will observe NOTIFIED on
        // its next call to `park`.
    }
}
use crate::queue::Queue;
#[derive(Debug)]
#[repr(C)]
@ -104,6 +51,10 @@ impl<T: Send> Receiver<T> {
self.0.state.load(Ordering::Acquire) != State::Ready as u8
}
pub fn sender(&self) -> Sender<T> {
Sender(self.0.clone())
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn wait(&self) {
loop {
@ -182,15 +133,15 @@ impl<T: Send> Receiver<T> {
// `State::Ready`.
//
// In either case, this thread now has unique access to `val`.
unsafe { self.take() }
}
unsafe fn take(&self) -> thread::Result<T> {
assert_eq!(
self.0.state.swap(State::Taken as u8, Ordering::Acquire),
State::Ready as u8
);
unsafe { self.take() }
}
unsafe fn take(&self) -> thread::Result<T> {
let result = unsafe { (*self.0.val.get()).take().map(|b| *b).unwrap() };
result
@ -221,6 +172,30 @@ impl<T: Send> Sender<T> {
}
}
}
/// Returns a reference to the `Parker` stored in this channel's
/// `waiting_thread` pointer.
///
/// # Safety
///
/// NOTE(review): this dereferences `waiting_thread`; the caller presumably
/// has to guarantee that it still points to a live `Parker` — confirm
/// against `channel()`.
pub unsafe fn parker(&self) -> &Parker {
    let inner = &self.0;
    // SAFETY: validity of `waiting_thread` is the caller's obligation.
    unsafe { inner.waiting_thread.as_ref() }
}
/// Stores `val` in the channel, marks the channel as `State::Ready` and
/// unparks the waiting thread if the previous state was `State::Waiting`.
///
/// # Safety
///
/// The caller must ensure that this function or `send` are only ever called once.
pub unsafe fn send_as_ref(&self, val: thread::Result<T>) {
    let inner = &self.0;
    let boxed = Some(Box::new(val));

    // SAFETY:
    // Only this thread can write to `val` and none can read it
    // yet.
    unsafe {
        *inner.val.get() = boxed;
    }

    let previous = inner.state.swap(State::Ready as u8, Ordering::AcqRel);
    if previous != State::Waiting as u8 {
        // Nobody is waiting on the result, so there is no one to wake.
        return;
    }

    // SAFETY:
    // A `Receiver` already wrote its thread to `waiting_thread`
    // *before* setting the `state` to `State::Waiting`.
    unsafe {
        inner.waiting_thread.as_ref().unpark();
    }
}
}
pub fn channel<T: Send>(thread: NonNull<Parker>) -> (Sender<T>, Receiver<T>) {

View file

@ -1,4 +1,9 @@
use std::{
cell::UnsafeCell,
marker::PhantomPinned,
mem::{self, ManuallyDrop},
panic::{AssertUnwindSafe, catch_unwind},
pin::Pin,
ptr::NonNull,
sync::{
Arc, OnceLock,
@ -9,21 +14,28 @@ use std::{
use alloc::collections::BTreeMap;
use async_task::Runnable;
use parking_lot::{Condvar, Mutex};
use crate::{
channel::{Parker, Sender},
heartbeat::HeartbeatList,
job::{HeapJob, Job2 as Job, SharedJob, StackJob},
queue::ReceiverToken,
util::DropGuard,
workerthread::{HeartbeatThread, WorkerThread},
};
/// State shared between the threads of one pool.
// NOTE(review): this listing comes from a diff view, so some of the fields
// below may belong to the previous revision only — confirm against the file.
pub struct Context {
// Job bookkeeping (`jobs`, `injected_jobs`) behind a `parking_lot` mutex.
shared: Mutex<Shared>,
// Condition variable; presumably paired with `shared` — TODO confirm.
pub shared_job: Condvar,
// Set to `true` (relaxed store) by `set_should_exit`.
should_exit: AtomicBool,
pub heartbeats: HeartbeatList,
// Queue over which `Message` values are delivered.
pub(crate) queue: Arc<crate::queue::Queue<Message>>,
// NOTE(review): presumably the parker the heartbeat thread sleeps on — confirm.
pub(crate) heartbeat: Parker,
}
/// Messages delivered to receivers through the context's `queue`.
pub(crate) enum Message {
/// A job offered to any one receiver; sent with `anycast` by `Context::inject_job`.
Shared(SharedJob),
/// Sent with `unicast` to the receiver identified by a job's parker, right
/// after the job's result has been written to its channel.
WakeUp,
/// Broadcast to every receiver by `Context::set_should_exit`.
Exit,
// NOTE(review): no sender of this variant is visible in this view; presumably
// it signals that a scope has completed — confirm.
ScopeFinished,
}
pub(crate) struct Shared {
@ -52,22 +64,15 @@ impl Shared {
}
impl Context {
pub(crate) fn shared(&self) -> parking_lot::MutexGuard<'_, Shared> {
self.shared.lock()
}
pub fn new_with_threads(num_threads: usize) -> Arc<Self> {
#[cfg(feature = "tracing")]
tracing::trace!("Creating context with {} threads", num_threads);
let this = Arc::new(Self {
shared: Mutex::new(Shared {
jobs: BTreeMap::new(),
injected_jobs: Vec::new(),
}),
shared_job: Condvar::new(),
should_exit: AtomicBool::new(false),
heartbeats: HeartbeatList::new(),
queue: crate::queue::Queue::new(),
heartbeat: Parker::new(),
});
// Create a barrier to synchronize the worker threads and the heartbeat thread
@ -94,7 +99,7 @@ impl Context {
std::thread::Builder::new()
.name("heartbeat-thread".to_string())
.spawn(move || {
HeartbeatThread::new(ctx).run(barrier);
HeartbeatThread::new(ctx, num_threads).run(barrier);
})
.expect("Failed to spawn heartbeat thread");
}
@ -106,7 +111,7 @@ impl Context {
pub fn set_should_exit(&self) {
self.should_exit.store(true, Ordering::Relaxed);
self.heartbeats.notify_all();
self.queue.as_sender().broadcast_with(|| Message::Exit);
}
pub fn should_exit(&self) -> bool {
@ -124,31 +129,7 @@ impl Context {
}
pub fn inject_job(&self, job: SharedJob) {
let mut shared = self.shared.lock();
shared.injected_jobs.push(job);
unsafe {
// SAFETY: we are holding the shared lock, so it is safe to notify
self.notify_job_shared();
}
}
/// caller should hold the shared lock while calling this
pub unsafe fn notify_job_shared(&self) {
let heartbeats = self.heartbeats.inner();
if let Some((i, sender)) = heartbeats
.iter()
.find(|(_, heartbeat)| heartbeat.is_waiting())
.or_else(|| heartbeats.iter().next())
{
_ = i;
#[cfg(feature = "tracing")]
tracing::trace!("Notifying worker thread {} about job sharing", i);
sender.wake();
} else {
#[cfg(feature = "tracing")]
tracing::warn!("No worker found to notify about job sharing");
}
self.queue.as_sender().anycast(Message::Shared(job));
}
/// Runs closure in this context, processing the other context's worker's jobs while waiting for the result.
@ -160,13 +141,18 @@ impl Context {
// current thread is not in the same context, create a job and inject it into the other thread's context, then wait while working on our jobs.
// SAFETY: we are waiting on this latch in this thread.
let job = StackJob::new(move |worker: &WorkerThread| f(worker));
let _pinned = StackJob::new(move |worker: &WorkerThread| f(worker));
let job = unsafe { Pin::new_unchecked(&_pinned) };
let job = Job::from_stackjob(&job);
unsafe {
self.inject_job(job.share(Some(worker.receiver.get_token().as_parker())));
}
self.inject_job(job.share(Some(worker.heartbeat.parker())));
let t = worker.wait_until_recv(job.take_receiver().expect("Job should have a receiver"));
let t = worker.wait_until_shared_job(&job).unwrap();
// touch the job to ensure it is dropped after we are done with it.
drop(_pinned);
crate::util::unwrap_or_panic(t)
}
@ -180,15 +166,69 @@ impl Context {
// current thread isn't a worker thread, create job and inject into context
let parker = Parker::new();
let job = StackJob::new(move |worker: &WorkerThread| f(worker));
struct CrossJob<F> {
f: UnsafeCell<ManuallyDrop<F>>,
_pin: PhantomPinned,
}
let job = Job::from_stackjob(&job);
impl<F> CrossJob<F> {
fn new(f: F) -> Self {
Self {
f: UnsafeCell::new(ManuallyDrop::new(f)),
_pin: PhantomPinned,
}
}
self.inject_job(job.share(Some(&parker)));
fn into_job<T>(self: &Self) -> Job<T>
where
F: FnOnce(&WorkerThread) -> T + Send,
T: Send,
{
Job::from_harness(Self::harness, NonNull::from(&*self).cast())
}
let recv = job.take_receiver().unwrap();
unsafe fn unwrap(&self) -> F {
unsafe { ManuallyDrop::take(&mut *self.f.get()) }
}
crate::util::unwrap_or_panic(recv.recv())
#[align(8)]
unsafe fn harness<T>(worker: &WorkerThread, this: NonNull<()>, sender: Option<Sender>)
where
F: FnOnce(&WorkerThread) -> T + Send,
T: Send,
{
let this: &CrossJob<F> = unsafe { this.cast().as_ref() };
let f = unsafe { this.unwrap() };
let sender: Option<Sender<T>> = unsafe { mem::transmute(sender) };
let result = catch_unwind(AssertUnwindSafe(|| f(worker)));
let sender = sender.unwrap();
unsafe {
sender.send_as_ref(result);
worker
.context
.queue
.as_sender()
.unicast(Message::WakeUp, ReceiverToken::from_parker(sender.parker()));
}
}
}
let pinned = CrossJob::new(move |worker: &WorkerThread| f(worker));
let job2 = pinned.into_job();
self.inject_job(job2.share(Some(&parker)));
let recv = job2.take_receiver().unwrap();
let out = crate::util::unwrap_or_panic(recv.recv());
// touch the job to ensure it is dropped after we are done with it.
drop(pinned);
drop(parker);
out
}
/// Run closure in this context.
@ -375,9 +415,12 @@ mod tests {
ctx.inject_job(job.share(Some(&parker)));
// Wait for the job to be executed
let recv = job.take_receiver().unwrap();
let result = recv.recv();
let result = crate::util::unwrap_or_panic(result);
let recv = job.take_receiver().expect("Job should have a receiver");
let Some(result) = recv.poll() else {
panic!("Expected a finished message");
};
let result = crate::util::unwrap_or_panic::<i32>(result);
assert_eq!(result, 42);
assert_eq!(counter.load(Ordering::SeqCst), 1);
}

View file

@ -4,13 +4,15 @@ use core::{
mem::{self, ManuallyDrop},
ptr::NonNull,
};
use std::cell::Cell;
use std::{cell::Cell, marker::PhantomPinned};
use alloc::boxed::Box;
use crate::{
WorkerThread,
channel::{Parker, Sender},
channel::{Parker, Receiver, Sender},
context::Message,
queue::ReceiverToken,
};
#[repr(transparent)]
@ -43,65 +45,89 @@ impl<F> HeapJob<F> {
}
}
type JobHarness =
unsafe fn(&WorkerThread, this: NonNull<()>, sender: Option<crate::channel::Sender>);
type JobHarness = unsafe fn(&WorkerThread, this: NonNull<()>, sender: Option<Sender>);
#[repr(C)]
pub struct Job2<T = ()> {
harness: JobHarness,
this: NonNull<()>,
receiver: Cell<Option<crate::channel::Receiver<T>>>,
inner: UnsafeCell<Job2Inner<T>>,
}
impl<T> Debug for Job2<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Job2")
.field("harness", &self.harness)
.field("this", &self.this)
.finish_non_exhaustive()
f.debug_struct("Job2").field("inner", &self.inner).finish()
}
}
#[repr(C)]
pub enum Job2Inner<T = ()> {
Local {
harness: JobHarness,
this: NonNull<()>,
_pin: PhantomPinned,
},
Shared {
receiver: Cell<Option<Receiver<T>>>,
},
}
#[derive(Debug)]
pub struct SharedJob {
harness: JobHarness,
this: NonNull<()>,
sender: Option<crate::channel::Sender>,
sender: Option<Sender<()>>,
}
unsafe impl Send for SharedJob {}
impl<T: Send> Job2<T> {
fn new(harness: JobHarness, this: NonNull<()>) -> Self {
let this = Self {
harness,
this,
receiver: Cell::new(None),
inner: UnsafeCell::new(Job2Inner::Local {
harness: harness,
this,
_pin: PhantomPinned,
}),
};
#[cfg(feature = "tracing")]
tracing::trace!("new job: {:?}", this);
this
}
pub fn share(&self, parker: Option<&Parker>) -> SharedJob {
#[cfg(feature = "tracing")]
tracing::trace!("sharing job: {:?}", self);
let (sender, receiver) = parker
.map(|parker| crate::channel::channel::<T>(parker.into()))
.unzip();
self.receiver.set(receiver);
SharedJob {
harness: self.harness,
this: self.this,
sender: unsafe { mem::transmute(sender) },
// self.receiver.set(receiver);
if let Job2Inner::Local {
harness,
this,
_pin: _,
} = unsafe {
self.inner.replace(Job2Inner::Shared {
receiver: Cell::new(receiver),
})
} {
// SAFETY: `this` is a valid pointer to the job.
unsafe {
SharedJob {
harness,
this,
sender: mem::transmute(sender), // Convert `Option<Sender<T>>` to `Option<Sender<()>>`
}
}
} else {
panic!("Job2 is already shared");
}
}
pub fn take_receiver(&self) -> Option<crate::channel::Receiver<T>> {
self.receiver.take()
pub fn take_receiver(&self) -> Option<Receiver<T>> {
unsafe {
if let Job2Inner::Shared { receiver } = self.inner.as_ref_unchecked() {
receiver.take()
} else {
None
}
}
}
pub fn from_stackjob<F>(job: &StackJob<F>) -> Self
@ -119,9 +145,9 @@ impl<T: Send> Job2<T> {
T: Send,
{
use std::panic::{AssertUnwindSafe, catch_unwind};
let sender: Option<Sender<T>> = unsafe { mem::transmute(sender) };
let f = unsafe { this.cast::<StackJob<F>>().as_ref().unwrap() };
let sender: Sender<T> = unsafe { mem::transmute(sender) };
// #[cfg(feature = "metrics")]
// if worker.heartbeat.parker() == mutex {
@ -132,7 +158,18 @@ impl<T: Send> Job2<T> {
// tracing::trace!("job sent to self");
// }
sender.send(catch_unwind(AssertUnwindSafe(|| f(worker))));
let result = catch_unwind(AssertUnwindSafe(|| f(worker)));
if let Some(sender) = sender {
unsafe {
sender.send_as_ref(result);
worker
.context
.queue
.as_sender()
.unicast(Message::WakeUp, ReceiverToken::from_parker(sender.parker()));
}
}
}
Self::new(harness::<F, T>, NonNull::from(job).cast())
@ -153,6 +190,7 @@ impl<T: Send> Job2<T> {
T: Send,
{
use std::panic::{AssertUnwindSafe, catch_unwind};
let sender: Option<Sender<T>> = unsafe { mem::transmute(sender) };
// expect MIRI to complain about this, but it is actually correct.
// because I am so much smarter than MIRI, naturally, obviously.
@ -160,9 +198,15 @@ impl<T: Send> Job2<T> {
let f = unsafe { (*Box::from_non_null(this.cast::<HeapJob<F>>())).into_inner() };
let result = catch_unwind(AssertUnwindSafe(|| f(worker)));
let sender: Option<Sender<T>> = unsafe { mem::transmute(sender) };
if let Some(sender) = sender {
sender.send(result);
unsafe {
sender.send_as_ref(result);
_ = worker
.context
.queue
.as_sender()
.unicast(Message::WakeUp, ReceiverToken::from_parker(sender.parker()));
}
}
}
@ -177,10 +221,6 @@ impl<T: Send> Job2<T> {
pub fn from_harness(harness: JobHarness, this: NonNull<()>) -> Self {
Self::new(harness, this)
}
pub fn is_shared(&self) -> bool {
unsafe { (&*self.receiver.as_ptr()).is_some() }
}
}
impl SharedJob {

View file

@ -1,7 +1,7 @@
#[cfg(feature = "metrics")]
use std::sync::atomic::Ordering;
use std::{hint::cold_path, sync::Arc};
use std::{hint::cold_path, pin::Pin, sync::Arc};
use crate::{
context::Context,
@ -84,7 +84,6 @@ impl WorkerThread {
// SAFETY: this function runs in a worker thread, so we can access the queue safely.
if count == 0 || queue_len < 3 {
cold_path();
self.join_heartbeat2(a, b)
} else {
(a.run_inline(self), b(self))
@ -103,12 +102,14 @@ impl WorkerThread {
#[cfg(feature = "metrics")]
self.metrics.num_joins.fetch_add(1, Ordering::Relaxed);
let job = a.into_job();
let _pinned = a.into_job();
let job = unsafe { Pin::new_unchecked(&_pinned) };
self.push_back(&job);
self.push_back(&*job);
self.tick();
// let rb = b(self);
let rb = match catch_unwind(AssertUnwindSafe(|| b(self))) {
Ok(val) => val,
Err(payload) => {
@ -117,32 +118,16 @@ impl WorkerThread {
cold_path();
// if b panicked, we need to wait for a to finish
let mut receiver = job.take_receiver();
self.wait_until_pred(|| match &receiver {
Some(recv) => recv.poll().is_some(),
None => {
receiver = job.take_receiver();
false
}
});
if let Some(recv) = job.take_receiver() {
_ = self.wait_until_recv(recv);
}
resume_unwind(payload);
}
};
let ra = if let Some(recv) = job.take_receiver() {
match self.wait_until_recv(recv) {
Some(t) => crate::util::unwrap_or_panic(t),
None => {
#[cfg(feature = "tracing")]
tracing::trace!(
"join_heartbeat: job was shared, but reclaimed, running a() inline"
);
// the job was shared, but not yet stolen, so we get to run the
// job inline
a.run_inline(self)
}
}
crate::util::unwrap_or_panic(self.wait_until_recv(recv))
} else {
self.pop_back();
@ -152,6 +137,9 @@ impl WorkerThread {
a.run_inline(self)
};
// touch the job to ensure it is not dropped while we are still using it.
drop(_pinned);
(ra, rb)
}
@ -183,41 +171,23 @@ impl WorkerThread {
cold_path();
// if b panicked, we need to wait for a to finish
let mut receiver = job.take_receiver();
self.wait_until_pred(|| match &receiver {
Some(recv) => recv.poll().is_some(),
None => {
receiver = job.take_receiver();
false
}
});
if let Some(recv) = job.take_receiver() {
_ = self.wait_until_recv(recv);
}
resume_unwind(payload);
}
};
let ra = if let Some(recv) = job.take_receiver() {
match self.wait_until_recv(recv) {
Some(t) => crate::util::unwrap_or_panic(t),
None => {
#[cfg(feature = "tracing")]
tracing::trace!(
"join_heartbeat: job was shared, but reclaimed, running a() inline"
);
// the job was shared, but not yet stolen, so we get to run the
// job inline
unsafe { a.unwrap()(self) }
}
}
crate::util::unwrap_or_panic(self.wait_until_recv(recv))
} else {
self.pop_back();
unsafe {
// SAFETY: we just popped the job from the queue, so it is safe to unwrap.
#[cfg(feature = "tracing")]
tracing::trace!("join_heartbeat: job was not shared, running a() inline");
a.unwrap()(self)
}
// SAFETY: we just popped the job from the queue, so it is safe to unwrap.
#[cfg(feature = "tracing")]
tracing::trace!("join_heartbeat: job was not shared, running a() inline");
a.run_inline(self)
};
(ra, rb)

View file

@ -1,68 +1,172 @@
use std::{
cell::UnsafeCell,
collections::{HashMap, HashSet},
collections::HashMap,
marker::{PhantomData, PhantomPinned},
mem::{self, MaybeUninit},
pin::Pin,
ptr::{self, NonNull},
sync::{
Arc,
atomic::{AtomicU8, AtomicU32, Ordering},
atomic::{AtomicU32, Ordering},
},
};
use crossbeam_utils::CachePadded;
use werkzeug::CachePadded;
use werkzeug::sync::Parker;
use werkzeug::ptr::TaggedAtomicPtr;
// A queue with multiple receivers and multiple producers, where a producer can send a message either to any one of the receivers (any-cast) or to one specific receiver (uni-cast).
// After being woken up from waiting on a message, the receiver will look up the index of the message in the queue and return it.
struct QueueInner<T> {
parked: HashSet<ReceiverToken>,
owned: HashMap<ReceiverToken, CachePadded<Slot<T>>>,
receivers: HashMap<ReceiverToken, CachePadded<(Slot<T>, bool)>>,
messages: Vec<T>,
_phantom: std::marker::PhantomData<T>,
}
struct Queue<T> {
pub struct Queue<T> {
inner: UnsafeCell<QueueInner<T>>,
lock: AtomicU32,
}
unsafe impl<T> Send for Queue<T> {}
unsafe impl<T> Sync for Queue<T> where T: Send {}
enum SlotKey {
Owned(ReceiverToken),
Indexed(usize),
}
struct Receiver<T> {
pub struct Receiver<T> {
queue: Arc<Queue<T>>,
lock: Pin<Box<(AtomicU32, PhantomPinned)>>,
lock: Pin<Box<(Parker, PhantomPinned)>>,
}
struct Sender<T> {
#[repr(transparent)]
pub struct Sender<T> {
queue: Arc<Queue<T>>,
}
// TODO: make this a linked list of slots so we can queue multiple messages for
// a single receiver
const SLOT_ALIGN: u8 = core::mem::align_of::<usize>().ilog2() as u8;
struct Slot<T> {
value: UnsafeCell<MaybeUninit<T>>,
state: AtomicU8,
next_and_state: TaggedAtomicPtr<Self, SLOT_ALIGN>,
_phantom: PhantomData<Self>,
}
impl<T> Slot<T> {
fn new() -> Self {
Self {
value: UnsafeCell::new(MaybeUninit::uninit()),
state: AtomicU8::new(0), // 0 means empty
next_and_state: TaggedAtomicPtr::new(ptr::null_mut(), 0), // 0 means empty
_phantom: PhantomData,
}
}
fn set(&self, value: T) {}
fn is_set(&self) -> bool {
self.next_and_state.tag(Ordering::Acquire) == 1
}
unsafe fn pop(&self) -> Option<T> {
NonNull::new(self.next_and_state.ptr(Ordering::Acquire))
.and_then(|next| {
// SAFETY: The next slot is a valid pointer to a Slot<T> that was allocated by us.
unsafe { next.as_ref().pop() }
})
.or_else(|| {
if self
.next_and_state
.swap_tag(0, Ordering::AcqRel, Ordering::Relaxed)
== 1
{
// SAFETY: The value is only initialized when the state is set to 1.
Some(unsafe { self.value.as_ref_unchecked().assume_init_read() })
} else {
None
}
})
}
/// Stores `value` in the first slot of the chain that is not set,
/// allocating a successor slot when needed.
///
/// # Safety
///
/// The caller must ensure that they have exclusive access to the slot.
unsafe fn push(&self, value: T) {
    if !self.is_set() {
        // SAFETY: the slot is not set, so `value` holds no initialized
        // data, and the caller guarantees exclusive access.
        unsafe { self.value.as_mut_unchecked().write(value) };
        // Publish the value: a tag of 1 marks the slot as set.
        self.next_and_state
            .set_tag(1, Ordering::Release, Ordering::Relaxed);
        return;
    }

    // This slot is occupied, so hand the value to the successor.
    let successor = self.next_ptr();
    // SAFETY: `next_ptr` returns an existing or freshly allocated slot.
    unsafe {
        successor.as_ref().push(value);
    }
}
/// Returns the successor slot, allocating it first if none is linked yet.
fn next_ptr(&self) -> NonNull<Slot<T>> {
    match NonNull::new(self.next_and_state.ptr(Ordering::Acquire)) {
        Some(existing) => existing.cast(),
        None => self.alloc_next(),
    }
}
/// Allocates a successor slot and links it behind this one.
///
/// If a successor is already linked (another caller won the race), the
/// fresh allocation is freed and the existing successor is returned.
fn alloc_next(&self) -> NonNull<Slot<T>> {
    let ours = Box::into_raw(Box::new(Slot::new()));

    let linked = loop {
        match self.next_and_state.compare_exchange_weak_ptr(
            ptr::null_mut(),
            ours,
            Ordering::Release,
            Ordering::Acquire,
        ) {
            Ok(_) => break ours,
            // The pointer is still null, so the weak exchange failed
            // without a competing successor. That is an expected outcome
            // of a weak compare-exchange and not worth reporting: retry.
            Err(other) if other.is_null() => continue,
            Err(other) => {
                // next was allocated under us, so we need to drop the slot we just allocated again.
                #[cfg(feature = "tracing")]
                tracing::trace!(
                    "Slot::alloc_next: next was allocated under us, dropping it. ours: {:p}, other: {:p}",
                    ours,
                    other
                );
                // SAFETY: `ours` came from `Box::into_raw` above and was
                // never published, so we still own it exclusively.
                drop(unsafe { Box::from_raw(ours) });
                break other;
            }
        }
    };

    // SAFETY: `linked` is either our fresh allocation or the non-null
    // pointer observed in the failed exchange.
    unsafe { NonNull::new_unchecked(linked) }
}
}
impl<T> Drop for Slot<T> {
fn drop(&mut self) {
// drop next chain
if let Some(next) = NonNull::new(self.next_and_state.swap_ptr(
ptr::null_mut(),
Ordering::Release,
Ordering::Relaxed,
)) {
// SAFETY: The next slot is a valid pointer to a Slot<T> that was allocated by us.
// We drop this in place because idk..
unsafe {
next.drop_in_place();
_ = Box::<mem::ManuallyDrop<Self>>::from_non_null(next.cast());
}
}
// SAFETY: The value is only initialized when the state is set to 1.
if mem::needs_drop::<T>() && self.state.load(Ordering::Acquire) == 1 {
if mem::needs_drop::<T>() && self.next_and_state.tag(Ordering::Acquire) == 1 {
unsafe { self.value.as_mut_unchecked().assume_init_drop() };
}
}
@ -77,19 +181,35 @@ impl<T> Drop for Slot<T> {
/// A token that can be used to identify a specific receiver in a queue.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ReceiverToken(werkzeug::util::Send<*const u32>);
pub struct ReceiverToken(werkzeug::util::Send<NonNull<u32>>);
impl ReceiverToken {
pub fn as_ptr(&self) -> *mut u32 {
self.0.into_inner().as_ptr()
}
pub unsafe fn as_parker(&self) -> &Parker {
// SAFETY: The pointer is guaranteed to be valid and aligned, as it comes from a pinned Parker.
unsafe { Parker::from_ptr(self.as_ptr()) }
}
pub unsafe fn from_parker(parker: &Parker) -> Self {
// SAFETY: The pointer is guaranteed to be valid and aligned, as it comes from a pinned Parker.
let ptr = NonNull::from(parker).cast::<u32>();
ReceiverToken(werkzeug::util::Send(ptr))
}
}
impl<T> Queue<T> {
pub fn new() -> Self {
Self {
pub fn new() -> Arc<Self> {
Arc::new(Self {
inner: UnsafeCell::new(QueueInner {
parked: HashSet::new(),
messages: Vec::new(),
owned: HashMap::new(),
receivers: HashMap::new(),
_phantom: PhantomData,
}),
lock: AtomicU32::new(0),
}
})
}
pub fn new_sender(self: &Arc<Self>) -> Sender<T> {
@ -98,22 +218,28 @@ impl<T> Queue<T> {
}
}
/// Returns the number of receivers currently registered with the queue.
pub fn num_receivers(self: &Arc<Self>) -> usize {
    // Hold the queue lock while reading the receiver map.
    let guard = self.lock();
    let count = self.inner().receivers.len();
    drop(guard);
    count
}
pub fn as_sender(self: &Arc<Self>) -> &Sender<T> {
unsafe { mem::transmute::<&Arc<Self>, &Sender<T>>(self) }
}
pub fn new_receiver(self: &Arc<Self>) -> Receiver<T> {
let recv = Receiver {
queue: self.clone(),
lock: Box::pin((AtomicU32::new(0), PhantomPinned)),
lock: Box::pin((Parker::new(), PhantomPinned)),
};
// allocate slot for the receiver
let token = recv.get_token();
let _guard = recv.queue.lock();
recv.queue.inner().owned.insert(
token,
CachePadded::new(Slot {
value: UnsafeCell::new(MaybeUninit::uninit()),
state: AtomicU8::new(0), // 0 means empty
}),
);
recv.queue
.inner()
.receivers
.insert(token, CachePadded::new((Slot::new(), false)));
drop(_guard);
recv
@ -134,28 +260,27 @@ impl<T> Queue<T> {
}
impl<T> QueueInner<T> {
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
fn poll(&mut self, token: ReceiverToken) -> Option<T> {
// check if someone has sent a message to this receiver
let slot = self.owned.get(&token).unwrap();
if slot.state.swap(0, Ordering::Acquire) == 1 {
// SAFETY: the slot is owned by this receiver and contains a message.
return Some(unsafe { slot.value.as_ref_unchecked().assume_init_read() });
} else if let Some(t) = self.messages.pop() {
return Some(t);
} else {
None
}
let CachePadded((slot, _)) = self.receivers.get(&token)?;
unsafe { slot.pop() }.or_else(|| {
// if the slot is empty, we can check the indexed messages
#[cfg(feature = "tracing")]
tracing::trace!("QueueInner::poll: checking open messages");
self.messages.pop()
})
}
}
impl<T> Receiver<T> {
fn get_token(&self) -> ReceiverToken {
pub fn get_token(&self) -> ReceiverToken {
// the token is just the pointer to the lock of this receiver.
// the lock is pinned, so its address is stable across calls to `receive`.
ReceiverToken(werkzeug::util::Send(
&self.lock.0 as *const AtomicU32 as *const u32,
))
ReceiverToken(werkzeug::util::Send(NonNull::from(&self.lock.0).cast()))
}
}
@ -167,12 +292,13 @@ impl<T> Drop for Receiver<T> {
let queue = self.queue.inner();
// remove the receiver from the queue
_ = queue.owned.remove(&self.get_token());
_ = queue.receivers.remove(&self.get_token());
}
}
}
impl<T: Send> Receiver<T> {
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn recv(&self) -> T {
let token = self.get_token();
@ -183,22 +309,23 @@ impl<T: Send> Receiver<T> {
// check if someone has sent a message to this receiver
if let Some(t) = queue.poll(token) {
queue.parked.remove(&token);
queue.receivers.get_mut(&token).unwrap().1 = false; // mark the slot as not parked
return t;
}
// there was no message for this receiver, so we need to park it
queue.parked.insert(token);
queue.receivers.get_mut(&token).unwrap().1 = true; // mark the slot as parked
// wait for a message to be sent to this receiver
drop(_guard);
unsafe {
let lock = werkzeug::sync::Lock::from_ptr(token.0.into_inner().cast_mut());
lock.wait();
}
self.lock.0.park_with_callback(move || {
// drop the lock guard after having set the lock state to waiting.
// this avoids a deadlock if the sender tries to send a message
// while the receiver is in the process of parking (I think..)
drop(_guard);
});
}
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn try_recv(&self) -> Option<T> {
let token = self.get_token();
@ -214,61 +341,92 @@ impl<T: Send> Receiver<T> {
impl<T: Send> Sender<T> {
/// Delivers `value` to one parked receiver if there is one; otherwise the
/// message is stored in the queue's shared message list so that a
/// receiver can pick it up later.
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn anycast(&self, value: T) {
    let _guard = self.queue.lock();

    // SAFETY: The queue is locked, so we can safely access the inner queue.
    if let Err(value) = unsafe { self.try_anycast_inner(value) } {
        #[cfg(feature = "tracing")]
        tracing::trace!(
            "Queue::anycast: no parked receiver found, adding message to indexed slots"
        );

        // No parked receiver took the message, so keep it in the shared
        // list. Waking up a parked receiver is not necessary here, as any
        // receivers that don't have a free slot are currently waking up.
        self.queue.inner().messages.push(value);
    }
}
/// Tries to deliver `value` to a parked receiver.
///
/// Returns `Err(value)` with the message handed back when no receiver
/// could take it.
pub fn try_anycast(&self, value: T) -> Result<(), T> {
    let guard = self.queue.lock();
    // SAFETY: The queue is locked, so we can safely access the inner queue.
    let outcome = unsafe { self.try_anycast_inner(value) };
    drop(guard);
    outcome
}
/// The caller must hold the lock on the queue for the duration of this function.
unsafe fn try_anycast_inner(&self, value: T) -> Result<(), T> {
// look for a receiver that is parked
let queue = self.queue.inner();
if let Some((token, slot)) = queue.parked.iter().find_map(|token| {
// ensure the slot is available
queue.owned.get(token).and_then(|s| {
if s.state.load(Ordering::Acquire) == 0 {
Some((*token, s))
} else {
None
}
})
}) {
if let Some((token, slot)) =
queue
.receivers
.iter()
.find_map(|(token, CachePadded((slot, is_parked)))| {
// ensure the slot is available
if *is_parked && !slot.is_set() {
Some((*token, slot))
} else {
None
}
})
{
// we found a receiver that is parked, so we can send the message to it
unsafe {
slot.value.as_mut_unchecked().write(value);
slot.state.store(1, Ordering::Release);
werkzeug::sync::Lock::from_ptr(token.0.into_inner().cast_mut()).wake_one();
slot.next_and_state
.set_tag(1, Ordering::Release, Ordering::Relaxed);
Parker::from_ptr(token.0.into_inner().as_ptr()).unpark();
}
return;
return Ok(());
} else {
// no parked receiver found, so we want to add the message to the indexed slots
queue.messages.push(value);
// waking up a parked receiver is not necessary here, as any
// receivers that don't have a free slot are currently waking up.
return Err(value);
}
}
/// Sends a message to a specific receiver, waking it if it is parked.
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn unicast(&self, value: T, receiver: ReceiverToken) -> Result<(), T> {
// lock the queue
let _guard = self.queue.lock();
let queue = self.queue.inner();
let Some(slot) = queue.owned.get_mut(&receiver) else {
let Some(CachePadded((slot, _))) = queue.receivers.get_mut(&receiver) else {
return Err(value);
};
// SAFETY: The slot is owned by this receiver.
unsafe { slot.value.as_mut_unchecked().write(value) };
slot.state.store(1, Ordering::Release);
// check if the receiver is parked
if queue.parked.contains(&receiver) {
// wake the receiver
unsafe {
werkzeug::sync::Lock::from_ptr(receiver.0.into_inner().cast_mut()).wake_one();
}
unsafe {
slot.push(value);
}
// wake the receiver
unsafe {
Parker::from_ptr(receiver.0.into_inner().as_ptr()).unpark();
}
Ok(())
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn broadcast(&self, value: T)
where
T: Clone,
@ -278,25 +436,37 @@ impl<T: Send> Sender<T> {
let queue = self.queue.inner();
// send the message to all receivers
for (token, slot) in queue.owned.iter() {
for (token, CachePadded((slot, _))) in queue.receivers.iter() {
// SAFETY: The slot is owned by this receiver.
if slot.state.load(Ordering::Acquire) != 0 {
// the slot is not available, so we skip it
continue;
}
unsafe { slot.push(value.clone()) };
// wake the receiver
unsafe {
slot.value.as_mut_unchecked().write(value.clone());
Parker::from_ptr(token.0.into_inner().as_ptr()).unpark();
}
slot.state.store(1, Ordering::Release);
}
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn broadcast_with<F>(&self, mut f: F)
where
F: FnMut() -> T,
{
// lock the queue
let _guard = self.queue.lock();
let queue = self.queue.inner();
// send the message to all receivers
for (token, CachePadded((slot, _))) in queue.receivers.iter() {
// SAFETY: The slot is owned by this receiver.
unsafe { slot.push(f()) };
// check if the receiver is parked
if queue.parked.contains(token) {
// wake the receiver
unsafe {
werkzeug::sync::Lock::from_ptr(token.0.into_inner().cast_mut()).wake_one();
}
// wake the receiver
unsafe {
Parker::from_ptr(token.0.into_inner().as_ptr()).unpark();
}
}
}
@ -308,13 +478,12 @@ mod tests {
#[test]
fn test_queue() {
let queue = Arc::new(Queue::<i32>::new());
let queue = Queue::<i32>::new();
let sender = queue.new_sender();
let receiver1 = queue.new_receiver();
let receiver2 = queue.new_receiver();
let token1 = receiver1.get_token();
let token2 = receiver2.get_token();
sender.anycast(42);
@ -325,4 +494,146 @@ mod tests {
assert_eq!(receiver1.try_recv(), None);
assert_eq!(receiver2.recv(), 100);
}
/// A broadcast message must reach every registered receiver.
#[test]
fn queue_broadcast() {
    let queue = Queue::<i32>::new();
    let sender = queue.new_sender();
    let receiver1 = queue.new_receiver();
    let receiver2 = queue.new_receiver();

    sender.broadcast(42);

    for receiver in [&receiver1, &receiver2] {
        assert_eq!(receiver.recv(), 42);
    }
}
/// A message addressed to the receiver is returned before a message that
/// was sent with `anycast`.
#[test]
fn queue_multiple_messages() {
    let queue = Queue::<i32>::new();
    let sender = queue.new_sender();
    let receiver = queue.new_receiver();

    sender.anycast(1);
    let delivered = sender.unicast(2, receiver.get_token());
    delivered.unwrap();

    let first = receiver.recv();
    assert_eq!(first, 2);
    let second = receiver.recv();
    assert_eq!(second, 1);
}
/// Five receiver threads print the messages they receive until each one
/// gets the broadcast `Exit` message.
#[test]
fn queue_threaded() {
    #[derive(Debug, Clone, Copy)]
    enum Message {
        Send(i32),
        Exit,
    }

    let queue = Queue::<Message>::new();
    let sender = queue.new_sender();

    // Register every receiver on this thread before any message is sent.
    let mut workers = Vec::with_capacity(5);
    for _ in 0..5 {
        let handle = queue.clone();
        let receiver = handle.new_receiver();
        workers.push(std::thread::spawn(move || {
            while let Message::Send(value) = receiver.recv() {
                println!("Receiver {:?} Received: {}", receiver.get_token(), value);
            }
            println!("Exiting thread");
        }));
    }

    // Send messages to the receivers
    (0..10).for_each(|i| sender.anycast(Message::Send(i)));

    // Send exit messages to all receivers
    sender.broadcast(Message::Exit);

    for worker in workers {
        worker.join().unwrap();
    }

    println!("All threads have exited.");
}
#[test]
fn drop_slot() {
    // Dropping a slot that still holds a value must neither panic nor
    // double free.
    let slot = Slot::<i32>::new();

    // NOTE(review): `push` is an unsafe fn; presumably sound here because the
    // slot is used only by this test — confirm against `Slot::push`'s contract.
    unsafe { slot.push(42) };

    drop(slot);
}
#[test]
fn drop_slot_chain() {
    /// Increments the shared counter when created and decrements it when
    /// dropped, so the counter equals the number of live instances.
    struct DropCheck<'a>(&'a AtomicU32);

    impl<'a> DropCheck<'a> {
        fn new(counter: &'a AtomicU32) -> Self {
            counter.fetch_add(1, Ordering::SeqCst);
            Self(counter)
        }
    }

    impl Drop for DropCheck<'_> {
        fn drop(&mut self) {
            self.0.fetch_sub(1, Ordering::SeqCst);
        }
    }

    let live = AtomicU32::new(0);
    let slot = Slot::<DropCheck>::new();

    for _ in 0..10 {
        let item = DropCheck::new(&live);
        // NOTE(review): `push` is an unsafe fn; presumably sound because the
        // slot is only touched from this test — confirm against its contract.
        unsafe { slot.push(item) };
    }

    // Nothing has been dropped yet: all ten values are still held by the slot.
    assert_eq!(live.load(Ordering::SeqCst), 10);

    drop(slot);

    assert_eq!(
        live.load(Ordering::SeqCst),
        0,
        "All DropCheck instances should have been dropped"
    );
}
#[test]
fn send_self() {
    // A message addressed to the receiver's own token must be readable
    // from that receiver.
    let queue = Queue::<i32>::new();
    let tx = queue.new_sender();
    let rx = queue.new_receiver();

    let own_token = rx.get_token();
    tx.unicast(42, own_token).unwrap();

    assert_eq!(rx.recv(), 42);
}
#[test]
fn send_self_many() {
    let queue = Queue::<i32>::new();
    let tx = queue.new_sender();
    let rx = queue.new_receiver();

    // Queue up 0 through 9, every message addressed to our own receiver.
    (0..10).for_each(|value| {
        tx.unicast(value, rx.get_token()).unwrap();
    });

    // The test expects them back newest-first: 9, 8, ..., 0.
    let mut expected = 10;
    while expected > 0 {
        expected -= 1;
        assert_eq!(rx.recv(), expected);
    }
}
}

View file

@ -1,6 +1,6 @@
use std::{
any::Any,
marker::PhantomData,
marker::{PhantomData, PhantomPinned},
panic::{AssertUnwindSafe, catch_unwind},
pin::{self, Pin},
ptr::{self, NonNull},
@ -11,15 +11,17 @@ use std::{
};
use async_task::Runnable;
use werkzeug::util;
use crate::{
channel::Sender,
context::Context,
context::{Context, Message},
job::{
HeapJob, Job2 as Job,
HeapJob, Job2 as Job, SharedJob,
traits::{InlineJob, IntoJob},
},
latch::{CountLatch, Probe},
queue::ReceiverToken,
util::{DropGuard, SendPtr},
workerthread::WorkerThread,
};
@ -47,7 +49,7 @@ use crate::{
struct ScopeInner {
outstanding_jobs: AtomicUsize,
parker: NonNull<crate::channel::Parker>,
parker: ReceiverToken,
panic: AtomicPtr<Box<dyn Any + Send + 'static>>,
}
@ -66,7 +68,7 @@ impl ScopeInner {
fn from_worker(worker: &WorkerThread) -> Self {
Self {
outstanding_jobs: AtomicUsize::new(0),
parker: worker.heartbeat.parker().into(),
parker: worker.receiver.get_token(),
panic: AtomicPtr::new(ptr::null_mut()),
}
}
@ -75,11 +77,13 @@ impl ScopeInner {
self.outstanding_jobs.fetch_add(1, Ordering::Relaxed);
}
fn decrement(&self) {
fn decrement(&self, worker: &WorkerThread) {
if self.outstanding_jobs.fetch_sub(1, Ordering::Relaxed) == 1 {
unsafe {
self.parker.as_ref().unpark();
}
worker
.context
.queue
.as_sender()
.unicast(Message::ScopeFinished, self.parker);
}
}
@ -196,19 +200,31 @@ impl<'scope, 'env> Scope<'scope, 'env> {
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
fn wait_for_jobs(&self) {
#[cfg(feature = "tracing")]
tracing::trace!(
"waiting for {} jobs to finish.",
self.inner().outstanding_jobs.load(Ordering::Relaxed)
);
self.worker().wait_until_pred(|| {
// SAFETY: we are in a worker thread, so the inner is valid.
loop {
let count = self.inner().outstanding_jobs.load(Ordering::Relaxed);
#[cfg(feature = "tracing")]
tracing::trace!("waiting for {} jobs to finish.", count);
count == 0
});
if count == 0 {
break;
}
match self.worker().receiver.recv() {
Message::Shared(shared_job) => unsafe {
SharedJob::execute(shared_job, self.worker());
},
Message::ScopeFinished => {
#[cfg(feature = "tracing")]
tracing::trace!("scope finished, decrementing outstanding jobs.");
assert_eq!(self.inner().outstanding_jobs.load(Ordering::Acquire), 0);
break;
}
Message::WakeUp | Message::Exit => {}
}
}
}
/// Forwards to the inner scope state's `decrement`, passing the worker this
/// scope is bound to.
fn decrement(&self) {
    let inner = self.inner();
    inner.decrement(self.worker());
}
fn inner(&self) -> &ScopeInner {
@ -243,6 +259,7 @@ impl<'scope, 'env> Scope<'scope, 'env> {
)
}
#[align(8)]
unsafe fn harness<'scope, 'env, T>(
worker: &WorkerThread,
this: NonNull<()>,
@ -268,7 +285,7 @@ impl<'scope, 'env> Scope<'scope, 'env> {
scope.inner().panicked(payload);
}
scope.inner().decrement();
scope.decrement();
},
self.inner,
);
@ -309,7 +326,7 @@ impl<'scope, 'env> Scope<'scope, 'env> {
let future = async move {
let _guard = DropGuard::new(move || {
scope.inner().decrement();
scope.decrement();
});
// TODO: handle panics here
@ -358,6 +375,7 @@ impl<'scope, 'env> Scope<'scope, 'env> {
struct ScopeJob<F> {
f: UnsafeCell<ManuallyDrop<F>>,
inner: SendPtr<ScopeInner>,
_pin: PhantomPinned,
}
impl<F> ScopeJob<F> {
@ -365,22 +383,24 @@ impl<'scope, 'env> Scope<'scope, 'env> {
Self {
f: UnsafeCell::new(ManuallyDrop::new(f)),
inner,
_pin: PhantomPinned,
}
}
fn into_job<'scope, 'env, T>(&self) -> Job<T>
fn into_job<'scope, 'env, T>(self: Pin<&Self>) -> Job<T>
where
F: FnOnce(Scope<'scope, 'env>) -> T + Send,
'env: 'scope,
T: Send,
{
Job::from_harness(Self::harness, NonNull::from(self).cast())
Job::from_harness(Self::harness, NonNull::from(&*self).cast())
}
unsafe fn unwrap(&self) -> F {
unsafe { ManuallyDrop::take(&mut *self.f.get()) }
}
#[align(8)]
unsafe fn harness<'scope, 'env, T>(
worker: &WorkerThread,
this: NonNull<()>,
@ -391,16 +411,25 @@ impl<'scope, 'env> Scope<'scope, 'env> {
T: Send,
{
let this: &ScopeJob<F> = unsafe { this.cast().as_ref() };
let sender: Option<Sender<T>> = unsafe { mem::transmute(sender) };
let f = unsafe { this.unwrap() };
let scope = unsafe { Scope::<'scope, 'env>::new_unchecked(worker, this.inner) };
let sender: Sender<T> = unsafe { mem::transmute(sender) };
// SAFETY: we are in a worker thread, so the inner is valid.
sender.send(catch_unwind(AssertUnwindSafe(|| f(scope))));
let result = catch_unwind(AssertUnwindSafe(|| f(scope)));
let sender = sender.unwrap();
unsafe {
sender.send_as_ref(result);
worker
.context
.queue
.as_sender()
.unicast(Message::WakeUp, ReceiverToken::from_parker(sender.parker()));
}
}
}
impl<'scope, 'env, F, T> IntoJob<T> for &ScopeJob<F>
impl<'scope, 'env, F, T> IntoJob<T> for Pin<&ScopeJob<F>>
where
F: FnOnce(Scope<'scope, 'env>) -> T + Send,
'env: 'scope,
@ -411,7 +440,7 @@ impl<'scope, 'env> Scope<'scope, 'env> {
}
}
impl<'scope, 'env, F, T> InlineJob<T> for &ScopeJob<F>
impl<'scope, 'env, F, T> InlineJob<T> for Pin<&ScopeJob<F>>
where
F: FnOnce(Scope<'scope, 'env>) -> T + Send,
'env: 'scope,
@ -422,8 +451,14 @@ impl<'scope, 'env> Scope<'scope, 'env> {
}
}
return worker
.join_heartbeat2_every::<_, _, _, _, 64>(&ScopeJob::new(a, self.inner), |_| b(*self));
let _pinned = ScopeJob::new(a, self.inner);
let job = unsafe { Pin::new_unchecked(&_pinned) };
let (a, b) = worker.join_heartbeat2(job, |_| b(*self));
// touch job here to ensure it is not dropped before we run the join.
drop(_pinned);
(a, b)
// let stack = ScopeJob::new(a, self.inner);
// let job = ScopeJob::into_job(&stack);
@ -528,13 +563,17 @@ mod tests {
#[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
fn scope_join_one() {
    // Joins two closures on a single-threaded pool and checks both the
    // combined result and that each closure ran exactly once.
    let pool = ThreadPool::new_with_threads(1);

    // Side-effect tracker: closure `a` adds 1, closure `b` adds 2, so a final
    // value of 3 means both ran exactly once.
    let count = AtomicU8::new(0);

    let a = pool.scope(|scope| {
        // `fetch_add` returns the *previous* counter value, which depends on
        // which closure happens to run first. It must therefore not feed into
        // the returned numbers; the closures return fixed values instead.
        let (a, b) = scope.join(
            |_| {
                count.fetch_add(1, Ordering::Relaxed);
                3 + 4
            },
            |_| {
                count.fetch_add(2, Ordering::Relaxed);
                5 + 6
            },
        );
        a + b
    });

    assert_eq!(a, 18);
    assert_eq!(count.load(Ordering::Relaxed), 3);
}
#[test]
@ -553,9 +592,9 @@ mod tests {
}
pool.scope(|scope| {
let total = sum(scope, 10);
assert_eq!(total, 1023);
// eprintln!("Total sum: {}", total);
let total = sum(scope, 5);
// assert_eq!(total, 1023);
eprintln!("Total sum: {}", total);
});
}

View file

@ -2,6 +2,7 @@ use std::sync::Arc;
use crate::{Scope, context::Context, scope::scope_with_context};
/// Handle to a thread pool; a thin wrapper around the shared [`Context`].
///
/// The `#[repr(transparent)]` layout is load-bearing: `ThreadPool::global`
/// transmutes the value returned by `Context::global_context()` into a
/// `&'static ThreadPool`, which is only valid while this struct has exactly
/// one field of type `Arc<Context>`.
#[repr(transparent)]
pub struct ThreadPool {
// the only field; do not add others without revisiting `ThreadPool::global`
pub(crate) context: Arc<Context>,
}
@ -9,7 +10,7 @@ pub struct ThreadPool {
impl Drop for ThreadPool {
/// Calls `set_should_exit` on the shared context when this handle is dropped.
fn drop(&mut self) {
// TODO: Ensure that the context is properly cleaned up when the thread pool is dropped.
// NOTE(review): worker loops check `should_exit()` before blocking on
// their receiver; presumably `set_should_exit` also wakes blocked
// workers — confirm, otherwise they may never observe the flag.
self.context.set_should_exit();
}
}
}
@ -25,9 +26,9 @@ impl ThreadPool {
Self { context }
}
pub fn global() -> Self {
let context = Context::global_context().clone();
Self { context }
pub fn global() -> &'static Self {
// SAFETY: `ThreadPool` is a `#[repr(transparent)]` wrapper around `Arc<Context>`,
// so a `&'static Arc<Context>` can be reinterpreted as a `&'static ThreadPool`.
unsafe { core::mem::transmute(Context::global_context()) }
}
pub fn scope<'env, F, R>(&self, f: F) -> R

View file

@ -8,18 +8,21 @@ use std::{
time::Duration,
};
use crossbeam_utils::CachePadded;
#[cfg(feature = "metrics")]
use werkzeug::CachePadded;
use crate::{
channel::Receiver,
context::Context,
context::{Context, Message},
heartbeat::OwnedHeartbeatReceiver,
job::{Job2 as Job, JobQueue as JobList, SharedJob},
queue,
util::DropGuard,
};
pub struct WorkerThread {
pub(crate) context: Arc<Context>,
pub(crate) receiver: queue::Receiver<Message>,
pub(crate) queue: UnsafeCell<JobList>,
pub(crate) heartbeat: OwnedHeartbeatReceiver,
pub(crate) join_count: Cell<u8>,
@ -37,6 +40,7 @@ impl WorkerThread {
let heartbeat = context.heartbeats.new_heartbeat();
Self {
receiver: context.queue.new_receiver(),
context,
queue: UnsafeCell::new(JobList::new()),
heartbeat,
@ -82,85 +86,26 @@ impl WorkerThread {
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
fn run_inner(&self) {
let mut job = None;
'outer: loop {
if let Some(job) = job.take() {
self.execute(job);
loop {
if self.context.should_exit() {
break;
}
// no more jobs, wait to be notified of a new job or a heartbeat.
while job.is_none() {
if self.context.should_exit() {
// if the context is stopped, break out of the outer loop which
// will exit the thread.
break 'outer;
match self.receiver.recv() {
Message::Shared(shared_job) => {
self.execute(shared_job);
}
job = self.find_work_or_wait();
Message::Exit => break,
Message::WakeUp | Message::ScopeFinished => {}
}
}
}
}
impl WorkerThread {
/// Looks for work in the local queue, then in the shared context, and if no
/// work is found, waits for the thread to be notified of a new job, after
/// which it returns `None`.
/// The caller should then check for `should_exit` to determine if the
/// thread should exit, or look for work again.
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub(crate) fn find_work_or_wait(&self) -> Option<SharedJob> {
if let Some(job) = self.find_work() {
return Some(job);
}
#[cfg(feature = "tracing")]
tracing::trace!("waiting for new job");
self.heartbeat.parker().park();
#[cfg(feature = "tracing")]
tracing::trace!("woken up from wait");
None
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub(crate) fn find_work_or_wait_unless<F>(&self, mut pred: F) -> Option<SharedJob>
where
F: FnMut() -> bool,
{
if let Some(job) = self.find_work() {
return Some(job);
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// Check the predicate while holding the lock. This is very important,
// because the lock must be held when notifying us of the result of a
// job we scheduled.
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// no jobs found, wait for a heartbeat or a new job
#[cfg(feature = "tracing")]
tracing::trace!(worker = self.heartbeat.index(), "waiting for new job");
if !pred() {
self.heartbeat.parker().park();
}
#[cfg(feature = "tracing")]
tracing::trace!(worker = self.heartbeat.index(), "woken up from wait");
None
}
fn find_work(&self) -> Option<SharedJob> {
let mut guard = self.context.shared();
if let Some(job) = guard.pop_job() {
#[cfg(feature = "metrics")]
self.metrics.num_jobs_stolen.fetch_add(1, Ordering::Relaxed);
#[cfg(feature = "tracing")]
tracing::trace!("WorkerThread::find_work_inner: found shared job: {:?}", job);
return Some(job);
}
None
}
/// Checks if the worker thread has received a heartbeat, and if so,
/// attempts to share a job with other workers. If a job was popped from
/// the queue, but not shared, this function runs the job locally.
pub(crate) fn tick(&self) {
if self.heartbeat.take() {
#[cfg(feature = "metrics")]
@ -170,6 +115,7 @@ impl WorkerThread {
"received heartbeat, thread id: {:?}",
self.heartbeat.index()
);
self.heartbeat_cold();
}
}
@ -177,28 +123,31 @@ impl WorkerThread {
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
fn execute(&self, job: SharedJob) {
unsafe { SharedJob::execute(job, self) };
self.tick();
// TODO: maybe tick here?
}
/// Attempts to share a job with other workers within the same context.
/// Returns nothing: whether the job was actually shared is not reported to the caller.
#[cold]
fn heartbeat_cold(&self) {
let mut guard = self.context.shared();
if let Some(job) = self.pop_back() {
#[cfg(feature = "tracing")]
tracing::trace!("heartbeat: sharing job: {:?}", job);
if !guard.jobs.contains_key(&self.heartbeat.id()) {
if let Some(job) = self.pop_back() {
#[cfg(feature = "tracing")]
tracing::trace!("heartbeat: sharing job: {:?}", job);
#[cfg(feature = "metrics")]
self.metrics.num_jobs_shared.fetch_add(1, Ordering::Relaxed);
#[cfg(feature = "metrics")]
self.metrics.num_jobs_shared.fetch_add(1, Ordering::Relaxed);
if let Err(Message::Shared(job)) =
self.context
.queue
.as_sender()
.try_anycast(Message::Shared(unsafe {
job.as_ref()
.share(Some(self.receiver.get_token().as_parker()))
}))
{
unsafe {
guard.jobs.insert(
self.heartbeat.id(),
job.as_ref().share(Some(self.heartbeat.parker())),
);
// SAFETY: we are holding the lock on the shared context.
self.context.notify_job_shared();
SharedJob::execute(job, self);
}
}
}
@ -264,13 +213,14 @@ impl WorkerThread {
/// State held by the heartbeat thread.
pub struct HeartbeatThread {
// shared pool context; the heartbeat loop polls `should_exit()` on it
ctx: Arc<Context>,
// number of workers passed to `HeartbeatThread::new`
// NOTE(review): in the visible part of the run loop this is only referenced
// from commented-out code — confirm it is still needed
num_workers: usize,
}
impl HeartbeatThread {
const HEARTBEAT_INTERVAL: Duration = Duration::from_micros(100);
pub fn new(ctx: Arc<Context>) -> Self {
Self { ctx }
pub fn new(ctx: Arc<Context>, num_workers: usize) -> Self {
Self { ctx, num_workers }
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
@ -282,6 +232,14 @@ impl HeartbeatThread {
let mut i = 0;
loop {
let sleep_for = {
// loop {
// if self.ctx.should_exit() || self.ctx.queue.num_receivers() != self.num_workers
// {
// break;
// }
// self.ctx.heartbeat.park();
// }
if self.ctx.should_exit() {
break;
}
@ -306,88 +264,18 @@ impl HeartbeatThread {
}
impl WorkerThread {
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
pub fn wait_until_shared_job<T: Send>(&self, job: &Job<T>) -> Option<std::thread::Result<T>> {
let recv = (*job).take_receiver().unwrap();
let mut out = recv.poll();
while std::hint::unlikely(out.is_none()) {
if let Some(job) = self.find_work() {
unsafe {
SharedJob::execute(job, self);
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn wait_until_recv<T: Send>(&self, recv: Receiver<T>) -> std::thread::Result<T> {
loop {
if let Some(result) = recv.poll() {
break result;
}
out = recv.poll();
}
out
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn wait_until_recv<T: Send>(&self, recv: Receiver<T>) -> Option<std::thread::Result<T>> {
if self
.context
.shared()
.jobs
.remove(&self.heartbeat.id())
.is_some()
{
#[cfg(feature = "tracing")]
tracing::trace!("reclaiming shared job");
return None;
}
while recv.is_empty() {
if let Some(job) = self.find_work() {
unsafe {
SharedJob::execute(job, self);
}
continue;
}
recv.wait();
}
Some(recv.recv())
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn wait_until_pred<F>(&self, mut pred: F)
where
F: FnMut() -> bool,
{
if !pred() {
#[cfg(feature = "tracing")]
tracing::trace!("thread {:?} waiting on predicate", self.heartbeat.index());
self.wait_until_latch_cold(pred);
}
}
#[cold]
fn wait_until_latch_cold<F>(&self, mut pred: F)
where
F: FnMut() -> bool,
{
if let Some(shared_job) = self.context.shared().jobs.remove(&self.heartbeat.id()) {
#[cfg(feature = "tracing")]
tracing::trace!(
"thread {:?} reclaiming shared job: {:?}",
self.heartbeat.index(),
shared_job
);
unsafe { SharedJob::execute(shared_job, self) };
}
// do the usual thing and wait for the job's latch
// do the usual thing??? chatgipity really said this..
while !pred() {
// check local jobs before locking shared context
if let Some(job) = self.find_work() {
unsafe {
SharedJob::execute(job, self);
}
match self.receiver.recv() {
Message::Shared(shared_job) => unsafe {
SharedJob::execute(shared_job, self);
},
Message::WakeUp | Message::Exit | Message::ScopeFinished => {}
}
}
}

View file

@ -87,6 +87,7 @@ fn join_distaff(tree_size: usize) {
let sum = sum(&tree, tree.root().unwrap(), s);
sum
});
eprintln!("sum: {sum}");
std::hint::black_box(sum);
}
}
@ -134,7 +135,7 @@ fn join_rayon(tree_size: usize) {
}
fn main() {
//tracing_subscriber::fmt::init();
// tracing_subscriber::fmt::init();
use tracing_subscriber::layer::SubscriberExt;
tracing::subscriber::set_global_default(
tracing_subscriber::registry().with(tracing_tracy::TracyLayer::default()),
@ -166,6 +167,7 @@ fn main() {
}
eprintln!("Done!");
println!("Done!");
// // wait for user input before exiting
// std::io::stdin().read_line(&mut String::new()).unwrap();
}