From a3b9222ed9825b52d664005be214e9e69f94c730 Mon Sep 17 00:00:00 2001 From: Janis Date: Tue, 24 Jun 2025 19:12:27 +0200 Subject: [PATCH] latch: add UnsafeWakeLatch, remove Context from WakeLatch; worker loop checks local queue first --- distaff/src/context.rs | 5 ++- distaff/src/join.rs | 10 +++-- distaff/src/latch.rs | 74 +++++++++++++++++++++++++++++-------- distaff/src/scope.rs | 6 +-- distaff/src/workerthread.rs | 51 +++++++++++++++---------- 5 files changed, 104 insertions(+), 42 deletions(-) diff --git a/distaff/src/context.rs b/distaff/src/context.rs index 051eec4..b2350bb 100644 --- a/distaff/src/context.rs +++ b/distaff/src/context.rs @@ -14,7 +14,7 @@ use parking_lot::{Condvar, Mutex}; use crate::{ job::{HeapJob, Job, StackJob}, - latch::{LatchRef, MutexLatch, WakeLatch}, + latch::{LatchRef, MutexLatch, UnsafeWakeLatch}, workerthread::{HeartbeatThread, WorkerThread}, }; @@ -196,7 +196,8 @@ impl Context { { // current thread is not in the same context, create a job and inject it into the other thread's context, then wait while working on our jobs. - let latch = WakeLatch::new(self.clone(), worker.index); + // SAFETY: we are waiting on this latch in this thread. + let latch = unsafe { UnsafeWakeLatch::new(&raw const worker.heartbeat.latch) }; let job = StackJob::new( move || { diff --git a/distaff/src/join.rs b/distaff/src/join.rs index 3a4e4e9..5f13b3f 100644 --- a/distaff/src/join.rs +++ b/distaff/src/join.rs @@ -1,9 +1,10 @@ -use std::{hint::cold_path, sync::Arc}; +use std::{hint::cold_path, ptr::NonNull, sync::Arc}; use crate::{ context::Context, job::{JobState, StackJob}, - latch::{AsCoreLatch, LatchRef, WakeLatch}, + latch::{AsCoreLatch, LatchRef, UnsafeWakeLatch, WakeLatch}, + util::SendPtr, workerthread::WorkerThread, }; @@ -62,7 +63,10 @@ impl WorkerThread { { use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind}; - let latch = WakeLatch::new(self.context.clone(), self.index); + // SAFETY: this thread's heartbeat latch is valid until the job sets it + // because we will be waiting on it. 
+ let latch = unsafe { UnsafeWakeLatch::new(&raw const (*self.heartbeat).latch) }; + let a = StackJob::new( move || { // TODO: bench whether tick'ing here is good. diff --git a/distaff/src/latch.rs b/distaff/src/latch.rs index 767f86f..def9d72 100644 --- a/distaff/src/latch.rs +++ b/distaff/src/latch.rs @@ -6,7 +6,7 @@ use std::sync::{Arc, atomic::AtomicU8}; use parking_lot::{Condvar, Mutex}; -use crate::context::Context; +use crate::{WorkerThread, context::Context}; pub trait Latch { unsafe fn set_raw(this: *const Self); @@ -325,17 +325,16 @@ impl Probe for MutexLatch { } } +/// Must only be `set` from a worker thread: when the latch was sleeping, `set_raw` unwraps `WorkerThread::current_ref()` to reach the context. pub struct WakeLatch { inner: AtomicLatch, - context: Arc, worker_index: AtomicUsize, } impl WakeLatch { - pub fn new(context: Arc, worker_index: usize) -> Self { + pub fn new(worker_index: usize) -> Self { Self { inner: AtomicLatch::new(), - context, worker_index: AtomicUsize::new(worker_index), } } @@ -349,10 +348,10 @@ impl Latch for WakeLatch { #[inline] unsafe fn set_raw(this: *const Self) { unsafe { - let ctx = (&*this).context.clone(); let worker_index = (&*this).worker_index.load(Ordering::Relaxed); if CoreLatch::set(&(&*this).inner) { + let ctx = WorkerThread::current_ref().unwrap().context.clone(); // If the latch was sleeping, wake the worker thread ctx.shared().heartbeats.get(&worker_index).and_then(|weak| { weak.upgrade() @@ -377,6 +376,48 @@ impl AsCoreLatch for WakeLatch { } } +pub struct UnsafeWakeLatch { + inner: AtomicLatch, + waker: *const MutexLatch, +} + +impl UnsafeWakeLatch { + /// # Safety + /// `waker` must point to a valid `MutexLatch` until this latch has been set, because setting this latch may call `Latch::set_raw` on `waker`. 
+ pub unsafe fn new(waker: *const MutexLatch) -> Self { + Self { + inner: AtomicLatch::new(), + waker, + } + } +} + +impl Latch for UnsafeWakeLatch { + #[inline] + unsafe fn set_raw(this: *const Self) { + unsafe { + let waker = (*this).waker; + if CoreLatch::set(&(&*this).inner) { + Latch::set_raw(waker); + } + } + } +} + +impl Probe for UnsafeWakeLatch { + #[inline] + fn probe(&self) -> bool { + self.inner.probe() + } +} + +impl AsCoreLatch for UnsafeWakeLatch { + #[inline] + fn as_core_latch(&self) -> &CoreLatch { + &self.inner + } +} + #[cfg(test)] mod tests { use std::sync::Barrier; @@ -505,17 +546,18 @@ mod tests { let barrier = Arc::new(Barrier::new(2)); tracing::info!("running scope in worker thread"); - let latch = context.run_in_worker(|worker| { + context.run_in_worker(|worker| { tracing::info!("worker thread started: {:?}", worker.index); - let latch = WakeLatch::new(worker.context.clone(), worker.index); + let latch = Arc::new(WakeLatch::new(worker.index)); worker.context.spawn({ let heartbeat = worker.heartbeat.clone(); let barrier = barrier.clone(); let count = count.clone(); - // set sleeping outside of the closure so we don't have to deal with lifetimes - latch.as_core_latch().set_sleeping(); + // mark the latch as sleeping before either job is spawned: if the setter job + // ran first, `set_raw` would see a non-sleeping latch and skip the wake-up. + latch.as_core_latch().set_sleeping(); move || { tracing::info!("sleeping workerthread"); heartbeat.latch.wait_and_reset(); tracing::info!("woken up workerthread"); count.fetch_add(1, Ordering::SeqCst); @@ -524,14 +566,16 @@ mod tests { } }); - latch + worker.context.spawn({ + move || { + tracing::info!("setting latch in worker thread"); + unsafe { + Latch::set_raw(&*latch); + } + } + }); }); - tracing::info!("setting latch in main thread"); - unsafe { - Latch::set_raw(&latch); - } - tracing::info!("main thread set latch, waiting for worker thread to wake up"); barrier.wait(); assert_eq!( diff --git a/distaff/src/scope.rs b/distaff/src/scope.rs index e663922..06625ca 100644 --- a/distaff/src/scope.rs +++ b/distaff/src/scope.rs @@ 
-256,10 +256,10 @@ impl<'scope, 'env> Scope<'scope, 'env> { ) } - unsafe fn from_context(ctx: Arc) -> Self { + unsafe fn from_context(context: Arc) -> Self { Self { - context: ctx.clone(), - job_counter: CountLatch::new(WakeLatch::new(ctx, 0)), + context, + job_counter: CountLatch::new(WakeLatch::new(0)), panic: AtomicPtr::new(ptr::null_mut()), _scope: PhantomData, _env: PhantomData, diff --git a/distaff/src/workerthread.rs b/distaff/src/workerthread.rs index bbc72a0..edec7eb 100644 --- a/distaff/src/workerthread.rs +++ b/distaff/src/workerthread.rs @@ -9,7 +9,7 @@ use crossbeam_utils::CachePadded; use crate::{ context::{Context, Heartbeat}, - job::{Job, JobResult, JobVec as JobList}, + job::{Job, JobList, JobResult}, latch::{AsCoreLatch, CoreLatch, Probe}, util::DropGuard, }; @@ -70,30 +70,43 @@ impl WorkerThread { self.execute(job); } - let mut guard = self.context.shared(); - if guard.should_exit() { - // if the context is stopped, break out of the outer loop which - // will exit the thread. - break 'outer; - } + // we executed the shared job, now we want to check for any + // local jobs which this job might have spawned. + let next = self + .pop_front() + .map(|job| (Some(job), None)) + .unwrap_or_else(|| { + let mut guard = self.context.shared(); + (guard.pop_job(), Some(guard)) + }); - // TODO: also check the local queue? - match guard.pop_job() { - Some(popped) => { - tracing::trace!("worker: popping job: {:?}", popped); - job = Some(popped); - // found job, continue inner loop - continue; - } - None => { - tracing::trace!("worker: no job, waiting for shared job"); - // no more jobs, break out of inner loop and wait for shared job + match next { + // no job, but guard => neither `pop_front` nor the shared `pop_job` returned a job; check if we should exit + (None, Some(guard)) => { + tracing::trace!("worker: no local job, waiting for shared job"); + + if guard.should_exit() { + // if the context is stopped, break out of the outer loop which + // will exit the thread. 
+ break 'outer; + } + + // no jobs left, hand the guard to the outer loop which waits for a shared job break guard; } + // some job => drop guard, continue inner loop + (Some(next), _) => { + tracing::trace!("worker: executing job: {:?}", next); + job = Some(next); + continue; + } + // no job and no guard is unreachable: the guard is only `None` when `pop_front` returned a job. + _ => unreachable!(), } }; self.context.shared_job.wait(&mut guard); + // a job was shared and we were notified, so we want to execute that job before any possible local jobs. job = guard.pop_job(); } } @@ -101,7 +114,7 @@ impl WorkerThread { impl WorkerThread { #[inline(always)] - fn tick(&self) { + pub(crate) fn tick(&self) { if self.heartbeat.is_pending() { tracing::trace!("received heartbeat, thread id: {:?}", self.index); self.heartbeat_cold();