diff --git a/src/praetor/mod.rs b/src/praetor/mod.rs
index d6fcc59..62c9d11 100644
--- a/src/praetor/mod.rs
+++ b/src/praetor/mod.rs
@@ -1081,6 +1081,7 @@ struct WorkerThread {
     index: usize,
     heartbeat: Arc>,
     queue: UnsafeCell<JobList>,
+    join_count: Cell<u8>,
 }

 pub struct Scope<'scope> {
@@ -1115,7 +1116,7 @@ impl WorkerThread {
             index,
             heartbeat,
             queue: UnsafeCell::new(JobList::new()),
-            // join_count: Cell::new(0),
+            join_count: Cell::new(0),
             // _pd: PhantomData,
         }
     }
@@ -1505,109 +1506,28 @@ impl<'scope> Scope<'scope> {
         A: FnOnce(&Self) -> RA + Send,
         B: FnOnce(&Self) -> RB + Send,
     {
-        self.join_heartbeat_every::<_, _, _, _, 64>(a, b)
-        // self.join_heartbeat(a, b)
-    }
-
-    pub fn join_seq<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
-    where
-        RA: Send,
-        RB: Send,
-        A: FnOnce(&Self) -> RA + Send,
-        B: FnOnce(&Self) -> RB + Send,
-    {
-        let rb = b(&self);
-        let ra = a(&self);
-
-        (ra, rb)
-    }
-
-    pub fn join_heartbeat_every<A, B, RA, RB, const TIMES: usize>(&self, a: A, b: B) -> (RA, RB)
-    where
-        RA: Send,
-        RB: Send,
-        A: FnOnce(&Self) -> RA + Send,
-        B: FnOnce(&Self) -> RB + Send,
-    {
-        thread_local! {
-            static JOIN_COUNT: Cell<usize> = Cell::new(0);
+        #[inline(always)]
+        fn make_scope_closure<'scope, A, RA>(
+            this: SendPtr<Scope<'scope>>,
+            a: A,
+        ) -> impl FnOnce() -> RA + use<'scope, RA, A>
+        where
+            A: FnOnce(&Scope<'scope>) -> RA + Send,
+            RA: Send,
+        {
+            let scope = unsafe { this.as_ref() };
+            move || a(scope)
         }

-        // a threadlocal counter is much faster than a sync atomic counter
-        let count = JOIN_COUNT.with(|count| {
-            count.set(count.get().wrapping_add(1) % TIMES);
-            count.get()
-        });
-
-        // let count = self.join_count.load(Ordering::Relaxed);
-        // self.join_count
-        //     .store(count.wrapping_add(1) % TIMES, Ordering::Relaxed);
-
-        // let count = self
-        //     .join_count
-        //     .update(Ordering::Relaxed, Ordering::Relaxed, |n| {
-        //         n.wrapping_add(1) % TIMES
-        //     });
-
-        if count == 1 {
-            self.join_heartbeat(a, b)
-        } else {
-            self.join_seq(a, b)
-        }
-    }
-
-    pub fn join_heartbeat<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
-    where
-        RA: Send,
-        RB: Send,
-        A: FnOnce(&Self) -> RA + Send,
-        B: FnOnce(&Self) -> RB + Send,
-    {
         let worker = WorkerThread::current_ref().expect("join is run in workerthread.");
-
         let this = SendPtr::new_const(self).unwrap();
-        let a = StackJob::new(
-            move || unsafe {
-                WorkerThread::current_ref()
-                    .expect("stackjob is run in workerthread.")
-                    .tick();
-                let scope = this.as_ref();
-                a(scope)
-            },
-            NopLatch,
-        );
-
-        let job = a.as_job();
-        worker.push_front(&job);
-
-        use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
-        let rb = match catch_unwind(AssertUnwindSafe(|| b(self))) {
-            Ok(val) => val,
-            Err(payload) => {
-                cold_path();
-                // if b panicked, we need to wait for a to finish
-                worker.wait_until_job::<RA>(unsafe { job.transmute_ref::<RA>() });
-                resume_unwind(payload);
-            }
-        };
-
-        let ra = if job.state() == JobState::Empty as u8 {
-            unsafe {
-                job.unlink();
-            }
-
-            // a is allowed to panic here, because we already finished b.
-            unsafe { a.unwrap()() }
-        } else {
-            match worker.wait_until_job::<RA>(unsafe { job.transmute_ref::<RA>() }) {
-                Some(t) => t.into_result(), // propagate panic here
-                None => unsafe { a.unwrap()() },
-            }
-        };
-
-        drop(a);
-        (ra, rb)
+        unsafe {
+            worker.context.join_heartbeat_every::<_, _, _, _, 64>(
+                make_scope_closure(this, a),
+                make_scope_closure(this, b),
+            )
+        }
     }

     fn from_context(ctx: Arc<Context>) -> Self {
@@ -1629,7 +1549,19 @@ where
     A: FnOnce() -> RA + Send,
     B: FnOnce() -> RB + Send,
 {
-    Scope::scope(|scope| scope.join(|_| a(), |_| b()))
+    join_in(Context::global().clone(), a, b)
+}
+
+/// run two closures potentially in parallel, in the global threadpool.
+#[allow(dead_code)]
+pub fn join_in<A, B, RA, RB>(context: Arc<Context>, a: A, b: B) -> (RA, RB)
+where
+    RA: Send,
+    RB: Send,
+    A: FnOnce() -> RA + Send,
+    B: FnOnce() -> RB + Send,
+{
+    context.join(a, b)
 }

 pub struct ThreadPool {
@@ -1649,6 +1581,16 @@ impl ThreadPool {
         }
     }

+    pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
+    where
+        RA: Send,
+        RB: Send,
+        A: FnOnce() -> RA + Send,
+        B: FnOnce() -> RB + Send,
+    {
+        self.context.join(a, b)
+    }
+
     pub fn scope<'scope, R, F>(&self, f: F) -> R
     where
         F: FnOnce(&Scope<'scope>) -> R + Send,
@@ -1757,6 +1699,112 @@ impl Context {
         self.shared_job.notify_one();
     }

+    #[inline]
+    pub fn join<A, B, RA, RB>(self: &Arc<Self>, a: A, b: B) -> (RA, RB)
+    where
+        RA: Send,
+        RB: Send,
+        A: FnOnce() -> RA + Send,
+        B: FnOnce() -> RB + Send,
+    {
+        // SAFETY: join_heartbeat_every is safe to call from a worker thread.
+        self.run_in_worker(move |_| unsafe { self.join_heartbeat_every::<_, _, _, _, 64>(a, b) })
+    }
+
+    #[inline]
+    fn join_seq<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
+    where
+        RA: Send,
+        RB: Send,
+        A: FnOnce() -> RA + Send,
+        B: FnOnce() -> RB + Send,
+    {
+        let rb = b();
+        let ra = a();
+
+        (ra, rb)
+    }
+
+    /// This function must be called from a worker thread.
+    #[inline]
+    unsafe fn join_heartbeat_every<A, B, RA, RB, const TIMES: usize>(&self, a: A, b: B) -> (RA, RB)
+    where
+        RA: Send,
+        RB: Send,
+        A: FnOnce() -> RA + Send,
+        B: FnOnce() -> RB + Send,
+    {
+        let worker = WorkerThread::current_ref().expect("join is run in workerthread.");
+
+        // SAFETY: each worker is only ever used by one thread, so this is safe.
+        worker
+            .join_count
+            .set(worker.join_count.get().wrapping_add(1) % TIMES as u8);
+
+        // TODO: add counter to job queue, check for low job count to decide whether to use heartbeat or seq.
+        // see: chili
+        if worker.join_count.get() == 0 {
+            self.join_heartbeat(a, b)
+        } else {
+            self.join_seq(a, b)
+        }
+    }
+
+    /// This function must be called from a worker thread.
+    unsafe fn join_heartbeat<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
+    where
+        RA: Send,
+        RB: Send,
+        A: FnOnce() -> RA + Send,
+        B: FnOnce() -> RB + Send,
+    {
+        let worker = WorkerThread::current_ref().expect("join is run in workerthread.");
+
+        let a = StackJob::new(
+            move || {
+                // TODO: bench whether tick'ing here is good.
+                // turns out this actually costs a lot of time, likely because of the thread local check.
+                // WorkerThread::current_ref()
+                //     .expect("stackjob is run in workerthread.")
+                //     .tick();
+
+                a()
+            },
+            NopLatch,
+        );
+
+        let job = a.as_job();
+        worker.push_front(&job);
+
+        use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
+        let rb = match catch_unwind(AssertUnwindSafe(|| b())) {
+            Ok(val) => val,
+            Err(payload) => {
+                cold_path();
+                // if b panicked, we need to wait for a to finish
+                worker.wait_until_job::<RA>(unsafe { job.transmute_ref::<RA>() });
+                resume_unwind(payload);
+            }
+        };
+
+        let ra = if job.state() == JobState::Empty as u8 {
+            unsafe {
+                job.unlink();
+            }
+
+            // a is allowed to panic here, because we already finished b.
+            unsafe { a.unwrap()() }
+        } else {
+            match worker.wait_until_job::<RA>(unsafe { job.transmute_ref::<RA>() }) {
+                Some(t) => t.into_result(), // propagate panic here
+                None => unsafe { a.unwrap()() },
+            }
+        };
+
+        drop(a);
+        (ra, rb)
+    }
+
     /// Runs closure in this context, processing the other context's worker's jobs while waiting for the result.
     fn run_in_worker_cross<T, F>(self: &Arc<Self>, worker: &WorkerThread, f: F) -> T
     where