use std::{hint::cold_path, ptr::NonNull, sync::Arc};
use crate::{
context::Context,
job::{JobState, StackJob},
latch::{AsCoreLatch, LatchRef, UnsafeWakeLatch, WakeLatch},
util::SendPtr,
workerthread::WorkerThread,
};
impl WorkerThread {
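    /// Runs both closures sequentially on the current thread, `b` first and then
    /// `a`, mirroring `join_heartbeat`, which also runs `b` inline before
    /// resolving `a`'s result.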
#[inline]
    fn join_seq<RA, RB, A, B>(&self, a: A, b: B) -> (RA, RB)
where
RA: Send,
RB: Send,
A: FnOnce() -> RA + Send,
B: FnOnce() -> RB + Send,
{
let rb = b();
let ra = a();
(ra, rb)
}
/// This function must be called from a worker thread.
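    ///
    /// Only every `TIMES`-th call (or a call made while the local job queue is
    /// nearly empty) takes the `join_heartbeat` path; all other calls fall back
    /// to running `join_seq` sequentially on this thread.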
#[inline]
    pub(crate) fn join_heartbeat_every<RA, RB, A, B, const TIMES: usize>(
        &self,
        a: A,
        b: B,
    ) -> (RA, RB)
where
RA: Send,
RB: Send,
A: FnOnce() -> RA + Send,
B: FnOnce() -> RB + Send,
{
// SAFETY: each worker is only ever used by one thread, so this is safe.
let count = self.join_count.get();
self.join_count.set(count.wrapping_add(1) % TIMES as u8);
// TODO: add counter to job queue, check for low job count to decide whether to use heartbeat or seq.
// see: chili
// SAFETY: this function runs in a worker thread, so we can access the queue safely.
if count == 0 || unsafe { self.queue.as_ref_unchecked().len() } < 3 {
cold_path();
self.join_heartbeat(a, b)
} else {
self.join_seq(a, b)
}
}
/// This function must be called from a worker thread.
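    ///
    /// Pushes `a` onto this worker's queue as a shareable job, runs `b` inline,
    /// and then either reclaims `a` and runs it on this thread (if no other
    /// worker executed it) or waits for the shared job's result.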
#[inline]
    fn join_heartbeat<RA, RB, A, B>(&self, a: A, b: B) -> (RA, RB)
where
RA: Send,
RB: Send,
A: FnOnce() -> RA + Send,
B: FnOnce() -> RB + Send,
{
use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
// SAFETY: this thread's heartbeat latch is valid until the job sets it
// because we will be waiting on it.
let latch = unsafe { UnsafeWakeLatch::new(&raw const (*self.heartbeat).latch) };
let a = StackJob::new(a, LatchRef::new(&latch));
let job = a.as_job();
self.push_back(&job);
self.tick();
let rb = match catch_unwind(AssertUnwindSafe(|| b())) {
Ok(val) => val,
Err(payload) => {
cold_path();
tracing::debug!("join_heartbeat: b panicked, waiting for a to finish");
// if b panicked, we need to wait for a to finish
self.wait_until_latch(&latch);
resume_unwind(payload);
}
};
let ra = if job.state() == JobState::Empty as u8 {
// remove job from the queue, so it doesn't get run again.
// job.unlink();
//SAFETY: we are in a worker thread, so we can safely access the queue.
// unsafe {
// self.queue.as_mut_unchecked().remove(&job);
// }
            // we pushed the job to the back of the queue, so any `join`s called by `b` on this
            // worker thread will have already popped their own job, or seen it be executed.
self.pop_back();
// a is allowed to panic here, because we already finished b.
unsafe { a.unwrap()() }
} else {
            match self.wait_until_job::<RA>(unsafe { job.transmute_ref() }, latch.as_core_latch()) {
Some(t) => t.into_result(), // propagate panic here
// the job was shared, but not yet stolen, so we get to run the
// job inline
None => unsafe { a.unwrap()() },
}
};
drop(a);
(ra, rb)
}
}
impl Context {
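    /// Runs `a` and `b`, potentially in parallel, on a worker thread of this
    /// context's threadpool.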
#[inline]
    pub fn join<RA, RB, A, B>(self: &Arc<Self>, a: A, b: B) -> (RA, RB)
where
RA: Send,
RB: Send,
A: FnOnce() -> RA + Send,
B: FnOnce() -> RB + Send,
{
// SAFETY: join_heartbeat_every is safe to call from a worker thread.
self.run_in_worker(move |worker| worker.join_heartbeat_every::<_, _, _, _, 64>(a, b))
}
}
/// Run two closures, potentially in parallel, in the global threadpool.
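///
/// # Example
///
/// A minimal usage sketch; assumes the global context/threadpool is available:
///
/// ```ignore
/// // the two closures may run on different worker threads
/// let (sum, product) = join(|| 2 + 3, || 2 * 3);
/// assert_eq!((sum, product), (5, 6));
/// ```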
#[allow(dead_code)]
pub fn join<RA, RB, A, B>(a: A, b: B) -> (RA, RB)
where
RA: Send,
RB: Send,
A: FnOnce() -> RA + Send,
B: FnOnce() -> RB + Send,
{
join_in(Context::global_context().clone(), a, b)
}
/// Run two closures, potentially in parallel, in the threadpool of the given `context`.
#[allow(dead_code)]
fn join_in<RA, RB, A, B>(context: Arc<Context>, a: A, b: B) -> (RA, RB)
where
RA: Send,
RB: Send,
A: FnOnce() -> RA + Send,
B: FnOnce() -> RB + Send,
{
context.join(a, b)
}