executor/distaff/src/join.rs

230 lines
6.9 KiB
Rust

#[cfg(feature = "metrics")]
use std::sync::atomic::Ordering;
use std::{hint::cold_path, sync::Arc};
use crate::{
context::Context,
job::{
Job2 as Job, StackJob,
traits::{InlineJob, IntoJob},
},
workerthread::WorkerThread,
};
impl WorkerThread {
/// Runs both closures back to back on the calling worker, without queueing a job.
///
/// `b` is invoked first and `a` second; the results are still returned in argument order
/// as `(ra, rb)`.
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
fn join_seq<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
where
    A: FnOnce(&WorkerThread) -> RA,
    B: FnOnce(&WorkerThread) -> RB,
{
    // Keep the b-then-a execution order; only the tuple is in (a, b) order.
    let second = b(self);
    (a(self), second)
}
pub(crate) fn join_heartbeat_every<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
where
A: FnOnce(&WorkerThread) -> RA + Send,
B: FnOnce(&WorkerThread) -> RB,
RA: Send,
{
// self.join_heartbeat_every_inner::<A, B, RA, RB, 2>(a, b)
self.join_heartbeat(a, b)
}
/// Picks between a heartbeat join and a plain sequential join.
///
/// A per-worker counter (`join_count`) is advanced on every call and wraps around after
/// `TIMES` calls. The heartbeat path (`join_heartbeat`) is taken when the counter was at
/// zero before this call, or when fewer than three jobs are queued on this worker;
/// otherwise both closures run sequentially through `join_seq`.
///
/// The counter is a `u8`, so the effective period is `TIMES` clamped to `1..=256`:
/// `TIMES == 0` behaves like `1` (heartbeat on every call) and anything above `256`
/// behaves like `256`.
///
/// This function must be called from a worker thread.
#[allow(dead_code)]
#[inline(always)]
fn join_heartbeat_every_inner<A, B, RA, RB, const TIMES: usize>(&self, a: A, b: B) -> (RA, RB)
where
    RA: Send,
    A: FnOnce(&WorkerThread) -> RA + Send,
    B: FnOnce(&WorkerThread) -> RB,
{
    // `TIMES as u8` truncates: 0 and every multiple of 256 would become a zero divisor
    // (panicking in `%`), and other values above 255 would silently shorten the period.
    // Do the arithmetic in `usize` with an explicitly bounded period instead.
    let period = TIMES.clamp(1, usize::from(u8::MAX) + 1);

    let count = self.join_count.get();

    // SAFETY: each worker is only ever used by one thread and this function runs on that
    // worker thread, so nothing else accesses the queue while we read its length.
    let queue_len = unsafe { self.queue.as_ref_unchecked().len() };

    // The dividend is at most 255, so the remainder always fits back into a `u8`.
    let next = usize::from(count.wrapping_add(1)) % period;
    self.join_count.set(next as u8);

    // TODO: add counter to job queue, check for low job count to decide whether to use
    // heartbeat or seq. see: chili
    if count == 0 || queue_len < 3 {
        cold_path();
        self.join_heartbeat(a, b)
    } else {
        self.join_seq(a, b)
    }
}
/// `InlineJob` flavour of the periodic join: picks between `join_heartbeat2` and running
/// both closures directly on this worker.
///
/// A per-worker counter (`join_count`) is advanced on every call and wraps around after
/// `TIMES` calls. The heartbeat path (`join_heartbeat2`) is taken when the counter was at
/// zero before this call, or when fewer than three jobs are queued on this worker;
/// otherwise `a` is run inline followed by `b`.
///
/// The counter is a `u8`, so the effective period is `TIMES` clamped to `1..=256`:
/// `TIMES == 0` behaves like `1` (heartbeat on every call) and anything above `256`
/// behaves like `256`.
///
/// This function must be called from a worker thread.
#[allow(dead_code)]
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub(crate) fn join_heartbeat2_every<A, B, RA, RB, const TIMES: usize>(
    &self,
    a: A,
    b: B,
) -> (RA, RB)
where
    RA: Send,
    A: InlineJob<RA> + Copy,
    B: FnOnce(&WorkerThread) -> RB,
{
    // `TIMES as u8` truncates: 0 and every multiple of 256 would become a zero divisor
    // (panicking in `%`), and other values above 255 would silently shorten the period.
    // Do the arithmetic in `usize` with an explicitly bounded period instead.
    let period = TIMES.clamp(1, usize::from(u8::MAX) + 1);

    let count = self.join_count.get();

    // SAFETY: each worker is only ever used by one thread and this function runs on that
    // worker thread, so nothing else accesses the queue while we read its length.
    let queue_len = unsafe { self.queue.as_ref_unchecked().len() };

    // The dividend is at most 255, so the remainder always fits back into a `u8`.
    let next = usize::from(count.wrapping_add(1)) % period;
    self.join_count.set(next as u8);

    // TODO: add counter to job queue, check for low job count to decide whether to use
    // heartbeat or seq. see: chili
    if count == 0 || queue_len < 3 {
        cold_path();
        self.join_heartbeat2(a, b)
    } else {
        // NOTE(review): this branch runs `a` before `b`, whereas `join_heartbeat2` runs
        // `b` first — confirm the difference in ordering is intentional.
        (a.run_inline(self), b(self))
    }
}
/// Joins `a` and `b`, queueing `a` as a job on this worker while `b` runs on the calling
/// thread.
///
/// `a` is `Copy`: one copy is converted into a job with `into_job` and pushed onto this
/// worker's queue, the other is kept so it can be run inline later. After `b` returns, the
/// result of `a` is either received via `wait_until_recv` (when `job.is_shared()` reports
/// true) or computed here by popping the job again and calling `a.run_inline(self)`.
///
/// If `b` panics, the panic is caught, the result of a shared `a` is awaited and discarded,
/// and the panic from `b` is then resumed.
///
/// Returns `(ra, rb)` in argument order, although `b` is executed first on this thread.
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub(crate) fn join_heartbeat2<RA, A, B, RB>(&self, a: A, b: B) -> (RA, RB)
where
B: FnOnce(&WorkerThread) -> RB,
A: InlineJob<RA> + Copy,
RA: Send,
{
use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
// Count this join when the `metrics` feature is enabled.
#[cfg(feature = "metrics")]
self.metrics.num_joins.fetch_add(1, Ordering::Relaxed);
// `job` lives in this stack frame and the queue only receives a reference to it, so the
// job must no longer be reachable through the queue when this function returns or unwinds.
let job = a.into_job();
self.push_back(&job);
// NOTE(review): `tick` is presumably the heartbeat check that may hand queued jobs to
// other workers — confirm against its definition.
self.tick();
// Run `b` on this thread, catching a panic so that `a` can be dealt with before unwinding.
let rb = match catch_unwind(AssertUnwindSafe(|| b(self))) {
Ok(val) => val,
Err(payload) => {
// NOTE(review): the message names `join_heartbeat`, but this is `join_heartbeat2`.
#[cfg(feature = "tracing")]
tracing::debug!("join_heartbeat: b panicked, waiting for a to finish");
cold_path();
// If `b` panicked and the job was shared, wait for `a` to finish; its result is dropped.
// NOTE(review): when the job was NOT shared it is not popped on this path, unlike the
// success path below — confirm nothing in the queue still refers to `job` after unwinding.
if job.is_shared() {
_ = self.wait_until_recv::<RA>();
}
resume_unwind(payload);
}
};
let ra = if job.is_shared() {
// The job was shared: receive its result.
// NOTE(review): `unwrap_or_panic` presumably re-raises a panic that occurred in `a` — confirm.
crate::util::unwrap_or_panic(self.wait_until_recv())
} else {
// The job was not shared: take it back off the queue and run `a` here instead.
// This assumes the job pushed above is the one at the back of the queue.
self.pop_back();
// NOTE(review): the message names `join_heartbeat`, but this is `join_heartbeat2`.
#[cfg(feature = "tracing")]
tracing::trace!("join_heartbeat: job was not shared, running a() inline");
a.run_inline(self)
};
(ra, rb)
}
/// Joins `a` and `b`, queueing `a` as a job on this worker while `b` runs on the calling
/// thread.
///
/// `a` is wrapped in a `StackJob` and a `Job` created from it is pushed onto this worker's
/// queue. After `b` returns, the result of `a` is either received via `wait_until_recv`
/// (when `job.is_shared()` reports true) or computed here by popping the job again and
/// running the stack job inline.
///
/// If `b` panics, the panic is caught, the result of a shared `a` is awaited and discarded,
/// and the panic from `b` is then resumed.
///
/// Returns `(ra, rb)` in argument order, although `b` is executed first on this thread.
///
/// This function must be called from a worker thread.
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
fn join_heartbeat<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
where
RA: Send,
A: FnOnce(&WorkerThread) -> RA + Send,
B: FnOnce(&WorkerThread) -> RB,
{
use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
// Count this join when the `metrics` feature is enabled.
#[cfg(feature = "metrics")]
self.metrics.num_joins.fetch_add(1, Ordering::Relaxed);
// Both the stack job and the `Job` built from it live in this stack frame; the queue only
// receives a reference, so the job must no longer be reachable through the queue when this
// function returns or unwinds.
let a = StackJob::new(a);
let job = Job::from_stackjob(&a);
self.push_back(&job);
// NOTE(review): `tick` is presumably the heartbeat check that may hand queued jobs to
// other workers — confirm against its definition.
self.tick();
// Run `b` on this thread, catching a panic so that `a` can be dealt with before unwinding.
let rb = match catch_unwind(AssertUnwindSafe(|| b(self))) {
Ok(val) => val,
Err(payload) => {
#[cfg(feature = "tracing")]
tracing::debug!("join_heartbeat: b panicked, waiting for a to finish");
cold_path();
// If `b` panicked and the job was shared, wait for `a` to finish; its result is dropped.
// NOTE(review): when the job was NOT shared it is not popped on this path, unlike the
// success path below — confirm nothing in the queue still refers to `job` after unwinding.
if job.is_shared() {
_ = self.wait_until_recv::<RA>();
}
resume_unwind(payload);
}
};
let ra = if job.is_shared() {
// The job was shared: receive its result.
// NOTE(review): `unwrap_or_panic` presumably re-raises a panic that occurred in `a` — confirm.
crate::util::unwrap_or_panic(self.wait_until_recv())
} else {
// The job was not shared: take it back off the queue and run `a` here instead.
// This assumes the job pushed above is the one at the back of the queue.
self.pop_back();
#[cfg(feature = "tracing")]
tracing::trace!("join_heartbeat: job was not shared, running a() inline");
a.run_inline(self)
};
(ra, rb)
}
}
impl Context {
pub fn join<A, B, RA, RB>(self: &Arc<Self>, a: A, b: B) -> (RA, RB)
where
A: FnOnce() -> RA + Send,
B: FnOnce() -> RB + Send,
RA: Send,
RB: Send,
{
// SAFETY: join_heartbeat_every is safe to call from a worker thread.
self.run_in_worker(move |worker| {
worker.join_heartbeat_every::<_, _, _, _>(|_| a(), |_| b())
})
}
}
/// run two closures potentially in parallel, in the global threadpool.
pub fn join<A, B, RA, RB>(a: A, b: B) -> (RA, RB)
where
A: FnOnce() -> RA + Send,
B: FnOnce() -> RB + Send,
RA: Send,
RB: Send,
{
join_in(Context::global_context().clone(), a, b)
}
/// run two closures potentially in parallel, in the global threadpool.
#[allow(dead_code)]
fn join_in<A, B, RA, RB>(context: Arc<Context>, a: A, b: B) -> (RA, RB)
where
A: FnOnce() -> RA + Send,
B: FnOnce() -> RB + Send,
RA: Send,
RB: Send,
{
context.join(a, b)
}