diff --git a/benches/join.rs b/benches/join.rs
index b12080c..ba345da 100644
--- a/benches/join.rs
+++ b/benches/join.rs
@@ -56,7 +56,7 @@ mod tree {
     }
 }
 
-const TREE_SIZE: usize = 16;
+const TREE_SIZE: usize = 13;
 
 #[bench]
 fn join_melange(b: &mut Bencher) {
@@ -93,6 +93,7 @@ fn join_melange(b: &mut Bencher) {
 
 #[bench]
 fn join_praetor(b: &mut Bencher) {
+    let _ = tracing_subscriber::fmt().with_test_writer().try_init();
     use executor::praetor::Scope;
 
     let pool = executor::praetor::ThreadPool::global();
@@ -121,6 +122,7 @@ fn join_praetor(b: &mut Bencher) {
 
 #[bench]
 fn join_sync(b: &mut Bencher) {
+    let _ = tracing_subscriber::fmt().with_test_writer().try_init();
     let tree = tree::Tree::new(TREE_SIZE, 1u32);
 
     fn sum(tree: &tree::Tree<u32>, node: usize) -> u32 {
diff --git a/src/praetor/mod.rs b/src/praetor/mod.rs
index 72cab6f..ae267a1 100644
--- a/src/praetor/mod.rs
+++ b/src/praetor/mod.rs
@@ -426,12 +426,46 @@ mod job {
         }
     }
 
-    #[derive(Debug)]
+    // NOTE: `head` and `tail` are swapped from the usual convention here;
+    // the list is linked as:
+    // tail <-> job1 <-> job2 <-> ... <-> head
     pub struct JobList {
         head: Box<Job>,
         tail: Box<Job>,
     }
 
+    impl Debug for JobList {
+        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+            f.debug_struct("JobList")
+                .field("head", &self.head)
+                .field("tail", &self.tail)
+                .field_with("jobs", |f| {
+                    let mut list = f.debug_list();
+
+                    // SAFETY: head always has a prev
+                    let mut job = unsafe { self.head().as_ref().link_mut().prev.unwrap() };
+                    loop {
+                        if job == self.tail() {
+                            break;
+                        }
+
+                        let job_ref = unsafe { job.as_ref() };
+                        list.entry(&job_ref);
+
+                        // SAFETY: we are iterating over the linked list
+                        if let Some(next) = unsafe { job_ref.link_mut().prev } {
+                            job = next;
+                        } else {
+                            tracing::trace!("prev job is none?");
+                            break;
+                        };
+                    }
+
+                    list.finish()
+                })
+                .finish()
+        }
+    }
+
     impl JobList {
         pub fn new() -> JobList {
             let head = Box::new(Job::empty());
@@ -743,6 +777,7 @@ mod job {
                     return JobResult::new(result);
                 } else {
                     // spin until lock is released.
+                    tracing::trace!("spin-waiting for job: {:?}", self);
                     spin.spin();
                 }
             }
@@ -779,6 +814,12 @@ mod job {
     }
 
     pub fn execute(job: NonNull<Job>) {
+        tracing::trace!(
+            "thread {:?}: executing job: {:?}",
+            std::thread::current().name(),
+            job
+        );
+
         // SAFETY: self is non-null
         unsafe {
             let this = job.as_ref();
@@ -990,9 +1031,10 @@ use async_task::Runnable;
 use crossbeam::utils::CachePadded;
 use job::*;
 use parking_lot::{Condvar, Mutex};
+use parking_lot_core::SpinWait;
 use util::{DropGuard, SendPtr};
 
-use crate::latch::{AtomicLatch, LatchRef, NopLatch};
+use crate::latch::{AtomicLatch, LatchRef, NopLatch, Probe};
 
 #[derive(Debug, Default)]
 pub struct JobCounter {
@@ -1212,11 +1254,12 @@ impl WorkerThread {
 
         if !guard.jobs.contains_key(&self.index) {
             if let Some(job) = self.pop_back() {
+                tracing::trace!("heartbeat: sharing job: {:?}", job);
                 unsafe {
                     job.as_ref().set_pending();
                 }
                 guard.jobs.insert(self.index, job);
-                self.context.shared_job.notify_one();
+                self.context.notify_shared_job();
             }
         }
 
@@ -1228,69 +1271,72 @@ impl WorkerThread {
         // does this optimise?
        assert!(!latch.probe());
 
-        'outer: while !latch.probe() {
-            // take the shared job, if it exists
-            if let Some(shared_job) = self.context.shared.lock().jobs.remove(&self.index) {
-                self.execute(shared_job);
-            }
-
-            while !latch.probe() {
-                let mut guard = self.context.shared.lock();
-
-                match guard.jobs.pop_first().map(|(_, job)| job) {
-                    Some(job) => {
-                        drop(guard);
-                        self.execute(job);
-                        continue 'outer;
-                    }
-                    None => {
-                        // TODO: spin2win
-                        self.context.shared_job.wait(&mut guard);
-                    }
-                }
-            }
-        }
+        self.wait_until_predicate(|| latch.probe())
     }
 
     pub fn wait_until_latch<Latch: Probe>(&self, latch: &Latch) {
         if !latch.probe() {
-            self.wait_until_latch_cold(latch);
+            self.wait_until_latch_cold(latch)
         }
     }
 
-    pub fn wait_until_job<T>(&self, job: &Job<T>) -> Option<JobResult<T>> {
-        // take the shared job and check if it is our job
-        let shared_job = self.context.shared.lock().jobs.remove(&self.index);
+    #[inline]
+    fn wait_until_predicate<F>(&self, pred: F)
+    where
+        F: Fn() -> bool,
+    {
+        'outer: while !pred() {
+            // take the shared job, if it exists
+            if let Some(shared_job) = self.context.shared.lock().jobs.remove(&self.index) {
+                self.execute(shared_job);
+            }
 
-        if let Some(ptr) = shared_job {
-            if ptr.as_ptr() == &*job as *const _ as *mut _ {
-                // we can more efficiently run the job inline
-                return None;
-            } else {
-                // execute this job since it hasn't been taken
-                self.execute(ptr);
+            // process local jobs before locking the shared context
+            while let Some(job) = self.pop_front() {
+                unsafe {
+                    job.as_ref().set_pending();
+                }
+                self.execute(job);
+            }
+
+            while !pred() {
+                let mut guard = self.context.shared.lock();
+                let mut spin = SpinWait::new();
+
+                match guard.pop_job() {
+                    Some(job) => {
+                        drop(guard);
+                        self.execute(job);
+
+                        continue 'outer;
+                    }
+                    None => {
+                        // TODO: spin2win
+                        tracing::trace!("waiting for shared job, thread id: {:?}", self.index);
+
+                        // TODO: wait on the latch? if we had something that can
+                        // signal being done, i.e. that can be waited on instead
+                        // of shared jobs, we should wait on it instead, but we
+                        // would also still want to receive shared jobs?
+                        // self.context.shared_job.wait(&mut guard);
+                        // if spin.spin() {
+                        //     // wait for more shared jobs.
+                        //     // self.context.shared_job.wait(&mut guard);
+                        //     return;
+                        // }
+                        std::thread::yield_now();
+                    }
+                }
            }
        }
+
+        return;
+    }
 
-        while job.state() != JobState::Finished as u8 {
-            let Some(job) = self
-                .context
-                .shared
-                .lock()
-                .jobs
-                .pop_first()
-                .map(|(_, job)| job)
-                // .or_else(|| {
-                //     self.pop_front().inspect(|job| unsafe {
-                //         job.as_ref().set_pending();
-                //     })
-                // })
-            else {
-                break;
-            };
-
-            self.execute(job);
-        }
+    pub fn wait_until_job<T>(&self, job: &Job<T>) -> Option<JobResult<T>> {
+        self.wait_until_predicate(|| {
+            // check if the job is finished
+            job.state() == JobState::Finished as u8
+        });
 
         // someone else has this job and is working on it,
         // while job isn't done, suspend thread.
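The `wait_until_predicate` rewrite above is the core of this change: instead of blocking on the `shared_job` condvar, a waiting worker now (1) runs the job its heartbeat shared, (2) drains its own local deque, and (3) polls the shared context, yielding when nothing is queued. Below is a standalone sketch of that three-phase loop, reduced to plain `std` types; `Shared`, `Worker`, and the boxed-closure jobs are hypothetical stand-ins for the crate's `SharedContext`, `WorkerThread`, and `Job` machinery, not the real API.

```rust
use std::collections::VecDeque;
use std::sync::Mutex;

struct Shared {
    jobs: Mutex<VecDeque<Box<dyn FnOnce() + Send>>>,
}

struct Worker<'a> {
    shared: &'a Shared,
    local: VecDeque<Box<dyn FnOnce() + Send>>,
}

impl<'a> Worker<'a> {
    // Mirrors `wait_until_predicate`: keep running jobs until `pred` holds,
    // preferring local work and yielding (instead of blocking) when idle.
    fn wait_until(&mut self, pred: impl Fn() -> bool) {
        while !pred() {
            // 1. prefer local jobs: no locking needed and likely cache-hot
            if let Some(job) = self.local.pop_front() {
                job();
                continue;
            }
            // 2. otherwise take one shared job, holding the lock only to pop
            let shared = self.shared.jobs.lock().unwrap().pop_front();
            match shared {
                Some(job) => job(),
                // 3. nothing anywhere: yield rather than wait on a condvar
                None => std::thread::yield_now(),
            }
        }
    }
}

fn main() {
    use std::sync::atomic::{AtomicBool, Ordering};
    static DONE: AtomicBool = AtomicBool::new(false);

    let shared = Shared { jobs: Mutex::new(VecDeque::new()) };
    let mut worker = Worker { shared: &shared, local: VecDeque::new() };
    worker.local.push_back(Box::new(|| DONE.store(true, Ordering::Relaxed)));
    worker.wait_until(|| DONE.load(Ordering::Relaxed));
    assert!(DONE.load(Ordering::Relaxed));
}
```

Trading the condvar wait for `yield_now` keeps workers responsive to latches that get set without a matching notification, at the cost of some busy-waiting; the unused `SpinWait` and the commented-out wait in the diff point at that same unresolved tradeoff.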
@@ -1308,10 +1354,14 @@ where
 impl<'scope> Scope<'scope> {
     fn wait_for_jobs(&self) {
-        tracing::trace!("waiting for {} jobs to finish.", self.job_counter.count());
-
         let thread = WorkerThread::current_ref().unwrap();
+        tracing::trace!("waiting for {} jobs to finish.", self.job_counter.count());
+        tracing::trace!("thread id: {:?}, jobs: {:?}", thread.index, unsafe {
+            thread.queue.as_ref_unchecked()
+        });
+
+        thread.wait_until_latch(&self.job_counter);
         unsafe { self.job_counter.wait() };
     }
 
     pub fn scope<F, R>(f: F) -> R
@@ -1325,7 +1375,7 @@ impl<'scope> Scope<'scope> {
         })
     }
 
-    pub fn scope_with_context<F, R>(context: Arc<Context>, f: F) -> R
+    fn scope_with_context<F, R>(context: Arc<Context>, f: F) -> R
     where
         F: FnOnce(&Self) -> R + Send,
         R: Send,
@@ -1416,7 +1466,7 @@ impl<'scope> Scope<'scope> {
     where
         F: FnOnce(&Scope<'scope>) + Send,
     {
-        WorkerThread::with_in(&self.context, |worker| {
+        self.context.run_in_worker(|worker| {
            self.job_counter.increment();
 
            let this = SendPtr::new_const(self).unwrap();
@@ -1522,13 +1572,14 @@ impl<'scope> Scope<'scope> {
         A: FnOnce(&Self) -> RA + Send,
         B: FnOnce(&Self) -> RB + Send,
     {
-        // let count = self.join_count.get();
-        // self.join_count.set(count.wrapping_add(1) % TIMES);
-        let count = self
-            .join_count
-            .update(Ordering::Relaxed, Ordering::Relaxed, |n| {
-                n.wrapping_add(1) % TIMES
-            });
+        let count = self.join_count.load(Ordering::Relaxed);
+        self.join_count
+            .store(count.wrapping_add(1) % TIMES, Ordering::Relaxed);
+        // let count = self
+        //     .join_count
+        //     .update(Ordering::Relaxed, Ordering::Relaxed, |n| {
+        //         n.wrapping_add(1) % TIMES
+        //     });
 
         if count == 1 {
             self.join_heartbeat(a, b)
@@ -1562,13 +1613,23 @@ impl<'scope> Scope<'scope> {
         let job = a.as_job();
         worker.push_front(&job);
 
-        let rb = b(self);
+        use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
+        let rb = match catch_unwind(AssertUnwindSafe(|| b(self))) {
+            Ok(val) => val,
+            Err(payload) => {
+                cold_path();
+                // if b panicked, we need to wait for a to finish
+                worker.wait_until_job::<RA>(unsafe { job.transmute_ref::<RA>() });
+                resume_unwind(payload);
+            }
+        };
 
         let ra = if job.state() == JobState::Empty as u8 {
             unsafe {
                 job.unlink();
            }
 
+            // a is allowed to panic here, because we already finished b.
             unsafe { a.unwrap()() }
         } else {
             match worker.wait_until_job::<RA>(unsafe { job.transmute_ref::<RA>() }) {
@@ -1592,16 +1653,17 @@ impl<'scope> Scope<'scope> {
     }
 }
 
-// #[allow(dead_code)]
-// pub fn join<RA, RB, A, B>(a: A, b: B) -> (RA, RB)
-// where
-//     RA: Send,
-//     RB: Send,
-//     A: FnOnce() -> RA + Send,
-//     B: FnOnce() -> RB + Send,
-// {
-//     Scope::with(|scope| scope.join(|_| a(), |_| b()))
-// }
+/// Runs two closures, potentially in parallel, on the global thread pool.
+#[allow(dead_code)]
+pub fn join<RA, RB, A, B>(a: A, b: B) -> (RA, RB)
+where
+    RA: Send,
+    RB: Send,
+    A: FnOnce() -> RA + Send,
+    B: FnOnce() -> RB + Send,
+{
+    Scope::scope(|scope| scope.join(|_| a(), |_| b()))
+}
 
 pub struct ThreadPool {
     context: Arc<Context>,
@@ -1648,7 +1710,7 @@ unsafe impl Send for SharedContext {}
 impl SharedContext {
     fn new_heartbeat(&mut self) -> (Arc<CachePadded<AtomicBool>>, usize) {
         let index = self.heartbeats_id;
-        self.heartbeats_id.checked_add(1).unwrap();
+        self.heartbeats_id = self.heartbeats_id.checked_add(1).unwrap();
 
         let is_set = Arc::new(CachePadded::new(AtomicBool::new(false)));
         let weak = Arc::downgrade(&is_set);
@@ -1693,14 +1755,21 @@ impl Context {
         // let num_threads = 2;
         let barrier = Arc::new(std::sync::Barrier::new(num_threads + 1));
 
-        for _ in 0..num_threads {
+        for i in 0..num_threads {
             let ctx = this.clone();
             let barrier = barrier.clone();
-            std::thread::spawn(|| worker(ctx, barrier));
+            std::thread::Builder::new()
+                .name(format!("{:?}-worker-{}", Arc::as_ptr(&this), i))
+                .spawn(|| worker(ctx, barrier))
+                .expect("Failed to spawn worker thread");
         }
 
         let ctx = this.clone();
-        std::thread::spawn(|| heartbeat_worker(ctx));
+
+        std::thread::Builder::new()
+            .name(format!("{:?}-heartbeat", Arc::as_ptr(&this)))
+            .spawn(|| heartbeat_worker(ctx))
+            .expect("Failed to spawn heartbeat thread");
 
         barrier.wait();
 
@@ -1714,10 +1783,14 @@ impl Context {
     pub fn inject_job(&self, job: NonNull<Job>) {
         let mut guard = self.shared.lock();
         guard.injected_jobs.push(job);
+        self.notify_shared_job();
+    }
+
+    fn notify_shared_job(&self) {
         self.shared_job.notify_one();
     }
 
-    /// Runs closure in this context, processing the worker's jobs while waiting for the result.
+    /// Runs the closure in this context, processing the calling worker's jobs
+    /// (which belong to a different context) while waiting for the result.
    fn run_in_worker_cross<T, F>(self: &Arc<Self>, worker: &WorkerThread, f: F) -> T
    where
        F: FnOnce(&WorkerThread) -> T + Send,
    {
@@ -1738,8 +1811,10 @@ impl Context {
         );
 
         let job = job.as_job();
+        job.set_pending();
 
         self.inject_job(Into::into(&job));
+        // no need to wait for the latch to signal, because we wait on the job itself anyway
         worker.wait_until_latch(&latch);
 
         let t = unsafe { job.transmute_ref::<T>().wait().into_result() };
@@ -1769,6 +1844,7 @@ impl Context {
         );
 
         let job = job.as_job();
+        job.set_pending();
 
         self.inject_job(Into::into(&job));
         latch.wait();
@@ -1788,12 +1864,19 @@ impl Context {
             Some(worker) => {
                 // check if worker is in the same context
                 if Arc::ptr_eq(&worker.context, self) {
+                    tracing::trace!("run_in_worker: current thread");
                     f(worker)
                 } else {
+                    // current thread is a worker for a different context
+                    tracing::trace!("run_in_worker: cross-context");
                     self.run_in_worker_cross(worker, f)
                 }
             }
-            None => self.run_in_worker_cold(f),
+            None => {
+                // current thread is not a worker for any context
+                tracing::trace!("run_in_worker: inject into context");
+                self.run_in_worker_cold(f)
+            }
         }
     }
 }
@@ -1826,18 +1909,18 @@ fn worker(ctx: Arc<Context>, barrier: Arc<std::sync::Barrier>) {
     }
 
     let _guard = DropGuard::new(|| unsafe {
+        tracing::trace!("worker thread dropping {:?}", std::thread::current());
         WorkerThread::drop_in_place_and_dealloc(WorkerThread::unset_current().unwrap());
     });
 
-    let scope = WorkerThread::current_ref().unwrap();
+    let worker = WorkerThread::current_ref().unwrap();
 
     barrier.wait();
 
    let mut job = ctx.shared.lock().pop_job();
     loop {
-        tracing::trace!("worker({:?}): new job {:?}", std::thread::current(), job);
         if let Some(job) = job {
-            scope.execute(job);
+            worker.execute(job);
         }
 
         let mut guard = ctx.shared.lock();
diff --git a/src/praetor/tests.rs b/src/praetor/tests.rs
index 0e88370..ba2f049 100644
--- a/src/praetor/tests.rs
+++ b/src/praetor/tests.rs
@@ -3,6 +3,8 @@ use std::{
     pin::{pin, Pin},
 };
 
+use tracing_test::traced_test;
+
 use super::{util::TaggedAtomicPtr, *};
 
 fn pin_ptr<T>(pin: &Pin<&mut T>) -> NonNull<T> {
@@ -446,6 +448,35 @@ fn join() {
     eprintln!("x: {x}");
 }
 
+#[test]
+#[traced_test]
+fn join_many() {
+    use crate::util::tree::{Tree, TREE_SIZE};
+
+    let pool = ThreadPool::new();
+
+    let tree = Tree::new(TREE_SIZE, 1u32);
+
+    fn sum(tree: &Tree<u32>, node: usize, scope: &Scope<'_>) -> u32 {
+        let node = tree.get(node);
+        let (l, r) = scope.join_heartbeat(
+            |s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(),
+            |s| {
+                node.right
+                    .map(|node| sum(tree, node, s))
+                    .unwrap_or_default()
+            },
+        );
+
+        // eprintln!("node: {node:?}, l: {l}, r: {r}");
+
+        node.leaf + l + r
+    }
+
+    let sum = pool.scope(|s| sum(&tree, tree.root().unwrap(), s));
+    eprintln!("{sum}");
+}
+
 #[test]
 fn rebox() {
     struct A(u32);
diff --git a/src/util.rs b/src/util.rs
index 5678402..9036839 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -65,3 +65,71 @@ impl XorShift64Star {
         (self.next() % n as u64) as usize
     }
 }
+
+#[macro_export]
+macro_rules! cfg_miri {
+    (
+        @miri => { $($tokens:tt)* }$(,)?
+        _ => { $($tokens2:tt)* }
+    ) => {
+        #[cfg(miri)]
+        {
+            $($tokens)*
+        }
+        #[cfg(not(miri))]
+        {
+            $($tokens2)*
+        }
+    };
+}
+
+pub mod tree {
+    pub struct Tree<T> {
+        nodes: Box<[Node<T>]>,
+        root: Option<usize>,
+    }
+
+    #[derive(Debug, Clone)]
+    pub struct Node<T> {
+        pub leaf: T,
+        pub left: Option<usize>,
+        pub right: Option<usize>,
+    }
+
+    impl<T> Tree<T> {
+        pub fn new(depth: usize, t: T) -> Tree<T>
+        where
+            T: Copy,
+        {
+            // a full binary tree of this depth has 2^(depth + 1) - 1 nodes
+            let mut nodes = Vec::with_capacity((1 << (depth + 1)) - 1);
+            let root = Self::build_node(&mut nodes, depth, t);
+            Self {
+                nodes: nodes.into_boxed_slice(),
+                root: Some(root),
+            }
+        }
+
+        pub fn root(&self) -> Option<usize> {
+            self.root
+        }
+
+        pub fn get(&self, index: usize) -> &Node<T> {
+            &self.nodes[index]
+        }
+
+        pub fn build_node(nodes: &mut Vec<Node<T>>, depth: usize, t: T) -> usize
+        where
+            T: Copy,
+        {
+            let node = Node {
+                leaf: t,
+                left: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)),
+                right: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)),
+            };
+            nodes.push(node);
+            nodes.len() - 1
+        }
+    }
+
+    pub const TREE_SIZE: usize = 16;
+}
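Taken together, the public surface this diff lands is the top-level `join` plus the `util::tree` helper it is tested and benchmarked against. A hypothetical usage sketch follows; it assumes the crate is named `executor` (as in the benches) and that `util::tree` is reachable from the caller, neither of which the diff itself guarantees.

```rust
use executor::praetor::join;
use executor::util::tree::{Tree, TREE_SIZE}; // assumed to be publicly exported

// Sum a binary tree by recursively joining the two subtree sums,
// mirroring the `join_many` test but through the new free function.
fn sum(tree: &Tree<u32>, node: usize) -> u32 {
    let node = tree.get(node);
    let (l, r) = join(
        || node.left.map(|n| sum(tree, n)).unwrap_or_default(),
        || node.right.map(|n| sum(tree, n)).unwrap_or_default(),
    );
    node.leaf + l + r
}

fn main() {
    let tree = Tree::new(TREE_SIZE, 1u32);
    // a full tree of depth 16 has 2^17 - 1 nodes, each contributing 1
    println!("{}", sum(&tree, tree.root().unwrap()));
}
```

Since `join` routes through `Scope::scope` and `Scope::join`, only every `TIMES`-th call takes the `join_heartbeat` path that pushes a shareable job; the rest run inline, which is what keeps this kind of fine-grained recursion cheap near the leaves.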