From 448d2d02b43357c1582252bd04b6a1892ba21fc2 Mon Sep 17 00:00:00 2001 From: Janis Date: Fri, 20 Jun 2025 15:35:34 +0200 Subject: [PATCH] moved join methods to workerthread --- src/lib.rs | 1 - src/praetor/mod.rs | 206 +++++++++++++++++++++---------------------- src/praetor/tests.rs | 2 +- 3 files changed, 101 insertions(+), 108 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d79a764..7444558 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,7 +2,6 @@ vec_deque_pop_if, unsafe_cell_access, debug_closure_helpers, - cell_update, cold_path, fn_align, box_vec_non_null, diff --git a/src/praetor/mod.rs b/src/praetor/mod.rs index 62c9d11..28c9d0b 100644 --- a/src/praetor/mod.rs +++ b/src/praetor/mod.rs @@ -87,7 +87,7 @@ mod util { // Miri doesn't like tagging pointers that it doesn't know the alignment of. // This includes function pointers, which aren't guaranteed to be aligned to // anything, but generally have an alignment of 8, and can be specified to - // be aligned to `n` with `#[repr(align(n))]`. + // be aligned to `n` with `#[align(n)]`. #[repr(transparent)] pub struct TaggedAtomicPtr { ptr: AtomicPtr<()>, @@ -910,7 +910,7 @@ mod job { F: FnOnce() -> T + Send, T: Send, { - #[repr(align(8))] + #[align(8)] unsafe fn harness(this: *const (), job: *const Job<()>) where F: FnOnce() -> T + Send, @@ -963,7 +963,7 @@ mod job { F: FnOnce() -> T + Send, T: Send, { - #[repr(align(8))] + #[align(8)] unsafe fn harness(this: *const (), job: *const Job<()>) where F: FnOnce() -> T + Send, @@ -1225,6 +1225,96 @@ impl WorkerThread { self.heartbeat.store(false, Ordering::Relaxed); } + #[inline] + fn join_seq(&self, a: A, b: B) -> (RA, RB) + where + RA: Send, + RB: Send, + A: FnOnce() -> RA + Send, + B: FnOnce() -> RB + Send, + { + let rb = b(); + let ra = a(); + + (ra, rb) + } + + /// This function must be called from a worker thread. + #[inline] + fn join_heartbeat_every(&self, a: A, b: B) -> (RA, RB) + where + RA: Send, + RB: Send, + A: FnOnce() -> RA + Send, + B: FnOnce() -> RB + Send, + { + // SAFETY: each worker is only ever used by one thread, so this is safe. + let count = self.join_count.get(); + self.join_count.set(count.wrapping_add(1) % TIMES as u8); + + // TODO: add counter to job queue, check for low job count to decide whether to use heartbeat or seq. + // see: chili + if self.join_count.get() == 1 { + self.join_heartbeat(a, b) + } else { + self.join_seq(a, b) + } + } + + /// This function must be called from a worker thread. + fn join_heartbeat(&self, a: A, b: B) -> (RA, RB) + where + RA: Send, + RB: Send, + A: FnOnce() -> RA + Send, + B: FnOnce() -> RB + Send, + { + use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe}; + + let a = StackJob::new( + move || { + // TODO: bench whether tick'ing here is good. + // turns out this actually costs a lot of time, likely because of the thread local check. + // WorkerThread::current_ref() + // .expect("stackjob is run in workerthread.") + // .tick(); + + a() + }, + NopLatch, + ); + + let job = a.as_job(); + self.push_front(&job); + + let rb = match catch_unwind(AssertUnwindSafe(|| b())) { + Ok(val) => val, + Err(payload) => { + cold_path(); + // if b panicked, we need to wait for a to finish + self.wait_until_job::(unsafe { job.transmute_ref::() }); + resume_unwind(payload); + } + }; + + let ra = if job.state() == JobState::Empty as u8 { + unsafe { + job.unlink(); + } + + // a is allowed to panic here, because we already finished b. + unsafe { a.unwrap()() } + } else { + match self.wait_until_job::(unsafe { job.transmute_ref::() }) { + Some(t) => t.into_result(), // propagate panic here + None => unsafe { a.unwrap()() }, + } + }; + + drop(a); + (ra, rb) + } + #[cold] fn wait_until_latch_cold(&self, latch: &Latch) { // does this optimise? @@ -1355,7 +1445,7 @@ impl<'scope> Scope<'scope> { #[allow(dead_code)] fn make_job T, T>(f: F) -> Job { - #[repr(align(8))] + #[align(8)] unsafe fn harness T, T>(this: *const (), job: *const Job) { let f = unsafe { Box::from_raw(this.cast::().cast_mut()) }; @@ -1462,7 +1552,7 @@ impl<'scope> Scope<'scope> { }; let schedule = move |runnable: Runnable| { - #[repr(align(8))] + #[align(8)] unsafe fn harness(this: *const (), job: *const Job) { let runnable = Runnable::<()>::from_raw(NonNull::new_unchecked(this.cast_mut())); @@ -1522,12 +1612,10 @@ impl<'scope> Scope<'scope> { let worker = WorkerThread::current_ref().expect("join is run in workerthread."); let this = SendPtr::new_const(self).unwrap(); - unsafe { - worker.context.join_heartbeat_every::<_, _, _, _, 64>( - make_scope_closure(this, a), - make_scope_closure(this, b), - ) - } + worker.join_heartbeat_every::<_, _, _, _, 64>( + make_scope_closure(this, a), + make_scope_closure(this, b), + ) } fn from_context(ctx: Arc) -> Self { @@ -1708,101 +1796,7 @@ impl Context { B: FnOnce() -> RB + Send, { // SAFETY: join_heartbeat_every is safe to call from a worker thread. - self.run_in_worker(move |_| unsafe { self.join_heartbeat_every::<_, _, _, _, 64>(a, b) }) - } - - #[inline] - fn join_seq(&self, a: A, b: B) -> (RA, RB) - where - RA: Send, - RB: Send, - A: FnOnce() -> RA + Send, - B: FnOnce() -> RB + Send, - { - let rb = b(); - let ra = a(); - - (ra, rb) - } - - /// This function must be called from a worker thread. - #[inline] - unsafe fn join_heartbeat_every(&self, a: A, b: B) -> (RA, RB) - where - RA: Send, - RB: Send, - A: FnOnce() -> RA + Send, - B: FnOnce() -> RB + Send, - { - let worker = WorkerThread::current_ref().expect("join is run in workerthread."); - - // SAFETY: each worker is only ever used by one thread, so this is safe. - worker - .join_count - .set(worker.join_count.get().wrapping_add(1) % TIMES as u8); - - // TODO: add counter to job queue, check for low job count to decide whether to use heartbeat or seq. - // see: chili - if worker.join_count.get() == 0 { - self.join_heartbeat(a, b) - } else { - self.join_seq(a, b) - } - } - - /// This function must be called from a worker thread. - unsafe fn join_heartbeat(&self, a: A, b: B) -> (RA, RB) - where - RA: Send, - RB: Send, - A: FnOnce() -> RA + Send, - B: FnOnce() -> RB + Send, - { - let worker = WorkerThread::current_ref().expect("join is run in workerthread."); - - let a = StackJob::new( - move || { - // TODO: bench whether tick'ing here is good. - // turns out this actually costs a lot of time, likely because of the thread local check. - // WorkerThread::current_ref() - // .expect("stackjob is run in workerthread.") - // .tick(); - - a() - }, - NopLatch, - ); - - let job = a.as_job(); - worker.push_front(&job); - - use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe}; - let rb = match catch_unwind(AssertUnwindSafe(|| b())) { - Ok(val) => val, - Err(payload) => { - cold_path(); - // if b panicked, we need to wait for a to finish - worker.wait_until_job::(unsafe { job.transmute_ref::() }); - resume_unwind(payload); - } - }; - - let ra = if job.state() == JobState::Empty as u8 { - unsafe { - job.unlink(); - } - - // a is allowed to panic here, because we already finished b. - unsafe { a.unwrap()() } - } else { - match worker.wait_until_job::(unsafe { job.transmute_ref::() }) { - Some(t) => t.into_result(), // propagate panic here - None => unsafe { a.unwrap()() }, - } - }; - - drop(a); - (ra, rb) + self.run_in_worker(move |worker| worker.join_heartbeat_every::<_, _, _, _, 64>(a, b)) } /// Runs closure in this context, processing the other context's worker's jobs while waiting for the result. diff --git a/src/praetor/tests.rs b/src/praetor/tests.rs index ba2f049..2da68ca 100644 --- a/src/praetor/tests.rs +++ b/src/praetor/tests.rs @@ -459,7 +459,7 @@ fn join_many() { fn sum(tree: &Tree, node: usize, scope: &Scope) -> u32 { let node = tree.get(node); - let (l, r) = scope.join_heartbeat( + let (l, r) = scope.join( |s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(), |s| { node.right