152 lines
4.5 KiB
Rust
152 lines
4.5 KiB
Rust
use std::{hint::cold_path, ptr::NonNull, sync::Arc};
|
|
|
|
use crate::{
|
|
context::Context,
|
|
job::{JobState, StackJob},
|
|
latch::{AsCoreLatch, LatchRef, UnsafeWakeLatch, WakeLatch},
|
|
util::SendPtr,
|
|
workerthread::WorkerThread,
|
|
};
|
|
|
|
impl WorkerThread {
|
|
#[inline]
|
|
fn join_seq<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
|
|
where
|
|
RA: Send,
|
|
RB: Send,
|
|
A: FnOnce() -> RA + Send,
|
|
B: FnOnce() -> RB + Send,
|
|
{
|
|
let rb = b();
|
|
let ra = a();
|
|
|
|
(ra, rb)
|
|
}
|
|
|
|
/// This function must be called from a worker thread.
|
|
#[inline]
|
|
pub(crate) fn join_heartbeat_every<A, B, RA, RB, const TIMES: usize>(
|
|
&self,
|
|
a: A,
|
|
b: B,
|
|
) -> (RA, RB)
|
|
where
|
|
RA: Send,
|
|
RB: Send,
|
|
A: FnOnce() -> RA + Send,
|
|
B: FnOnce() -> RB + Send,
|
|
{
|
|
// SAFETY: each worker is only ever used by one thread, so this is safe.
|
|
let count = self.join_count.get();
|
|
self.join_count.set(count.wrapping_add(1) % TIMES as u8);
|
|
|
|
// TODO: add counter to job queue, check for low job count to decide whether to use heartbeat or seq.
|
|
// see: chili
|
|
|
|
// SAFETY: this function runs in a worker thread, so we can access the queue safely.
|
|
if count == 0 || unsafe { self.queue.as_ref_unchecked().len() } < 3 {
|
|
cold_path();
|
|
self.join_heartbeat(a, b)
|
|
} else {
|
|
self.join_seq(a, b)
|
|
}
|
|
}
|
|
|
|
/// This function must be called from a worker thread.
|
|
#[inline]
|
|
fn join_heartbeat<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
|
|
where
|
|
RA: Send,
|
|
RB: Send,
|
|
A: FnOnce() -> RA + Send,
|
|
B: FnOnce() -> RB + Send,
|
|
{
|
|
use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
|
|
|
|
// SAFETY: this thread's heartbeat latch is valid until the job sets it
|
|
// because we will be waiting on it.
|
|
let latch = unsafe { UnsafeWakeLatch::new(&raw const self.heartbeat().latch) };
|
|
|
|
let a = StackJob::new(a, LatchRef::new(&latch));
|
|
|
|
let job = a.as_job();
|
|
self.push_back(&job);
|
|
|
|
self.tick();
|
|
|
|
let rb = match catch_unwind(AssertUnwindSafe(|| b())) {
|
|
Ok(val) => val,
|
|
Err(payload) => {
|
|
cold_path();
|
|
tracing::debug!("join_heartbeat: b panicked, waiting for a to finish");
|
|
// if b panicked, we need to wait for a to finish
|
|
self.wait_until_latch(&latch);
|
|
resume_unwind(payload);
|
|
}
|
|
};
|
|
|
|
let ra = if job.state() == JobState::Empty as u8 {
|
|
// remove job from the queue, so it doesn't get run again.
|
|
// job.unlink();
|
|
//SAFETY: we are in a worker thread, so we can safely access the queue.
|
|
// unsafe {
|
|
// self.queue.as_mut_unchecked().remove(&job);
|
|
// }
|
|
|
|
// we pushed the job to the back of the queue, any `join`s called by `b` on this worker thread will have already popped their job, or seen it be executed.
|
|
self.pop_back();
|
|
|
|
// a is allowed to panic here, because we already finished b.
|
|
unsafe { a.unwrap()() }
|
|
} else {
|
|
match self.wait_until_job::<RA>(unsafe { job.transmute_ref() }, latch.as_core_latch()) {
|
|
Some(t) => t.into_result(), // propagate panic here
|
|
// the job was shared, but not yet stolen, so we get to run the
|
|
// job inline
|
|
None => unsafe { a.unwrap()() },
|
|
}
|
|
};
|
|
|
|
drop(a);
|
|
(ra, rb)
|
|
}
|
|
}
|
|
|
|
impl Context {
|
|
#[inline]
|
|
pub fn join<A, B, RA, RB>(self: &Arc<Self>, a: A, b: B) -> (RA, RB)
|
|
where
|
|
RA: Send,
|
|
RB: Send,
|
|
A: FnOnce() -> RA + Send,
|
|
B: FnOnce() -> RB + Send,
|
|
{
|
|
// SAFETY: join_heartbeat_every is safe to call from a worker thread.
|
|
self.run_in_worker(move |worker| worker.join_heartbeat_every::<_, _, _, _, 64>(a, b))
|
|
}
|
|
}
|
|
|
|
/// run two closures potentially in parallel, in the global threadpool.
|
|
#[allow(dead_code)]
|
|
pub fn join<A, B, RA, RB>(a: A, b: B) -> (RA, RB)
|
|
where
|
|
RA: Send,
|
|
RB: Send,
|
|
A: FnOnce() -> RA + Send,
|
|
B: FnOnce() -> RB + Send,
|
|
{
|
|
join_in(Context::global_context().clone(), a, b)
|
|
}
|
|
|
|
/// run two closures potentially in parallel, in the global threadpool.
|
|
#[allow(dead_code)]
|
|
fn join_in<A, B, RA, RB>(context: Arc<Context>, a: A, b: B) -> (RA, RB)
|
|
where
|
|
RA: Send,
|
|
RB: Send,
|
|
A: FnOnce() -> RA + Send,
|
|
B: FnOnce() -> RB + Send,
|
|
{
|
|
context.join(a, b)
|
|
}
|