From ed4acbfbd7ad6d4e5b603b0073792fea38e4d4df Mon Sep 17 00:00:00 2001 From: Janis Date: Tue, 24 Jun 2025 18:03:23 +0200 Subject: [PATCH] erm....... --- Cargo.toml | 1 + benches/join.rs | 33 ++++ distaff/Cargo.toml | 6 +- distaff/src/context.rs | 134 ++++++++++++++- distaff/src/job.rs | 320 +++++++++++++++++++++++++++++++++++- distaff/src/join.rs | 48 +++++- distaff/src/latch.rs | 178 +++++++++++++++++++- distaff/src/lib.rs | 6 + distaff/src/scope.rs | 241 ++++++++++++++++----------- distaff/src/threadpool.rs | 92 +++++++++++ distaff/src/util.rs | 65 ++++++++ distaff/src/workerthread.rs | 100 +++-------- src/praetor/mod.rs | 2 +- 13 files changed, 1038 insertions(+), 188 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index af5838e..4fe3ac4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,3 +48,4 @@ cfg-if = "1.0.0" [dev-dependencies] async-std = "1.13.0" tracing-test = "0.2.5" +distaff = {path = "distaff"} diff --git a/benches/join.rs b/benches/join.rs index 895e6e2..2321934 100644 --- a/benches/join.rs +++ b/benches/join.rs @@ -184,3 +184,36 @@ fn join_rayon(b: &mut Bencher) { assert_ne!(sum(&tree, tree.root().unwrap()), 0); }); } + +#[bench] +fn join_distaff(b: &mut Bencher) { + use distaff::*; + let pool = ThreadPool::new(); + let tree = tree::Tree::new(TREE_SIZE, 1u32); + + fn sum<'scope, 'env>( + tree: &tree::Tree, + node: usize, + scope: &'scope Scope<'scope, 'env>, + ) -> u32 { + let node = tree.get(node); + let (l, r) = scope.join( + |s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(), + |s| { + node.right + .map(|node| sum(tree, node, s)) + .unwrap_or_default() + }, + ); + + node.leaf + l + r + } + + b.iter(move || { + pool.scope(|s| { + let sum = sum(&tree, tree.root().unwrap(), s); + // eprintln!("{sum}"); + assert_ne!(sum, 0); + }); + }); +} diff --git a/distaff/Cargo.toml b/distaff/Cargo.toml index 5206454..e5ca077 100644 --- a/distaff/Cargo.toml +++ b/distaff/Cargo.toml @@ -13,4 +13,8 @@ tracing = "0.1.40" parking_lot_core = "0.9.10" crossbeam-utils = "0.8.21" -async-task = "4.7.1" \ No newline at end of file +async-task = "4.7.1" + +[dev-dependencies] +tracing-test = "0.2.5" +futures = "0.3" \ No newline at end of file diff --git a/distaff/src/context.rs b/distaff/src/context.rs index 71427a0..051eec4 100644 --- a/distaff/src/context.rs +++ b/distaff/src/context.rs @@ -8,11 +8,12 @@ use std::{ use alloc::collections::BTreeMap; +use async_task::Runnable; use crossbeam_utils::CachePadded; use parking_lot::{Condvar, Mutex}; use crate::{ - job::{Job, StackJob}, + job::{HeapJob, Job, StackJob}, latch::{LatchRef, MutexLatch, WakeLatch}, workerthread::{HeartbeatThread, WorkerThread}, }; @@ -50,10 +51,6 @@ impl Heartbeat { pub fn is_pending(&self) -> bool { self.heartbeat.load(Ordering::Relaxed) == Self::PENDING } - - pub fn is_sleeping(&self) -> bool { - self.heartbeat.load(Ordering::Relaxed) == Self::SLEEPING - } } pub struct Context { @@ -87,6 +84,7 @@ impl Shared { // this is unlikely, so make the function cold? 
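// A minimal fork-join sketch of the pattern exercised by the join_distaff
// benchmark above: recursive work is split with Scope::join, which may run the
// two halves on different workers. Assumes only the ThreadPool/Scope API
// introduced later in this patch; the slice data, cutoff, and function name
// `parallel_sum` are illustrative.
use distaff::{Scope, ThreadPool};

fn parallel_sum<'scope, 'env>(data: &'env [u32], scope: &'scope Scope<'scope, 'env>) -> u64 {
    // below the cutoff, summing sequentially is cheaper than scheduling
    if data.len() <= 1024 {
        return data.iter().map(|&x| u64::from(x)).sum();
    }
    let (left, right) = data.split_at(data.len() / 2);
    let (l, r) = scope.join(|s| parallel_sum(left, s), |s| parallel_sum(right, s));
    l + r
}

fn main() {
    let pool = ThreadPool::new();
    let data = vec![1u32; 1 << 20];
    let total = pool.scope(|s| parallel_sum(&data, s));
    assert_eq!(total, 1 << 20);
}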
// TODO: profile this if !self.injected_jobs.is_empty() { + // SAFETY: we checked that injected_jobs is not empty unsafe { return Some(self.pop_injected_job()) }; } else { self.jobs.pop_first().map(|(_, job)| job) @@ -105,7 +103,7 @@ impl Shared { impl Context { #[inline] - pub fn shared(&self) -> parking_lot::MutexGuard<'_, Shared> { + pub(crate) fn shared(&self) -> parking_lot::MutexGuard<'_, Shared> { self.shared.lock() } @@ -159,6 +157,17 @@ impl Context { this } + pub fn set_should_exit(&self) { + let mut shared = self.shared.lock(); + shared.should_exit = true; + for (_, heartbeat) in shared.heartbeats.iter() { + if let Some(heartbeat) = heartbeat.upgrade() { + heartbeat.latch.set(); + } + } + self.shared_job.notify_all(); + } + pub fn new() -> Arc { Self::new_with_threads(crate::util::available_parallelism()) } @@ -270,6 +279,66 @@ impl Context { } } +impl Context { + pub fn spawn(self: &Arc, f: F) + where + F: FnOnce() + Send + 'static, + { + let job = Box::new(HeapJob::new(f)).into_boxed_job(); + tracing::trace!("Context::spawn: spawning job: {:?}", job); + unsafe { + (&*job).set_pending(); + self.inject_job(NonNull::new_unchecked(job)); + } + } + + pub fn spawn_future(self: &Arc, future: F) -> async_task::Task + where + F: Future + Send + 'static, + T: Send + 'static, + { + let schedule = move |runnable: Runnable| { + #[align(8)] + unsafe fn harness(this: *const (), job: *const Job) { + unsafe { + let runnable = + Runnable::<()>::from_raw(NonNull::new_unchecked(this.cast_mut())); + runnable.run(); + + // SAFETY: job was turned into raw + drop(Box::from_raw(job.cast_mut())); + } + } + + let job = Box::new(Job::::new(harness::, runnable.into_raw())); + + // casting into Job<()> here + unsafe { + job.set_pending(); + self.inject_job(NonNull::new_unchecked(Box::into_raw(job) as *mut Job<()>)); + } + }; + + let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) }; + + runnable.schedule(); + + task + } + + #[allow(dead_code)] + fn spawn_async(self: &Arc, f: Fn) -> async_task::Task + where + Fn: FnOnce() -> Fut + Send + 'static, + Fut: Future + Send + 'static, + T: Send + 'static, + { + let future = async move { f().await }; + + self.spawn_future(future) + } +} + pub fn run_in_worker(f: F) -> T where T: Send, @@ -277,3 +346,56 @@ where { Context::global_context().run_in_worker(f) } + +#[cfg(test)] +mod tests { + use tracing_test::traced_test; + + use super::*; + + #[test] + fn run_in_worker_test() { + let ctx = Context::global_context().clone(); + let result = ctx.run_in_worker(|_| 42); + assert_eq!(result, 42); + } + + #[test] + fn spawn_future_test() { + let ctx = Context::global_context().clone(); + let task = ctx.spawn_future(async { 42 }); + + // Wait for the task to complete + let result = futures::executor::block_on(task); + assert_eq!(result, 42); + } + + #[test] + fn spawn_async_test() { + let ctx = Context::global_context().clone(); + let task = ctx.spawn_async(|| async { 42 }); + + // Wait for the task to complete + let result = futures::executor::block_on(task); + assert_eq!(result, 42); + } + + #[test] + fn spawn_test() { + let ctx = Context::global_context().clone(); + let counter = Arc::new(AtomicU8::new(0)); + let barrier = Arc::new(std::sync::Barrier::new(2)); + + ctx.spawn({ + let counter = counter.clone(); + let barrier = barrier.clone(); + move || { + counter.fetch_add(1, Ordering::SeqCst); + barrier.wait(); + } + }); + + barrier.wait(); + assert_eq!(counter.load(Ordering::SeqCst), 1); + } +} diff --git a/distaff/src/job.rs b/distaff/src/job.rs 
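// Sketch of the two spawning paths Context gains above: a fire-and-forget
// closure (boxed as a HeapJob and injected into the shared queue) and a future
// scheduled through async_task. Written as a crate-internal snippet because
// lib.rs does not re-export Context; mirrors spawn_test/spawn_future_test, and
// `futures` is only the dev-dependency added in distaff/Cargo.toml. The helper
// name `demo` is hypothetical.
use std::sync::{
    Arc,
    atomic::{AtomicU8, Ordering},
};

use crate::context::Context;

fn demo(ctx: &Arc<Context>) {
    // fire-and-forget: the closure must be 'static, so shared state is moved in
    let hits = Arc::new(AtomicU8::new(0));
    let barrier = Arc::new(std::sync::Barrier::new(2));
    ctx.spawn({
        let (hits, barrier) = (hits.clone(), barrier.clone());
        move || {
            hits.fetch_add(1, Ordering::SeqCst);
            barrier.wait();
        }
    });
    barrier.wait(); // rendezvous with the injected job
    assert_eq!(hits.load(Ordering::SeqCst), 1);

    // futures become injected jobs too; the Task handle yields the result
    let task = ctx.spawn_future(async { 21 * 2 });
    assert_eq!(futures::executor::block_on(task), 42);
}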
index ed5830f..85e5bdd 100644 --- a/distaff/src/job.rs +++ b/distaff/src/job.rs @@ -40,6 +40,59 @@ impl JobState { } pub use joblist::JobList; +pub use jobvec::JobVec; + +// replacement for `JobList` that uses a VecDeque instead of a linked list. +mod jobvec { + use std::ptr::NonNull; + + use super::Job; + use alloc::collections::VecDeque; + + #[derive(Debug)] + pub struct JobVec { + jobs: VecDeque>, + } + + impl JobVec { + pub fn new() -> Self { + Self { + jobs: VecDeque::new(), + } + } + + pub fn remove(&mut self, job: &Job) { + // SAFETY: job is guaranteed to be valid and non-null + let job_ptr = unsafe { NonNull::new_unchecked(job as *const Job as _) }; + self.jobs.retain(|j| *j != job_ptr); + } + + pub fn push_front(&mut self, job: *const Job) { + let job_ptr = unsafe { NonNull::new_unchecked(job as _) }; + self.jobs.push_front(job_ptr); + } + pub fn push_back(&mut self, job: *const Job) { + let job_ptr = unsafe { NonNull::new_unchecked(job as _) }; + self.jobs.push_back(job_ptr); + } + + pub fn pop_front(&mut self) -> Option> { + self.jobs.pop_front() + } + + pub fn pop_back(&mut self) -> Option> { + self.jobs.pop_back() + } + + pub fn is_empty(&self) -> bool { + self.jobs.is_empty() + } + + pub fn len(&self) -> usize { + self.jobs.len() + } + } +} mod joblist { use core::{fmt::Debug, ptr::NonNull}; @@ -87,6 +140,12 @@ mod joblist { self.tail } + pub fn remove(&mut self, job: &Job) { + job.unlink(); + + self.job_count -= 1; + } + /// `job` must be valid until it is removed from the list. pub unsafe fn push_front(&mut self, job: *const Job) { self.job_count += 1; @@ -124,8 +183,6 @@ mod joblist { } pub fn pop_front(&mut self) -> Option> { - self.job_count -= 1; - let headlink = unsafe { self.head.as_ref().link_mut() }; // SAFETY: headlink.next is guaranteed to be Some. @@ -139,12 +196,13 @@ mod joblist { headlink.next = Some(next); next_link.prev = Some(self.head); + // decrement job count after having potentially short-circuited + self.job_count -= 1; + Some(job) } pub fn pop_back(&mut self) -> Option> { - self.job_count -= 1; - let taillink = unsafe { self.tail.as_ref().link_mut() }; // SAFETY: taillink.prev is guaranteed to be Some. @@ -158,6 +216,9 @@ mod joblist { taillink.prev = Some(prev); prev_link.next = Some(self.tail); + // decrement job count after having potentially short-circuited + self.job_count -= 1; + Some(job) } @@ -266,8 +327,6 @@ impl Clone for Link { // `Link` is invariant over `T` impl Copy for Link {} -struct Thread; - union ValueOrThis { uninit: (), value: ManuallyDrop>, @@ -385,7 +444,8 @@ impl Job { } /// assumes job is in a `JobList` - pub unsafe fn unlink(&self) { + pub fn unlink(&self) { + // SAFETY: if the job isn't linked, these will operate on a dummy value. 
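// The essence of the JobVec replacement above: a VecDeque of raw job pointers
// where removal works by pointer identity (VecDeque::retain), trading the
// intrusive list's O(1) unlink for simpler ownership and no link fields.
// Standalone sketch with u32 payloads standing in for Job<T>; the type name
// `Queue` is illustrative.
use std::{collections::VecDeque, ptr::NonNull};

struct Queue {
    jobs: VecDeque<NonNull<u32>>,
}

impl Queue {
    fn new() -> Self {
        Self { jobs: VecDeque::new() }
    }
    fn push_back(&mut self, job: *mut u32) {
        self.jobs.push_back(NonNull::new(job).expect("job pointer must be non-null"));
    }
    fn pop_front(&mut self) -> Option<NonNull<u32>> {
        self.jobs.pop_front()
    }
    // remove a specific job by identity (e.g. one reclaimed by join_heartbeat)
    fn remove(&mut self, job: *mut u32) {
        self.jobs.retain(|j| j.as_ptr() != job);
    }
}

fn main() {
    let (a, b, c) = (
        Box::into_raw(Box::new(1u32)),
        Box::into_raw(Box::new(2u32)),
        Box::into_raw(Box::new(3u32)),
    );
    let mut q = Queue::new();
    q.push_back(a);
    q.push_back(b);
    q.push_back(c);

    q.remove(b); // "unlink" the middle job
    assert_eq!(q.pop_front().map(|p| p.as_ptr()), Some(a));
    assert_eq!(q.pop_front().map(|p| p.as_ptr()), Some(c));
    assert!(q.pop_front().is_none());

    // reclaim the boxes so the sketch doesn't leak
    for ptr in [a, b, c] {
        unsafe { drop(Box::from_raw(ptr)) };
    }
}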
unsafe { let mut dummy = None; let Link { prev, next } = *self.link_mut(); @@ -590,6 +650,7 @@ mod stackjob { let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f())); + tracing::trace!("job completed: {:?}", job); let job = unsafe { &*job.cast::>() }; job.complete(result); @@ -664,3 +725,248 @@ mod heapjob { pub use heapjob::HeapJob; pub use stackjob::StackJob; + +#[cfg(test)] +mod tests { + use crate::latch::{AtomicLatch, LatchRef}; + + use super::*; + + #[test] + fn job_lifecycle() { + let latch = AtomicLatch::new(); + let stack = StackJob::new(|| 3 + 4, LatchRef::new(&latch)); + + let job = stack.as_job::(); + + assert_eq!(job.state(), JobState::Empty as u8); + + job.set_pending(); + assert_eq!(job.state(), JobState::Pending as u8); + + // execute the job + Job::<()>::execute(unsafe { NonNull::new_unchecked(&job as *const Job as _) }); + + // wait for the job to finish + let result = unsafe { job.transmute_ref::().wait() }; + assert_eq!(result.into_result(), 7); + } + + #[test] + fn job_lifecycle_panic() { + let latch = AtomicLatch::new(); + let stack = StackJob::new(|| panic!("test panic"), LatchRef::new(&latch)); + + let job = stack.as_job::(); + + assert_eq!(job.state(), JobState::Empty as u8); + + job.set_pending(); + assert_eq!(job.state(), JobState::Pending as u8); + + // execute the job + Job::<()>::execute(unsafe { NonNull::new_unchecked(&job as *const Job as _) }); + + // wait for the job to finish + let result = unsafe { job.transmute_ref::().wait() }; + assert!(result.into_inner().is_err()); + } + + #[test] + fn joblist_popback() { + let mut list = JobList::new(); + let job1 = Box::into_raw(Box::new(Job::::empty())); + let job2 = Box::into_raw(Box::new(Job::::empty())); + + unsafe { + list.push_back(job1); + list.push_back(job2); + } + + assert_eq!(list.len(), 2); + + let popped_job = list.pop_back().unwrap(); + assert_eq!(popped_job.as_ptr(), job2 as _); + + let popped_job = list.pop_back().unwrap(); + assert_eq!(popped_job.as_ptr(), job1 as _); + + assert!(list.is_empty()); + } + + #[test] + fn joblist_popfront() { + let mut list = JobList::new(); + let job1 = Box::into_raw(Box::new(Job::::empty())); + let job2 = Box::into_raw(Box::new(Job::::empty())); + + unsafe { + list.push_front(job1); + list.push_front(job2); + } + + assert_eq!(list.len(), 2); + + let popped_job = list.pop_front().unwrap(); + assert_eq!(popped_job.as_ptr(), job2 as _); + + let popped_job = list.pop_front().unwrap(); + assert_eq!(popped_job.as_ptr(), job1 as _); + + assert!(list.is_empty()); + } + + #[test] + fn joblist_unlink_middle() { + let mut list = JobList::new(); + let job1 = Box::into_raw(Box::new(Job::::empty())); + let job2 = Box::into_raw(Box::new(Job::::empty())); + let job3 = Box::into_raw(Box::new(Job::::empty())); + + unsafe { + list.push_back(job1); + list.push_back(job2); + list.push_back(job3); + } + + assert_eq!(list.len(), 3); + + // Unlink the middle job (job2) + unsafe { + (&*job2).unlink(); + } + + // Check that job1 and job3 are still in the list + let popped_job1 = list.pop_front().unwrap(); + assert_eq!(popped_job1.as_ptr(), job1 as _); + + let popped_job3 = list.pop_front().unwrap(); + assert_eq!(popped_job3.as_ptr(), job3 as _); + } + + #[test] + fn joblist_unlink_head() { + let mut list = JobList::new(); + let job1 = Box::into_raw(Box::new(Job::::empty())); + let job2 = Box::into_raw(Box::new(Job::::empty())); + + unsafe { + list.push_back(job1); + list.push_back(job2); + } + + assert_eq!(list.len(), 2); + + unsafe { + (&*job1).unlink(); + } + + // Check 
that job2 is still in the list + let popped_job2 = list.pop_front().unwrap(); + assert_eq!(popped_job2.as_ptr(), job2 as _); + } + + #[test] + fn joblist_unlink_tail() { + let mut list = JobList::new(); + let job1 = Box::into_raw(Box::new(Job::::empty())); + let job2 = Box::into_raw(Box::new(Job::::empty())); + + unsafe { + list.push_back(job1); + list.push_back(job2); + } + + assert_eq!(list.len(), 2); + + unsafe { + (&*job2).unlink(); + } + + // Check that job1 is still in the list + let popped_job1 = list.pop_front().unwrap(); + assert_eq!(popped_job1.as_ptr(), job1 as _); + } + + #[test] + fn joblist_unlink_single() { + let mut list = JobList::new(); + let job1 = Box::into_raw(Box::new(Job::::empty())); + + unsafe { + list.push_back(job1); + } + + assert_eq!(list.len(), 1); + + unsafe { + (&*job1).unlink(); + } + + // Check that popping from an empty list returns None + assert!(list.pop_front().is_none()); + } + + #[test] + fn joblist_pop_empty() { + let mut list = JobList::new(); + + // Popping from an empty list should return None + assert!(list.pop_front().is_none()); + assert!(list.pop_back().is_none()); + } + + #[test] + fn jobvec_push_front() { + let mut vec = JobVec::new(); + let job1 = Box::into_raw(Box::new(Job::::empty())); + let job2 = Box::into_raw(Box::new(Job::::empty())); + + vec.push_front(job1); + vec.push_front(job2); + + assert_eq!(vec.len(), 2); + + let popped_job = vec.pop_front().unwrap(); + assert_eq!(popped_job.as_ptr(), job2 as _); + let popped_job = vec.pop_front().unwrap(); + assert_eq!(popped_job.as_ptr(), job1 as _); + assert!(vec.is_empty()); + } + + #[test] + fn jobvec_push_back() { + let mut vec = JobVec::new(); + let job1 = Box::into_raw(Box::new(Job::::empty())); + let job2 = Box::into_raw(Box::new(Job::::empty())); + + vec.push_back(job1); + vec.push_back(job2); + + assert_eq!(vec.len(), 2); + + let popped_job = vec.pop_back().unwrap(); + assert_eq!(popped_job.as_ptr(), job2 as _); + let popped_job = vec.pop_back().unwrap(); + assert_eq!(popped_job.as_ptr(), job1 as _); + assert!(vec.is_empty()); + } + + #[test] + fn jobvec_push_front_pop_back() { + let mut vec = JobVec::new(); + let job1 = Box::into_raw(Box::new(Job::::empty())); + let job2 = Box::into_raw(Box::new(Job::::empty())); + + vec.push_front(job1); + vec.push_front(job2); + + assert_eq!(vec.len(), 2); + + let popped_job = vec.pop_back().unwrap(); + assert_eq!(popped_job.as_ptr(), job1 as _); + let popped_job = vec.pop_back().unwrap(); + assert_eq!(popped_job.as_ptr(), job2 as _); + assert!(vec.is_empty()); + } +} diff --git a/distaff/src/join.rs b/distaff/src/join.rs index eef4ea4..3a4e4e9 100644 --- a/distaff/src/join.rs +++ b/distaff/src/join.rs @@ -1,6 +1,7 @@ -use std::hint::cold_path; +use std::{hint::cold_path, sync::Arc}; use crate::{ + context::Context, job::{JobState, StackJob}, latch::{AsCoreLatch, LatchRef, WakeLatch}, workerthread::WorkerThread, @@ -69,7 +70,6 @@ impl WorkerThread { // WorkerThread::current_ref() // .expect("stackjob is run in workerthread.") // .tick(); - a() }, LatchRef::new(&latch), @@ -82,6 +82,7 @@ impl WorkerThread { Ok(val) => val, Err(payload) => { cold_path(); + tracing::debug!("join_heartbeat: b panicked, waiting for a to finish"); // if b panicked, we need to wait for a to finish self.wait_until_latch(&latch); resume_unwind(payload); @@ -89,8 +90,11 @@ impl WorkerThread { }; let ra = if job.state() == JobState::Empty as u8 { + // remove job from the queue, so it doesn't get run again. 
+ // job.unlink(); + //SAFETY: we are in a worker thread, so we can safely access the queue. unsafe { - job.unlink(); + self.queue.as_mut_unchecked().remove(&job); } // a is allowed to panic here, because we already finished b. @@ -108,3 +112,41 @@ impl WorkerThread { (ra, rb) } } + +impl Context { + #[inline] + pub fn join(self: &Arc, a: A, b: B) -> (RA, RB) + where + RA: Send, + RB: Send, + A: FnOnce() -> RA + Send, + B: FnOnce() -> RB + Send, + { + // SAFETY: join_heartbeat_every is safe to call from a worker thread. + self.run_in_worker(move |worker| worker.join_heartbeat_every::<_, _, _, _, 64>(a, b)) + } +} + +/// run two closures potentially in parallel, in the global threadpool. +#[allow(dead_code)] +pub fn join(a: A, b: B) -> (RA, RB) +where + RA: Send, + RB: Send, + A: FnOnce() -> RA + Send, + B: FnOnce() -> RB + Send, +{ + join_in(Context::global_context().clone(), a, b) +} + +/// run two closures potentially in parallel, in the global threadpool. +#[allow(dead_code)] +fn join_in(context: Arc, a: A, b: B) -> (RA, RB) +where + RA: Send, + RB: Send, + A: FnOnce() -> RA + Send, + B: FnOnce() -> RB + Send, +{ + context.join(a, b) +} diff --git a/distaff/src/latch.rs b/distaff/src/latch.rs index 6a8c20b..767f86f 100644 --- a/distaff/src/latch.rs +++ b/distaff/src/latch.rs @@ -37,6 +37,13 @@ impl AtomicLatch { inner: AtomicU8::new(Self::UNSET), } } + + pub const fn new_set() -> Self { + Self { + inner: AtomicU8::new(Self::SET), + } + } + #[inline] pub fn reset(&self) { self.inner.store(Self::UNSET, Ordering::Release); @@ -46,6 +53,10 @@ impl AtomicLatch { self.inner.load(Ordering::Acquire) } + pub fn set_sleeping(&self) { + self.inner.store(Self::SLEEPING, Ordering::Release); + } + /// returns true if the latch was previously sleeping. #[inline] pub unsafe fn set(this: *const Self) -> bool { @@ -244,7 +255,7 @@ impl Latch for CountLatch { impl Probe for CountLatch { #[inline] fn probe(&self) -> bool { - self.inner.probe() + self.count.load(Ordering::Relaxed) == 0 } } @@ -365,3 +376,168 @@ impl AsCoreLatch for WakeLatch { &self.inner } } + +#[cfg(test)] +mod tests { + use std::sync::Barrier; + + use tracing::Instrument; + use tracing_test::traced_test; + + use super::*; + + #[test] + fn test_atomic_latch() { + let latch = AtomicLatch::new(); + assert_eq!(latch.get(), AtomicLatch::UNSET); + unsafe { + assert!(!latch.probe()); + AtomicLatch::set_raw(&latch); + } + assert_eq!(latch.get(), AtomicLatch::SET); + assert!(latch.probe()); + latch.reset(); + assert_eq!(latch.get(), AtomicLatch::UNSET); + } + + #[test] + fn core_latch_sleep() { + let latch = AtomicLatch::new(); + assert_eq!(latch.get(), AtomicLatch::UNSET); + latch.set_sleeping(); + assert_eq!(latch.get(), AtomicLatch::SLEEPING); + unsafe { + assert!(!latch.probe()); + assert!(AtomicLatch::set(&latch)); + } + assert_eq!(latch.get(), AtomicLatch::SET); + assert!(latch.probe()); + latch.reset(); + assert_eq!(latch.get(), AtomicLatch::UNSET); + } + + #[test] + fn nop_latch() { + assert!( + core::mem::size_of::() == 0, + "NopLatch should be zero-sized" + ); + } + + #[test] + fn thread_wake_latch() { + let latch = Arc::new(ThreadWakeLatch::new()); + let main = Arc::new(ThreadWakeLatch::new()); + + let handle = std::thread::spawn({ + let latch = latch.clone(); + let main = main.clone(); + move || unsafe { + Latch::set_raw(&*main); + latch.wait(); + } + }); + + unsafe { + main.wait(); + Latch::set_raw(&*latch); + } + + handle.join().expect("Thread should join successfully"); + assert!( + !latch.probe() && !main.probe(), + "Latch should 
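// Usage sketch for the free `join` added in join.rs above: run two closures on
// the global context, potentially in parallel. Only the `distaff::join`
// re-export from lib.rs is assumed; the naive Fibonacci is purely illustrative.
fn fib(n: u64) -> u64 {
    if n < 2 {
        return n;
    }
    let (a, b) = distaff::join(|| fib(n - 1), || fib(n - 2));
    a + b
}

fn main() {
    assert_eq!(fib(20), 6765);
}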
be set after waiting thread wakes up" + ); + } + + #[test] + fn count_latch() { + let latch = CountLatch::new(AtomicLatch::new()); + assert_eq!(latch.count(), 0); + latch.increment(); + assert_eq!(latch.count(), 1); + assert!(!latch.probe()); + latch.increment(); + assert_eq!(latch.count(), 2); + assert!(!latch.probe()); + + unsafe { + Latch::set_raw(&latch); + } + assert!(!latch.probe()); + assert_eq!(latch.count(), 1); + + unsafe { + Latch::set_raw(&latch); + } + assert!(latch.probe()); + assert_eq!(latch.count(), 0); + } + + #[test] + fn mutex_latch() { + let latch = Arc::new(MutexLatch::new()); + assert!(!latch.probe()); + latch.set(); + assert!(latch.probe()); + latch.reset(); + assert!(!latch.probe()); + + // Test wait functionality + let latch_clone = latch.clone(); + let handle = std::thread::spawn(move || { + latch_clone.wait(); + }); + + // Give the thread time to block + std::thread::sleep(std::time::Duration::from_millis(100)); + assert!(!latch.probe()); + + latch.set(); + assert!(latch.probe()); + handle.join().expect("Thread should join successfully"); + } + + #[test] + fn wake_latch() { + let context = Context::new_with_threads(1); + let count = Arc::new(AtomicUsize::new(0)); + let barrier = Arc::new(Barrier::new(2)); + + tracing::info!("running scope in worker thread"); + let latch = context.run_in_worker(|worker| { + tracing::info!("worker thread started: {:?}", worker.index); + let latch = WakeLatch::new(worker.context.clone(), worker.index); + worker.context.spawn({ + let heartbeat = worker.heartbeat.clone(); + let barrier = barrier.clone(); + let count = count.clone(); + // set sleeping outside of the closure so we don't have to deal with lifetimes + latch.as_core_latch().set_sleeping(); + move || { + tracing::info!("sleeping workerthread"); + heartbeat.latch.wait_and_reset(); + tracing::info!("woken up workerthread"); + count.fetch_add(1, Ordering::SeqCst); + tracing::info!("waiting on barrier"); + barrier.wait(); + } + }); + + latch + }); + + tracing::info!("setting latch in main thread"); + unsafe { + Latch::set_raw(&latch); + } + + tracing::info!("main thread set latch, waiting for worker thread to wake up"); + barrier.wait(); + assert_eq!( + count.load(Ordering::SeqCst), + 1, + "Latch should have woken the worker thread" + ); + } +} diff --git a/distaff/src/lib.rs b/distaff/src/lib.rs index 2bad781..f064e21 100644 --- a/distaff/src/lib.rs +++ b/distaff/src/lib.rs @@ -20,3 +20,9 @@ mod scope; mod threadpool; pub mod util; mod workerthread; + +pub use context::run_in_worker; +pub use join::join; +pub use scope::{Scope, scope}; +pub use threadpool::ThreadPool; +pub use workerthread::WorkerThread; diff --git a/distaff/src/scope.rs b/distaff/src/scope.rs index 106ab11..e663922 100644 --- a/distaff/src/scope.rs +++ b/distaff/src/scope.rs @@ -11,14 +11,14 @@ use std::{ use async_task::Runnable; use crate::{ - context::{Context, run_in_worker}, + context::Context, job::{HeapJob, Job}, latch::{AsCoreLatch, CountLatch, WakeLatch}, util::{DropGuard, SendPtr}, workerthread::WorkerThread, }; -pub struct Scope<'scope> { +pub struct Scope<'scope, 'env: 'scope> { // latch to wait on before the scope finishes job_counter: CountLatch, // local threadpool @@ -26,55 +26,44 @@ pub struct Scope<'scope> { // panic error panic: AtomicPtr>, // variant lifetime - _pd: PhantomData, + _scope: PhantomData<&'scope mut &'scope ()>, + _env: PhantomData<&'env mut &'env ()>, } -pub fn scope<'scope, F, R>(f: F) -> R +pub fn scope<'env, F, R>(f: F) -> R where - F: FnOnce(&Scope<'scope>) -> R + Send, 
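// Standalone sketch of the three-state core-latch protocol exercised by the
// latch tests above (UNSET / SLEEPING / SET): a worker marks itself SLEEPING
// before blocking, and whoever sets the latch learns from the previous state
// whether a wake-up is required. The constant values and the name
// `CoreLatchSketch` are placeholders, not necessarily the crate's.
use std::sync::atomic::{AtomicU8, Ordering};

const UNSET: u8 = 0;
const SET: u8 = 1;
const SLEEPING: u8 = 2;

struct CoreLatchSketch(AtomicU8);

impl CoreLatchSketch {
    const fn new() -> Self {
        Self(AtomicU8::new(UNSET))
    }
    fn set_sleeping(&self) {
        self.0.store(SLEEPING, Ordering::Release);
    }
    fn probe(&self) -> bool {
        self.0.load(Ordering::Acquire) == SET
    }
    /// returns true if the waiter had marked itself sleeping and must be woken
    fn set(&self) -> bool {
        self.0.swap(SET, Ordering::AcqRel) == SLEEPING
    }
}

fn main() {
    let latch = CoreLatchSketch::new();
    assert!(!latch.probe());
    latch.set_sleeping();
    assert!(latch.set(), "waiter was sleeping; caller should wake it");
    assert!(latch.probe());
}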
+ F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> R + Send, R: Send, { - Scope::<'scope>::scope(f) + scope_with_context(Context::global_context(), f) } -impl<'scope> Scope<'scope> { +pub fn scope_with_context<'env, F, R>(context: &Arc, f: F) -> R +where + F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> R + Send, + R: Send, +{ + context.run_in_worker(|worker| { + // SAFETY: we call complete() after creating this scope, which + // ensures that any jobs spawned from the scope exit before the + // scope closes. + let this = unsafe { Scope::from_context(context.clone()) }; + this.complete(worker, || f(&this)) + }) +} + +impl<'scope, 'env> Scope<'scope, 'env> { fn wait_for_jobs(&self, worker: &WorkerThread) { - tracing::trace!("waiting for {} jobs to finish.", self.job_counter.count()); - tracing::trace!("thread id: {:?}, jobs: {:?}", worker.index, unsafe { - worker.queue.as_ref_unchecked() - }); + if self.job_counter.count() > 0 { + tracing::trace!("waiting for {} jobs to finish.", self.job_counter.count()); + tracing::trace!("thread id: {:?}, jobs: {:?}", worker.index, unsafe { + worker.queue.as_ref_unchecked() + }); - // set worker index in the job counter - self.job_counter.inner().set_worker_index(worker.index); - worker.wait_until_latch(self.job_counter.as_core_latch()); - } - - pub fn scope(f: F) -> R - where - F: FnOnce(&Self) -> R + Send, - R: Send, - { - run_in_worker(|worker| { - // SAFETY: we call complete() after creating this scope, which - // ensures that any jobs spawned from the scope exit before the - // scope closes. - let this = unsafe { Self::from_context(worker.context.clone()) }; - this.complete(worker, || f(&this)) - }) - } - - fn scope_with_context(context: Arc, f: F) -> R - where - F: FnOnce(&Self) -> R + Send, - R: Send, - { - context.run_in_worker(|worker| { - // SAFETY: we call complete() after creating this scope, which - // ensures that any jobs spawned from the scope exit before the - // scope closes. - let this = unsafe { Self::from_context(context.clone()) }; - this.complete(worker, || f(&this)) - }) + // set worker index in the job counter + self.job_counter.inner().set_worker_index(worker.index); + worker.wait_until_latch(self.job_counter.as_core_latch()); + } } /// should be called from within a worker thread. 
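// Why Scope grew the second 'env lifetime above: jobs spawned on the scope may
// borrow data that outlives the scope call, and wait_for_jobs guarantees they
// finish before scope() returns. Sketch mirroring the spawn_borrow test further
// down; the per-slot chunking and values are illustrative.
use distaff::ThreadPool;

fn main() {
    let pool = ThreadPool::new_with_threads(2);
    let mut results = vec![0u32; 4];

    pool.scope(|scope| {
        // each job borrows a disjoint &mut slot from the environment
        for (i, slot) in results.iter_mut().enumerate() {
            scope.spawn(move |_| {
                *slot = (i as u32 + 1) * 10;
            });
        }
    });

    // all spawned jobs completed before the scope returned
    assert_eq!(results, vec![10, 20, 30, 40]);
}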
@@ -153,9 +142,9 @@ impl<'scope> Scope<'scope> { }); } - pub fn spawn(&self, f: F) + pub fn spawn(&'scope self, f: F) where - F: FnOnce(&Scope<'scope>) + Send, + F: FnOnce(&'scope Self) + Send, { self.context.run_in_worker(|worker| { self.job_counter.increment(); @@ -176,70 +165,81 @@ impl<'scope> Scope<'scope> { }); } - pub fn spawn_future(&self, future: F) -> async_task::Task + pub fn spawn_future(&'scope self, future: F) -> async_task::Task where F: Future + Send + 'scope, T: Send + 'scope, { - self.context.run_in_worker(|worker| { - self.job_counter.increment(); - - let this = SendPtr::new_const(&self.job_counter).unwrap(); - - let future = async move { - let _guard = DropGuard::new(move || unsafe { - this.as_ref().decrement(); - }); - future.await - }; - - let schedule = move |runnable: Runnable| { - #[align(8)] - unsafe fn harness(this: *const (), job: *const Job) { - unsafe { - let runnable = - Runnable::<()>::from_raw(NonNull::new_unchecked(this.cast_mut())); - runnable.run(); - - // SAFETY: job was turned into raw - drop(Box::from_raw(job.cast_mut())); - } - } - - let job = Box::new(Job::::new(harness::, runnable.into_raw())); - - // casting into Job<()> here - worker.push_front(Box::into_raw(job) as _); - }; - - let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) }; - - runnable.schedule(); - - task - }) + self.spawn_async_internal(move |_| future) } #[allow(dead_code)] - fn spawn_async<'a, T, Fut, Fn>(&'a self, f: Fn) -> async_task::Task + pub fn spawn_async(&'scope self, f: Fn) -> async_task::Task where - Fn: FnOnce(&Scope) -> Fut + Send + 'static, - Fut: Future + Send + 'static, - T: Send + 'static, + Fn: FnOnce(&'scope Self) -> Fut + Send + 'scope, + Fut: Future + Send + 'scope, + T: Send + 'scope, { - let this = SendPtr::new_const(self).unwrap(); - let future = async move { f(unsafe { this.as_ref() }).await }; - - self.spawn_future(future) + self.spawn_async_internal(f) } #[inline] - pub fn join(&self, a: A, b: B) -> (RA, RB) + fn spawn_async_internal(&'scope self, f: Fn) -> async_task::Task + where + Fn: FnOnce(&'scope Self) -> Fut + Send + 'scope, + Fut: Future + Send + 'scope, + T: Send + 'scope, + { + self.job_counter.increment(); + + let this = SendPtr::new_const(self).unwrap(); + // let this = SendPtr::new_const(&self.job_counter).unwrap(); + + let future = async move { + // SAFETY: this is valid until we decrement the job counter. 
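// Usage sketch for the async spawns on Scope above: spawn_async takes a closure
// receiving the scope and returning a future, and the scope's job counter keeps
// the scope open until that future has run. Mirrors the spawn_future tests
// below; futures::executor::block_on comes from the dev-dependency.
use distaff::ThreadPool;

fn main() {
    let pool = ThreadPool::new_with_threads(2);
    let mut x = 0u32;

    let task = pool.scope(|scope| {
        scope.spawn_async(|_| async {
            // borrows from the environment, like a synchronous spawn
            x += 1;
            7u32
        })
    });

    // the scope already waited for the future; the task is ready to poll
    assert_eq!(futures::executor::block_on(task), 7);
    assert_eq!(x, 1);
}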
+ unsafe { + let _guard = DropGuard::new(move || { + this.as_unchecked_ref().job_counter.decrement(); + }); + f(this.as_ref()).await + } + }; + + let schedule = move |runnable: Runnable| { + #[align(8)] + unsafe fn harness(this: *const (), job: *const Job) { + unsafe { + let runnable = + Runnable::<()>::from_raw(NonNull::new_unchecked(this.cast_mut())); + runnable.run(); + + // SAFETY: job was turned into raw + drop(Box::from_raw(job.cast_mut())); + } + } + + let job = Box::new(Job::new(harness, runnable.into_raw())); + + // casting into Job<()> here + WorkerThread::current_ref() + .expect("spawn_async_internal is run in workerthread.") + .push_front(Box::into_raw(job) as _); + }; + + let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) }; + + runnable.schedule(); + + task + } + + #[inline] + pub fn join(&'scope self, a: A, b: B) -> (RA, RB) where RA: Send, RB: Send, - A: FnOnce(&Self) -> RA + Send, - B: FnOnce(&Self) -> RB + Send, + A: FnOnce(&'scope Self) -> RA + Send, + B: FnOnce(&'scope Self) -> RB + Send, { let worker = WorkerThread::current_ref().expect("join is run in workerthread."); let this = SendPtr::new_const(self).unwrap(); @@ -261,7 +261,60 @@ impl<'scope> Scope<'scope> { context: ctx.clone(), job_counter: CountLatch::new(WakeLatch::new(ctx, 0)), panic: AtomicPtr::new(ptr::null_mut()), - _pd: PhantomData, + _scope: PhantomData, + _env: PhantomData, } } } + +#[cfg(test)] +mod tests { + use std::sync::atomic::AtomicU8; + + use tracing_test::traced_test; + + use super::*; + use crate::ThreadPool; + + #[test] + fn spawn() { + let pool = ThreadPool::new_with_threads(1); + let count = Arc::new(AtomicU8::new(0)); + + scope_with_context(&pool.context, |scope| { + scope.spawn(|_| { + count.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + }); + }); + + assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 1); + } + + #[test] + #[traced_test] + fn join() { + let pool = ThreadPool::new_with_threads(1); + + let a = pool.scope(|scope| { + let (a, b) = scope.join(|_| 3 + 4, |_| 5 + 6); + a + b + }); + + assert_eq!(a, 18); + } + + #[test] + fn spawn_future() { + let pool = ThreadPool::new_with_threads(1); + let mut x = 0; + pool.scope(|scope| { + let task = scope.spawn_async(|_| async { + x += 1; + }); + + task.detach(); + }); + + assert_eq!(x, 1); + } +} diff --git a/distaff/src/threadpool.rs b/distaff/src/threadpool.rs index 8b13789..89c5223 100644 --- a/distaff/src/threadpool.rs +++ b/distaff/src/threadpool.rs @@ -1 +1,93 @@ +use std::sync::Arc; +use crate::{Scope, context::Context, scope::scope_with_context}; + +pub struct ThreadPool { + pub(crate) context: Arc, +} + +impl Drop for ThreadPool { + fn drop(&mut self) { + // Ensure that the context is properly cleaned up when the thread pool is dropped. + self.context.set_should_exit(); + } +} + +impl ThreadPool { + pub fn new_with_threads(num_threads: usize) -> Self { + let context = Context::new_with_threads(num_threads); + Self { context } + } + + /// Creates a new thread pool with a thread per hardware thread. 
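// Basic lifecycle sketch for the ThreadPool type introduced above: construct a
// pool, run a fork-join pair, hand it a detached 'static job, and rely on Drop
// (via Context::set_should_exit) to shut the workers down. Mirrors the
// threadpool tests; the values are illustrative.
use std::sync::{Arc, Barrier};

use distaff::ThreadPool;

fn main() {
    let pool = ThreadPool::new_with_threads(2);

    // fork-join: results come back in argument order
    let (a, b) = pool.join(|| 3 + 4, || 5 * 6);
    assert_eq!((a, b), (7, 30));

    // detached spawn: the closure must be 'static, so shared state is moved in
    let barrier = Arc::new(Barrier::new(2));
    pool.spawn({
        let barrier = barrier.clone();
        move || {
            barrier.wait();
        }
    });
    barrier.wait(); // rendezvous with the spawned job before dropping the pool

    // `pool` dropped here: set_should_exit() wakes and stops the workers
}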
+ pub fn new() -> Self { + let context = Context::new(); + Self { context } + } + + pub fn scope<'env, F, R>(&self, f: F) -> R + where + F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> R + Send, + R: Send, + { + scope_with_context(&self.context, f) + } + + pub fn spawn(&self, f: F) + where + F: FnOnce() + Send + 'static, + { + self.context.spawn(f) + } + + pub fn join(&self, a: A, b: B) -> (RA, RB) + where + RA: Send, + RB: Send, + A: FnOnce() -> RA + Send, + B: FnOnce() -> RB + Send, + { + self.context.join(a, b) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn spawn_borrow() { + let pool = ThreadPool::new_with_threads(1); + let mut x = 0; + pool.scope(|scope| { + scope.spawn(|_| { + x += 1; + }); + }); + assert_eq!(x, 1); + } + + #[test] + fn spawn_future() { + let pool = ThreadPool::new_with_threads(1); + let mut x = 0; + let task = pool.scope(|scope| { + let task = scope.spawn_async(|_| async { + x += 1; + }); + + task + }); + + futures::executor::block_on(task); + assert_eq!(x, 1); + } + + #[test] + fn join() { + let pool = ThreadPool::new_with_threads(1); + let (a, b) = pool.join(|| 3 + 4, || 5 * 6); + assert_eq!(a, 7); + assert_eq!(b, 30); + } +} diff --git a/distaff/src/util.rs b/distaff/src/util.rs index 5f250cd..93ac342 100644 --- a/distaff/src/util.rs +++ b/distaff/src/util.rs @@ -93,6 +93,11 @@ impl SendPtr { pub const unsafe fn new_const_unchecked(ptr: *const T) -> Self { unsafe { Self::new_unchecked(ptr.cast_mut()) } } + + pub unsafe fn as_unchecked_ref(&self) -> &T { + // SAFETY: `self.0` is a valid non-null pointer. + unsafe { self.0.as_ref() } + } } /// A tagged atomic pointer that can store a pointer and a tag `BITS` wide in the same space @@ -402,3 +407,63 @@ pub fn available_parallelism() -> usize { .map(|n| n.get()) .unwrap_or(1) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn tagged_ptr_exchange() { + let ptr = Box::into_raw(Box::new(42u32)); + let tagged_ptr = TaggedAtomicPtr::::new(ptr, 0b11); + assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0b11); + assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr); + + assert_eq!( + tagged_ptr + .compare_exchange_tag(0b11, 0b10, Ordering::Relaxed, Ordering::Relaxed) + .unwrap(), + 0b11 + ); + + assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0b10); + assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr); + + unsafe { + _ = Box::from_raw(ptr); + } + } + + #[test] + fn value_inline() { + assert!(SmallBox::::is_inline(), "u32 should be inline"); + assert!(SmallBox::::is_inline(), "u8 should be inline"); + assert!( + SmallBox::>::is_inline(), + "Box should be inline" + ); + assert!( + SmallBox::<[u32; 2]>::is_inline(), + "[u32; 2] should be inline" + ); + assert!( + !SmallBox::<[u32; 3]>::is_inline(), + "[u32; 3] should not be inline" + ); + assert!(SmallBox::::is_inline(), "usize should be inline"); + + #[repr(C, align(16))] + struct LargeType(u8); + assert!( + !SmallBox::::is_inline(), + "LargeType should not be inline" + ); + + #[repr(C, align(4))] + struct SmallType(u8); + assert!( + SmallBox::::is_inline(), + "SmallType should be inline" + ); + } +} diff --git a/distaff/src/workerthread.rs b/distaff/src/workerthread.rs index f68a8bb..bbc72a0 100644 --- a/distaff/src/workerthread.rs +++ b/distaff/src/workerthread.rs @@ -6,11 +6,10 @@ use std::{ }; use crossbeam_utils::CachePadded; -use parking_lot_core::SpinWait; use crate::{ context::{Context, Heartbeat}, - job::{Job, JobList, JobResult}, + job::{Job, JobResult, JobVec as JobList}, latch::{AsCoreLatch, CoreLatch, Probe}, 
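// The idea behind TaggedAtomicPtr, exercised by tagged_ptr_exchange above: a
// pointer to a value with alignment >= 4 leaves its two low bits free, so a
// small tag can share one atomic word with the pointer. Standalone sketch using
// a plain AtomicUsize; the crate's type generalizes this over a BITS parameter.
use std::sync::atomic::{AtomicUsize, Ordering};

const TAG_MASK: usize = 0b11;

fn main() {
    let raw = Box::into_raw(Box::new(42u32)); // 4-byte aligned, low bits are zero
    let word = AtomicUsize::new(raw as usize | 0b01);

    let packed = word.load(Ordering::Relaxed);
    let ptr = (packed & !TAG_MASK) as *mut u32;
    assert_eq!(packed & TAG_MASK, 0b01);
    assert_eq!(unsafe { *ptr }, 42);

    // exchange just the tag while keeping the pointer bits intact
    word.store((packed & !TAG_MASK) | 0b10, Ordering::Relaxed);
    assert_eq!(word.load(Ordering::Relaxed) & TAG_MASK, 0b10);

    unsafe { drop(Box::from_raw(ptr)) };
}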
util::DropGuard, }; @@ -19,7 +18,7 @@ pub struct WorkerThread { pub(crate) context: Arc, pub(crate) index: usize, pub(crate) queue: UnsafeCell, - heartbeat: Arc>, + pub(crate) heartbeat: Arc>, pub(crate) join_count: Cell, } @@ -39,11 +38,6 @@ impl WorkerThread { join_count: Cell::new(0), } } - - fn new() -> Self { - let context = Context::global_context().clone(); - Self::new_in(context) - } } impl WorkerThread { @@ -72,7 +66,7 @@ impl WorkerThread { let mut job = self.context.shared().pop_job(); 'outer: loop { let mut guard = loop { - if let Some(job) = job { + if let Some(job) = job.take() { self.execute(job); } @@ -83,9 +77,11 @@ impl WorkerThread { break 'outer; } + // TODO: also check the local queue? match guard.pop_job() { - Some(job) => { - tracing::trace!("worker: popping job: {:?}", job); + Some(popped) => { + tracing::trace!("worker: popping job: {:?}", popped); + job = Some(popped); // found job, continue inner loop continue; } @@ -107,6 +103,7 @@ impl WorkerThread { #[inline(always)] fn tick(&self) { if self.heartbeat.is_pending() { + tracing::trace!("received heartbeat, thread id: {:?}", self.index); self.heartbeat_cold(); } } @@ -190,8 +187,11 @@ impl WorkerThread { unsafe fn drop_in_place(this: *mut Self) { unsafe { - this.drop_in_place(); - drop(Box::from_raw(this)); + // SAFETY: this is only called when the thread is exiting, so we can + // safely drop the thread. We use `drop_in_place` to prevent `Box` + // from creating a no-alias reference to the worker thread. + core::ptr::drop_in_place(this); + _ = Box::>::from_raw(this as _); } } } @@ -258,11 +258,6 @@ impl WorkerThread { assert!(!latch.probe()); 'outer: while !latch.probe() { - // take a shared job, if it exists - if let Some(shared_job) = self.context.shared().jobs.remove(&self.index) { - self.execute(shared_job); - } - // process local jobs before locking shared context while let Some(job) = self.pop_front() { unsafe { @@ -271,8 +266,16 @@ impl WorkerThread { self.execute(job); } + // take a shared job, if it exists + if let Some(shared_job) = self.context.shared().jobs.remove(&self.index) { + self.execute(shared_job); + } + while !latch.probe() { - let job = self.context.shared().pop_job(); + let job = { + let mut guard = self.context.shared(); + guard.jobs.remove(&self.index).or_else(|| guard.pop_job()) + }; match job { Some(job) => { @@ -281,8 +284,6 @@ impl WorkerThread { continue 'outer; } None => { - tracing::trace!("waiting for shared job, thread id: {:?}", self.index); - // TODO: wait on latch? if we have something that can // signal being done, e.g. can be waited on instead of // shared jobs, we should wait on it instead, but we @@ -297,6 +298,9 @@ impl WorkerThread { // Yield? same as spinning, really, so just exit and let the upstream use wait // std::thread::yield_now(); + tracing::trace!("thread {:?} is sleeping", self.index); + + latch.set_sleeping(); self.heartbeat.latch.wait_and_reset(); // since we were sleeping, the shared job can't be populated, // so resuming the inner loop is fine. 
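// Shape of the sleep/wake path in wait_until_latch_cold above, restated with
// std primitives: prefer running local then shared jobs, and only when nothing
// is runnable block on the per-worker heartbeat latch until another thread sets
// it (the real code also marks the core latch SLEEPING first). Illustrative
// structure only; the crate uses its own latch types, not a Condvar, and the
// names `WakeSketch` / `wait_until` are hypothetical.
use std::sync::{Condvar, Mutex};

struct WakeSketch {
    woken: Mutex<bool>,
    condvar: Condvar,
}

impl WakeSketch {
    fn new() -> Self {
        Self { woken: Mutex::new(false), condvar: Condvar::new() }
    }

    /// blocks until `wake` is called, then resets (like `wait_and_reset`)
    fn wait_and_reset(&self) {
        let mut woken = self.woken.lock().unwrap();
        while !*woken {
            woken = self.condvar.wait(woken).unwrap();
        }
        *woken = false;
    }

    fn wake(&self) {
        *self.woken.lock().unwrap() = true;
        self.condvar.notify_one();
    }
}

fn wait_until(done: impl Fn() -> bool, try_run_one_job: impl Fn() -> bool, wake: &WakeSketch) {
    while !done() {
        if try_run_one_job() {
            continue; // useful work beats sleeping
        }
        // nothing runnable: sleep until a heartbeat or a finished job wakes us
        wake.wait_and_reset();
    }
}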
@@ -339,58 +343,4 @@ impl WorkerThread { self.wait_until_latch_cold(latch) } } - - #[inline] - fn wait_until_predicate(&self, pred: F) - where - F: Fn() -> bool, - { - 'outer: while !pred() { - // take a shared job, if it exists - if let Some(shared_job) = self.context.shared().jobs.remove(&self.index) { - self.execute(shared_job); - } - - // process local jobs before locking shared context - while let Some(job) = self.pop_front() { - unsafe { - job.as_ref().set_pending(); - } - self.execute(job); - } - - while !pred() { - let mut guard = self.context.shared(); - let mut _spin = SpinWait::new(); - - match guard.pop_job() { - Some(job) => { - drop(guard); - self.execute(job); - - continue 'outer; - } - None => { - tracing::trace!("waiting for shared job, thread id: {:?}", self.index); - - // TODO: wait on latch? if we have something that can - // signal being done, e.g. can be waited on instead of - // shared jobs, we should wait on it instead, but we - // would also want to receive shared jobs still? - // Spin? probably just wastes CPU time. - // self.context.shared_job.wait(&mut guard); - // if spin.spin() { - // // wait for more shared jobs. - // // self.context.shared_job.wait(&mut guard); - // return; - // } - // Yield? same as spinning, really, so just exit and let the upstream use wait - // std::thread::yield_now(); - return; - } - } - } - } - return; - } } diff --git a/src/praetor/mod.rs b/src/praetor/mod.rs index 90191cd..08e782c 100644 --- a/src/praetor/mod.rs +++ b/src/praetor/mod.rs @@ -813,7 +813,7 @@ mod job { } } - /// call this when popping value from local queue + /// must be called before `execute()` pub fn set_pending(&self) { let mut spin = SpinWait::new(); loop {