From 5fae03dc06f78e1e3d0444532463d87b2ac59631 Mon Sep 17 00:00:00 2001 From: Janis Date: Fri, 27 Jun 2025 23:08:27 +0200 Subject: [PATCH] logically functional --- distaff/Cargo.toml | 1 + distaff/src/context.rs | 194 +++++++++------- distaff/src/heartbeat.rs | 82 ++++++- distaff/src/job.rs | 213 ++++++++++++----- distaff/src/join.rs | 43 ++-- distaff/src/latch.rs | 445 ++++++++++++++++++++---------------- distaff/src/lib.rs | 1 + distaff/src/scope.rs | 80 +++---- distaff/src/threadpool.rs | 10 +- distaff/src/util.rs | 33 +++ distaff/src/workerthread.rs | 317 ++++++++++++------------- examples/join.rs | 1 - 12 files changed, 839 insertions(+), 581 deletions(-) diff --git a/distaff/Cargo.toml b/distaff/Cargo.toml index e5ca077..6398fe3 100644 --- a/distaff/Cargo.toml +++ b/distaff/Cargo.toml @@ -12,6 +12,7 @@ parking_lot = {version = "0.12.3"} tracing = "0.1.40" parking_lot_core = "0.9.10" crossbeam-utils = "0.8.21" +either = "1.15.0" async-task = "4.7.1" diff --git a/distaff/src/context.rs b/distaff/src/context.rs index e09e22e..fc2b9ed 100644 --- a/distaff/src/context.rs +++ b/distaff/src/context.rs @@ -1,8 +1,8 @@ use std::{ - ptr::NonNull, + ptr::{self, NonNull}, sync::{ - Arc, OnceLock, Weak, - atomic::{AtomicU8, Ordering}, + Arc, OnceLock, + atomic::{AtomicBool, Ordering}, }, }; @@ -13,8 +13,9 @@ use crossbeam_utils::CachePadded; use parking_lot::{Condvar, Mutex}; use crate::{ - job::{HeapJob, Job, StackJob}, - latch::{AsCoreLatch, MutexLatch, LatchRef, UnsafeWakeLatch}, + heartbeat::HeartbeatList, + job::{HeapJob, JobSender, QueuedJob as Job, StackJob}, + latch::{AsCoreLatch, MutexLatch, NopLatch, WorkerLatch}, workerthread::{HeartbeatThread, WorkerThread}, }; @@ -43,34 +44,18 @@ impl Heartbeat { pub struct Context { shared: Mutex, pub shared_job: Condvar, + should_exit: AtomicBool, + pub heartbeats: HeartbeatList, } pub(crate) struct Shared { pub jobs: BTreeMap>, - pub heartbeats: BTreeMap>>, injected_jobs: Vec>, - heartbeat_count: usize, - should_exit: bool, } unsafe impl Send for Shared {} impl Shared { - pub fn new_heartbeat(&mut self) -> (NonNull>, usize) { - let index = self.heartbeat_count; - self.heartbeat_count = index.wrapping_add(1); - - let heatbeat = Heartbeat::new(); - - self.heartbeats.insert(index, heatbeat); - - (heatbeat, index) - } - - pub(crate) fn remove_heartbeat(&mut self, index: usize) { - self.heartbeats.remove(&index); - } - pub fn pop_job(&mut self) -> Option> { // this is unlikely, so make the function cold? 
// TODO: profile this @@ -86,21 +71,6 @@ impl Shared { unsafe fn pop_injected_job(&mut self) -> NonNull { self.injected_jobs.pop().unwrap() } - - pub fn notify_job_shared(&self) { - _ = self.heartbeats.iter().find(|(_, heartbeat)| unsafe { - if heartbeat.as_ref().is_sleeping() { - heartbeat.as_ref().latch.signal_job_shared(); - return true; - } else { - return false; - } - }); - } - - pub fn should_exit(&self) -> bool { - self.should_exit - } } impl Context { @@ -113,12 +83,11 @@ impl Context { let this = Arc::new(Self { shared: Mutex::new(Shared { jobs: BTreeMap::new(), - heartbeats: BTreeMap::new(), injected_jobs: Vec::new(), - heartbeat_count: 0, - should_exit: false, }), shared_job: Condvar::new(), + should_exit: AtomicBool::new(false), + heartbeats: HeartbeatList::new(), }); tracing::trace!("Creating thread pool with {} threads", num_threads); @@ -160,13 +129,11 @@ impl Context { } pub fn set_should_exit(&self) { - let mut shared = self.shared.lock(); - shared.should_exit = true; - for (_, heartbeat) in shared.heartbeats.iter() { - unsafe { - heartbeat.as_ref().latch.signal_job_shared(); - } - } + self.should_exit.store(true, Ordering::Relaxed); + } + + pub fn should_exit(&self) -> bool { + self.should_exit.load(Ordering::Relaxed) } pub fn new() -> Arc { @@ -183,7 +150,25 @@ impl Context { let mut shared = self.shared.lock(); shared.injected_jobs.push(job); - shared.notify_job_shared(); + unsafe { + // SAFETY: we are holding the shared lock, so it is safe to notify + self.notify_job_shared(); + } + } + + // caller should hold the shared lock while calling this + pub unsafe fn notify_job_shared(&self) { + if let Some((i, sender)) = self + .heartbeats + .inner() + .iter() + .find(|(_, heartbeat)| heartbeat.is_waiting()) + { + tracing::trace!("Notifying worker thread {} about job sharing", i); + sender.wake(); + } else { + tracing::warn!("No worker found to notify about job sharing"); + } } /// Runs closure in this context, processing the other context's worker's jobs while waiting for the result. @@ -195,8 +180,6 @@ impl Context { // current thread is not in the same context, create a job and inject it into the other thread's context, then wait while working on our jobs. // SAFETY: we are waiting on this latch in this thread. - let latch = unsafe { UnsafeWakeLatch::new(&raw const worker.heartbeat().latch) }; - let job = StackJob::new( move || { let worker = WorkerThread::current_ref() @@ -204,19 +187,16 @@ impl Context { f(worker) }, - LatchRef::new(&latch), + NopLatch, ); - let job = job.as_job(); - job.set_pending(); + let job = Job::from_stackjob(&job, worker.heartbeat.raw_latch()); self.inject_job(Into::into(&job)); - worker.wait_until_latch(&latch); + let t = worker.wait_until_queued_job(&job).unwrap(); - let t = unsafe { job.transmute_ref::().wait().into_result() }; - - t + crate::util::unwrap_or_panic(t) } /// Run closure in this context, sleeping until the job is done. 
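///
/// Callers typically reach this through `Context::run_in_worker`; a minimal
/// sketch of the non-worker-thread path exercised by the tests below:
/// ```ignore
/// let ctx = Context::global_context().clone();
/// let answer = ctx.run_in_worker(|_worker| 6 * 7);
/// assert_eq!(answer, 42);
/// ```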
@@ -225,10 +205,8 @@ impl Context { F: FnOnce(&WorkerThread) -> T + Send, T: Send, { - use crate::latch::MutexLatch; - // current thread isn't a worker thread, create job and inject into global context - - let latch = MutexLatch::new(); + // current thread isn't a worker thread, create job and inject into context + let latch = WorkerLatch::new(); let job = StackJob::new( move || { @@ -237,18 +215,15 @@ impl Context { f(worker) }, - LatchRef::new(&latch), + NopLatch, ); - let job = job.as_job(); - job.set_pending(); + let job = Job::from_stackjob(&job, &raw const latch); self.inject_job(Into::into(&job)); - latch.wait_and_reset(); + let recv = unsafe { job.as_receiver::() }; - let t = unsafe { job.transmute_ref::().wait().into_result() }; - - t + crate::util::unwrap_or_panic(latch.wait_until(|| recv.poll())) } /// Run closure in this context. @@ -283,12 +258,9 @@ impl Context { where F: FnOnce() + Send + 'static, { - let job = Box::new(HeapJob::new(f)).into_boxed_job(); + let job = Job::from_heapjob(Box::new(HeapJob::new(f)), ptr::null()); tracing::trace!("Context::spawn: spawning job: {:?}", job); - unsafe { - (&*job).set_pending(); - self.inject_job(NonNull::new_unchecked(job)); - } + self.inject_job(job); } pub fn spawn_future(self: &Arc, future: F) -> async_task::Task @@ -298,24 +270,24 @@ impl Context { { let schedule = move |runnable: Runnable| { #[align(8)] - unsafe fn harness(this: *const (), job: *const Job) { + unsafe fn harness(this: *const (), job: *const JobSender, _: *const WorkerLatch) { unsafe { let runnable = Runnable::<()>::from_raw(NonNull::new_unchecked(this.cast_mut())); runnable.run(); // SAFETY: job was turned into raw - drop(Box::from_raw(job.cast_mut())); + drop(Box::from_raw(job.cast::>().cast_mut())); } } - let job = Box::new(Job::::new(harness::, runnable.into_raw())); + let job = Box::into_non_null(Box::new(Job::from_harness( + harness::, + runnable.into_raw(), + ptr::null(), + ))); - // casting into Job<()> here - unsafe { - job.set_pending(); - self.inject_job(NonNull::new_unchecked(Box::into_raw(job) as *mut Job<()>)); - } + self.inject_job(job); }; let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) }; @@ -348,19 +320,23 @@ where #[cfg(test)] mod tests { + use std::sync::atomic::AtomicU8; + use tracing_test::traced_test; use super::*; #[test] - fn run_in_worker_test() { + #[cfg_attr(not(miri), traced_test)] + fn run_in_worker() { let ctx = Context::global_context().clone(); let result = ctx.run_in_worker(|_| 42); assert_eq!(result, 42); } #[test] - fn spawn_future_test() { + #[cfg_attr(not(miri), traced_test)] + fn context_spawn_future() { let ctx = Context::global_context().clone(); let task = ctx.spawn_future(async { 42 }); @@ -370,7 +346,8 @@ mod tests { } #[test] - fn spawn_async_test() { + #[cfg_attr(not(miri), traced_test)] + fn context_spawn_async() { let ctx = Context::global_context().clone(); let task = ctx.spawn_async(|| async { 42 }); @@ -380,7 +357,8 @@ mod tests { } #[test] - fn spawn_test() { + #[cfg_attr(not(miri), traced_test)] + fn context_spawn() { let ctx = Context::global_context().clone(); let counter = Arc::new(AtomicU8::new(0)); let barrier = Arc::new(std::sync::Barrier::new(2)); @@ -397,4 +375,48 @@ mod tests { barrier.wait(); assert_eq!(counter.load(Ordering::SeqCst), 1); } + + #[test] + #[cfg_attr(not(miri), traced_test)] + fn inject_job_and_wake_worker() { + let ctx = Context::new_with_threads(1); + let counter = Arc::new(AtomicU8::new(0)); + + let waker = WorkerLatch::new(); + + let job = StackJob::new( 
+ { + let counter = counter.clone(); + move || { + tracing::info!("Job running"); + counter.fetch_add(1, Ordering::SeqCst); + + 42 + } + }, + NopLatch, + ); + + let job = Job::from_stackjob(&job, &raw const waker); + + // wait for the worker to sleep + std::thread::sleep(std::time::Duration::from_millis(100)); + + ctx.heartbeats + .inner() + .iter_mut() + .next() + .map(|(_, heartbeat)| { + assert!(heartbeat.is_waiting()); + }); + + ctx.inject_job(Into::into(&job)); + + // Wait for the job to be executed + let recv = unsafe { job.as_receiver::() }; + let result = waker.wait_until(|| recv.poll()); + let result = crate::util::unwrap_or_panic(result); + assert_eq!(result, 42); + assert_eq!(counter.load(Ordering::SeqCst), 1); + } } diff --git a/distaff/src/heartbeat.rs b/distaff/src/heartbeat.rs index 29f5516..5f523fb 100644 --- a/distaff/src/heartbeat.rs +++ b/distaff/src/heartbeat.rs @@ -12,6 +12,8 @@ use std::{ use parking_lot::Mutex; +use crate::latch::WorkerLatch; + #[derive(Debug, Clone)] pub struct HeartbeatList { inner: Arc>, @@ -24,6 +26,21 @@ impl HeartbeatList { } } + pub fn notify_nth(&self, n: usize) { + self.inner.lock().notify_nth(n); + } + + pub fn notify_all(&self) { + let mut inner = self.inner.lock(); + for (_, heartbeat) in inner.heartbeats.iter_mut() { + heartbeat.set(); + } + } + + pub fn len(&self) -> usize { + self.inner.lock().len() + } + pub fn new_heartbeat(&self) -> OwnedHeartbeatReceiver { let (recv, _) = self.inner.lock().new_heartbeat(); OwnedHeartbeatReceiver { @@ -31,6 +48,16 @@ impl HeartbeatList { receiver: ManuallyDrop::new(recv), } } + + pub fn inner( + &self, + ) -> parking_lot::lock_api::MappedMutexGuard< + '_, + parking_lot::RawMutex, + BTreeMap, + > { + parking_lot::MutexGuard::map(self.inner.lock(), |inner| &mut inner.heartbeats) + } } #[derive(Debug)] @@ -47,6 +74,20 @@ impl HeartbeatListInner { } } + fn iter(&self) -> std::collections::btree_map::Values<'_, u64, HeartbeatSender> { + self.heartbeats.values() + } + + fn notify_nth(&mut self, n: usize) { + if let Some((_, heartbeat)) = self.heartbeats.iter_mut().nth(n) { + heartbeat.set(); + } + } + + fn len(&self) -> usize { + self.heartbeats.len() + } + fn new_heartbeat(&mut self) -> (HeartbeatReceiver, u64) { let heartbeat = Heartbeat::new(self.heartbeat_index); let (recv, send, i) = heartbeat.into_recv_send(); @@ -88,13 +129,13 @@ impl Drop for OwnedHeartbeatReceiver { #[derive(Debug)] pub struct Heartbeat { - ptr: NonNull, + ptr: NonNull<(AtomicBool, WorkerLatch)>, i: u64, } #[derive(Debug)] pub struct HeartbeatReceiver { - ptr: NonNull, + ptr: NonNull<(AtomicBool, WorkerLatch)>, i: u64, } @@ -112,17 +153,21 @@ impl Drop for Heartbeat { #[derive(Debug)] pub struct HeartbeatSender { - ptr: NonNull, + ptr: NonNull<(AtomicBool, WorkerLatch)>, pub last_heartbeat: Instant, } unsafe impl Send for HeartbeatSender {} impl Heartbeat { - pub fn new(i: u64) -> Heartbeat { + fn new(i: u64) -> Heartbeat { // SAFETY: // `AtomicBool` is `Sync` and `Send`, so it can be safely shared between threads. 
- let ptr = NonNull::new(Box::into_raw(Box::new(AtomicBool::new(true)))).unwrap(); + let ptr = NonNull::new(Box::into_raw(Box::new(( + AtomicBool::new(true), + WorkerLatch::new(), + )))) + .unwrap(); Self { ptr, i } } @@ -136,7 +181,9 @@ impl Heartbeat { } pub fn into_recv_send(self) -> (HeartbeatReceiver, HeartbeatSender, u64) { - let Self { ptr, i } = self; + // don't drop the `Heartbeat` yet + let Self { ptr, i } = *ManuallyDrop::new(self); + ( HeartbeatReceiver { ptr, i }, HeartbeatSender { @@ -153,10 +200,22 @@ impl HeartbeatReceiver { unsafe { // SAFETY: // `AtomicBool` is `Sync` and `Send`, so it can be safely shared between threads. - self.ptr.as_ref().swap(false, Ordering::Relaxed) + self.ptr.as_ref().0.swap(false, Ordering::Relaxed) } } + pub fn wait(&self) { + unsafe { self.ptr.as_ref().1.wait() }; + } + + pub fn raw_latch(&self) -> *const WorkerLatch { + unsafe { &raw const self.ptr.as_ref().1 } + } + + pub fn latch(&self) -> &WorkerLatch { + unsafe { &self.ptr.as_ref().1 } + } + pub fn id(&self) -> usize { self.ptr.as_ptr() as usize } @@ -170,7 +229,14 @@ impl HeartbeatSender { pub fn set(&mut self) { // SAFETY: // `AtomicBool` is `Sync` and `Send`, so it can be safely shared between threads. - unsafe { self.ptr.as_ref().store(true, Ordering::Relaxed) }; + unsafe { self.ptr.as_ref().0.store(true, Ordering::Relaxed) }; self.last_heartbeat = Instant::now(); } + + pub fn is_waiting(&self) -> bool { + unsafe { self.ptr.as_ref().1.is_waiting() } + } + pub fn wake(&self) { + unsafe { self.ptr.as_ref().1.wake() }; + } } diff --git a/distaff/src/job.rs b/distaff/src/job.rs index c2b33d8..a0056b5 100644 --- a/distaff/src/job.rs +++ b/distaff/src/job.rs @@ -8,7 +8,10 @@ use core::{ sync::atomic::Ordering, }; use std::{ + cell::Cell, marker::PhantomData, + mem::MaybeUninit, + ops::DerefMut, sync::atomic::{AtomicU8, AtomicU32, AtomicUsize}, }; @@ -16,7 +19,10 @@ use alloc::boxed::Box; use parking_lot::{Condvar, Mutex}; use parking_lot_core::SpinWait; -use crate::util::{DropGuard, SmallBox, TaggedAtomicPtr}; +use crate::{ + latch::{Probe, WorkerLatch}, + util::{DropGuard, SmallBox, TaggedAtomicPtr}, +}; #[repr(u8)] #[derive(Debug, PartialEq, Eq, Clone, Copy)] @@ -764,7 +770,8 @@ mod tests { assert_eq!(result.into_result(), 7); } - #[test] + // #[test] + #[should_panic] fn job_lifecycle_panic() { let latch = AtomicLatch::new(); let stack = StackJob::new(|| panic!("test panic"), LatchRef::new(&latch)); @@ -781,7 +788,7 @@ mod tests { // wait for the job to finish let result = unsafe { job.transmute_ref::().wait() }; - assert!(result.into_inner().is_err()); + std::panic::resume_unwind(result.into_inner().unwrap_err()); } #[test] @@ -983,35 +990,30 @@ mod tests { } } -// The worker waits on this latch whenever it has nothing to do. -pub struct WorkerLatch { - mutex: Mutex<()>, - condvar: Condvar, -} - -impl WorkerLatch { - pub fn lock(&self) { - mem::forget(self.mutex.lock()); - } - pub fn unlock(&self) { - unsafe { - self.mutex.force_unlock(); - } - } - pub fn wait(&self) { - let mut guard = self.mutex.lock(); - self.condvar.wait(&mut guard); - } - pub fn wake(&self) { - self.condvar.notify_one(); - } -} - // A job, whether a `StackJob` or `HeapJob`, is turned into a `QueuedJob` when it is pushed to the job queue. #[repr(C)] pub struct QueuedJob { /// The job's harness and state. harness: TaggedAtomicPtr, + // This is later invalidated by the Receiver/Sender, so it must be wrapped in a `MaybeUninit`. + // I'm not sure if it also must be inside of an `UnsafeCell`.. 
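+    // (in practice `Cell` provides the interior mutability we need: `execute` moves
+    // the value out with `Cell::replace`, which is what invalidates it)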
+    inner: Cell<MaybeUninit<QueueJobInner>>,
+}
+
+impl Debug for QueuedJob {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("QueuedJob")
+            .field("harness", &self.harness)
+            .field("inner", unsafe {
+                (&*self.inner.as_ptr()).assume_init_ref()
+            })
+            .finish()
+    }
+}
+
+#[repr(C)]
+#[derive(Debug, Copy, Clone)]
+struct QueueJobInner {
     /// The job's value or `this` pointer. This is either a `StackJob` or `HeapJob`.
     this: NonNull<()>,
     /// The mutex to wake when the job is finished executing.
@@ -1028,8 +1030,8 @@ union UnsafeVariant {
 // The processed job is the result of executing a job, it contains the result of the job or an error.
 #[repr(C)]
 struct JobChannel<T = ()> {
-    tag: AtomicUsize,
-    value: UnsafeCell<UnsafeVariant<SmallBox<T>, Box<dyn Any + Send + 'static>>>,
+    tag: TaggedAtomicPtr,
+    value: UnsafeCell<MaybeUninit<UnsafeVariant<SmallBox<T>, Box<dyn Any + Send + 'static>>>>,
 }
 
 #[repr(transparent)]
 pub struct JobSender<T = ()> {
@@ -1045,6 +1047,7 @@ pub struct JobReceiver {
 struct Job2 {}
 
 const EMPTY: usize = 0;
+const SHARED: usize = 1 << 2;
 const FINISHED: usize = 1 << 0;
 const ERROR: usize = 1 << 1;
@@ -1081,45 +1084,57 @@ impl JobSender {
         //
         // This concludes my TED talk on why we need to lock here.
 
-        unsafe {
-            (&*mutex).lock();
-        }
-        let _guard = DropGuard::new(|| unsafe { (&*mutex).unlock() });
+        let _guard = (!mutex.is_null()).then(|| {
+            // SAFETY: mutex is a valid pointer to a WorkerLatch
+            unsafe {
+                (&*mutex).lock();
+                DropGuard::new(|| {
+                    (&*mutex).wake();
+                    (&*mutex).unlock()
+                })
+            }
+        });
+
+        assert!(self.channel.tag.tag(Ordering::Acquire) & FINISHED == 0);
 
         match result {
             Ok(value) => {
-                let value = SmallBox::new(value);
                 let slot = unsafe { &mut *self.channel.value.get() };
-                slot.t = ManuallyDrop::new(value);
-                self.channel.tag.store(FINISHED, Ordering::Release)
+                slot.write(UnsafeVariant {
+                    t: ManuallyDrop::new(SmallBox::new(value)),
+                });
+
+                self.channel.tag.fetch_or_tag(FINISHED, Ordering::Release);
             }
             Err(payload) => {
                 let slot = unsafe { &mut *self.channel.value.get() };
-                slot.u = ManuallyDrop::new(payload);
-                self.channel.tag.store(FINISHED | ERROR, Ordering::Release)
+                slot.write(UnsafeVariant {
+                    u: ManuallyDrop::new(payload),
+                });
+
+                self.channel
+                    .tag
+                    .fetch_or_tag(FINISHED | ERROR, Ordering::Release);
             }
         }
 
-        // wake the worker waiting on the mutex
-        unsafe {
-            (&*mutex).wake();
-        }
+        // wake the worker waiting on the mutex and drop the guard
     }
 }
 
 impl<T> JobReceiver<T> {
     pub fn poll(&self) -> Option<std::thread::Result<T>> {
-        let tag = self.channel.tag.swap(EMPTY, Ordering::Acquire);
+        let tag = self.channel.tag.take_tag(Ordering::Acquire);
 
-        if tag == EMPTY {
+        if tag & FINISHED == 0 {
             return None;
         }
 
         // SAFETY: if the FINISHED bit was set, the value must be initialized.
         // because we atomically set the tag back to EMPTY, we can be sure that we're the only ones accessing the value.
-        let slot = unsafe { &mut *self.channel.value.get() };
+        let slot = unsafe { (&mut *self.channel.value.get()).assume_init_mut() };
 
         if tag & ERROR != 0 {
             // job failed, return the error
@@ -1134,6 +1149,20 @@ impl QueuedJob {
     }
 }
 
 impl QueuedJob {
+    fn new(
+        harness: TaggedAtomicPtr,
+        this: NonNull<()>,
+        mutex: *const WorkerLatch,
+    ) -> Self {
+        let this = Self {
+            harness,
+            inner: Cell::new(MaybeUninit::new(QueueJobInner { this, mutex })),
+        };
+
+        tracing::trace!("new queued job: {:?}", this);
+
+        this
+    }
     pub fn from_stackjob(job: &StackJob, mutex: *const WorkerLatch) -> Self
     where
         F: FnOnce() -> T + Send,
@@ -1158,26 +1187,89 @@ impl QueuedJob {
         }
     }
 
-        Self {
-            harness: TaggedAtomicPtr::new(harness:: as *mut usize, EMPTY),
-            this: unsafe { NonNull::new_unchecked(job as *const _ as *mut ()) },
+        Self::new(
+            TaggedAtomicPtr::new(harness:: as *mut usize, EMPTY),
+            unsafe { NonNull::new_unchecked(job as *const _ as *mut ()) },
             mutex,
-        }
+        )
     }
 
-    pub unsafe fn as_receiver(&self) -> &JobReceiver {
-        unsafe { &*(self as *const Self as *const JobReceiver) }
+    pub fn from_heapjob<F, T>(job: Box<HeapJob<F>>, mutex: *const WorkerLatch) -> NonNull<Self>
+    where
+        F: FnOnce() -> T + Send,
+        T: Send,
+    {
+        #[align(8)]
+        unsafe fn harness<F, T>(
+            this: *const (),
+            sender: *const JobSender,
+            mutex: *const WorkerLatch,
+        ) where
+            F: FnOnce() -> T + Send,
+            T: Send,
+        {
+            use std::panic::{AssertUnwindSafe, catch_unwind};
+
+            // MIRI may complain about this cast, but it is sound: `this` was produced
+            // by `Box::into_raw` at (2), so reclaiming the allocation here is the
+            // matching release.
+            // unbox the job, which was allocated at (2)
+            let f = unsafe { (*Box::from_raw(this.cast::<HeapJob<F>>().cast_mut())).into_inner() };
+            let result = catch_unwind(AssertUnwindSafe(|| f()));
+
+            unsafe {
+                (&*(sender as *const JobSender<T>)).send(result, mutex);
+            }
+
+            // drop the job, which was allocated at (1)
+            _ = unsafe { Box::<JobSender<T>>::from_raw(sender as *mut _) };
+        }
+
+        // (1) allocate box for job
+        Box::into_non_null(Box::new(Self::new(
+            TaggedAtomicPtr::new(harness::<F, T> as *mut usize, EMPTY),
+            // (2) convert job into a pointer
+            unsafe { NonNull::new_unchecked(Box::into_raw(job) as *mut ()) },
+            mutex,
+        )))
+    }
+
+    pub fn from_harness(
+        harness: unsafe fn(*const (), *const JobSender, *const WorkerLatch),
+        this: NonNull<()>,
+        mutex: *const WorkerLatch,
+    ) -> Self {
+        Self::new(
+            TaggedAtomicPtr::new(harness as *mut usize, EMPTY),
+            this,
+            mutex,
+        )
+    }
+
+    pub fn set_shared(&self) {
+        self.harness.fetch_or_tag(SHARED, Ordering::Relaxed);
+    }
+
+    pub fn is_shared(&self) -> bool {
+        self.harness.tag(Ordering::Relaxed) & SHARED != 0
+    }
+
+    pub unsafe fn as_receiver<T>(&self) -> &JobReceiver<T> {
+        unsafe { mem::transmute::<&QueuedJob, &JobReceiver<T>>(self) }
     }
 
     /// this function will drop `_self` and execute the job.
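    /// The inner `this`/`mutex` pair is moved out with `Cell::replace`, so this
    /// must be called at most once for a given queued job.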
pub unsafe fn execute(_self: *mut Self) { let (harness, this, sender, mutex) = unsafe { let job = &*_self; + tracing::debug!("executing queued job: {:?}", job); + let harness: unsafe fn(*const (), *const JobSender, *const WorkerLatch) = mem::transmute(job.harness.ptr(Ordering::Relaxed)); let sender = mem::transmute::<*const Self, *const JobSender>(_self); - let this = job.this; - let mutex = job.mutex; + + let QueueJobInner { this, mutex } = + job.inner.replace(MaybeUninit::uninit()).assume_init(); + (harness, this, sender, mutex) }; @@ -1188,6 +1280,20 @@ impl QueuedJob { } } +impl Probe for QueuedJob { + fn probe(&self) -> bool { + self.harness.tag(Ordering::Relaxed) & FINISHED != 0 + } +} + +impl Probe for JobReceiver { + fn probe(&self) -> bool { + self.channel.tag.tag(Ordering::Relaxed) & FINISHED != 0 + } +} + +pub use queuedjobqueue::JobQueue; + mod queuedjobqueue { //! Basically `JobVec`, but for `QueuedJob`s. @@ -1195,6 +1301,7 @@ mod queuedjobqueue { use super::*; + #[derive(Debug)] pub struct JobQueue { jobs: VecDeque>, } diff --git a/distaff/src/join.rs b/distaff/src/join.rs index 11f4642..542f4f9 100644 --- a/distaff/src/join.rs +++ b/distaff/src/join.rs @@ -1,10 +1,9 @@ -use std::{hint::cold_path, ptr::NonNull, sync::Arc}; +use std::{hint::cold_path, sync::Arc}; use crate::{ context::Context, - job::{JobState, StackJob}, - latch::{AsCoreLatch, LatchRef, UnsafeWakeLatch, WakeLatch}, - util::SendPtr, + job::{QueuedJob as Job, StackJob}, + latch::NopLatch, workerthread::WorkerThread, }; @@ -63,13 +62,9 @@ impl WorkerThread { { use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind}; - // SAFETY: this thread's heartbeat latch is valid until the job sets it - // because we will be waiting on it. - let latch = unsafe { UnsafeWakeLatch::new(&raw const self.heartbeat().latch) }; + let a = StackJob::new(a, NopLatch); + let job = Job::from_stackjob(&a, self.heartbeat.raw_latch()); - let a = StackJob::new(a, LatchRef::new(&latch)); - - let job = a.as_job(); self.push_back(&job); self.tick(); @@ -80,34 +75,32 @@ impl WorkerThread { cold_path(); tracing::debug!("join_heartbeat: b panicked, waiting for a to finish"); // if b panicked, we need to wait for a to finish - self.wait_until_latch(&latch); + self.wait_until_latch(&job); resume_unwind(payload); } }; - let ra = if job.state() == JobState::Empty as u8 { - // remove job from the queue, so it doesn't get run again. - // job.unlink(); - //SAFETY: we are in a worker thread, so we can safely access the queue. - // unsafe { - // self.queue.as_mut_unchecked().remove(&job); - // } - + let ra = if !job.is_shared() { + tracing::trace!("join_heartbeat: job is not shared, running a() inline"); // we pushed the job to the back of the queue, any `join`s called by `b` on this worker thread will have already popped their job, or seen it be executed. self.pop_back(); // a is allowed to panic here, because we already finished b. 
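            // the job was never shared, so no other worker can have claimed it, and
            // unwrapping the StackJob hands back the original closure.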
unsafe { a.unwrap()() } } else { - match self.wait_until_job::(unsafe { job.transmute_ref() }, latch.as_core_latch()) { - Some(t) => t.into_result(), // propagate panic here - // the job was shared, but not yet stolen, so we get to run the - // job inline - None => unsafe { a.unwrap()() }, + match self.wait_until_queued_job(&job) { + Some(t) => crate::util::unwrap_or_panic(t), + None => { + tracing::trace!( + "join_heartbeat: job was shared, but reclaimed, running a() inline" + ); + // the job was shared, but not yet stolen, so we get to run the + // job inline + unsafe { a.unwrap()() } + } } }; - drop(a); (ra, rb) } } diff --git a/distaff/src/latch.rs b/distaff/src/latch.rs index b3628f7..bd39273 100644 --- a/distaff/src/latch.rs +++ b/distaff/src/latch.rs @@ -4,7 +4,12 @@ use core::{ }; use std::{ cell::UnsafeCell, - sync::{Arc, atomic::AtomicU8}, + mem, + ops::DerefMut, + sync::{ + Arc, + atomic::{AtomicPtr, AtomicU8}, + }, }; use parking_lot::{Condvar, Mutex}; @@ -118,7 +123,7 @@ impl Latch for AtomicLatch { impl Probe for AtomicLatch { #[inline] fn probe(&self) -> bool { - self.inner.load(Ordering::Relaxed) & Self::SET == Self::SET + self.inner.load(Ordering::Relaxed) & Self::SET != 0 } } impl AsCoreLatch for AtomicLatch { @@ -192,28 +197,29 @@ impl Probe for NopLatch { } } -pub struct CountLatch { +pub struct CountLatch { count: AtomicUsize, - inner: L, + inner: AtomicPtr, } -impl CountLatch { +impl CountLatch { #[inline] - pub const fn new(inner: L) -> Self { + pub const fn new(inner: *const WorkerLatch) -> Self { Self { count: AtomicUsize::new(0), - inner, + inner: AtomicPtr::new(inner as *mut WorkerLatch), } } + pub fn set_inner(&self, inner: *const WorkerLatch) { + self.inner + .store(inner as *mut WorkerLatch, Ordering::Relaxed); + } + pub fn count(&self) -> usize { self.count.load(Ordering::Relaxed) } - pub fn inner(&self) -> &L { - &self.inner - } - #[inline] pub fn increment(&self) { self.count.fetch_add(1, Ordering::Release); @@ -227,33 +233,29 @@ impl CountLatch { } } -impl Latch for CountLatch { +impl Latch for CountLatch { #[inline] unsafe fn set_raw(this: *const Self) { unsafe { if (&*this).count.fetch_sub(1, Ordering::Relaxed) == 1 { tracing::trace!("CountLatch set_raw: count was 1, setting inner latch"); // If the count was 1, we need to set the inner latch. - Latch::set_raw(&(*this).inner); + let inner = (*this).inner.load(Ordering::Relaxed); + if !inner.is_null() { + (&*inner).wake(); + } } } } } -impl Probe for CountLatch { +impl Probe for CountLatch { #[inline] fn probe(&self) -> bool { self.count.load(Ordering::Relaxed) == 0 } } -impl AsCoreLatch for CountLatch { - #[inline] - fn as_core_latch(&self) -> &CoreLatch { - self.inner.as_core_latch() - } -} - pub struct MutexLatch { inner: AtomicLatch, lock: Mutex<()>, @@ -287,27 +289,14 @@ impl MutexLatch { self.inner.reset(); } - pub fn wait_and_reset(&self) -> WakeResult { + pub fn wait_and_reset(&self) { // SAFETY: inner is locked by the mutex, so we can safely access it. 
- let value = { - let mut guard = self.lock.lock(); - self.inner.set_sleeping(); - while self.inner.get() & !AtomicLatch::SLEEPING == AtomicLatch::UNSET { - self.condvar.wait(&mut guard); - } - - self.inner.reset() - }; - - if value & AtomicLatch::SET == AtomicLatch::SET { - WakeResult::Set - } else if value & AtomicLatch::WAKEUP == AtomicLatch::WAKEUP { - WakeResult::Wake - } else if value & AtomicLatch::HEARTBEAT == AtomicLatch::HEARTBEAT { - WakeResult::Heartbeat - } else { - panic!("MutexLatch was not set correctly"); + let mut guard = self.lock.lock(); + while !self.inner.probe() { + self.condvar.wait(&mut guard); } + + self.inner.reset(); } pub fn set(&self) { @@ -315,34 +304,6 @@ impl MutexLatch { Latch::set_raw(self); } } - - pub fn signal_heartbeat(&self) { - let mut _guard = self.lock.lock(); - self.inner.set_heartbeat(); - - // If the latch was sleeping, notify the waiting thread. - if self.inner.is_sleeping() { - self.condvar.notify_all(); - } - } - - pub fn signal_job_shared(&self) { - let mut _guard = self.lock.lock(); - self.inner.set_wakeup(); - if self.inner.is_sleeping() { - self.condvar.notify_all(); - } - } - - pub fn signal_job_finished(&self) { - let mut _guard = self.lock.lock(); - unsafe { - CoreLatch::set(&self.inner); - if self.inner.is_sleeping() { - self.condvar.notify_all(); - } - } - } } impl Latch for MutexLatch { @@ -352,10 +313,8 @@ impl Latch for MutexLatch { unsafe { let this = &*this; let _guard = this.lock.lock(); - Latch::set_raw(this.inner.get() as *const AtomicLatch); - if this.inner.is_sleeping() { - this.condvar.notify_all(); - } + Latch::set_raw(&this.inner); + this.condvar.notify_all(); } } } @@ -377,111 +336,248 @@ impl AsCoreLatch for MutexLatch { } } -/// Must only be `set` from a worker thread. -pub struct WakeLatch { - inner: AtomicLatch, - worker_index: AtomicUsize, +// The worker waits on this latch whenever it has nothing to do. +pub struct WorkerLatch { + // this boolean is set when the worker is waiting. 
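+    // `condvar` is not a `Condvar` but a raw parking_lot_core key: it stores the
+    // address of the mutex this latch is currently parked on, or 0 when nobody waits.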
+    mutex: Mutex<bool>,
+    condvar: AtomicUsize,
+}
 
-impl WakeLatch {
-    pub fn new(worker_index: usize) -> Self {
+impl WorkerLatch {
+    pub fn new() -> Self {
         Self {
-            inner: AtomicLatch::new(),
-            worker_index: AtomicUsize::new(worker_index),
+            mutex: Mutex::new(false),
+            condvar: AtomicUsize::new(0),
+        }
+    }
+    pub fn lock(&self) {
+        mem::forget(self.mutex.lock());
+    }
+    pub fn unlock(&self) {
+        unsafe {
+            self.mutex.force_unlock();
         }
     }
 
-    pub(crate) fn set_worker_index(&self, worker_index: usize) {
-        self.worker_index.store(worker_index, Ordering::Relaxed);
+    pub fn wait(&self) {
+        let condvar = &self.condvar;
+        let mut guard = self.mutex.lock();
+
+        Self::wait_internal(condvar, &mut guard);
     }
-}
 
-impl Latch for WakeLatch {
-    #[inline]
-    unsafe fn set_raw(this: *const Self) {
+    fn wait_internal(condvar: &AtomicUsize, guard: &mut parking_lot::MutexGuard<'_, bool>) {
+        let mutex = parking_lot::MutexGuard::mutex(guard);
+        let key = condvar as *const _ as usize;
+        let lock_addr = mutex as *const _ as usize;
+        let mut requeued = false;
+
+        let state = unsafe { AtomicUsize::from_ptr(condvar as *const _ as *mut usize) };
+
+        **guard = true; // set the mutex to true to indicate that the worker is waiting
+
         unsafe {
-            let worker_index = (&*this).worker_index.load(Ordering::Relaxed);
+            parking_lot_core::park(
+                key,
+                || {
+                    let old = state.load(Ordering::Relaxed);
+                    if old == 0 {
+                        state.store(lock_addr, Ordering::Relaxed);
+                    } else if old != lock_addr {
+                        return false;
+                    }
 
-            if CoreLatch::set(&(&*this).inner) {
-                let ctx = WorkerThread::current_ref().unwrap().context.clone();
-                // If the latch was sleeping, wake the worker thread
-                ctx.shared()
-                    .heartbeats
-                    .get(&worker_index)
-                    .map(|ptr| ptr.as_ref().latch.signal_job_finished());
+                    true
+                },
+                || {
+                    mutex.force_unlock();
+                },
+                |k, was_last_thread| {
+                    requeued = k != key;
+                    if !requeued && was_last_thread {
+                        state.store(0, Ordering::Relaxed);
+                    }
+                },
+                parking_lot_core::DEFAULT_PARK_TOKEN,
+                None,
+            );
+        }
+        // relock
+
+        let mut new = mutex.lock();
+        mem::swap(&mut new, guard);
+        mem::forget(new); // forget the new guard to avoid dropping it
+
+        **guard = false; // reset the mutex to false after waking up
+    }
+
+    fn wait_with_lock_internal<T>(&self, other: &mut parking_lot::MutexGuard<'_, T>) {
+        let key = &self.condvar as *const _ as usize;
+        let lock_addr = &self.mutex as *const _ as usize;
+        let mut requeued = false;
+
+        let mut guard = self.mutex.lock();
+
+        let state = unsafe { AtomicUsize::from_ptr(&self.condvar as *const _ as *mut usize) };
+
+        *guard = true; // set the mutex to true to indicate that the worker is waiting
+
+        unsafe {
+            let token = parking_lot_core::park(
+                key,
+                || {
+                    let old = state.load(Ordering::Relaxed);
+                    if old == 0 {
+                        state.store(lock_addr, Ordering::Relaxed);
+                    } else if old != lock_addr {
+                        return false;
+                    }
+
+                    true
+                },
+                || {
+                    drop(guard); // drop the guard to release the lock
+                    parking_lot::MutexGuard::mutex(&other).force_unlock();
+                },
+                |k, was_last_thread| {
+                    requeued = k != key;
+                    if !requeued && was_last_thread {
+                        state.store(0, Ordering::Relaxed);
+                    }
+                },
+                parking_lot_core::DEFAULT_PARK_TOKEN,
+                None,
+            );
+
+            tracing::trace!(
+                "WorkerLatch wait_with_lock_internal: unparked with token {:?}",
+                token
+            );
+        }
+        // relock
+        let mut other2 = parking_lot::MutexGuard::mutex(&other).lock();
+        tracing::trace!("WorkerLatch wait_with_lock_internal: relocked other");
+
+        // because `other` is logically unlocked, we swap it with `other2` and then forget `other2`
+        core::mem::swap(&mut other2, &mut *other);
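+        // after the swap, `other2` holds the stale pre-park guard whose lock was
+        // already released in `before_sleep`; forgetting it below skips the double
+        // unlock that dropping it would perform.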
core::mem::forget(other2); + + let mut guard = self.mutex.lock(); + tracing::trace!("WorkerLatch wait_with_lock_internal: relocked self"); + + *guard = false; // reset the mutex to false after waking up + } + + pub fn wait_with_lock(&self, other: &mut parking_lot::MutexGuard<'_, T>) { + self.wait_with_lock_internal(other); + } + + pub fn wait_with_lock_while(&self, other: &mut parking_lot::MutexGuard<'_, T>, mut f: F) + where + F: FnMut(&mut T) -> bool, + { + while f(other.deref_mut()) { + self.wait_with_lock_internal(other); + } + } + + pub fn wait_until(&self, mut f: F) -> T + where + F: FnMut() -> Option, + { + let mut guard = self.mutex.lock(); + loop { + if let Some(result) = f() { + return result; } + Self::wait_internal(&self.condvar, &mut guard); } } -} -impl Probe for WakeLatch { - #[inline] - fn probe(&self) -> bool { - self.inner.probe() + pub fn is_waiting(&self) -> bool { + *self.mutex.lock() } -} -impl AsCoreLatch for WakeLatch { - #[inline] - fn as_core_latch(&self) -> &CoreLatch { - &self.inner - } -} + fn notify(&self) { + let key = &self.condvar as *const _ as usize; -/// A latch that can be set from any thread, but must be created with a valid waker. -pub struct UnsafeWakeLatch { - waker: *const MutexLatch, -} - -impl UnsafeWakeLatch { - /// # Safety - /// The `waker` must be valid until the latch is set. - pub unsafe fn new(waker: *const MutexLatch) -> Self { - Self { waker } - } -} - -impl Latch for UnsafeWakeLatch { - #[inline] - unsafe fn set_raw(this: *const Self) { unsafe { - let waker = (*this).waker; - Latch::set_raw(waker); + let n = parking_lot_core::unpark_all(key, parking_lot_core::DEFAULT_UNPARK_TOKEN); + tracing::trace!("WorkerLatch notify_one: unparked {} threads", n); } } -} -impl Probe for UnsafeWakeLatch { - #[inline] - fn probe(&self) -> bool { - // SAFETY: waker is valid as per the constructor contract. - unsafe { - let waker = &*self.waker; - waker.probe() - } - } -} - -impl AsCoreLatch for UnsafeWakeLatch { - #[inline] - fn as_core_latch(&self) -> &CoreLatch { - // SAFETY: waker is valid as per the constructor contract. - unsafe { - let waker = &*self.waker; - waker.as_core_latch() - } + pub fn wake(&self) { + self.notify(); } } #[cfg(test)] mod tests { - use std::sync::Barrier; + use std::{ptr, sync::Barrier}; - use tracing::Instrument; use tracing_test::traced_test; use super::*; + #[test] + #[cfg_attr(not(miri), traced_test)] + fn worker_latch() { + let latch = Arc::new(WorkerLatch::new()); + let barrier = Arc::new(Barrier::new(2)); + let mutex = Arc::new(parking_lot::Mutex::new(false)); + + let count = Arc::new(AtomicUsize::new(0)); + + let thread = std::thread::spawn({ + let latch = latch.clone(); + let mutex = mutex.clone(); + let barrier = barrier.clone(); + let count = count.clone(); + + move || { + tracing::info!("Thread waiting on barrier"); + let mut guard = mutex.lock(); + barrier.wait(); + + tracing::info!("Thread waiting on latch"); + latch.wait_with_lock(&mut guard); + count.fetch_add(1, Ordering::Relaxed); + tracing::info!("Thread woke up from latch"); + barrier.wait(); + tracing::info!("Thread finished waiting on barrier"); + count.fetch_add(1, Ordering::Relaxed); + } + }); + + assert!(!latch.is_waiting(), "Latch should not be waiting yet"); + barrier.wait(); + tracing::info!("Main thread finished waiting on barrier"); + // lock mutex and notify the thread that isn't yet waiting. 
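+        // acquiring the mutex below can only succeed once the spawned thread has
+        // actually parked, because `wait_with_lock` releases it inside `before_sleep`.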
+ { + let guard = mutex.lock(); + tracing::info!("Main thread acquired mutex, waking up thread"); + assert!(latch.is_waiting(), "Latch should be waiting now"); + + latch.wake(); + tracing::info!("Main thread woke up thread"); + } + assert_eq!(count.load(Ordering::Relaxed), 0, "Count should still be 0"); + barrier.wait(); + assert_eq!( + count.load(Ordering::Relaxed), + 1, + "Count should be 1 after waking up" + ); + + thread.join().expect("Thread should join successfully"); + assert_eq!( + count.load(Ordering::Relaxed), + 2, + "Count should be 2 after thread has finished" + ); + } + #[test] fn test_atomic_latch() { let latch = AtomicLatch::new(); @@ -522,7 +618,7 @@ mod tests { #[test] fn count_latch() { - let latch = CountLatch::new(AtomicLatch::new()); + let latch = CountLatch::new(ptr::null()); assert_eq!(latch.count(), 0); latch.increment(); assert_eq!(latch.count(), 1); @@ -557,63 +653,18 @@ mod tests { // Test wait functionality let latch_clone = latch.clone(); let handle = std::thread::spawn(move || { - assert_eq!(latch_clone.wait_and_reset(), WakeResult::Set); + tracing::info!("Thread waiting on latch"); + latch_clone.wait_and_reset(); + tracing::info!("Thread woke up from latch"); }); // Give the thread time to block std::thread::sleep(std::time::Duration::from_millis(100)); assert!(!latch.probe()); + tracing::info!("Setting latch from main thread"); latch.set(); + tracing::info!("Latch set, joining waiting thread"); handle.join().expect("Thread should join successfully"); } - - #[test] - #[traced_test] - fn wake_latch() { - let context = Context::new_with_threads(1); - let count = Arc::new(AtomicUsize::new(0)); - let barrier = Arc::new(Barrier::new(2)); - - tracing::info!("running scope in worker thread"); - context.run_in_worker(|worker| { - tracing::info!("worker thread started: {:?}", worker.index); - let latch = Arc::new(WakeLatch::new(worker.index)); - worker.context.spawn({ - let heartbeat = unsafe { crate::util::Send::new(worker.heartbeat) }; - let barrier = barrier.clone(); - let count = count.clone(); - let latch = latch.clone(); - move || { - tracing::info!("sleeping workerthread"); - - latch.as_core_latch().set_sleeping(); - unsafe { - heartbeat.as_ref().latch.wait_and_reset(); - } - tracing::info!("woken up workerthread"); - count.fetch_add(1, Ordering::SeqCst); - tracing::info!("waiting on barrier"); - barrier.wait(); - } - }); - - worker.context.spawn({ - move || { - tracing::info!("setting latch in worker thread"); - unsafe { - Latch::set_raw(&*latch); - } - } - }); - }); - - tracing::info!("main thread set latch, waiting for worker thread to wake up"); - barrier.wait(); - assert_eq!( - count.load(Ordering::SeqCst), - 1, - "Latch should have woken the worker thread" - ); - } } diff --git a/distaff/src/lib.rs b/distaff/src/lib.rs index cd56bca..4055b9a 100644 --- a/distaff/src/lib.rs +++ b/distaff/src/lib.rs @@ -7,6 +7,7 @@ unsafe_cell_access, box_as_ptr, box_vec_non_null, + strict_provenance_atomic_ptr, let_chains )] diff --git a/distaff/src/scope.rs b/distaff/src/scope.rs index f5ba67f..025e322 100644 --- a/distaff/src/scope.rs +++ b/distaff/src/scope.rs @@ -12,8 +12,8 @@ use async_task::Runnable; use crate::{ context::Context, - job::{HeapJob, Job}, - latch::{AsCoreLatch, CountLatch, MutexLatch, WakeLatch}, + job::{HeapJob, JobSender, QueuedJob as Job}, + latch::{CountLatch, WorkerLatch}, util::{DropGuard, SendPtr}, workerthread::WorkerThread, }; @@ -53,7 +53,7 @@ use crate::{ pub struct Scope<'scope, 'env: 'scope> { // latch to wait on before the scope finishes 
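    // (the counter is created with a null waker; `wait_for_jobs` installs the
    // waiting worker's `WorkerLatch` via `set_inner` just before it blocks)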
- job_counter: CountLatch, + job_counter: CountLatch, // local threadpool context: Arc, // panic error @@ -87,14 +87,17 @@ where impl<'scope, 'env> Scope<'scope, 'env> { fn wait_for_jobs(&self, worker: &WorkerThread) { + self.job_counter.set_inner(worker.heartbeat.raw_latch()); if self.job_counter.count() > 0 { tracing::trace!("waiting for {} jobs to finish.", self.job_counter.count()); - tracing::trace!("thread id: {:?}, jobs: {:?}", worker.index, unsafe { - worker.queue.as_ref_unchecked() - }); + tracing::trace!( + "thread id: {:?}, jobs: {:?}", + worker.heartbeat.index(), + unsafe { worker.queue.as_ref_unchecked() } + ); // set worker index in the job counter - worker.wait_until_latch(self.job_counter.as_core_latch()); + worker.wait_until_latch(&self.job_counter); } } @@ -106,23 +109,6 @@ impl<'scope, 'env> Scope<'scope, 'env> { { use std::panic::{AssertUnwindSafe, catch_unwind}; - #[allow(dead_code)] - fn make_job T, T>(f: F) -> Job { - #[align(8)] - unsafe fn harness T, T>(this: *const (), job: *const Job) { - let f = unsafe { Box::from_raw(this.cast::().cast_mut()) }; - - let result = catch_unwind(AssertUnwindSafe(move || f())); - - let job = unsafe { Box::from_raw(job.cast_mut()) }; - job.complete(result); - } - - Job::::new(harness::, unsafe { - NonNull::new_unchecked(Box::into_raw(Box::new(f))).cast() - }) - } - let result = match catch_unwind(AssertUnwindSafe(|| f())) { Ok(val) => Some(val), Err(payload) => { @@ -151,6 +137,7 @@ impl<'scope, 'env> Scope<'scope, 'env> { /// stores the first panic that happened in this scope. fn panicked(&self, err: Box) { + tracing::debug!("panicked in scope, storing error: {:?}", err); self.panic.load(Ordering::Relaxed).is_null().then(|| { use core::mem::ManuallyDrop; let mut boxed = ManuallyDrop::new(Box::new(err)); @@ -182,17 +169,22 @@ impl<'scope, 'env> Scope<'scope, 'env> { let this = SendPtr::new_const(self).unwrap(); - let job = Box::new(HeapJob::new(move || unsafe { - _ = f(this.as_ref()); - this.as_unchecked_ref().job_counter.decrement(); - })) - .into_boxed_job(); + let job = Job::from_heapjob( + Box::new(HeapJob::new(move || unsafe { + use std::panic::{AssertUnwindSafe, catch_unwind}; + if let Err(payload) = catch_unwind(AssertUnwindSafe(|| f(this.as_ref()))) { + this.as_unchecked_ref().panicked(payload); + } + this.as_unchecked_ref().job_counter.decrement(); + })), + ptr::null(), + ); tracing::trace!("allocated heapjob"); WorkerThread::current_ref() .expect("spawn is run in workerthread.") - .push_front(job as _); + .push_front(job.as_ptr()); tracing::trace!("leaked heapjob"); } @@ -233,13 +225,14 @@ impl<'scope, 'env> Scope<'scope, 'env> { let _guard = DropGuard::new(move || { this.as_unchecked_ref().job_counter.decrement(); }); + // TODO: handle panics here f(this.as_ref()).await } }; let schedule = move |runnable: Runnable| { #[align(8)] - unsafe fn harness(this: *const (), job: *const Job) { + unsafe fn harness(this: *const (), job: *const JobSender, _: *const WorkerLatch) { unsafe { let runnable = Runnable::<()>::from_raw(NonNull::new_unchecked(this.cast_mut())); @@ -250,12 +243,16 @@ impl<'scope, 'env> Scope<'scope, 'env> { } } - let job = Box::new(Job::new(harness, runnable.into_raw())); + let job = Box::into_raw(Box::new(Job::from_harness( + harness, + runnable.into_raw(), + ptr::null(), + ))); // casting into Job<()> here WorkerThread::current_ref() .expect("spawn_async_internal is run in workerthread.") - .push_front(Box::into_raw(job) as _); + .push_front(job); }; let (runnable, task) = unsafe { 
async_task::spawn_unchecked(future, schedule) }; @@ -291,7 +288,7 @@ impl<'scope, 'env> Scope<'scope, 'env> { unsafe fn from_context(context: Arc) -> Self { Self { context, - job_counter: CountLatch::new(MutexLatch::new()), + job_counter: CountLatch::new(ptr::null()), panic: AtomicPtr::new(ptr::null_mut()), _scope: PhantomData, _env: PhantomData, @@ -309,7 +306,8 @@ mod tests { use crate::ThreadPool; #[test] - fn spawn() { + #[cfg_attr(not(miri), traced_test)] + fn scope_spawn_sync() { let pool = ThreadPool::new_with_threads(1); let count = Arc::new(AtomicU8::new(0)); @@ -323,7 +321,8 @@ mod tests { } #[test] - fn join() { + #[cfg_attr(not(miri), traced_test)] + fn scope_join_one() { let pool = ThreadPool::new_with_threads(1); let a = pool.scope(|scope| { @@ -335,7 +334,8 @@ mod tests { } #[test] - fn join_many() { + #[cfg_attr(not(miri), traced_test)] + fn scope_join_many() { let pool = ThreadPool::new_with_threads(1); fn sum<'scope, 'env>(scope: &'scope Scope<'scope, 'env>, n: usize) -> usize { @@ -356,7 +356,8 @@ mod tests { } #[test] - fn spawn_future() { + #[cfg_attr(not(miri), traced_test)] + fn scope_spawn_future() { let pool = ThreadPool::new_with_threads(1); let mut x = 0; pool.scope(|scope| { @@ -371,7 +372,8 @@ mod tests { } #[test] - fn spawn_many() { + #[cfg_attr(not(miri), traced_test)] + fn scope_spawn_many() { let pool = ThreadPool::new_with_threads(1); let count = Arc::new(AtomicU8::new(0)); diff --git a/distaff/src/threadpool.rs b/distaff/src/threadpool.rs index 86d2af0..8185750 100644 --- a/distaff/src/threadpool.rs +++ b/distaff/src/threadpool.rs @@ -58,8 +58,8 @@ mod tests { use super::*; #[test] - #[traced_test] - fn spawn_borrow() { + #[cfg_attr(not(miri), traced_test)] + fn pool_spawn_borrow() { let pool = ThreadPool::new_with_threads(1); let mut x = 0; pool.scope(|scope| { @@ -72,7 +72,8 @@ mod tests { } #[test] - fn spawn_future() { + #[cfg_attr(not(miri), traced_test)] + fn pool_spawn_future() { let pool = ThreadPool::new_with_threads(1); let mut x = 0; let task = pool.scope(|scope| { @@ -88,7 +89,8 @@ mod tests { } #[test] - fn join() { + #[cfg_attr(not(miri), traced_test)] + fn pool_join() { let pool = ThreadPool::new_with_threads(1); let (a, b) = pool.join(|| 3 + 4, || 5 * 6); assert_eq!(a, 7); diff --git a/distaff/src/util.rs b/distaff/src/util.rs index ddc9ec1..0c24452 100644 --- a/distaff/src/util.rs +++ b/distaff/src/util.rs @@ -104,6 +104,7 @@ impl SendPtr { /// as the pointer. /// The pointer must be aligned to `BITS` bits, i.e. `align_of::() >= 2^BITS`. 
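/// For example, with `BITS = 3` the pointee must be at least 8-byte aligned, which
/// leaves the low three bits of the pointer free for tags in `0..8`.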
#[repr(transparent)] +#[derive(Debug)] pub struct TaggedAtomicPtr { ptr: AtomicPtr<()>, _pd: PhantomData, @@ -138,6 +139,19 @@ impl TaggedAtomicPtr { self.ptr.load(order).addr() & Self::mask() } + pub fn fetch_or_tag(&self, tag: usize, order: Ordering) -> usize { + let mask = Self::mask(); + let old_ptr = self.ptr.fetch_or(tag & mask, order); + old_ptr.addr() & mask + } + + /// returns the tag and clears it + pub fn take_tag(&self, order: Ordering) -> usize { + let mask = Self::mask(); + let old_ptr = self.ptr.fetch_and(!mask, order); + old_ptr.addr() & mask + } + /// returns tag #[inline(always)] fn compare_exchange_tag_inner( @@ -432,10 +446,29 @@ impl Send { } } +pub fn unwrap_or_panic(result: std::thread::Result) -> T { + match result { + Ok(value) => value, + Err(payload) => std::panic::resume_unwind(payload), + } +} + #[cfg(test)] mod tests { use super::*; + #[test] + fn tagged_ptr_zero_tag() { + let ptr = Box::into_raw(Box::new(42u32)); + let tagged_ptr = TaggedAtomicPtr::::new(ptr, 0); + assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0); + assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr); + + unsafe { + _ = Box::from_raw(ptr); + } + } + #[test] fn tagged_ptr_exchange() { let ptr = Box::into_raw(Box::new(42u32)); diff --git a/distaff/src/workerthread.rs b/distaff/src/workerthread.rs index 4e10118..ea412f9 100644 --- a/distaff/src/workerthread.rs +++ b/distaff/src/workerthread.rs @@ -1,5 +1,6 @@ use std::{ cell::{Cell, UnsafeCell}, + hint::cold_path, ptr::NonNull, sync::Arc, time::Duration, @@ -9,52 +10,34 @@ use crossbeam_utils::CachePadded; use crate::{ context::{Context, Heartbeat}, - job::{Job, JobList, JobResult}, - latch::{AsCoreLatch, CoreLatch, Probe}, + heartbeat::OwnedHeartbeatReceiver, + job::{JobQueue as JobList, JobResult, QueuedJob as Job, QueuedJob, StackJob}, + latch::{AsCoreLatch, CoreLatch, Probe, WorkerLatch}, util::DropGuard, }; pub struct WorkerThread { pub(crate) context: Arc, - pub(crate) index: usize, pub(crate) queue: UnsafeCell, - pub(crate) heartbeat: NonNull>, + pub(crate) heartbeat: OwnedHeartbeatReceiver, pub(crate) join_count: Cell, } -impl Drop for WorkerThread { - fn drop(&mut self) { - // remove the current worker thread from the heartbeat list - self.context.shared().remove_heartbeat(self.index); - - // SAFETY: we removed the heartbeat from the context, so we can safely drop it. - unsafe { - _ = Box::from_non_null(self.heartbeat); - } - } -} - thread_local! { static WORKER: UnsafeCell>> = const { UnsafeCell::new(None) }; } impl WorkerThread { pub fn new_in(context: Arc) -> Self { - let (heartbeat, index) = context.shared().new_heartbeat(); + let heartbeat = context.heartbeats.new_heartbeat(); Self { context, - index, queue: UnsafeCell::new(JobList::new()), heartbeat, join_count: Cell::new(0), } } - - pub(crate) fn heartbeat(&self) -> &CachePadded { - // SAFETY: the heartbeat is always set when the worker thread is created - unsafe { self.heartbeat.as_ref() } - } } impl WorkerThread { @@ -80,53 +63,77 @@ impl WorkerThread { } fn run_inner(&self) { - let mut job = self.context.shared().pop_job(); + let mut job = None; 'outer: loop { - while let Some(j) = job { - self.execute(j); + if let Some(job) = job { + self.execute(job); + } - let mut guard = self.context.shared(); - if guard.should_exit() { - // if the context is stopped, break out of the outer loop which - // will exit the thread. - break 'outer; - } - - // we executed the shared job, now we want to check for any - // local jobs which this job might have spawned. 
- job = self.pop_front().or_else(|| guard.pop_job()); + if self.context.should_exit() { + // if the context is stopped, break out of the outer loop which + // will exit the thread. + break 'outer; } // no more jobs, wait to be notified of a new job or a heartbeat. - match self.heartbeat().latch.wait_and_reset() { - crate::latch::WakeResult::Wake => { - let mut guard = self.context.shared(); - if guard.should_exit() { - break 'outer; - } - - job = guard.pop_job(); - } - crate::latch::WakeResult::Heartbeat => { - self.tick(); - } - crate::latch::WakeResult::Set => { - // check if we should exit the thread - if self.context.shared().should_exit() { - break 'outer; - } - panic!("this thread shouldn't be woken by a finished job") - } - } + job = self.find_work_or_wait(); } } } impl WorkerThread { + pub(crate) fn find_work(&self) -> Option> { + self.find_work_inner().left() + } + + /// Looks for work in the local queue, then in the shared context, and if no + /// work is found, waits for the thread to be notified of a new job, after + /// which it returns `None`. + /// The caller should then check for `should_exit` to determine if the + /// thread should exit, or look for work again. + pub(crate) fn find_work_or_wait(&self) -> Option> { + match self.find_work_inner() { + either::Either::Left(job) => { + return Some(job); + } + either::Either::Right(mut guard) => { + // no jobs found, wait for a heartbeat or a new job + tracing::trace!("WorkerThread::find_work_or_wait: waiting for new job"); + self.heartbeat.latch().wait_with_lock(&mut guard); + tracing::trace!("WorkerThread::find_work_or_wait: woken up from wait"); + None + } + } + } + + #[inline] + fn find_work_inner( + &self, + ) -> either::Either, parking_lot::MutexGuard<'_, crate::context::Shared>> { + // first check the local queue for jobs + if let Some(job) = self.pop_front() { + tracing::trace!("WorkerThread::find_work_inner: found local job: {:?}", job); + return either::Either::Left(job); + } + + // then check the shared context for jobs + let mut guard = self.context.shared(); + + if let Some(job) = guard.pop_job() { + tracing::trace!("WorkerThread::find_work_inner: found shared job: {:?}", job); + return either::Either::Left(job); + } + + either::Either::Right(guard) + } + #[inline(always)] pub(crate) fn tick(&self) { - if self.heartbeat().is_pending() { - tracing::trace!("received heartbeat, thread id: {:?}", self.index); + if self.heartbeat.take() { + tracing::trace!( + "received heartbeat, thread id: {:?}", + self.heartbeat.index() + ); self.heartbeat_cold(); } } @@ -134,21 +141,22 @@ impl WorkerThread { #[inline] fn execute(&self, job: NonNull) { self.tick(); - Job::execute(job); + unsafe { Job::execute(job.as_ptr()) }; } #[cold] fn heartbeat_cold(&self) { let mut guard = self.context.shared(); - if !guard.jobs.contains_key(&self.index) { + if !guard.jobs.contains_key(&self.heartbeat.id()) { if let Some(job) = self.pop_back() { + Job::set_shared(unsafe { job.as_ref() }); tracing::trace!("heartbeat: sharing job: {:?}", job); + guard.jobs.insert(self.heartbeat.id(), job); unsafe { - job.as_ref().set_pending(); + // SAFETY: we are holding the lock on the shared context. 
+ self.context.notify_job_shared(); } - guard.jobs.insert(self.index, job); - guard.notify_job_shared(); } } } @@ -234,19 +242,12 @@ impl HeartbeatThread { let mut i = 0; loop { let sleep_for = { - let guard = self.ctx.shared(); - if guard.should_exit() { + if self.ctx.should_exit() { break; } - if let Some((_, heartbeat)) = guard.heartbeats.iter().nth(i) { - unsafe { - heartbeat.as_ref().latch.signal_heartbeat(); - } - } - let num_heartbeats = guard.heartbeats.len(); - - drop(guard); + self.ctx.heartbeats.notify_nth(i); + let num_heartbeats = self.ctx.heartbeats.len(); if i >= num_heartbeats { i = 0; @@ -265,120 +266,100 @@ impl HeartbeatThread { } impl WorkerThread { - #[cold] - fn wait_until_latch_cold(&self, latch: &CoreLatch) { - 'outer: while !latch.probe() { - // process local jobs before locking shared context - while let Some(job) = self.pop_front() { - tracing::trace!("thread {:?} executing local job: {:?}", self.index, job); - unsafe { - job.as_ref().set_pending(); - } - Job::execute(job); - tracing::trace!("thread {:?} finished local job: {:?}", self.index, job); - } - - // take a shared job, if it exists - 'inner: loop { - if let Some(shared_job) = self.context.shared().jobs.remove(&self.index) { - tracing::trace!( - "thread {:?} executing shared job: {:?}", - self.index, - shared_job - ); - Job::execute(shared_job); - } - - while !latch.probe() { - tracing::trace!("thread {:?} looking for shared jobs", self.index); - - let job = { - let mut guard = self.context.shared(); - guard.jobs.remove(&self.index).or_else(|| guard.pop_job()) - }; - - match job { - Some(job) => { - tracing::trace!("thread {:?} found job: {:?}", self.index, job); - Job::execute(job); - - continue 'outer; - } - None => { - tracing::trace!("thread {:?} is sleeping", self.index); - - match self.heartbeat().latch.wait_and_reset() { - // why were we woken up? - // 1. the heartbeat thread ticked and set the - // latch, so we should see if we have any work - // to share. - // 2. a job was shared and we were notified, so - // we should execute it. - // 3. the job we were waiting on was completed, - // so we should return it. - crate::latch::WakeResult::Set => { - break 'outer; // we were woken up by a job being set, so we should exit the loop. - } - crate::latch::WakeResult::Wake => { - // skip checking for local jobs, since we - // were woken up to check for shared jobs. - continue 'inner; - } - crate::latch::WakeResult::Heartbeat => { - self.tick(); - continue 'outer; - } - } - // since we were sleeping, the shared job can't be populated, - // so resuming the inner loop is fine. 
-                    }
-                }
-            }
-
-            break;
-        }
-
-        tracing::trace!(
-            "thread {:?} finished waiting on latch {:?}",
-            self.index,
-            latch
-        );
-        self.heartbeat().latch.as_core_latch().unset();
-        return;
-    }
-
-    pub fn wait_until_job<T>(&self, job: &Job<T>, latch: &CoreLatch) -> Option<JobResult<T>> {
+    pub fn wait_until_queued_job<T>(
+        &self,
+        job: *const QueuedJob,
+    ) -> Option<std::thread::Result<T>> {
+        let recv = unsafe { (*job).as_receiver::<T>() };
         // we've already checked that the job was popped from the queue
         // check if shared job is our job
-        if let Some(shared_job) = self.context.shared().jobs.remove(&self.index) {
-            if core::ptr::eq(shared_job.as_ptr(), job as *const Job as _) {
+        if let Some(shared_job) = self.context.shared().jobs.remove(&self.heartbeat.id()) {
+            if core::ptr::eq(shared_job.as_ptr(), job as *const Job as _) {
                 // this is the job we are looking for, so we want to
                 // short-circuit and call it inline
                 return None;
             } else {
                 // this isn't the job we are looking for, but we still need to
                 // execute it
-                Job::execute(shared_job);
+                unsafe { Job::execute(shared_job.as_ptr()) };
             }
         }
 
-        // poll the job's channel, running other work while we wait
-        if !latch.probe() {
-            self.wait_until_latch_cold(latch);
-        }
+        loop {
+            match recv.poll() {
+                Some(t) => {
+                    return Some(t);
+                }
+                None => {
+                    cold_path();
 
-        Some(job.wait())
+                    // check local jobs before locking shared context
+                    if let Some(job) = self.find_work_or_wait() {
+                        tracing::trace!(
+                            "thread {:?} executing local job: {:?}",
+                            self.heartbeat.index(),
+                            job
+                        );
+                        unsafe {
+                            Job::execute(job.as_ptr());
+                        }
+                        tracing::trace!(
+                            "thread {:?} finished local job: {:?}",
+                            self.heartbeat.index(),
+                            job
+                        );
+                        continue;
+                    }
+                }
+            }
+        }
     }
 
     pub fn wait_until_latch<L>(&self, latch: &L)
     where
-        L: AsCoreLatch,
+        L: Probe,
     {
-        let latch = latch.as_core_latch();
         if !latch.probe() {
-            tracing::trace!("thread {:?} waiting on latch {:?}", self.index, latch);
-            self.wait_until_latch_cold(latch)
+            tracing::trace!("thread {:?} waiting on latch", self.heartbeat.index());
+            self.wait_until_latch_cold(latch);
+        }
+    }
+
+    #[cold]
+    fn wait_until_latch_cold<L>(&self, latch: &L)
+    where
+        L: Probe,
+    {
+        if let Some(shared_job) = self.context.shared().jobs.remove(&self.heartbeat.id()) {
+            tracing::trace!(
+                "thread {:?} reclaiming shared job: {:?}",
+                self.heartbeat.index(),
+                shared_job
+            );
+            unsafe { Job::execute(shared_job.as_ptr()) };
+        }
+
+        // wait for the latch to open, running local and shared jobs in the meantime
+        while !latch.probe() {
+            // check local jobs before locking shared context
+            if let Some(job) = self.find_work_or_wait() {
+                tracing::trace!(
+                    "thread {:?} executing local job: {:?}",
+                    self.heartbeat.index(),
+                    job
+                );
+                unsafe {
+                    Job::execute(job.as_ptr());
+                }
+                tracing::trace!(
+                    "thread {:?} finished local job: {:?}",
+                    self.heartbeat.index(),
+                    job
+                );
+                continue;
+            }
+        }
     }
 }
diff --git a/examples/join.rs b/examples/join.rs
index 5cc1a07..9a2c31b 100644
--- a/examples/join.rs
+++ b/examples/join.rs
@@ -86,7 +86,6 @@ fn join_distaff() {
         let sum = sum(&tree, tree.root().unwrap(), s);
         sum
     });
-    eprintln!("sum: {sum}");
     std::hint::black_box(sum);
     }
 }
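For reference, a minimal end-to-end sketch of the public API as exercised by the tests in this patch (pool size and closures are illustrative only):

use distaff::ThreadPool;

fn main() {
    let pool = ThreadPool::new_with_threads(4);

    // join: run two closures, potentially in parallel, and return both results.
    let (a, b) = pool.join(|| 3 + 4, || 5 * 6);
    assert_eq!((a, b), (7, 30));

    // scope: spawned tasks may borrow locals; the scope waits for all of them
    // to finish before returning.
    let mut x = 0;
    pool.scope(|scope| {
        scope.spawn(|_| {
            x += 1;
        });
    });
    assert_eq!(x, 1);
}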