diff --git a/benches/join.rs b/benches/join.rs
index 2321934..749d36d 100644
--- a/benches/join.rs
+++ b/benches/join.rs
@@ -56,7 +56,7 @@ mod tree {
     }
 }
 
-const TREE_SIZE: usize = 16;
+const TREE_SIZE: usize = 8;
 
 #[bench]
 fn join_melange(b: &mut Bencher) {
@@ -210,10 +210,12 @@ fn join_distaff(b: &mut Bencher) {
     }
 
     b.iter(move || {
-        pool.scope(|s| {
+        let sum = pool.scope(|s| {
             let sum = sum(&tree, tree.root().unwrap(), s);
-            // eprintln!("{sum}");
             assert_ne!(sum, 0);
+            sum
         });
+        eprintln!("{sum}");
     });
+    eprintln!("Done with distaff join");
 }
diff --git a/distaff/src/context.rs b/distaff/src/context.rs
index bfae9e2..2cc55a0 100644
--- a/distaff/src/context.rs
+++ b/distaff/src/context.rs
@@ -16,6 +16,7 @@ use crate::{
     heartbeat::HeartbeatList,
     job::{HeapJob, JobSender, QueuedJob as Job, StackJob},
     latch::{AsCoreLatch, MutexLatch, NopLatch, WorkerLatch},
+    util::DropGuard,
     workerthread::{HeartbeatThread, WorkerThread},
 };
 
@@ -80,6 +81,8 @@ impl Context {
     }
 
     pub fn new_with_threads(num_threads: usize) -> Arc<Self> {
+        tracing::trace!("Creating context with {} threads", num_threads);
+
         let this = Arc::new(Self {
             shared: Mutex::new(Shared {
                 jobs: BTreeMap::new(),
@@ -90,8 +93,6 @@ impl Context {
             heartbeats: HeartbeatList::new(),
         });
 
-        tracing::trace!("Creating thread pool with {} threads", num_threads);
-
         // Create a barrier to synchronize the worker threads and the heartbeat thread
         let barrier = Arc::new(std::sync::Barrier::new(num_threads + 2));
 
@@ -104,8 +105,7 @@ impl Context {
                 .spawn(move || {
                     let worker = Box::new(WorkerThread::new_in(ctx));
 
-                    barrier.wait();
-                    worker.run();
+                    worker.run(barrier);
                 })
                 .expect("Failed to spawn worker thread");
         }
@@ -117,8 +117,7 @@ impl Context {
         std::thread::Builder::new()
             .name("heartbeat-thread".to_string())
             .spawn(move || {
-                barrier.wait();
-                HeartbeatThread::new(ctx).run();
+                HeartbeatThread::new(ctx).run(barrier);
             })
             .expect("Failed to spawn heartbeat thread");
     }
@@ -234,6 +233,9 @@ impl Context {
         T: Send,
         F: FnOnce(&WorkerThread) -> T + Send,
     {
+        let _guard = DropGuard::new(|| {
+            tracing::trace!("run_in_worker: finished");
+        });
         match WorkerThread::current_ref() {
             Some(worker) => {
                 // check if worker is in the same context
diff --git a/distaff/src/job.rs b/distaff/src/job.rs
index 70d3d26..5d55c4b 100644
--- a/distaff/src/job.rs
+++ b/distaff/src/job.rs
@@ -1054,6 +1054,7 @@ const ERROR: usize = 1 << 1;
 impl<T> JobSender<T> {
     #[tracing::instrument(level = "trace", skip_all)]
     pub fn send(&self, result: std::thread::Result<T>, mutex: *const WorkerLatch) {
+        tracing::trace!("sending job ({:?}) result", &raw const *self);
         // We want to lock here so that we can be sure that we wake the worker
         // only if it was waiting, and not immediately after having received the
         // result and waiting for further work:
diff --git a/distaff/src/latch.rs b/distaff/src/latch.rs
index 17f843a..cdcb605 100644
--- a/distaff/src/latch.rs
+++ b/distaff/src/latch.rs
@@ -460,7 +460,7 @@ impl WorkerLatch {
         tracing::trace!("WorkerLatch wait_with_lock_internal: relocked other");
 
         // because `other` is logically unlocked, we swap it with `other2` and then forget `other2`
-        core::mem::swap(&mut *other2, &mut *other);
+        core::mem::swap(&mut other2, other);
         core::mem::forget(other2);
 
         let mut guard = self.mutex.lock();
@@ -546,11 +546,12 @@ mod tests {
 
                 tracing::info!("Thread waiting on latch");
                 latch.wait_with_lock(&mut guard);
-                count.fetch_add(1, Ordering::Relaxed);
+                count.fetch_add(1, Ordering::SeqCst);
                 tracing::info!("Thread woke up from latch");
 
                 barrier.wait();
+                barrier.wait();
                 tracing::info!("Thread finished waiting on barrier");
-                count.fetch_add(1, Ordering::Relaxed);
+                count.fetch_add(1, Ordering::SeqCst);
             }
         });
@@ -566,17 +567,18 @@ mod tests {
             latch.wake();
             tracing::info!("Main thread woke up thread");
         }
-        assert_eq!(count.load(Ordering::Relaxed), 0, "Count should still be 0");
+        assert_eq!(count.load(Ordering::SeqCst), 0, "Count should still be 0");
 
         barrier.wait();
         assert_eq!(
-            count.load(Ordering::Relaxed),
+            count.load(Ordering::SeqCst),
             1,
             "Count should be 1 after waking up"
        );
+        barrier.wait();
         thread.join().expect("Thread should join successfully");
         assert_eq!(
-            count.load(Ordering::Relaxed),
+            count.load(Ordering::SeqCst),
             2,
             "Count should be 2 after thread has finished"
         );
@@ -645,7 +647,7 @@ mod tests {
     }
 
     #[test]
-    #[traced_test]
+    #[cfg_attr(not(miri), traced_test)]
     fn mutex_latch() {
         let latch = Arc::new(MutexLatch::new());
         assert!(!latch.probe());
diff --git a/distaff/src/lib.rs b/distaff/src/lib.rs
index 4055b9a..c2f0afd 100644
--- a/distaff/src/lib.rs
+++ b/distaff/src/lib.rs
@@ -8,6 +8,7 @@
     box_as_ptr,
     box_vec_non_null,
     strict_provenance_atomic_ptr,
+    likely_unlikely,
     let_chains
 )]
 
diff --git a/distaff/src/scope.rs b/distaff/src/scope.rs
index 813736e..f7d7293 100644
--- a/distaff/src/scope.rs
+++ b/distaff/src/scope.rs
@@ -184,13 +184,10 @@ impl<'scope, 'env> Scope<'scope, 'env> {
             ptr::null(),
         );
 
-        tracing::trace!("allocated heapjob");
-
-        WorkerThread::current_ref()
-            .expect("spawn is run in workerthread.")
-            .push_front(job.as_ptr());
-
-        tracing::trace!("leaked heapjob");
+        self.context.inject_job(job);
+        // WorkerThread::current_ref()
+        //     .expect("spawn is run in workerthread.")
+        //     .push_front(job.as_ptr());
     }
 
     pub fn spawn_future<F, T>(&'scope self, future: F) -> async_task::Task<T>
@@ -247,16 +244,17 @@ impl<'scope, 'env> Scope<'scope, 'env> {
                 }
             }
 
-            let job = Box::into_raw(Box::new(Job::from_harness(
+            let job = Box::into_non_null(Box::new(Job::from_harness(
                 harness,
                 runnable.into_raw(),
                 ptr::null(),
             )));
 
             // casting into Job<()> here
-            WorkerThread::current_ref()
-                .expect("spawn_async_internal is run in workerthread.")
-                .push_front(job);
+            self.context.inject_job(job);
+            // WorkerThread::current_ref()
+            //     .expect("spawn_async_internal is run in workerthread.")
+            //     .push_front(job);
         };
 
         let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) };
diff --git a/distaff/src/workerthread.rs b/distaff/src/workerthread.rs
index dccb48d..a3b525a 100644
--- a/distaff/src/workerthread.rs
+++ b/distaff/src/workerthread.rs
@@ -2,7 +2,7 @@ use std::{
     cell::{Cell, UnsafeCell},
     hint::cold_path,
     ptr::NonNull,
-    sync::Arc,
+    sync::{Arc, Barrier},
     time::Duration,
 };
 
@@ -42,7 +42,7 @@ impl WorkerThread {
 
 impl WorkerThread {
     #[tracing::instrument(level = "trace", skip_all)]
-    pub fn run(self: Box<Self>) {
+    pub fn run(self: Box<Self>, barrier: Arc<Barrier>) {
         let this = Box::into_raw(self);
         unsafe {
             Self::set_current(this);
@@ -56,6 +56,7 @@ impl WorkerThread {
 
         tracing::trace!("WorkerThread::run: starting worker thread");
 
+        barrier.wait();
         unsafe {
             (&*this).run_inner();
         }
@@ -106,6 +107,34 @@ impl WorkerThread {
                 tracing::trace!("WorkerThread::find_work_or_wait: waiting for new job");
                 self.heartbeat.latch().wait_with_lock(&mut guard);
                 tracing::trace!("WorkerThread::find_work_or_wait: woken up from wait");
+
+                None
+            }
+        }
+    }
+
+    #[tracing::instrument(level = "trace", skip_all)]
+    pub(crate) fn find_work_or_wait_unless<F>(&self, mut pred: F) -> Option<NonNull<Job>>
+    where
+        F: FnMut(&mut crate::context::Shared) -> bool,
+    {
+        match self.find_work_inner() {
+            either::Either::Left(job) => {
+                return Some(job);
+            }
+            either::Either::Right(mut guard) => {
+                // check the predicate while holding the lock
+                // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+                // this is very important, because the lock must be held when
+                // notifying us of the result of a job we scheduled.
+                // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+                if !pred(std::ops::DerefMut::deref_mut(&mut guard)) {
+                    // no jobs found, wait for a heartbeat or a new job
+                    tracing::trace!("WorkerThread::find_work_or_wait_unless: waiting for new job");
+                    self.heartbeat.latch().wait_with_lock(&mut guard);
+                    tracing::trace!("WorkerThread::find_work_or_wait_unless: woken up from wait");
+                }
+
                 None
             }
         }
@@ -146,8 +175,8 @@ impl WorkerThread {
     #[inline]
     #[tracing::instrument(level = "trace", skip(self))]
     fn execute(&self, job: NonNull<Job>) {
-        self.tick();
         unsafe { Job::execute(job.as_ptr()) };
+        self.tick();
     }
 
     #[cold]
@@ -243,8 +272,9 @@ impl HeartbeatThread {
     }
 
     #[tracing::instrument(level = "trace", skip(self))]
-    pub fn run(self) {
+    pub fn run(self, barrier: Arc<Barrier>) {
         tracing::trace!("new heartbeat thread {:?}", std::thread::current());
+        barrier.wait();
 
         let mut i = 0;
         loop {
@@ -282,6 +312,10 @@ impl WorkerThread {
         // we've already checked that the job was popped from the queue
         // check if shared job is our job
+        // skip checking if the job hasn't yet been claimed, because the
+        // overhead of waking a thread is so much bigger that it might never get
+        // the chance to actually claim it.
+
         // if let Some(shared_job) = self.context.shared().jobs.remove(&self.heartbeat.id()) {
         //     if core::ptr::eq(shared_job.as_ptr(), job as *const Job as _) {
         //         // this is the job we are looking for, so we want to
         //         // execute it and return the result.
         //     }
         // }
-        loop {
-            match recv.poll() {
-                Some(t) => {
-                    return Some(t);
-                }
-                None => {
-                    cold_path();
+        let mut out = recv.poll();
 
-                    // check local jobs before locking shared context
-                    if let Some(job) = self.find_work_or_wait() {
-                        tracing::trace!(
-                            "thread {:?} executing local job: {:?}",
-                            self.heartbeat.index(),
-                            job
-                        );
-                        unsafe {
-                            Job::execute(job.as_ptr());
-                        }
-                        tracing::trace!(
-                            "thread {:?} finished local job: {:?}",
-                            self.heartbeat.index(),
-                            job
-                        );
-                        continue;
-                    }
+        while std::hint::unlikely(out.is_none()) {
+            if let Some(job) = self.find_work_or_wait_unless(|_| {
+                out = recv.poll();
+                out.is_some()
+            }) {
+                unsafe {
+                    Job::execute(job.as_ptr());
                 }
             }
+
+            out = recv.poll();
         }
+
+        out
     }
 
     #[tracing::instrument(level = "trace", skip_all)]
@@ -365,20 +387,10 @@ impl WorkerThread {
         // do the usual thing??? chatgipity really said this..
         while !latch.probe() {
             // check local jobs before locking shared context
-            if let Some(job) = self.find_work_or_wait() {
-                tracing::trace!(
-                    "thread {:?} executing local job: {:?}",
-                    self.heartbeat.index(),
-                    job
-                );
+            if let Some(job) = self.find_work_or_wait_unless(|_| latch.probe()) {
                 unsafe {
                     Job::execute(job.as_ptr());
                 }
-                tracing::trace!(
-                    "thread {:?} finished local job: {:?}",
-                    self.heartbeat.index(),
-                    job
-                );
                 continue;
             }
         }
diff --git a/examples/join.rs b/examples/join.rs
index 48e6073..4691164 100644
--- a/examples/join.rs
+++ b/examples/join.rs
@@ -62,7 +62,8 @@ fn join_pool(tree_size: usize) {
 fn join_distaff(tree_size: usize) {
     use distaff::*;
 
-    let pool = ThreadPool::new();
+    let pool = ThreadPool::new_with_threads(6);
+
     let tree = Tree::new(tree_size, 1);
 
     fn sum<'scope, 'env>(tree: &Tree, node: usize, scope: &'scope Scope<'scope, 'env>) -> u32 {
@@ -81,11 +82,13 @@ fn join_distaff(tree_size: usize) {
         node.leaf + l + r
     }
 
-    let sum = pool.scope(|s| {
-        let sum = sum(&tree, tree.root().unwrap(), s);
-        sum
-    });
-    std::hint::black_box(sum);
+    for _ in 0..1000 {
+        let sum = pool.scope(|s| {
+            let sum = sum(&tree, tree.root().unwrap(), s);
+            sum
+        });
+        std::hint::black_box(sum);
+    }
 }
 
 fn join_chili(tree_size: usize) {
@@ -131,12 +134,17 @@ fn join_rayon(tree_size: usize) {
 }
 
 fn main() {
+    tracing_subscriber::fmt::init();
     // use tracing_subscriber::layer::SubscriberExt;
     // tracing::subscriber::set_global_default(
-    //     tracing_subscriber::registry().with(tracing_tracy::TracyLayer::default()),
+    //     tracing_subscriber::registry()
+    //         .with(tracing_tracy::TracyLayer::default()),
    // )
    // .expect("Failed to set global default subscriber");
 
+    // eprintln!("Press Enter to start profiling...");
+    // std::io::stdin().read_line(&mut String::new()).unwrap();
+
     let size = std::env::args()
         .nth(2)
         .and_then(|s| s.parse::<usize>().ok())
@@ -158,7 +166,7 @@
         }
     }
 
-    eprintln!("Done!");
-    // wait for user input before exiting
+    // eprintln!("Done!");
+    // // wait for user input before exiting
     // std::io::stdin().read_line(&mut String::new()).unwrap();
 }
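
A note on the startup handshake above: Context::new_with_threads sizes the barrier for every worker plus the heartbeat thread plus the constructing thread, and the wait now happens inside each thread's run() after its per-thread state is set up. A minimal sketch of the same pattern using only std types; worker_main and heartbeat_main are placeholder names, not distaff APIs.

use std::sync::{Arc, Barrier};
use std::thread;

// Per-thread setup runs before the wait, so by the time the barrier releases,
// every thread's state is fully constructed and registered.
fn worker_main(id: usize, barrier: Arc<Barrier>) {
    barrier.wait();
    println!("worker {id} running");
}

fn heartbeat_main(barrier: Arc<Barrier>) {
    barrier.wait();
    println!("heartbeat running");
}

fn main() {
    let num_threads = 4;
    // num_threads workers + 1 heartbeat thread + 1 constructing thread
    let barrier = Arc::new(Barrier::new(num_threads + 2));

    let mut handles = Vec::new();
    for id in 0..num_threads {
        let b = Arc::clone(&barrier);
        handles.push(thread::spawn(move || worker_main(id, b)));
    }
    let b = Arc::clone(&barrier);
    handles.push(thread::spawn(move || heartbeat_main(b)));

    // the constructor only hands out the pool once every thread has checked in
    barrier.wait();

    for h in handles {
        h.join().unwrap();
    }
}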
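
run_in_worker now installs a guard from distaff's util::DropGuard so the trace fires on every exit path. The crate's actual definition is not shown in this diff; a minimal guard with the assumed shape could look like this.

// Runs its closure when dropped, including on early return or unwind, so the
// "finished" trace above fires however run_in_worker exits.
struct DropGuard<F: FnOnce()>(Option<F>);

impl<F: FnOnce()> DropGuard<F> {
    fn new(f: F) -> Self {
        DropGuard(Some(f))
    }
}

impl<F: FnOnce()> Drop for DropGuard<F> {
    fn drop(&mut self) {
        if let Some(f) = self.0.take() {
            f();
        }
    }
}

fn main() {
    let _guard = DropGuard::new(|| println!("run_in_worker: finished"));
    println!("doing the actual work");
    // the guard's closure runs here, when _guard goes out of scope
}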
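
The comment block in find_work_or_wait_unless is the heart of that change: the predicate has to be evaluated while the shared lock is held, because JobSender::send publishes the result and wakes the worker under that same lock. A condvar analogue of the same lost-wakeup argument, sketched with plain std types rather than distaff's latches.

use std::sync::{Arc, Condvar, Mutex};
use std::thread;

fn main() {
    let state = Arc::new((Mutex::new(None::<u32>), Condvar::new()));

    let producer = {
        let state = Arc::clone(&state);
        thread::spawn(move || {
            let (lock, cvar) = &*state;
            // publish the result and wake the waiter while holding the lock,
            // mirroring JobSender::send locking the worker's latch mutex
            let mut slot = lock.lock().unwrap();
            *slot = Some(42);
            cvar.notify_one();
        })
    };

    let (lock, cvar) = &*state;
    let mut slot = lock.lock().unwrap();
    // predicate checked under the lock; if it already holds we never sleep,
    // otherwise wait() releases the lock atomically, so no wakeup can be lost
    while slot.is_none() {
        slot = cvar.wait(slot).unwrap();
    }
    println!("got result: {}", slot.unwrap());

    producer.join().unwrap();
}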
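
The rewritten wait loop polls for the job's own result and, while it is still pending, executes other queued work instead of blocking outright. A self-contained sketch of that poll-and-help structure, with a std channel standing in for recv.poll(); help_with_other_work is a placeholder, not a distaff API.

use std::sync::mpsc;
use std::thread;
use std::time::Duration;

// execute one queued job, if any; stands in for find_work_or_wait_unless
fn help_with_other_work(queue: &mut Vec<Box<dyn FnOnce() + Send>>) -> bool {
    if let Some(job) = queue.pop() {
        job();
        true
    } else {
        false
    }
}

fn main() {
    let (tx, rx) = mpsc::channel::<u32>();

    // another worker that will eventually publish the result we wait for
    let producer = thread::spawn(move || {
        thread::sleep(Duration::from_millis(5));
        tx.send(42).unwrap();
    });

    let mut local: Vec<Box<dyn FnOnce() + Send>> = Vec::new();
    local.push(Box::new(|| println!("helped with a queued job")));

    // poll for our own result; while it is not ready, run other queued work
    // instead of blocking, then poll again
    let result = loop {
        match rx.try_recv() {
            Ok(v) => break v,
            Err(mpsc::TryRecvError::Empty) => {
                if !help_with_other_work(&mut local) {
                    thread::yield_now();
                }
            }
            Err(mpsc::TryRecvError::Disconnected) => panic!("producer exited without sending"),
        }
    };

    println!("got {result}");
    producer.join().unwrap();
}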