230 lines
6.9 KiB
Rust
230 lines
6.9 KiB
Rust
#[cfg(feature = "metrics")]
|
|
use std::sync::atomic::Ordering;
|
|
|
|
use std::{hint::cold_path, sync::Arc};
|
|
|
|
use crate::{
|
|
context::Context,
|
|
job::{
|
|
Job2 as Job, StackJob,
|
|
traits::{InlineJob, IntoJob},
|
|
},
|
|
workerthread::WorkerThread,
|
|
};
|
|
|
|
impl WorkerThread {
|
|
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
|
fn join_seq<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
|
|
where
|
|
A: FnOnce(&WorkerThread) -> RA,
|
|
B: FnOnce(&WorkerThread) -> RB,
|
|
{
|
|
let rb = b(self);
|
|
let ra = a(self);
|
|
|
|
(ra, rb)
|
|
}
|
|
|
|
pub(crate) fn join_heartbeat_every<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
|
|
where
|
|
A: FnOnce(&WorkerThread) -> RA + Send,
|
|
B: FnOnce(&WorkerThread) -> RB,
|
|
RA: Send,
|
|
{
|
|
// self.join_heartbeat_every_inner::<A, B, RA, RB, 2>(a, b)
|
|
self.join_heartbeat(a, b)
|
|
}
|
|
|
|
/// This function must be called from a worker thread.
|
|
#[allow(dead_code)]
|
|
#[inline(always)]
|
|
fn join_heartbeat_every_inner<A, B, RA, RB, const TIMES: usize>(&self, a: A, b: B) -> (RA, RB)
|
|
where
|
|
RA: Send,
|
|
A: FnOnce(&WorkerThread) -> RA + Send,
|
|
B: FnOnce(&WorkerThread) -> RB,
|
|
{
|
|
// SAFETY: each worker is only ever used by one thread, so this is safe.
|
|
let count = self.join_count.get();
|
|
let queue_len = unsafe { self.queue.as_ref_unchecked().len() };
|
|
self.join_count.set(count.wrapping_add(1) % TIMES as u8);
|
|
|
|
// TODO: add counter to job queue, check for low job count to decide whether to use heartbeat or seq.
|
|
// see: chili
|
|
|
|
// SAFETY: this function runs in a worker thread, so we can access the queue safely.
|
|
if count == 0 || queue_len < 3 {
|
|
cold_path();
|
|
self.join_heartbeat(a, b)
|
|
} else {
|
|
self.join_seq(a, b)
|
|
}
|
|
}
|
|
|
|
/// This function must be called from a worker thread.
|
|
#[allow(dead_code)]
|
|
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
|
pub(crate) fn join_heartbeat2_every<A, B, RA, RB, const TIMES: usize>(
|
|
&self,
|
|
a: A,
|
|
b: B,
|
|
) -> (RA, RB)
|
|
where
|
|
RA: Send,
|
|
A: InlineJob<RA> + Copy,
|
|
B: FnOnce(&WorkerThread) -> RB,
|
|
{
|
|
// SAFETY: each worker is only ever used by one thread, so this is safe.
|
|
let count = self.join_count.get();
|
|
let queue_len = unsafe { self.queue.as_ref_unchecked().len() };
|
|
self.join_count.set(count.wrapping_add(1) % TIMES as u8);
|
|
|
|
// TODO: add counter to job queue, check for low job count to decide whether to use heartbeat or seq.
|
|
// see: chili
|
|
|
|
// SAFETY: this function runs in a worker thread, so we can access the queue safely.
|
|
if count == 0 || queue_len < 3 {
|
|
cold_path();
|
|
self.join_heartbeat2(a, b)
|
|
} else {
|
|
(a.run_inline(self), b(self))
|
|
}
|
|
}
|
|
|
|
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
|
pub(crate) fn join_heartbeat2<RA, A, B, RB>(&self, a: A, b: B) -> (RA, RB)
|
|
where
|
|
B: FnOnce(&WorkerThread) -> RB,
|
|
A: InlineJob<RA> + Copy,
|
|
RA: Send,
|
|
{
|
|
use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
|
|
|
|
#[cfg(feature = "metrics")]
|
|
self.metrics.num_joins.fetch_add(1, Ordering::Relaxed);
|
|
|
|
let job = a.into_job();
|
|
|
|
self.push_back(&job);
|
|
|
|
self.tick();
|
|
|
|
let rb = match catch_unwind(AssertUnwindSafe(|| b(self))) {
|
|
Ok(val) => val,
|
|
Err(payload) => {
|
|
#[cfg(feature = "tracing")]
|
|
tracing::debug!("join_heartbeat: b panicked, waiting for a to finish");
|
|
cold_path();
|
|
|
|
// if b panicked, we need to wait for a to finish
|
|
if job.is_shared() {
|
|
_ = self.wait_until_recv::<RA>();
|
|
}
|
|
|
|
resume_unwind(payload);
|
|
}
|
|
};
|
|
|
|
let ra = if job.is_shared() {
|
|
crate::util::unwrap_or_panic(self.wait_until_recv())
|
|
} else {
|
|
self.pop_back();
|
|
|
|
// SAFETY: we just popped the job from the queue, so it is safe to unwrap.
|
|
#[cfg(feature = "tracing")]
|
|
tracing::trace!("join_heartbeat: job was not shared, running a() inline");
|
|
a.run_inline(self)
|
|
};
|
|
|
|
(ra, rb)
|
|
}
|
|
|
|
/// This function must be called from a worker thread.
|
|
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
|
fn join_heartbeat<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
|
|
where
|
|
RA: Send,
|
|
A: FnOnce(&WorkerThread) -> RA + Send,
|
|
B: FnOnce(&WorkerThread) -> RB,
|
|
{
|
|
use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
|
|
|
|
#[cfg(feature = "metrics")]
|
|
self.metrics.num_joins.fetch_add(1, Ordering::Relaxed);
|
|
|
|
let a = StackJob::new(a);
|
|
let job = Job::from_stackjob(&a);
|
|
|
|
self.push_back(&job);
|
|
|
|
self.tick();
|
|
|
|
let rb = match catch_unwind(AssertUnwindSafe(|| b(self))) {
|
|
Ok(val) => val,
|
|
Err(payload) => {
|
|
#[cfg(feature = "tracing")]
|
|
tracing::debug!("join_heartbeat: b panicked, waiting for a to finish");
|
|
cold_path();
|
|
|
|
// if b panicked, we need to wait for a to finish
|
|
if job.is_shared() {
|
|
_ = self.wait_until_recv::<RA>();
|
|
}
|
|
|
|
resume_unwind(payload);
|
|
}
|
|
};
|
|
|
|
let ra = if job.is_shared() {
|
|
crate::util::unwrap_or_panic(self.wait_until_recv())
|
|
} else {
|
|
self.pop_back();
|
|
|
|
// SAFETY: we just popped the job from the queue, so it is safe to unwrap.
|
|
#[cfg(feature = "tracing")]
|
|
tracing::trace!("join_heartbeat: job was not shared, running a() inline");
|
|
a.run_inline(self)
|
|
};
|
|
|
|
(ra, rb)
|
|
}
|
|
}
|
|
|
|
impl Context {
|
|
pub fn join<A, B, RA, RB>(self: &Arc<Self>, a: A, b: B) -> (RA, RB)
|
|
where
|
|
A: FnOnce() -> RA + Send,
|
|
B: FnOnce() -> RB + Send,
|
|
RA: Send,
|
|
RB: Send,
|
|
{
|
|
// SAFETY: join_heartbeat_every is safe to call from a worker thread.
|
|
self.run_in_worker(move |worker| {
|
|
worker.join_heartbeat_every::<_, _, _, _>(|_| a(), |_| b())
|
|
})
|
|
}
|
|
}
|
|
|
|
/// run two closures potentially in parallel, in the global threadpool.
|
|
pub fn join<A, B, RA, RB>(a: A, b: B) -> (RA, RB)
|
|
where
|
|
A: FnOnce() -> RA + Send,
|
|
B: FnOnce() -> RB + Send,
|
|
RA: Send,
|
|
RB: Send,
|
|
{
|
|
join_in(Context::global_context().clone(), a, b)
|
|
}
|
|
|
|
/// run two closures potentially in parallel, in the global threadpool.
|
|
#[allow(dead_code)]
|
|
fn join_in<A, B, RA, RB>(context: Arc<Context>, a: A, b: B) -> (RA, RB)
|
|
where
|
|
A: FnOnce() -> RA + Send,
|
|
B: FnOnce() -> RB + Send,
|
|
RA: Send,
|
|
RB: Send,
|
|
{
|
|
context.join(a, b)
|
|
}
|