Compare commits
5 commits
d1244026ca
...
26b6ef264c
Author | SHA1 | Date | |
---|---|---|---|
|
26b6ef264c | ||
|
268879d97e | ||
|
7c6e338b77 | ||
|
0836c7c958 | ||
|
b635ea5579 |
|
@ -4,6 +4,7 @@ version = "0.1.0"
|
|||
edition = "2024"
|
||||
|
||||
[profile.bench]
|
||||
opt-level = 0
|
||||
debug = true
|
||||
|
||||
[profile.release]
|
||||
|
|
|
@ -17,62 +17,9 @@ enum State {
|
|||
Taken,
|
||||
}
|
||||
|
||||
// taken from `std`
|
||||
#[derive(Debug)]
|
||||
#[repr(transparent)]
|
||||
pub struct Parker {
|
||||
mutex: AtomicU32,
|
||||
}
|
||||
pub use werkzeug::sync::Parker;
|
||||
|
||||
impl Parker {
|
||||
const PARKED: u32 = u32::MAX;
|
||||
const EMPTY: u32 = 0;
|
||||
const NOTIFIED: u32 = 1;
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
mutex: AtomicU32::new(Self::EMPTY),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_parked(&self) -> bool {
|
||||
self.mutex.load(Ordering::Acquire) == Self::PARKED
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all, fields(this = self as *const Self as usize)))]
|
||||
pub fn park(&self) {
|
||||
if self.mutex.fetch_sub(1, Ordering::Acquire) == Self::NOTIFIED {
|
||||
// The thread was notified, so we can return immediately.
|
||||
return;
|
||||
}
|
||||
|
||||
loop {
|
||||
atomic_wait::wait(&self.mutex, Self::PARKED);
|
||||
|
||||
// We check whether we were notified or woke up spuriously with
|
||||
// acquire ordering in order to make-visible any writes made by the
|
||||
// thread that notified us.
|
||||
if self.mutex.swap(Self::EMPTY, Ordering::Acquire) == Self::NOTIFIED {
|
||||
// The thread was notified, so we can return immediately.
|
||||
return;
|
||||
} else {
|
||||
// spurious wakeup, so we need to re-park.
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all, fields(this = self as *const Self as usize)))]
|
||||
pub fn unpark(&self) {
|
||||
// write with Release ordering to ensure that any writes made by this
|
||||
// thread are made-available to the unparked thread.
|
||||
if self.mutex.swap(Self::NOTIFIED, Ordering::Release) == Self::PARKED {
|
||||
// The thread was parked, so we need to notify it.
|
||||
atomic_wait::wake_one(&self.mutex);
|
||||
} else {
|
||||
// The thread was not parked, so we don't need to do anything.
|
||||
}
|
||||
}
|
||||
}
|
||||
use crate::queue::Queue;
|
||||
|
||||
#[derive(Debug)]
|
||||
#[repr(C)]
|
||||
|
@ -104,6 +51,10 @@ impl<T: Send> Receiver<T> {
|
|||
self.0.state.load(Ordering::Acquire) != State::Ready as u8
|
||||
}
|
||||
|
||||
pub fn sender(&self) -> Sender<T> {
|
||||
Sender(self.0.clone())
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
||||
pub fn wait(&self) {
|
||||
loop {
|
||||
|
@ -182,15 +133,15 @@ impl<T: Send> Receiver<T> {
|
|||
// `State::Ready`.
|
||||
//
|
||||
// In either case, this thread now has unique access to `val`.
|
||||
unsafe { self.take() }
|
||||
}
|
||||
|
||||
unsafe fn take(&self) -> thread::Result<T> {
|
||||
assert_eq!(
|
||||
self.0.state.swap(State::Taken as u8, Ordering::Acquire),
|
||||
State::Ready as u8
|
||||
);
|
||||
|
||||
unsafe { self.take() }
|
||||
}
|
||||
|
||||
unsafe fn take(&self) -> thread::Result<T> {
|
||||
let result = unsafe { (*self.0.val.get()).take().map(|b| *b).unwrap() };
|
||||
|
||||
result
|
||||
|
@ -221,6 +172,30 @@ impl<T: Send> Sender<T> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn parker(&self) -> &Parker {
|
||||
unsafe { self.0.waiting_thread.as_ref() }
|
||||
}
|
||||
|
||||
/// The caller must ensure that this function or `send` are only ever called once.
|
||||
pub unsafe fn send_as_ref(&self, val: thread::Result<T>) {
|
||||
// SAFETY:
|
||||
// Only this thread can write to `val` and none can read it
|
||||
// yet.
|
||||
unsafe {
|
||||
*self.0.val.get() = Some(Box::new(val));
|
||||
}
|
||||
|
||||
if self.0.state.swap(State::Ready as u8, Ordering::AcqRel) == State::Waiting as u8 {
|
||||
// SAFETY:
|
||||
// A `Receiver` already wrote its thread to `waiting_thread`
|
||||
// *before* setting the `state` to `State::Waiting`.
|
||||
unsafe {
|
||||
let thread = self.0.waiting_thread.as_ref();
|
||||
thread.unpark();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn channel<T: Send>(thread: NonNull<Parker>) -> (Sender<T>, Receiver<T>) {
|
||||
|
|
|
@ -1,4 +1,9 @@
|
|||
use std::{
|
||||
cell::UnsafeCell,
|
||||
marker::PhantomPinned,
|
||||
mem::{self, ManuallyDrop},
|
||||
panic::{AssertUnwindSafe, catch_unwind},
|
||||
pin::Pin,
|
||||
ptr::NonNull,
|
||||
sync::{
|
||||
Arc, OnceLock,
|
||||
|
@ -9,21 +14,28 @@ use std::{
|
|||
use alloc::collections::BTreeMap;
|
||||
|
||||
use async_task::Runnable;
|
||||
use parking_lot::{Condvar, Mutex};
|
||||
|
||||
use crate::{
|
||||
channel::{Parker, Sender},
|
||||
heartbeat::HeartbeatList,
|
||||
job::{HeapJob, Job2 as Job, SharedJob, StackJob},
|
||||
queue::ReceiverToken,
|
||||
util::DropGuard,
|
||||
workerthread::{HeartbeatThread, WorkerThread},
|
||||
};
|
||||
|
||||
pub struct Context {
|
||||
shared: Mutex<Shared>,
|
||||
pub shared_job: Condvar,
|
||||
should_exit: AtomicBool,
|
||||
pub heartbeats: HeartbeatList,
|
||||
pub(crate) queue: Arc<crate::queue::Queue<Message>>,
|
||||
pub(crate) heartbeat: Parker,
|
||||
}
|
||||
|
||||
pub(crate) enum Message {
|
||||
Shared(SharedJob),
|
||||
WakeUp,
|
||||
Exit,
|
||||
ScopeFinished,
|
||||
}
|
||||
|
||||
pub(crate) struct Shared {
|
||||
|
@ -52,22 +64,15 @@ impl Shared {
|
|||
}
|
||||
|
||||
impl Context {
|
||||
pub(crate) fn shared(&self) -> parking_lot::MutexGuard<'_, Shared> {
|
||||
self.shared.lock()
|
||||
}
|
||||
|
||||
pub fn new_with_threads(num_threads: usize) -> Arc<Self> {
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!("Creating context with {} threads", num_threads);
|
||||
|
||||
let this = Arc::new(Self {
|
||||
shared: Mutex::new(Shared {
|
||||
jobs: BTreeMap::new(),
|
||||
injected_jobs: Vec::new(),
|
||||
}),
|
||||
shared_job: Condvar::new(),
|
||||
should_exit: AtomicBool::new(false),
|
||||
heartbeats: HeartbeatList::new(),
|
||||
queue: crate::queue::Queue::new(),
|
||||
heartbeat: Parker::new(),
|
||||
});
|
||||
|
||||
// Create a barrier to synchronize the worker threads and the heartbeat thread
|
||||
|
@ -94,7 +99,7 @@ impl Context {
|
|||
std::thread::Builder::new()
|
||||
.name("heartbeat-thread".to_string())
|
||||
.spawn(move || {
|
||||
HeartbeatThread::new(ctx).run(barrier);
|
||||
HeartbeatThread::new(ctx, num_threads).run(barrier);
|
||||
})
|
||||
.expect("Failed to spawn heartbeat thread");
|
||||
}
|
||||
|
@ -106,7 +111,7 @@ impl Context {
|
|||
|
||||
pub fn set_should_exit(&self) {
|
||||
self.should_exit.store(true, Ordering::Relaxed);
|
||||
self.heartbeats.notify_all();
|
||||
self.queue.as_sender().broadcast_with(|| Message::Exit);
|
||||
}
|
||||
|
||||
pub fn should_exit(&self) -> bool {
|
||||
|
@ -124,31 +129,7 @@ impl Context {
|
|||
}
|
||||
|
||||
pub fn inject_job(&self, job: SharedJob) {
|
||||
let mut shared = self.shared.lock();
|
||||
shared.injected_jobs.push(job);
|
||||
|
||||
unsafe {
|
||||
// SAFETY: we are holding the shared lock, so it is safe to notify
|
||||
self.notify_job_shared();
|
||||
}
|
||||
}
|
||||
|
||||
/// caller should hold the shared lock while calling this
|
||||
pub unsafe fn notify_job_shared(&self) {
|
||||
let heartbeats = self.heartbeats.inner();
|
||||
if let Some((i, sender)) = heartbeats
|
||||
.iter()
|
||||
.find(|(_, heartbeat)| heartbeat.is_waiting())
|
||||
.or_else(|| heartbeats.iter().next())
|
||||
{
|
||||
_ = i;
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!("Notifying worker thread {} about job sharing", i);
|
||||
sender.wake();
|
||||
} else {
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::warn!("No worker found to notify about job sharing");
|
||||
}
|
||||
self.queue.as_sender().anycast(Message::Shared(job));
|
||||
}
|
||||
|
||||
/// Runs closure in this context, processing the other context's worker's jobs while waiting for the result.
|
||||
|
@ -160,13 +141,18 @@ impl Context {
|
|||
// current thread is not in the same context, create a job and inject it into the other thread's context, then wait while working on our jobs.
|
||||
|
||||
// SAFETY: we are waiting on this latch in this thread.
|
||||
let job = StackJob::new(move |worker: &WorkerThread| f(worker));
|
||||
let _pinned = StackJob::new(move |worker: &WorkerThread| f(worker));
|
||||
let job = unsafe { Pin::new_unchecked(&_pinned) };
|
||||
|
||||
let job = Job::from_stackjob(&job);
|
||||
unsafe {
|
||||
self.inject_job(job.share(Some(worker.receiver.get_token().as_parker())));
|
||||
}
|
||||
|
||||
self.inject_job(job.share(Some(worker.heartbeat.parker())));
|
||||
let t = worker.wait_until_recv(job.take_receiver().expect("Job should have a receiver"));
|
||||
|
||||
let t = worker.wait_until_shared_job(&job).unwrap();
|
||||
// touch the job to ensure it is dropped after we are done with it.
|
||||
drop(_pinned);
|
||||
|
||||
crate::util::unwrap_or_panic(t)
|
||||
}
|
||||
|
@ -180,15 +166,69 @@ impl Context {
|
|||
// current thread isn't a worker thread, create job and inject into context
|
||||
let parker = Parker::new();
|
||||
|
||||
let job = StackJob::new(move |worker: &WorkerThread| f(worker));
|
||||
struct CrossJob<F> {
|
||||
f: UnsafeCell<ManuallyDrop<F>>,
|
||||
_pin: PhantomPinned,
|
||||
}
|
||||
|
||||
let job = Job::from_stackjob(&job);
|
||||
impl<F> CrossJob<F> {
|
||||
fn new(f: F) -> Self {
|
||||
Self {
|
||||
f: UnsafeCell::new(ManuallyDrop::new(f)),
|
||||
_pin: PhantomPinned,
|
||||
}
|
||||
}
|
||||
|
||||
self.inject_job(job.share(Some(&parker)));
|
||||
fn into_job<T>(self: &Self) -> Job<T>
|
||||
where
|
||||
F: FnOnce(&WorkerThread) -> T + Send,
|
||||
T: Send,
|
||||
{
|
||||
Job::from_harness(Self::harness, NonNull::from(&*self).cast())
|
||||
}
|
||||
|
||||
let recv = job.take_receiver().unwrap();
|
||||
unsafe fn unwrap(&self) -> F {
|
||||
unsafe { ManuallyDrop::take(&mut *self.f.get()) }
|
||||
}
|
||||
|
||||
crate::util::unwrap_or_panic(recv.recv())
|
||||
#[align(8)]
|
||||
unsafe fn harness<T>(worker: &WorkerThread, this: NonNull<()>, sender: Option<Sender>)
|
||||
where
|
||||
F: FnOnce(&WorkerThread) -> T + Send,
|
||||
T: Send,
|
||||
{
|
||||
let this: &CrossJob<F> = unsafe { this.cast().as_ref() };
|
||||
let f = unsafe { this.unwrap() };
|
||||
let sender: Option<Sender<T>> = unsafe { mem::transmute(sender) };
|
||||
|
||||
let result = catch_unwind(AssertUnwindSafe(|| f(worker)));
|
||||
|
||||
let sender = sender.unwrap();
|
||||
unsafe {
|
||||
sender.send_as_ref(result);
|
||||
worker
|
||||
.context
|
||||
.queue
|
||||
.as_sender()
|
||||
.unicast(Message::WakeUp, ReceiverToken::from_parker(sender.parker()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let pinned = CrossJob::new(move |worker: &WorkerThread| f(worker));
|
||||
let job2 = pinned.into_job();
|
||||
|
||||
self.inject_job(job2.share(Some(&parker)));
|
||||
|
||||
let recv = job2.take_receiver().unwrap();
|
||||
|
||||
let out = crate::util::unwrap_or_panic(recv.recv());
|
||||
|
||||
// touch the job to ensure it is dropped after we are done with it.
|
||||
drop(pinned);
|
||||
drop(parker);
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
/// Run closure in this context.
|
||||
|
@ -375,9 +415,12 @@ mod tests {
|
|||
ctx.inject_job(job.share(Some(&parker)));
|
||||
|
||||
// Wait for the job to be executed
|
||||
let recv = job.take_receiver().unwrap();
|
||||
let result = recv.recv();
|
||||
let result = crate::util::unwrap_or_panic(result);
|
||||
let recv = job.take_receiver().expect("Job should have a receiver");
|
||||
let Some(result) = recv.poll() else {
|
||||
panic!("Expected a finished message");
|
||||
};
|
||||
|
||||
let result = crate::util::unwrap_or_panic::<i32>(result);
|
||||
assert_eq!(result, 42);
|
||||
assert_eq!(counter.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
|
|
@ -4,13 +4,15 @@ use core::{
|
|||
mem::{self, ManuallyDrop},
|
||||
ptr::NonNull,
|
||||
};
|
||||
use std::cell::Cell;
|
||||
use std::{cell::Cell, marker::PhantomPinned};
|
||||
|
||||
use alloc::boxed::Box;
|
||||
|
||||
use crate::{
|
||||
WorkerThread,
|
||||
channel::{Parker, Sender},
|
||||
channel::{Parker, Receiver, Sender},
|
||||
context::Message,
|
||||
queue::ReceiverToken,
|
||||
};
|
||||
|
||||
#[repr(transparent)]
|
||||
|
@ -43,65 +45,89 @@ impl<F> HeapJob<F> {
|
|||
}
|
||||
}
|
||||
|
||||
type JobHarness =
|
||||
unsafe fn(&WorkerThread, this: NonNull<()>, sender: Option<crate::channel::Sender>);
|
||||
type JobHarness = unsafe fn(&WorkerThread, this: NonNull<()>, sender: Option<Sender>);
|
||||
|
||||
#[repr(C)]
|
||||
pub struct Job2<T = ()> {
|
||||
harness: JobHarness,
|
||||
this: NonNull<()>,
|
||||
receiver: Cell<Option<crate::channel::Receiver<T>>>,
|
||||
inner: UnsafeCell<Job2Inner<T>>,
|
||||
}
|
||||
|
||||
impl<T> Debug for Job2<T> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Job2")
|
||||
.field("harness", &self.harness)
|
||||
.field("this", &self.this)
|
||||
.finish_non_exhaustive()
|
||||
f.debug_struct("Job2").field("inner", &self.inner).finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
pub enum Job2Inner<T = ()> {
|
||||
Local {
|
||||
harness: JobHarness,
|
||||
this: NonNull<()>,
|
||||
_pin: PhantomPinned,
|
||||
},
|
||||
Shared {
|
||||
receiver: Cell<Option<Receiver<T>>>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SharedJob {
|
||||
harness: JobHarness,
|
||||
this: NonNull<()>,
|
||||
sender: Option<crate::channel::Sender>,
|
||||
sender: Option<Sender<()>>,
|
||||
}
|
||||
|
||||
unsafe impl Send for SharedJob {}
|
||||
|
||||
impl<T: Send> Job2<T> {
|
||||
fn new(harness: JobHarness, this: NonNull<()>) -> Self {
|
||||
let this = Self {
|
||||
harness,
|
||||
this,
|
||||
receiver: Cell::new(None),
|
||||
inner: UnsafeCell::new(Job2Inner::Local {
|
||||
harness: harness,
|
||||
this,
|
||||
_pin: PhantomPinned,
|
||||
}),
|
||||
};
|
||||
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!("new job: {:?}", this);
|
||||
|
||||
this
|
||||
}
|
||||
|
||||
pub fn share(&self, parker: Option<&Parker>) -> SharedJob {
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!("sharing job: {:?}", self);
|
||||
|
||||
let (sender, receiver) = parker
|
||||
.map(|parker| crate::channel::channel::<T>(parker.into()))
|
||||
.unzip();
|
||||
|
||||
self.receiver.set(receiver);
|
||||
|
||||
SharedJob {
|
||||
harness: self.harness,
|
||||
this: self.this,
|
||||
sender: unsafe { mem::transmute(sender) },
|
||||
// self.receiver.set(receiver);
|
||||
if let Job2Inner::Local {
|
||||
harness,
|
||||
this,
|
||||
_pin: _,
|
||||
} = unsafe {
|
||||
self.inner.replace(Job2Inner::Shared {
|
||||
receiver: Cell::new(receiver),
|
||||
})
|
||||
} {
|
||||
// SAFETY: `this` is a valid pointer to the job.
|
||||
unsafe {
|
||||
SharedJob {
|
||||
harness,
|
||||
this,
|
||||
sender: mem::transmute(sender), // Convert `Option<Sender<T>>` to `Option<Sender<()>>`
|
||||
}
|
||||
}
|
||||
} else {
|
||||
panic!("Job2 is already shared");
|
||||
}
|
||||
}
|
||||
|
||||
pub fn take_receiver(&self) -> Option<crate::channel::Receiver<T>> {
|
||||
self.receiver.take()
|
||||
pub fn take_receiver(&self) -> Option<Receiver<T>> {
|
||||
unsafe {
|
||||
if let Job2Inner::Shared { receiver } = self.inner.as_ref_unchecked() {
|
||||
receiver.take()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_stackjob<F>(job: &StackJob<F>) -> Self
|
||||
|
@ -119,9 +145,9 @@ impl<T: Send> Job2<T> {
|
|||
T: Send,
|
||||
{
|
||||
use std::panic::{AssertUnwindSafe, catch_unwind};
|
||||
let sender: Option<Sender<T>> = unsafe { mem::transmute(sender) };
|
||||
|
||||
let f = unsafe { this.cast::<StackJob<F>>().as_ref().unwrap() };
|
||||
let sender: Sender<T> = unsafe { mem::transmute(sender) };
|
||||
|
||||
// #[cfg(feature = "metrics")]
|
||||
// if worker.heartbeat.parker() == mutex {
|
||||
|
@ -132,7 +158,18 @@ impl<T: Send> Job2<T> {
|
|||
// tracing::trace!("job sent to self");
|
||||
// }
|
||||
|
||||
sender.send(catch_unwind(AssertUnwindSafe(|| f(worker))));
|
||||
let result = catch_unwind(AssertUnwindSafe(|| f(worker)));
|
||||
|
||||
if let Some(sender) = sender {
|
||||
unsafe {
|
||||
sender.send_as_ref(result);
|
||||
worker
|
||||
.context
|
||||
.queue
|
||||
.as_sender()
|
||||
.unicast(Message::WakeUp, ReceiverToken::from_parker(sender.parker()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Self::new(harness::<F, T>, NonNull::from(job).cast())
|
||||
|
@ -153,6 +190,7 @@ impl<T: Send> Job2<T> {
|
|||
T: Send,
|
||||
{
|
||||
use std::panic::{AssertUnwindSafe, catch_unwind};
|
||||
let sender: Option<Sender<T>> = unsafe { mem::transmute(sender) };
|
||||
|
||||
// expect MIRI to complain about this, but it is actually correct.
|
||||
// because I am so much smarter than MIRI, naturally, obviously.
|
||||
|
@ -160,9 +198,15 @@ impl<T: Send> Job2<T> {
|
|||
let f = unsafe { (*Box::from_non_null(this.cast::<HeapJob<F>>())).into_inner() };
|
||||
|
||||
let result = catch_unwind(AssertUnwindSafe(|| f(worker)));
|
||||
let sender: Option<Sender<T>> = unsafe { mem::transmute(sender) };
|
||||
if let Some(sender) = sender {
|
||||
sender.send(result);
|
||||
unsafe {
|
||||
sender.send_as_ref(result);
|
||||
_ = worker
|
||||
.context
|
||||
.queue
|
||||
.as_sender()
|
||||
.unicast(Message::WakeUp, ReceiverToken::from_parker(sender.parker()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -177,10 +221,6 @@ impl<T: Send> Job2<T> {
|
|||
pub fn from_harness(harness: JobHarness, this: NonNull<()>) -> Self {
|
||||
Self::new(harness, this)
|
||||
}
|
||||
|
||||
pub fn is_shared(&self) -> bool {
|
||||
unsafe { (&*self.receiver.as_ptr()).is_some() }
|
||||
}
|
||||
}
|
||||
|
||||
impl SharedJob {
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#[cfg(feature = "metrics")]
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
use std::{hint::cold_path, sync::Arc};
|
||||
use std::{hint::cold_path, pin::Pin, sync::Arc};
|
||||
|
||||
use crate::{
|
||||
context::Context,
|
||||
|
@ -84,7 +84,6 @@ impl WorkerThread {
|
|||
|
||||
// SAFETY: this function runs in a worker thread, so we can access the queue safely.
|
||||
if count == 0 || queue_len < 3 {
|
||||
cold_path();
|
||||
self.join_heartbeat2(a, b)
|
||||
} else {
|
||||
(a.run_inline(self), b(self))
|
||||
|
@ -103,12 +102,14 @@ impl WorkerThread {
|
|||
#[cfg(feature = "metrics")]
|
||||
self.metrics.num_joins.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
let job = a.into_job();
|
||||
let _pinned = a.into_job();
|
||||
let job = unsafe { Pin::new_unchecked(&_pinned) };
|
||||
|
||||
self.push_back(&job);
|
||||
self.push_back(&*job);
|
||||
|
||||
self.tick();
|
||||
|
||||
// let rb = b(self);
|
||||
let rb = match catch_unwind(AssertUnwindSafe(|| b(self))) {
|
||||
Ok(val) => val,
|
||||
Err(payload) => {
|
||||
|
@ -117,32 +118,16 @@ impl WorkerThread {
|
|||
cold_path();
|
||||
|
||||
// if b panicked, we need to wait for a to finish
|
||||
let mut receiver = job.take_receiver();
|
||||
self.wait_until_pred(|| match &receiver {
|
||||
Some(recv) => recv.poll().is_some(),
|
||||
None => {
|
||||
receiver = job.take_receiver();
|
||||
false
|
||||
}
|
||||
});
|
||||
if let Some(recv) = job.take_receiver() {
|
||||
_ = self.wait_until_recv(recv);
|
||||
}
|
||||
|
||||
resume_unwind(payload);
|
||||
}
|
||||
};
|
||||
|
||||
let ra = if let Some(recv) = job.take_receiver() {
|
||||
match self.wait_until_recv(recv) {
|
||||
Some(t) => crate::util::unwrap_or_panic(t),
|
||||
None => {
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!(
|
||||
"join_heartbeat: job was shared, but reclaimed, running a() inline"
|
||||
);
|
||||
// the job was shared, but not yet stolen, so we get to run the
|
||||
// job inline
|
||||
a.run_inline(self)
|
||||
}
|
||||
}
|
||||
crate::util::unwrap_or_panic(self.wait_until_recv(recv))
|
||||
} else {
|
||||
self.pop_back();
|
||||
|
||||
|
@ -152,6 +137,9 @@ impl WorkerThread {
|
|||
a.run_inline(self)
|
||||
};
|
||||
|
||||
// touch the job to ensure it is not dropped while we are still using it.
|
||||
drop(_pinned);
|
||||
|
||||
(ra, rb)
|
||||
}
|
||||
|
||||
|
@ -183,41 +171,23 @@ impl WorkerThread {
|
|||
cold_path();
|
||||
|
||||
// if b panicked, we need to wait for a to finish
|
||||
let mut receiver = job.take_receiver();
|
||||
self.wait_until_pred(|| match &receiver {
|
||||
Some(recv) => recv.poll().is_some(),
|
||||
None => {
|
||||
receiver = job.take_receiver();
|
||||
false
|
||||
}
|
||||
});
|
||||
if let Some(recv) = job.take_receiver() {
|
||||
_ = self.wait_until_recv(recv);
|
||||
}
|
||||
|
||||
resume_unwind(payload);
|
||||
}
|
||||
};
|
||||
|
||||
let ra = if let Some(recv) = job.take_receiver() {
|
||||
match self.wait_until_recv(recv) {
|
||||
Some(t) => crate::util::unwrap_or_panic(t),
|
||||
None => {
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!(
|
||||
"join_heartbeat: job was shared, but reclaimed, running a() inline"
|
||||
);
|
||||
// the job was shared, but not yet stolen, so we get to run the
|
||||
// job inline
|
||||
unsafe { a.unwrap()(self) }
|
||||
}
|
||||
}
|
||||
crate::util::unwrap_or_panic(self.wait_until_recv(recv))
|
||||
} else {
|
||||
self.pop_back();
|
||||
|
||||
unsafe {
|
||||
// SAFETY: we just popped the job from the queue, so it is safe to unwrap.
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!("join_heartbeat: job was not shared, running a() inline");
|
||||
a.unwrap()(self)
|
||||
}
|
||||
// SAFETY: we just popped the job from the queue, so it is safe to unwrap.
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!("join_heartbeat: job was not shared, running a() inline");
|
||||
a.run_inline(self)
|
||||
};
|
||||
|
||||
(ra, rb)
|
||||
|
|
|
@ -1,68 +1,172 @@
|
|||
use std::{
|
||||
cell::UnsafeCell,
|
||||
collections::{HashMap, HashSet},
|
||||
collections::HashMap,
|
||||
marker::{PhantomData, PhantomPinned},
|
||||
mem::{self, MaybeUninit},
|
||||
pin::Pin,
|
||||
ptr::{self, NonNull},
|
||||
sync::{
|
||||
Arc,
|
||||
atomic::{AtomicU8, AtomicU32, Ordering},
|
||||
atomic::{AtomicU32, Ordering},
|
||||
},
|
||||
};
|
||||
|
||||
use crossbeam_utils::CachePadded;
|
||||
use werkzeug::CachePadded;
|
||||
use werkzeug::sync::Parker;
|
||||
|
||||
use werkzeug::ptr::TaggedAtomicPtr;
|
||||
|
||||
// A Queue with multiple receivers and multiple producers, where a producer can send a message to one of any of the receivers (any-cast), or one of the receivers (uni-cast).
|
||||
// After being woken up from waiting on a message, the receiver will look up the index of the message in the queue and return it.
|
||||
|
||||
struct QueueInner<T> {
|
||||
parked: HashSet<ReceiverToken>,
|
||||
owned: HashMap<ReceiverToken, CachePadded<Slot<T>>>,
|
||||
receivers: HashMap<ReceiverToken, CachePadded<(Slot<T>, bool)>>,
|
||||
messages: Vec<T>,
|
||||
_phantom: std::marker::PhantomData<T>,
|
||||
}
|
||||
|
||||
struct Queue<T> {
|
||||
pub struct Queue<T> {
|
||||
inner: UnsafeCell<QueueInner<T>>,
|
||||
lock: AtomicU32,
|
||||
}
|
||||
|
||||
unsafe impl<T> Send for Queue<T> {}
|
||||
unsafe impl<T> Sync for Queue<T> where T: Send {}
|
||||
|
||||
enum SlotKey {
|
||||
Owned(ReceiverToken),
|
||||
Indexed(usize),
|
||||
}
|
||||
|
||||
struct Receiver<T> {
|
||||
pub struct Receiver<T> {
|
||||
queue: Arc<Queue<T>>,
|
||||
lock: Pin<Box<(AtomicU32, PhantomPinned)>>,
|
||||
lock: Pin<Box<(Parker, PhantomPinned)>>,
|
||||
}
|
||||
|
||||
struct Sender<T> {
|
||||
#[repr(transparent)]
|
||||
pub struct Sender<T> {
|
||||
queue: Arc<Queue<T>>,
|
||||
}
|
||||
|
||||
// TODO: make this a linked list of slots so we can queue multiple messages for
|
||||
// a single receiver
|
||||
const SLOT_ALIGN: u8 = core::mem::align_of::<usize>().ilog2() as u8;
|
||||
struct Slot<T> {
|
||||
value: UnsafeCell<MaybeUninit<T>>,
|
||||
state: AtomicU8,
|
||||
next_and_state: TaggedAtomicPtr<Self, SLOT_ALIGN>,
|
||||
_phantom: PhantomData<Self>,
|
||||
}
|
||||
|
||||
impl<T> Slot<T> {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
value: UnsafeCell::new(MaybeUninit::uninit()),
|
||||
state: AtomicU8::new(0), // 0 means empty
|
||||
next_and_state: TaggedAtomicPtr::new(ptr::null_mut(), 0), // 0 means empty
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn set(&self, value: T) {}
|
||||
fn is_set(&self) -> bool {
|
||||
self.next_and_state.tag(Ordering::Acquire) == 1
|
||||
}
|
||||
|
||||
unsafe fn pop(&self) -> Option<T> {
|
||||
NonNull::new(self.next_and_state.ptr(Ordering::Acquire))
|
||||
.and_then(|next| {
|
||||
// SAFETY: The next slot is a valid pointer to a Slot<T> that was allocated by us.
|
||||
unsafe { next.as_ref().pop() }
|
||||
})
|
||||
.or_else(|| {
|
||||
if self
|
||||
.next_and_state
|
||||
.swap_tag(0, Ordering::AcqRel, Ordering::Relaxed)
|
||||
== 1
|
||||
{
|
||||
// SAFETY: The value is only initialized when the state is set to 1.
|
||||
Some(unsafe { self.value.as_ref_unchecked().assume_init_read() })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// the caller must ensure that they have exclusive access to the slot
|
||||
unsafe fn push(&self, value: T) {
|
||||
if self.is_set() {
|
||||
let next = self.next_ptr();
|
||||
unsafe {
|
||||
(next.as_ref()).push(value);
|
||||
}
|
||||
} else {
|
||||
// SAFETY: The value is only initialized when the state is set to 1.
|
||||
unsafe { self.value.as_mut_unchecked().write(value) };
|
||||
self.next_and_state
|
||||
.set_tag(1, Ordering::Release, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
fn next_ptr(&self) -> NonNull<Slot<T>> {
|
||||
if let Some(next) = NonNull::new(self.next_and_state.ptr(Ordering::Acquire)) {
|
||||
next.cast()
|
||||
} else {
|
||||
self.alloc_next()
|
||||
}
|
||||
}
|
||||
|
||||
fn alloc_next(&self) -> NonNull<Slot<T>> {
|
||||
let next = Box::into_raw(Box::new(Slot::new()));
|
||||
|
||||
let next = loop {
|
||||
match self.next_and_state.compare_exchange_weak_ptr(
|
||||
ptr::null_mut(),
|
||||
next,
|
||||
Ordering::Release,
|
||||
Ordering::Acquire,
|
||||
) {
|
||||
Ok(_) => break next,
|
||||
Err(other) => {
|
||||
if other.is_null() {
|
||||
eprintln!("What the sigma? Slot::alloc_next: other is null");
|
||||
continue;
|
||||
}
|
||||
// next was allocated under us, so we need to drop the slot we just allocated again.
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!(
|
||||
"Slot::alloc_next: next was allocated under us, dropping it. ours: {:p}, other: {:p}",
|
||||
next,
|
||||
other
|
||||
);
|
||||
_ = unsafe { Box::from_raw(next) };
|
||||
break other;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
unsafe {
|
||||
// SAFETY: The next slot is a valid pointer to a Slot<T> that was allocated by us.
|
||||
NonNull::new_unchecked(next)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Drop for Slot<T> {
|
||||
fn drop(&mut self) {
|
||||
// drop next chain
|
||||
if let Some(next) = NonNull::new(self.next_and_state.swap_ptr(
|
||||
ptr::null_mut(),
|
||||
Ordering::Release,
|
||||
Ordering::Relaxed,
|
||||
)) {
|
||||
// SAFETY: The next slot is a valid pointer to a Slot<T> that was allocated by us.
|
||||
// We drop this in place because idk..
|
||||
unsafe {
|
||||
next.drop_in_place();
|
||||
_ = Box::<mem::ManuallyDrop<Self>>::from_non_null(next.cast());
|
||||
}
|
||||
}
|
||||
|
||||
// SAFETY: The value is only initialized when the state is set to 1.
|
||||
if mem::needs_drop::<T>() && self.state.load(Ordering::Acquire) == 1 {
|
||||
if mem::needs_drop::<T>() && self.next_and_state.tag(Ordering::Acquire) == 1 {
|
||||
unsafe { self.value.as_mut_unchecked().assume_init_drop() };
|
||||
}
|
||||
}
|
||||
|
@ -77,19 +181,35 @@ impl<T> Drop for Slot<T> {
|
|||
/// A token that can be used to identify a specific receiver in a queue.
|
||||
#[repr(transparent)]
|
||||
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
|
||||
pub struct ReceiverToken(werkzeug::util::Send<*const u32>);
|
||||
pub struct ReceiverToken(werkzeug::util::Send<NonNull<u32>>);
|
||||
|
||||
impl ReceiverToken {
|
||||
pub fn as_ptr(&self) -> *mut u32 {
|
||||
self.0.into_inner().as_ptr()
|
||||
}
|
||||
|
||||
pub unsafe fn as_parker(&self) -> &Parker {
|
||||
// SAFETY: The pointer is guaranteed to be valid and aligned, as it comes from a pinned Parker.
|
||||
unsafe { Parker::from_ptr(self.as_ptr()) }
|
||||
}
|
||||
|
||||
pub unsafe fn from_parker(parker: &Parker) -> Self {
|
||||
// SAFETY: The pointer is guaranteed to be valid and aligned, as it comes from a pinned Parker.
|
||||
let ptr = NonNull::from(parker).cast::<u32>();
|
||||
ReceiverToken(werkzeug::util::Send(ptr))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Queue<T> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
pub fn new() -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
inner: UnsafeCell::new(QueueInner {
|
||||
parked: HashSet::new(),
|
||||
messages: Vec::new(),
|
||||
owned: HashMap::new(),
|
||||
receivers: HashMap::new(),
|
||||
_phantom: PhantomData,
|
||||
}),
|
||||
lock: AtomicU32::new(0),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn new_sender(self: &Arc<Self>) -> Sender<T> {
|
||||
|
@ -98,22 +218,28 @@ impl<T> Queue<T> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn num_receivers(self: &Arc<Self>) -> usize {
|
||||
let _guard = self.lock();
|
||||
self.inner().receivers.len()
|
||||
}
|
||||
|
||||
pub fn as_sender(self: &Arc<Self>) -> &Sender<T> {
|
||||
unsafe { mem::transmute::<&Arc<Self>, &Sender<T>>(self) }
|
||||
}
|
||||
|
||||
pub fn new_receiver(self: &Arc<Self>) -> Receiver<T> {
|
||||
let recv = Receiver {
|
||||
queue: self.clone(),
|
||||
lock: Box::pin((AtomicU32::new(0), PhantomPinned)),
|
||||
lock: Box::pin((Parker::new(), PhantomPinned)),
|
||||
};
|
||||
|
||||
// allocate slot for the receiver
|
||||
let token = recv.get_token();
|
||||
let _guard = recv.queue.lock();
|
||||
recv.queue.inner().owned.insert(
|
||||
token,
|
||||
CachePadded::new(Slot {
|
||||
value: UnsafeCell::new(MaybeUninit::uninit()),
|
||||
state: AtomicU8::new(0), // 0 means empty
|
||||
}),
|
||||
);
|
||||
recv.queue
|
||||
.inner()
|
||||
.receivers
|
||||
.insert(token, CachePadded::new((Slot::new(), false)));
|
||||
|
||||
drop(_guard);
|
||||
recv
|
||||
|
@ -134,28 +260,27 @@ impl<T> Queue<T> {
|
|||
}
|
||||
|
||||
impl<T> QueueInner<T> {
|
||||
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
||||
fn poll(&mut self, token: ReceiverToken) -> Option<T> {
|
||||
// check if someone has sent a message to this receiver
|
||||
let slot = self.owned.get(&token).unwrap();
|
||||
if slot.state.swap(0, Ordering::Acquire) == 1 {
|
||||
// SAFETY: the slot is owned by this receiver and contains a message.
|
||||
return Some(unsafe { slot.value.as_ref_unchecked().assume_init_read() });
|
||||
} else if let Some(t) = self.messages.pop() {
|
||||
return Some(t);
|
||||
} else {
|
||||
None
|
||||
}
|
||||
let CachePadded((slot, _)) = self.receivers.get(&token)?;
|
||||
|
||||
unsafe { slot.pop() }.or_else(|| {
|
||||
// if the slot is empty, we can check the indexed messages
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!("QueueInner::poll: checking open messages");
|
||||
|
||||
self.messages.pop()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Receiver<T> {
|
||||
fn get_token(&self) -> ReceiverToken {
|
||||
pub fn get_token(&self) -> ReceiverToken {
|
||||
// the token is just the pointer to the lock of this receiver.
|
||||
// the lock is pinned, so it's address is stable across calls to `receive`.
|
||||
|
||||
ReceiverToken(werkzeug::util::Send(
|
||||
&self.lock.0 as *const AtomicU32 as *const u32,
|
||||
))
|
||||
ReceiverToken(werkzeug::util::Send(NonNull::from(&self.lock.0).cast()))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -167,12 +292,13 @@ impl<T> Drop for Receiver<T> {
|
|||
let queue = self.queue.inner();
|
||||
|
||||
// remove the receiver from the queue
|
||||
_ = queue.owned.remove(&self.get_token());
|
||||
_ = queue.receivers.remove(&self.get_token());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Send> Receiver<T> {
|
||||
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
||||
pub fn recv(&self) -> T {
|
||||
let token = self.get_token();
|
||||
|
||||
|
@ -183,22 +309,23 @@ impl<T: Send> Receiver<T> {
|
|||
|
||||
// check if someone has sent a message to this receiver
|
||||
if let Some(t) = queue.poll(token) {
|
||||
queue.parked.remove(&token);
|
||||
queue.receivers.get_mut(&token).unwrap().1 = false; // mark the slot as not parked
|
||||
return t;
|
||||
}
|
||||
|
||||
// there was no message for this receiver, so we need to park it
|
||||
queue.parked.insert(token);
|
||||
queue.receivers.get_mut(&token).unwrap().1 = true; // mark the slot as parked
|
||||
|
||||
// wait for a message to be sent to this receiver
|
||||
drop(_guard);
|
||||
unsafe {
|
||||
let lock = werkzeug::sync::Lock::from_ptr(token.0.into_inner().cast_mut());
|
||||
lock.wait();
|
||||
}
|
||||
self.lock.0.park_with_callback(move || {
|
||||
// drop the lock guard after having set the lock state to waiting.
|
||||
// this avoids a deadlock if the sender tries to send a message
|
||||
// while the receiver is in the process of parking (I think..)
|
||||
drop(_guard);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
||||
pub fn try_recv(&self) -> Option<T> {
|
||||
let token = self.get_token();
|
||||
|
||||
|
@ -214,61 +341,92 @@ impl<T: Send> Receiver<T> {
|
|||
impl<T: Send> Sender<T> {
|
||||
/// Sends a message to one of the receivers in the queue, or makes it
|
||||
/// available to any receiver that will park in the future.
|
||||
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
||||
pub fn anycast(&self, value: T) {
|
||||
// look for a receiver that is parked
|
||||
let _guard = self.queue.lock();
|
||||
|
||||
// SAFETY: The queue is locked, so we can safely access the inner queue.
|
||||
match unsafe { self.try_anycast_inner(value) } {
|
||||
Ok(_) => {}
|
||||
Err(value) => {
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!(
|
||||
"Queue::anycast: no parked receiver found, adding message to indexed slots"
|
||||
);
|
||||
|
||||
// no parked receiver found, so we want to add the message to the indexed slots
|
||||
let queue = self.queue.inner();
|
||||
queue.messages.push(value);
|
||||
|
||||
// waking up a parked receiver is not necessary here, as any
|
||||
// receivers that don't have a free slot are currently waking up.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn try_anycast(&self, value: T) -> Result<(), T> {
|
||||
// lock the queue
|
||||
let _guard = self.queue.lock();
|
||||
|
||||
// SAFETY: The queue is locked, so we can safely access the inner queue.
|
||||
unsafe { self.try_anycast_inner(value) }
|
||||
}
|
||||
|
||||
/// The caller must hold the lock on the queue for the duration of this function.
|
||||
unsafe fn try_anycast_inner(&self, value: T) -> Result<(), T> {
|
||||
// look for a receiver that is parked
|
||||
let queue = self.queue.inner();
|
||||
if let Some((token, slot)) = queue.parked.iter().find_map(|token| {
|
||||
// ensure the slot is available
|
||||
queue.owned.get(token).and_then(|s| {
|
||||
if s.state.load(Ordering::Acquire) == 0 {
|
||||
Some((*token, s))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}) {
|
||||
if let Some((token, slot)) =
|
||||
queue
|
||||
.receivers
|
||||
.iter()
|
||||
.find_map(|(token, CachePadded((slot, is_parked)))| {
|
||||
// ensure the slot is available
|
||||
if *is_parked && !slot.is_set() {
|
||||
Some((*token, slot))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
{
|
||||
// we found a receiver that is parked, so we can send the message to it
|
||||
unsafe {
|
||||
slot.value.as_mut_unchecked().write(value);
|
||||
slot.state.store(1, Ordering::Release);
|
||||
werkzeug::sync::Lock::from_ptr(token.0.into_inner().cast_mut()).wake_one();
|
||||
slot.next_and_state
|
||||
.set_tag(1, Ordering::Release, Ordering::Relaxed);
|
||||
Parker::from_ptr(token.0.into_inner().as_ptr()).unpark();
|
||||
}
|
||||
|
||||
return;
|
||||
return Ok(());
|
||||
} else {
|
||||
// no parked receiver found, so we want to add the message to the indexed slots
|
||||
queue.messages.push(value);
|
||||
|
||||
// waking up a parked receiver is not necessary here, as any
|
||||
// receivers that don't have a free slot are currently waking up.
|
||||
return Err(value);
|
||||
}
|
||||
}
|
||||
|
||||
/// Sends a message to a specific receiver, waking it if it is parked.
|
||||
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
||||
pub fn unicast(&self, value: T, receiver: ReceiverToken) -> Result<(), T> {
|
||||
// lock the queue
|
||||
let _guard = self.queue.lock();
|
||||
let queue = self.queue.inner();
|
||||
|
||||
let Some(slot) = queue.owned.get_mut(&receiver) else {
|
||||
let Some(CachePadded((slot, _))) = queue.receivers.get_mut(&receiver) else {
|
||||
return Err(value);
|
||||
};
|
||||
// SAFETY: The slot is owned by this receiver.
|
||||
unsafe { slot.value.as_mut_unchecked().write(value) };
|
||||
slot.state.store(1, Ordering::Release);
|
||||
|
||||
// check if the receiver is parked
|
||||
if queue.parked.contains(&receiver) {
|
||||
// wake the receiver
|
||||
unsafe {
|
||||
werkzeug::sync::Lock::from_ptr(receiver.0.into_inner().cast_mut()).wake_one();
|
||||
}
|
||||
unsafe {
|
||||
slot.push(value);
|
||||
}
|
||||
|
||||
// wake the receiver
|
||||
unsafe {
|
||||
Parker::from_ptr(receiver.0.into_inner().as_ptr()).unpark();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
||||
pub fn broadcast(&self, value: T)
|
||||
where
|
||||
T: Clone,
|
||||
|
@ -278,25 +436,37 @@ impl<T: Send> Sender<T> {
|
|||
let queue = self.queue.inner();
|
||||
|
||||
// send the message to all receivers
|
||||
for (token, slot) in queue.owned.iter() {
|
||||
for (token, CachePadded((slot, _))) in queue.receivers.iter() {
|
||||
// SAFETY: The slot is owned by this receiver.
|
||||
|
||||
if slot.state.load(Ordering::Acquire) != 0 {
|
||||
// the slot is not available, so we skip it
|
||||
continue;
|
||||
}
|
||||
unsafe { slot.push(value.clone()) };
|
||||
|
||||
// wake the receiver
|
||||
unsafe {
|
||||
slot.value.as_mut_unchecked().write(value.clone());
|
||||
Parker::from_ptr(token.0.into_inner().as_ptr()).unpark();
|
||||
}
|
||||
slot.state.store(1, Ordering::Release);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
||||
pub fn broadcast_with<F>(&self, mut f: F)
|
||||
where
|
||||
F: FnMut() -> T,
|
||||
{
|
||||
// lock the queue
|
||||
let _guard = self.queue.lock();
|
||||
let queue = self.queue.inner();
|
||||
|
||||
// send the message to all receivers
|
||||
for (token, CachePadded((slot, _))) in queue.receivers.iter() {
|
||||
// SAFETY: The slot is owned by this receiver.
|
||||
|
||||
unsafe { slot.push(f()) };
|
||||
|
||||
// check if the receiver is parked
|
||||
if queue.parked.contains(token) {
|
||||
// wake the receiver
|
||||
unsafe {
|
||||
werkzeug::sync::Lock::from_ptr(token.0.into_inner().cast_mut()).wake_one();
|
||||
}
|
||||
// wake the receiver
|
||||
unsafe {
|
||||
Parker::from_ptr(token.0.into_inner().as_ptr()).unpark();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -308,13 +478,12 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_queue() {
|
||||
let queue = Arc::new(Queue::<i32>::new());
|
||||
let queue = Queue::<i32>::new();
|
||||
|
||||
let sender = queue.new_sender();
|
||||
let receiver1 = queue.new_receiver();
|
||||
let receiver2 = queue.new_receiver();
|
||||
|
||||
let token1 = receiver1.get_token();
|
||||
let token2 = receiver2.get_token();
|
||||
|
||||
sender.anycast(42);
|
||||
|
@ -325,4 +494,146 @@ mod tests {
|
|||
assert_eq!(receiver1.try_recv(), None);
|
||||
assert_eq!(receiver2.recv(), 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn queue_broadcast() {
|
||||
let queue = Queue::<i32>::new();
|
||||
|
||||
let sender = queue.new_sender();
|
||||
let receiver1 = queue.new_receiver();
|
||||
let receiver2 = queue.new_receiver();
|
||||
|
||||
sender.broadcast(42);
|
||||
|
||||
assert_eq!(receiver1.recv(), 42);
|
||||
assert_eq!(receiver2.recv(), 42);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn queue_multiple_messages() {
|
||||
let queue = Queue::<i32>::new();
|
||||
|
||||
let sender = queue.new_sender();
|
||||
let receiver = queue.new_receiver();
|
||||
|
||||
sender.anycast(1);
|
||||
sender.unicast(2, receiver.get_token()).unwrap();
|
||||
|
||||
assert_eq!(receiver.recv(), 2);
|
||||
assert_eq!(receiver.recv(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn queue_threaded() {
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum Message {
|
||||
Send(i32),
|
||||
Exit,
|
||||
}
|
||||
|
||||
let queue = Queue::<Message>::new();
|
||||
|
||||
let sender = queue.new_sender();
|
||||
|
||||
let threads = (0..5)
|
||||
.map(|_| {
|
||||
let queue_clone = queue.clone();
|
||||
let receiver = queue_clone.new_receiver();
|
||||
|
||||
std::thread::spawn(move || {
|
||||
loop {
|
||||
match receiver.recv() {
|
||||
Message::Send(value) => {
|
||||
println!("Receiver {:?} Received: {}", receiver.get_token(), value);
|
||||
}
|
||||
Message::Exit => {
|
||||
println!("Exiting thread");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Send messages to the receivers
|
||||
for i in 0..10 {
|
||||
sender.anycast(Message::Send(i));
|
||||
}
|
||||
|
||||
// Send exit messages to all receivers
|
||||
sender.broadcast(Message::Exit);
|
||||
for thread in threads {
|
||||
thread.join().unwrap();
|
||||
}
|
||||
println!("All threads have exited.");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn drop_slot() {
|
||||
// Test that dropping a slot does not cause a double free or panic
|
||||
let slot = Slot::<i32>::new();
|
||||
unsafe {
|
||||
slot.push(42);
|
||||
drop(slot);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn drop_slot_chain() {
|
||||
struct DropCheck<'a>(&'a AtomicU32);
|
||||
impl Drop for DropCheck<'_> {
|
||||
fn drop(&mut self) {
|
||||
self.0.fetch_sub(1, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> DropCheck<'a> {
|
||||
fn new(counter: &'a AtomicU32) -> Self {
|
||||
counter.fetch_add(1, Ordering::SeqCst);
|
||||
Self(counter)
|
||||
}
|
||||
}
|
||||
let counter = AtomicU32::new(0);
|
||||
let slot = Slot::<DropCheck>::new();
|
||||
for _ in 0..10 {
|
||||
unsafe {
|
||||
slot.push(DropCheck::new(&counter));
|
||||
}
|
||||
}
|
||||
assert_eq!(counter.load(Ordering::SeqCst), 10);
|
||||
drop(slot);
|
||||
assert_eq!(
|
||||
counter.load(Ordering::SeqCst),
|
||||
0,
|
||||
"All DropCheck instances should have been dropped"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn send_self() {
|
||||
// Test that sending a message to self works
|
||||
let queue = Queue::<i32>::new();
|
||||
let sender = queue.new_sender();
|
||||
let receiver = queue.new_receiver();
|
||||
|
||||
sender.unicast(42, receiver.get_token()).unwrap();
|
||||
assert_eq!(receiver.recv(), 42);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn send_self_many() {
|
||||
// Test that sending multiple messages to self works
|
||||
let queue = Queue::<i32>::new();
|
||||
let sender = queue.new_sender();
|
||||
let receiver = queue.new_receiver();
|
||||
|
||||
for i in 0..10 {
|
||||
sender.unicast(i, receiver.get_token()).unwrap();
|
||||
}
|
||||
|
||||
for i in (0..10).rev() {
|
||||
assert_eq!(receiver.recv(), i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use std::{
|
||||
any::Any,
|
||||
marker::PhantomData,
|
||||
marker::{PhantomData, PhantomPinned},
|
||||
panic::{AssertUnwindSafe, catch_unwind},
|
||||
pin::{self, Pin},
|
||||
ptr::{self, NonNull},
|
||||
|
@ -11,15 +11,17 @@ use std::{
|
|||
};
|
||||
|
||||
use async_task::Runnable;
|
||||
use werkzeug::util;
|
||||
|
||||
use crate::{
|
||||
channel::Sender,
|
||||
context::Context,
|
||||
context::{Context, Message},
|
||||
job::{
|
||||
HeapJob, Job2 as Job,
|
||||
HeapJob, Job2 as Job, SharedJob,
|
||||
traits::{InlineJob, IntoJob},
|
||||
},
|
||||
latch::{CountLatch, Probe},
|
||||
queue::ReceiverToken,
|
||||
util::{DropGuard, SendPtr},
|
||||
workerthread::WorkerThread,
|
||||
};
|
||||
|
@ -47,7 +49,7 @@ use crate::{
|
|||
|
||||
struct ScopeInner {
|
||||
outstanding_jobs: AtomicUsize,
|
||||
parker: NonNull<crate::channel::Parker>,
|
||||
parker: ReceiverToken,
|
||||
panic: AtomicPtr<Box<dyn Any + Send + 'static>>,
|
||||
}
|
||||
|
||||
|
@ -66,7 +68,7 @@ impl ScopeInner {
|
|||
fn from_worker(worker: &WorkerThread) -> Self {
|
||||
Self {
|
||||
outstanding_jobs: AtomicUsize::new(0),
|
||||
parker: worker.heartbeat.parker().into(),
|
||||
parker: worker.receiver.get_token(),
|
||||
panic: AtomicPtr::new(ptr::null_mut()),
|
||||
}
|
||||
}
|
||||
|
@ -75,11 +77,13 @@ impl ScopeInner {
|
|||
self.outstanding_jobs.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn decrement(&self) {
|
||||
fn decrement(&self, worker: &WorkerThread) {
|
||||
if self.outstanding_jobs.fetch_sub(1, Ordering::Relaxed) == 1 {
|
||||
unsafe {
|
||||
self.parker.as_ref().unpark();
|
||||
}
|
||||
worker
|
||||
.context
|
||||
.queue
|
||||
.as_sender()
|
||||
.unicast(Message::ScopeFinished, self.parker);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -196,19 +200,31 @@ impl<'scope, 'env> Scope<'scope, 'env> {
|
|||
|
||||
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
||||
fn wait_for_jobs(&self) {
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!(
|
||||
"waiting for {} jobs to finish.",
|
||||
self.inner().outstanding_jobs.load(Ordering::Relaxed)
|
||||
);
|
||||
|
||||
self.worker().wait_until_pred(|| {
|
||||
// SAFETY: we are in a worker thread, so the inner is valid.
|
||||
loop {
|
||||
let count = self.inner().outstanding_jobs.load(Ordering::Relaxed);
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!("waiting for {} jobs to finish.", count);
|
||||
count == 0
|
||||
});
|
||||
if count == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
match self.worker().receiver.recv() {
|
||||
Message::Shared(shared_job) => unsafe {
|
||||
SharedJob::execute(shared_job, self.worker());
|
||||
},
|
||||
Message::ScopeFinished => {
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!("scope finished, decrementing outstanding jobs.");
|
||||
assert_eq!(self.inner().outstanding_jobs.load(Ordering::Acquire), 0);
|
||||
break;
|
||||
}
|
||||
Message::WakeUp | Message::Exit => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn decrement(&self) {
|
||||
self.inner().decrement(self.worker());
|
||||
}
|
||||
|
||||
fn inner(&self) -> &ScopeInner {
|
||||
|
@ -243,6 +259,7 @@ impl<'scope, 'env> Scope<'scope, 'env> {
|
|||
)
|
||||
}
|
||||
|
||||
#[align(8)]
|
||||
unsafe fn harness<'scope, 'env, T>(
|
||||
worker: &WorkerThread,
|
||||
this: NonNull<()>,
|
||||
|
@ -268,7 +285,7 @@ impl<'scope, 'env> Scope<'scope, 'env> {
|
|||
scope.inner().panicked(payload);
|
||||
}
|
||||
|
||||
scope.inner().decrement();
|
||||
scope.decrement();
|
||||
},
|
||||
self.inner,
|
||||
);
|
||||
|
@ -309,7 +326,7 @@ impl<'scope, 'env> Scope<'scope, 'env> {
|
|||
|
||||
let future = async move {
|
||||
let _guard = DropGuard::new(move || {
|
||||
scope.inner().decrement();
|
||||
scope.decrement();
|
||||
});
|
||||
|
||||
// TODO: handle panics here
|
||||
|
@ -358,6 +375,7 @@ impl<'scope, 'env> Scope<'scope, 'env> {
|
|||
struct ScopeJob<F> {
|
||||
f: UnsafeCell<ManuallyDrop<F>>,
|
||||
inner: SendPtr<ScopeInner>,
|
||||
_pin: PhantomPinned,
|
||||
}
|
||||
|
||||
impl<F> ScopeJob<F> {
|
||||
|
@ -365,22 +383,24 @@ impl<'scope, 'env> Scope<'scope, 'env> {
|
|||
Self {
|
||||
f: UnsafeCell::new(ManuallyDrop::new(f)),
|
||||
inner,
|
||||
_pin: PhantomPinned,
|
||||
}
|
||||
}
|
||||
|
||||
fn into_job<'scope, 'env, T>(&self) -> Job<T>
|
||||
fn into_job<'scope, 'env, T>(self: Pin<&Self>) -> Job<T>
|
||||
where
|
||||
F: FnOnce(Scope<'scope, 'env>) -> T + Send,
|
||||
'env: 'scope,
|
||||
T: Send,
|
||||
{
|
||||
Job::from_harness(Self::harness, NonNull::from(self).cast())
|
||||
Job::from_harness(Self::harness, NonNull::from(&*self).cast())
|
||||
}
|
||||
|
||||
unsafe fn unwrap(&self) -> F {
|
||||
unsafe { ManuallyDrop::take(&mut *self.f.get()) }
|
||||
}
|
||||
|
||||
#[align(8)]
|
||||
unsafe fn harness<'scope, 'env, T>(
|
||||
worker: &WorkerThread,
|
||||
this: NonNull<()>,
|
||||
|
@ -391,16 +411,25 @@ impl<'scope, 'env> Scope<'scope, 'env> {
|
|||
T: Send,
|
||||
{
|
||||
let this: &ScopeJob<F> = unsafe { this.cast().as_ref() };
|
||||
let sender: Option<Sender<T>> = unsafe { mem::transmute(sender) };
|
||||
let f = unsafe { this.unwrap() };
|
||||
let scope = unsafe { Scope::<'scope, 'env>::new_unchecked(worker, this.inner) };
|
||||
let sender: Sender<T> = unsafe { mem::transmute(sender) };
|
||||
|
||||
// SAFETY: we are in a worker thread, so the inner is valid.
|
||||
sender.send(catch_unwind(AssertUnwindSafe(|| f(scope))));
|
||||
let result = catch_unwind(AssertUnwindSafe(|| f(scope)));
|
||||
|
||||
let sender = sender.unwrap();
|
||||
unsafe {
|
||||
sender.send_as_ref(result);
|
||||
worker
|
||||
.context
|
||||
.queue
|
||||
.as_sender()
|
||||
.unicast(Message::WakeUp, ReceiverToken::from_parker(sender.parker()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'scope, 'env, F, T> IntoJob<T> for &ScopeJob<F>
|
||||
impl<'scope, 'env, F, T> IntoJob<T> for Pin<&ScopeJob<F>>
|
||||
where
|
||||
F: FnOnce(Scope<'scope, 'env>) -> T + Send,
|
||||
'env: 'scope,
|
||||
|
@ -411,7 +440,7 @@ impl<'scope, 'env> Scope<'scope, 'env> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'scope, 'env, F, T> InlineJob<T> for &ScopeJob<F>
|
||||
impl<'scope, 'env, F, T> InlineJob<T> for Pin<&ScopeJob<F>>
|
||||
where
|
||||
F: FnOnce(Scope<'scope, 'env>) -> T + Send,
|
||||
'env: 'scope,
|
||||
|
@ -422,8 +451,14 @@ impl<'scope, 'env> Scope<'scope, 'env> {
|
|||
}
|
||||
}
|
||||
|
||||
return worker
|
||||
.join_heartbeat2_every::<_, _, _, _, 64>(&ScopeJob::new(a, self.inner), |_| b(*self));
|
||||
let _pinned = ScopeJob::new(a, self.inner);
|
||||
let job = unsafe { Pin::new_unchecked(&_pinned) };
|
||||
|
||||
let (a, b) = worker.join_heartbeat2(job, |_| b(*self));
|
||||
|
||||
// touch job here to ensure it is not dropped before we run the join.
|
||||
drop(_pinned);
|
||||
(a, b)
|
||||
|
||||
// let stack = ScopeJob::new(a, self.inner);
|
||||
// let job = ScopeJob::into_job(&stack);
|
||||
|
@ -528,13 +563,17 @@ mod tests {
|
|||
#[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
|
||||
fn scope_join_one() {
|
||||
let pool = ThreadPool::new_with_threads(1);
|
||||
let count = AtomicU8::new(0);
|
||||
|
||||
let a = pool.scope(|scope| {
|
||||
let (a, b) = scope.join(|_| 3 + 4, |_| 5 + 6);
|
||||
let (a, b) = scope.join(
|
||||
|_| count.fetch_add(1, Ordering::Relaxed) + 4,
|
||||
|_| count.fetch_add(2, Ordering::Relaxed) + 6,
|
||||
);
|
||||
a + b
|
||||
});
|
||||
|
||||
assert_eq!(a, 18);
|
||||
assert_eq!(count.load(Ordering::Relaxed), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -553,9 +592,9 @@ mod tests {
|
|||
}
|
||||
|
||||
pool.scope(|scope| {
|
||||
let total = sum(scope, 10);
|
||||
assert_eq!(total, 1023);
|
||||
// eprintln!("Total sum: {}", total);
|
||||
let total = sum(scope, 5);
|
||||
// assert_eq!(total, 1023);
|
||||
eprintln!("Total sum: {}", total);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ use std::sync::Arc;
|
|||
|
||||
use crate::{Scope, context::Context, scope::scope_with_context};
|
||||
|
||||
#[repr(transparent)]
|
||||
pub struct ThreadPool {
|
||||
pub(crate) context: Arc<Context>,
|
||||
}
|
||||
|
@ -9,7 +10,7 @@ pub struct ThreadPool {
|
|||
impl Drop for ThreadPool {
|
||||
fn drop(&mut self) {
|
||||
// TODO: Ensure that the context is properly cleaned up when the thread pool is dropped.
|
||||
// self.context.set_should_exit();
|
||||
self.context.set_should_exit();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -25,9 +26,9 @@ impl ThreadPool {
|
|||
Self { context }
|
||||
}
|
||||
|
||||
pub fn global() -> Self {
|
||||
let context = Context::global_context().clone();
|
||||
Self { context }
|
||||
pub fn global() -> &'static Self {
|
||||
// SAFETY: ThreadPool is a transparent wrapper around Arc<Context>,
|
||||
unsafe { core::mem::transmute(Context::global_context()) }
|
||||
}
|
||||
|
||||
pub fn scope<'env, F, R>(&self, f: F) -> R
|
||||
|
|
|
@ -8,18 +8,21 @@ use std::{
|
|||
time::Duration,
|
||||
};
|
||||
|
||||
use crossbeam_utils::CachePadded;
|
||||
#[cfg(feature = "metrics")]
|
||||
use werkzeug::CachePadded;
|
||||
|
||||
use crate::{
|
||||
channel::Receiver,
|
||||
context::Context,
|
||||
context::{Context, Message},
|
||||
heartbeat::OwnedHeartbeatReceiver,
|
||||
job::{Job2 as Job, JobQueue as JobList, SharedJob},
|
||||
queue,
|
||||
util::DropGuard,
|
||||
};
|
||||
|
||||
pub struct WorkerThread {
|
||||
pub(crate) context: Arc<Context>,
|
||||
pub(crate) receiver: queue::Receiver<Message>,
|
||||
pub(crate) queue: UnsafeCell<JobList>,
|
||||
pub(crate) heartbeat: OwnedHeartbeatReceiver,
|
||||
pub(crate) join_count: Cell<u8>,
|
||||
|
@ -37,6 +40,7 @@ impl WorkerThread {
|
|||
let heartbeat = context.heartbeats.new_heartbeat();
|
||||
|
||||
Self {
|
||||
receiver: context.queue.new_receiver(),
|
||||
context,
|
||||
queue: UnsafeCell::new(JobList::new()),
|
||||
heartbeat,
|
||||
|
@ -82,85 +86,26 @@ impl WorkerThread {
|
|||
|
||||
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
||||
fn run_inner(&self) {
|
||||
let mut job = None;
|
||||
'outer: loop {
|
||||
if let Some(job) = job.take() {
|
||||
self.execute(job);
|
||||
loop {
|
||||
if self.context.should_exit() {
|
||||
break;
|
||||
}
|
||||
|
||||
// no more jobs, wait to be notified of a new job or a heartbeat.
|
||||
while job.is_none() {
|
||||
if self.context.should_exit() {
|
||||
// if the context is stopped, break out of the outer loop which
|
||||
// will exit the thread.
|
||||
break 'outer;
|
||||
match self.receiver.recv() {
|
||||
Message::Shared(shared_job) => {
|
||||
self.execute(shared_job);
|
||||
}
|
||||
|
||||
job = self.find_work_or_wait();
|
||||
Message::Exit => break,
|
||||
Message::WakeUp | Message::ScopeFinished => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl WorkerThread {
|
||||
/// Looks for work in the local queue, then in the shared context, and if no
|
||||
/// work is found, waits for the thread to be notified of a new job, after
|
||||
/// which it returns `None`.
|
||||
/// The caller should then check for `should_exit` to determine if the
|
||||
/// thread should exit, or look for work again.
|
||||
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
||||
pub(crate) fn find_work_or_wait(&self) -> Option<SharedJob> {
|
||||
if let Some(job) = self.find_work() {
|
||||
return Some(job);
|
||||
}
|
||||
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!("waiting for new job");
|
||||
self.heartbeat.parker().park();
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!("woken up from wait");
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
||||
pub(crate) fn find_work_or_wait_unless<F>(&self, mut pred: F) -> Option<SharedJob>
|
||||
where
|
||||
F: FnMut() -> bool,
|
||||
{
|
||||
if let Some(job) = self.find_work() {
|
||||
return Some(job);
|
||||
}
|
||||
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
||||
// Check the predicate while holding the lock. This is very important,
|
||||
// because the lock must be held when notifying us of the result of a
|
||||
// job we scheduled.
|
||||
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
||||
// no jobs found, wait for a heartbeat or a new job
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!(worker = self.heartbeat.index(), "waiting for new job");
|
||||
if !pred() {
|
||||
self.heartbeat.parker().park();
|
||||
}
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!(worker = self.heartbeat.index(), "woken up from wait");
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn find_work(&self) -> Option<SharedJob> {
|
||||
let mut guard = self.context.shared();
|
||||
|
||||
if let Some(job) = guard.pop_job() {
|
||||
#[cfg(feature = "metrics")]
|
||||
self.metrics.num_jobs_stolen.fetch_add(1, Ordering::Relaxed);
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!("WorkerThread::find_work_inner: found shared job: {:?}", job);
|
||||
return Some(job);
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Checks if the worker thread has received a heartbeat, and if so,
|
||||
/// attempts to share a job with other workers. If a job was popped from
|
||||
/// the queue, but not shared, this function runs the job locally.
|
||||
pub(crate) fn tick(&self) {
|
||||
if self.heartbeat.take() {
|
||||
#[cfg(feature = "metrics")]
|
||||
|
@ -170,6 +115,7 @@ impl WorkerThread {
|
|||
"received heartbeat, thread id: {:?}",
|
||||
self.heartbeat.index()
|
||||
);
|
||||
|
||||
self.heartbeat_cold();
|
||||
}
|
||||
}
|
||||
|
@ -177,28 +123,31 @@ impl WorkerThread {
|
|||
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
||||
fn execute(&self, job: SharedJob) {
|
||||
unsafe { SharedJob::execute(job, self) };
|
||||
self.tick();
|
||||
// TODO: maybe tick here?
|
||||
}
|
||||
|
||||
/// Attempts to share a job with other workers within the same context.
|
||||
/// returns `true` if the job was shared, `false` if it was not.
|
||||
#[cold]
|
||||
fn heartbeat_cold(&self) {
|
||||
let mut guard = self.context.shared();
|
||||
if let Some(job) = self.pop_back() {
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!("heartbeat: sharing job: {:?}", job);
|
||||
|
||||
if !guard.jobs.contains_key(&self.heartbeat.id()) {
|
||||
if let Some(job) = self.pop_back() {
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!("heartbeat: sharing job: {:?}", job);
|
||||
|
||||
#[cfg(feature = "metrics")]
|
||||
self.metrics.num_jobs_shared.fetch_add(1, Ordering::Relaxed);
|
||||
#[cfg(feature = "metrics")]
|
||||
self.metrics.num_jobs_shared.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
if let Err(Message::Shared(job)) =
|
||||
self.context
|
||||
.queue
|
||||
.as_sender()
|
||||
.try_anycast(Message::Shared(unsafe {
|
||||
job.as_ref()
|
||||
.share(Some(self.receiver.get_token().as_parker()))
|
||||
}))
|
||||
{
|
||||
unsafe {
|
||||
guard.jobs.insert(
|
||||
self.heartbeat.id(),
|
||||
job.as_ref().share(Some(self.heartbeat.parker())),
|
||||
);
|
||||
// SAFETY: we are holding the lock on the shared context.
|
||||
self.context.notify_job_shared();
|
||||
SharedJob::execute(job, self);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -264,13 +213,14 @@ impl WorkerThread {
|
|||
|
||||
pub struct HeartbeatThread {
|
||||
ctx: Arc<Context>,
|
||||
num_workers: usize,
|
||||
}
|
||||
|
||||
impl HeartbeatThread {
|
||||
const HEARTBEAT_INTERVAL: Duration = Duration::from_micros(100);
|
||||
|
||||
pub fn new(ctx: Arc<Context>) -> Self {
|
||||
Self { ctx }
|
||||
pub fn new(ctx: Arc<Context>, num_workers: usize) -> Self {
|
||||
Self { ctx, num_workers }
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
|
||||
|
@ -282,6 +232,14 @@ impl HeartbeatThread {
|
|||
let mut i = 0;
|
||||
loop {
|
||||
let sleep_for = {
|
||||
// loop {
|
||||
// if self.ctx.should_exit() || self.ctx.queue.num_receivers() != self.num_workers
|
||||
// {
|
||||
// break;
|
||||
// }
|
||||
|
||||
// self.ctx.heartbeat.park();
|
||||
// }
|
||||
if self.ctx.should_exit() {
|
||||
break;
|
||||
}
|
||||
|
@ -306,88 +264,18 @@ impl HeartbeatThread {
|
|||
}
|
||||
|
||||
impl WorkerThread {
|
||||
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
|
||||
pub fn wait_until_shared_job<T: Send>(&self, job: &Job<T>) -> Option<std::thread::Result<T>> {
|
||||
let recv = (*job).take_receiver().unwrap();
|
||||
|
||||
let mut out = recv.poll();
|
||||
|
||||
while std::hint::unlikely(out.is_none()) {
|
||||
if let Some(job) = self.find_work() {
|
||||
unsafe {
|
||||
SharedJob::execute(job, self);
|
||||
}
|
||||
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
||||
pub fn wait_until_recv<T: Send>(&self, recv: Receiver<T>) -> std::thread::Result<T> {
|
||||
loop {
|
||||
if let Some(result) = recv.poll() {
|
||||
break result;
|
||||
}
|
||||
|
||||
out = recv.poll();
|
||||
}
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
||||
pub fn wait_until_recv<T: Send>(&self, recv: Receiver<T>) -> Option<std::thread::Result<T>> {
|
||||
if self
|
||||
.context
|
||||
.shared()
|
||||
.jobs
|
||||
.remove(&self.heartbeat.id())
|
||||
.is_some()
|
||||
{
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!("reclaiming shared job");
|
||||
return None;
|
||||
}
|
||||
|
||||
while recv.is_empty() {
|
||||
if let Some(job) = self.find_work() {
|
||||
unsafe {
|
||||
SharedJob::execute(job, self);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
recv.wait();
|
||||
}
|
||||
|
||||
Some(recv.recv())
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
||||
pub fn wait_until_pred<F>(&self, mut pred: F)
|
||||
where
|
||||
F: FnMut() -> bool,
|
||||
{
|
||||
if !pred() {
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!("thread {:?} waiting on predicate", self.heartbeat.index());
|
||||
self.wait_until_latch_cold(pred);
|
||||
}
|
||||
}
|
||||
|
||||
#[cold]
|
||||
fn wait_until_latch_cold<F>(&self, mut pred: F)
|
||||
where
|
||||
F: FnMut() -> bool,
|
||||
{
|
||||
if let Some(shared_job) = self.context.shared().jobs.remove(&self.heartbeat.id()) {
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!(
|
||||
"thread {:?} reclaiming shared job: {:?}",
|
||||
self.heartbeat.index(),
|
||||
shared_job
|
||||
);
|
||||
unsafe { SharedJob::execute(shared_job, self) };
|
||||
}
|
||||
|
||||
// do the usual thing and wait for the job's latch
|
||||
// do the usual thing??? chatgipity really said this..
|
||||
while !pred() {
|
||||
// check local jobs before locking shared context
|
||||
if let Some(job) = self.find_work() {
|
||||
unsafe {
|
||||
SharedJob::execute(job, self);
|
||||
}
|
||||
match self.receiver.recv() {
|
||||
Message::Shared(shared_job) => unsafe {
|
||||
SharedJob::execute(shared_job, self);
|
||||
},
|
||||
Message::WakeUp | Message::Exit | Message::ScopeFinished => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -87,6 +87,7 @@ fn join_distaff(tree_size: usize) {
|
|||
let sum = sum(&tree, tree.root().unwrap(), s);
|
||||
sum
|
||||
});
|
||||
eprintln!("sum: {sum}");
|
||||
std::hint::black_box(sum);
|
||||
}
|
||||
}
|
||||
|
@ -134,7 +135,7 @@ fn join_rayon(tree_size: usize) {
|
|||
}
|
||||
|
||||
fn main() {
|
||||
//tracing_subscriber::fmt::init();
|
||||
// tracing_subscriber::fmt::init();
|
||||
use tracing_subscriber::layer::SubscriberExt;
|
||||
tracing::subscriber::set_global_default(
|
||||
tracing_subscriber::registry().with(tracing_tracy::TracyLayer::default()),
|
||||
|
@ -166,6 +167,7 @@ fn main() {
|
|||
}
|
||||
|
||||
eprintln!("Done!");
|
||||
println!("Done!");
|
||||
// // wait for user input before exiting
|
||||
// std::io::stdin().read_line(&mut String::new()).unwrap();
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue