From fd0cd86a8d087770f29de96d80f62cba242747c6 Mon Sep 17 00:00:00 2001
From: Janis
Date: Fri, 31 Jan 2025 00:19:57 +0100
Subject: [PATCH] scope

---
 Cargo.toml |   6 +-
 src/lib.rs | 324 +++++++++++++++++++++++++++++++++++++++++++++++------
 2 files changed, 295 insertions(+), 35 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index daf874e..93b6150 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -23,4 +23,8 @@ tracing-subscriber = {version ="0.3.18", features = ["env-filter"]}
 anyhow = "1.0.89"
 thiserror = "2.0"
 bitflags = "2.6"
-# derive_more = "1.0.0"
\ No newline at end of file
+core_affinity = "0.8.1"
+# derive_more = "1.0.0"
+
+[dev-dependencies]
+tracing-test = "0.2.5"
diff --git a/src/lib.rs b/src/lib.rs
index a0ab289..3195884 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -20,7 +20,9 @@ use bitflags::bitflags;
 use crossbeam::{queue::SegQueue, utils::CachePadded};
 use latch::{AtomicLatch, Latch, LatchWaker, MutexLatch, Probe, ThreadWakeLatch};
 use parking_lot::{Condvar, Mutex};
+use scope::Scope;
 use task::{HeapTask, StackTask, TaskRef};
+use tracing::debug;
 
 pub mod task {
     use std::{cell::UnsafeCell, marker::PhantomPinned, pin::Pin};
@@ -78,6 +80,9 @@ pub mod task {
         pub fn run(self) {
            self.task.into_inner().unwrap()();
         }
+        pub unsafe fn run_as_ref(&self) {
+            ((&mut *self.task.get()).take().unwrap())();
+        }
 
         pub fn as_task_ref(self: Pin<&Self>) -> TaskRef {
             unsafe { TaskRef::new(&*self) }
@@ -380,39 +385,69 @@ impl ThreadState {
         self.status_changed.notify_all();
     }
 
-    fn set_should_terminate(&self) {
+    fn notify_should_terminate(&self) {
         unsafe {
             Latch::set_raw(&self.should_terminate);
         }
+        self.wake();
     }
 }
 
 const MAX_THREADS: usize = 32;
 
+type ThreadCallback = dyn Fn(&WorkerThread) + Send + Sync + 'static;
+
+pub struct ThreadPoolCallbacks {
+    at_entry: Option<Arc<ThreadCallback>>,
+    at_exit: Option<Arc<ThreadCallback>>,
+}
+
+impl ThreadPoolCallbacks {
+    pub const fn new_empty() -> ThreadPoolCallbacks {
+        Self {
+            at_entry: None,
+            at_exit: None,
+        }
+    }
+}
+
 pub struct ThreadPool {
     threads: [CachePadded<ThreadState>; MAX_THREADS],
     pool_state: CachePadded<ThreadPoolState>,
     global_queue: SegQueue<TaskRef>,
+    callbacks: CachePadded<ThreadPoolCallbacks>,
 }
 
 impl ThreadPool {
+    const INITIAL_THREAD_STATE: CachePadded<ThreadState> = CachePadded::new(ThreadState {
+        should_shove: AtomicBool::new(false),
+        shoved_task: Slot::new(),
+        status: Mutex::new(ThreadStatus::empty()),
+        status_changed: Condvar::new(),
+        should_terminate: AtomicLatch::new(),
+    });
+
     pub const fn new() -> Self {
-        const INITIAL_THREAD_STATE: CachePadded<ThreadState> = CachePadded::new(ThreadState {
-            should_shove: AtomicBool::new(false),
-            shoved_task: Slot::new(),
-            status: Mutex::new(ThreadStatus::empty()),
-            status_changed: Condvar::new(),
-            should_terminate: AtomicLatch::new(),
-        });
-
         Self {
-            threads: const { [INITIAL_THREAD_STATE; MAX_THREADS] },
+            threads: const { [Self::INITIAL_THREAD_STATE; MAX_THREADS] },
             pool_state: CachePadded::new(ThreadPoolState {
                 num_threads: AtomicUsize::new(0),
                 lock: Mutex::new(()),
-                heartbeat_state: INITIAL_THREAD_STATE,
+                heartbeat_state: Self::INITIAL_THREAD_STATE,
             }),
             global_queue: SegQueue::new(),
+            callbacks: CachePadded::new(ThreadPoolCallbacks::new_empty()),
+        }
+    }
+
+    pub const fn new_with_callbacks(callbacks: ThreadPoolCallbacks) -> ThreadPool {
+        Self {
+            threads: const { [Self::INITIAL_THREAD_STATE; MAX_THREADS] },
+            pool_state: CachePadded::new(ThreadPoolState {
+                num_threads: AtomicUsize::new(0),
+                lock: Mutex::new(()),
+                heartbeat_state: Self::INITIAL_THREAD_STATE,
+            }),
+            global_queue: SegQueue::new(),
+            callbacks: CachePadded::new(callbacks),
         }
     }
 
@@ -470,13 +505,15 @@ impl ThreadPool {
     fn resize<F: Fn(usize) -> usize>(&'static self, size: F) -> usize {
         if WorkerThread::is_worker_thread() {
             // acquire required here?
+            debug!("tried to resize from within threadpool!");
             return self.pool_state.num_threads.load(Ordering::Acquire);
         }
 
         let _guard = self.pool_state.lock.lock();
 
         let current_size = self.pool_state.num_threads.load(Ordering::Acquire);
-        let new_size = size(current_size).max(MAX_THREADS);
+        let new_size = size(current_size).clamp(0, MAX_THREADS);
+        debug!(current_size, new_size, "resizing threadpool");
 
         if new_size == current_size {
             return current_size;
@@ -490,7 +527,7 @@
             std::cmp::Ordering::Greater => {
                 let new_threads = &self.threads[current_size..new_size];
 
-                for (i, thread) in new_threads.iter().enumerate() {
+                for (i, _) in new_threads.iter().enumerate() {
                     std::thread::spawn(move || {
                         WorkerThread::worker_loop(&self, current_size + i);
                     });
@@ -510,10 +547,14 @@
                 }
             }
             std::cmp::Ordering::Less => {
+                debug!(
+                    "waiting for threads {:?} to terminate.",
+                    new_size..current_size
+                );
                 let terminating_threads = &self.threads[new_size..current_size];
 
                 for thread in terminating_threads {
-                    thread.set_should_terminate();
+                    thread.notify_should_terminate();
                }
                 for thread in terminating_threads {
                     thread.wait_for_termination();
@@ -521,7 +562,7 @@
 
                 #[cfg(not(feature = "internal_heartbeat"))]
                 if new_size == 0 {
-                    self.pool_state.heartbeat_state.set_should_terminate();
+                    self.pool_state.heartbeat_state.notify_should_terminate();
                     self.pool_state.heartbeat_state.wait_for_termination();
                 }
             }
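
[Review note, not part of the diff] The `.max(MAX_THREADS)` → `.clamp(0, MAX_THREADS)` change above fixes a real bug: `max` would have rounded every resize request up to 32 threads. For context, a minimal in-crate sketch of how the new callbacks and resizing are meant to be driven (closure bodies and the worker count are illustrative; the `at_entry`/`at_exit` fields are crate-private, and the pool is leaked here to satisfy the `&'static self` receivers):

    use std::sync::Arc;

    // Build a pool whose workers report when they start and stop.
    let pool = ThreadPool::new_with_callbacks(ThreadPoolCallbacks {
        at_entry: Some(Arc::new(|worker: &WorkerThread| {
            eprintln!("worker {} up", worker.index());
        })),
        at_exit: Some(Arc::new(|worker: &WorkerThread| {
            eprintln!("worker {} down", worker.index());
        })),
    });
    let pool: &'static ThreadPool = Box::leak(Box::new(pool));
    pool.resize_to(4); // oversized requests are clamped to MAX_THREADS
    pool.resize_to(0); // joins the workers, firing at_exit on each
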
@@ -699,6 +740,66 @@ impl ThreadPool {
             }
         });
     }
+
+    fn join<F, G, T, U>(&'static self, f: F, g: G) -> (T, U)
+    where
+        F: FnOnce() -> T + Send,
+        G: FnOnce() -> U + Send,
+        T: Send,
+        U: Send,
+    {
+        self.in_worker(|worker, _| {
+            let mut result_b = None;
+
+            let latch_b = ThreadWakeLatch::new(worker);
+            let task_b = pin!(StackTask::new(|| {
+                result_b = Some(g());
+                unsafe {
+                    Latch::set_raw(&latch_b);
+                }
+            }));
+
+            let ref_b = task_b.as_ref().as_task_ref();
+            let b_id = ref_b.id();
+            // TODO: maybe try to push this off to another thread immediately first?
+            worker.push_task(ref_b);
+
+            let result_a = f();
+
+            while !latch_b.probe() {
+                match worker.pop_task() {
+                    Some(task) => {
+                        if task.id() == b_id {
+                            worker.try_promote();
+                            unsafe {
+                                task_b.run_as_ref();
+                            }
+                            break;
+                        }
+
+                        worker.execute(task);
+                    }
+                    None => {
+                        worker.run_until(&latch_b);
+                    }
+                }
+            }
+
+            (result_a, result_b.unwrap())
+        })
+    }
+
+    fn scope<'scope, Fn, T>(&'static self, f: Fn) -> T
+    where
+        Fn: FnOnce(Pin<&Scope<'scope>>) -> T + Send,
+        T: Send,
+    {
+        self.in_worker(|owner, _| {
+            let scope = pin!(unsafe { Scope::<'scope>::new(owner) });
+            let result = f(scope.as_ref());
+            scope.complete(owner);
+            result
+        })
+    }
 }
 
 pub struct WorkerThread {
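
[Review note, not part of the diff] `join` queues `g` on the calling worker's local deque, runs `f` inline, then pops tasks until `g`'s latch trips: if it pops `g` back it runs it in place, any unrelated task it pops gets executed, and an empty deque falls back to `run_until` (sleep). A hypothetical divide-and-conquer use, assuming `join` is eventually made `pub` (it is private in this diff):

    fn sum(pool: &'static ThreadPool, xs: &[u64]) -> u64 {
        // Sequential cutoff: below this size forking costs more than it saves.
        if xs.len() <= 1024 {
            return xs.iter().sum();
        }
        let (lo, hi) = xs.split_at(xs.len() / 2);
        // Both closures only borrow, which the Send-only bounds permit.
        let (a, b) = pool.join(|| sum(pool, lo), || sum(pool, hi));
        a + b
    }
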
@@ -719,6 +820,12 @@ impl WorkerThread {
     fn info(&self) -> &ThreadState {
         &self.pool.threads[self.index as usize]
     }
+    fn pool(&self) -> &'static ThreadPool {
+        self.pool
+    }
+    fn index(&self) -> usize {
+        self.index
+    }
     fn is_worker_thread() -> bool {
         Self::with(|worker| worker.is_some())
     }
@@ -758,11 +865,10 @@ impl WorkerThread {
             match self.info().shoved_task.try_put(task) {
                 // shoved task is occupied, reinsert into queue
                 Some(task) => self.queue.push_back(task),
-                None => {
-                    // wake thread to execute task
-                    self.pool.wake_any(1);
-                }
+                None => {}
             }
+            // wake thread to execute task
+            self.pool.wake_any(1);
         }
     }
 
@@ -843,10 +949,18 @@ impl WorkerThread {
             last_heartbeat: UnsafeCell::new(std::time::Instant::now()),
         });
 
+        if let Some(callback) = pool.callbacks.at_entry.as_ref() {
+            callback(worker);
+        }
+
         info.notify_running();
         // info.notify_running();
 
         worker.run_until(&info.should_terminate);
 
+        if let Some(callback) = pool.callbacks.at_exit.as_ref() {
+            callback(worker);
+        }
+
         for task in worker.drain() {
             pool.inject(task);
         }
@@ -1039,9 +1153,116 @@ mod rng {
     }
 }
 
+mod scope {
+    use std::{
+        future::Future,
+        marker::{PhantomData, PhantomPinned},
+        ptr::{self, NonNull},
+    };
+
+    use async_task::{Runnable, Task};
+
+    use crate::{
+        latch::{CountWakeLatch, Latch},
+        task::{HeapTask, TaskRef},
+        ThreadPool, WorkerThread,
+    };
+
+    pub struct Scope<'scope> {
+        pool: &'static ThreadPool,
+        tasks_completed_latch: CountWakeLatch,
+        _marker: PhantomData<Box<dyn FnOnce(&Scope<'scope>) + Send + Sync + 'scope>>,
+    }
+
+    impl<'scope> Scope<'scope> {
+        pub unsafe fn new(owner: &WorkerThread) -> Scope<'scope> {
+            Scope {
+                pool: owner.pool(),
+                tasks_completed_latch: CountWakeLatch::new(1, owner),
+                _marker: PhantomData,
+            }
+        }
+
+        pub fn spawn<Fn>(&self, f: Fn)
+        where
+            Fn: FnOnce(&Scope<'scope>) + Send + 'scope,
+        {
+            self.tasks_completed_latch.increment();
+
+            let ptr = SendPtr::from_ref(self);
+            let task = HeapTask::new(move || unsafe {
+                let this = ptr.as_ref();
+                f(this);
+                Latch::set_raw(&this.tasks_completed_latch);
+            });
+
+            let taskref = unsafe { task.into_task_ref() };
+            self.pool.push_local_or_inject(taskref);
+        }
+
+        pub fn spawn_future<Fut, T>(&self, future: Fut) -> Task<T>
+        where
+            Fut: Future<Output = T> + Send + 'scope,
+            T: Send + 'scope,
+        {
+            self.tasks_completed_latch.increment();
+
+            let ptr = SendPtr::from_ref(self);
+            let schedule = move |runnable: Runnable| {
+                let taskref = unsafe {
+                    TaskRef::new_raw(runnable.into_raw().as_ptr(), |this| {
+                        let this = NonNull::new_unchecked(this.cast_mut());
+
+                        let runnable = Runnable::<()>::from_raw(this);
+                        runnable.run();
+                    })
+                };
+
+                unsafe {
+                    ptr.as_ref().pool.push_local_or_inject(taskref);
+                }
+            };
+
+            let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) };
+
+            runnable.schedule();
+            task
+        }
+
+        pub fn spawn_async<Fn, Fut, T>(&self, f: Fn) -> Task<T>
+        where
+            Fn: FnOnce() -> Fut + Send + 'scope,
+            Fut: Future<Output = T> + Send + 'scope,
+            T: Send + 'scope,
+        {
+            self.spawn_future(async move { f().await })
+        }
+
+        pub fn complete(&self, owner: &WorkerThread) {
+            unsafe {
+                Latch::set_raw(&self.tasks_completed_latch);
+            }
+            owner.run_until(&self.tasks_completed_latch);
+        }
+    }
+
+    struct SendPtr<T>(*const T);
+    impl<T> SendPtr<T> {
+        fn from_ref(t: &T) -> Self {
+            Self(ptr::from_ref(t).cast())
+        }
+        unsafe fn as_ref(&self) -> &T {
+            &*self.0
+        }
+    }
+    unsafe impl<T> Send for SendPtr<T> {}
+}
+
 #[cfg(test)]
 mod tests {
-    use std::cell::Cell;
+    use std::{cell::Cell, hint::black_box};
+
+    use crate::latch::CountWakeLatch;
 
     use super::*;
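
[Review note, not part of the diff] The scope API mirrors rayon's: `spawn` smuggles `&Scope` across threads via the raw `SendPtr`, every spawn increments `tasks_completed_latch`, and `complete` releases the scope's own initial count and keeps the owning worker executing tasks until the latch drains. One thing worth double-checking: `spawn_future` increments the latch, but nothing on the schedule/completion path appears to set it when the future finishes. A hypothetical borrowing use, assuming `ThreadPool::scope` is made `pub` (it is private in this diff):

    // Fill a buffer in parallel; each closure borrows one `&mut` slot, and
    // `scope` does not return until every spawned task has run.
    fn fill_squares(pool: &'static ThreadPool, out: &mut [u64]) {
        pool.scope(|s| {
            for (i, slot) in out.iter_mut().enumerate() {
                s.spawn(move |_| *slot = (i as u64) * (i as u64));
            }
        });
    }
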
'scope, + T: Send + 'scope, + { + self.spawn_future(async move { f().await }) + } + + pub fn complete(&self, owner: &WorkerThread) { + unsafe { + Latch::set_raw(&self.tasks_completed_latch); + } + owner.run_until(&self.tasks_completed_latch); + } + } + + struct SendPtr(*const T); + impl SendPtr { + fn from_ref(t: &T) -> Self { + Self(ptr::from_ref(t).cast()) + } + unsafe fn as_ref(&self) -> &T { + &*self.0 + } + } + unsafe impl Send for SendPtr {} +} + #[cfg(test)] mod tests { - use std::cell::Cell; + use std::{cell::Cell, hint::black_box}; + + use crate::latch::CountWakeLatch; use super::*; @@ -1055,30 +1276,65 @@ mod tests { 1861, 1867, 1871, 1873, 1877, 1879, 1889, 1901, 1907, ]; - fn run_on_static_pool(f: impl FnOnce(&'static ThreadPool)) { - let pool = Box::new(ThreadPool::new()); + fn run_in_scope(pool: ThreadPool, f: impl FnOnce(Pin<&Scope<'_>>) -> T + Send) -> T { + let pool = Box::new(pool); let ptr = Box::into_raw(pool); - { + let result = { let pool: &'static ThreadPool = unsafe { &*ptr }; - pool.ensure_one_worker(); - f(pool); + // pool.ensure_one_worker(); + pool.resize_to_available(); + let result = pool.scope(f); pool.resize_to(0); assert!(pool.global_queue.pop().is_none()); - } + result + }; let _pool = unsafe { Box::from_raw(ptr) }; + result } #[test] + #[tracing_test::traced_test] fn spawn_random() { - std::thread_local! {static WAIT_COUNT: Cell= Cell::new(0);} - run_on_static_pool(|pool| { - for &p in PRIMES { - pool.spawn(move || { - std::thread::sleep(Duration::from_micros(p as u64)); - }); - } - }); + std::thread_local! { + static WAIT_COUNT: Cell = const {Cell::new(0)}; + } + let counter = Arc::new(AtomicUsize::new(0)); + let elapsed = { + let pool = ThreadPool::new_with_callbacks(ThreadPoolCallbacks { + at_entry: Some(Arc::new(|_worker| { + // eprintln!("new worker thread: {}", worker.index); + })), + at_exit: Some(Arc::new({ + let counter = counter.clone(); + move |_worker: &WorkerThread| { + // eprintln!("thread {}: {}", worker.index, WAIT_COUNT.get()); + counter.fetch_add(WAIT_COUNT.get(), Ordering::Relaxed); + } + })), + }); + + let now = std::time::Instant::now(); + run_in_scope(pool, |s| { + for &p in core::iter::repeat_n(PRIMES, 0x1000).flatten() { + s.spawn(move |_| { + // std::thread::sleep(Duration::from_micros(p as u64)); + // spin for + let tmp = (0..p).reduce(|a, b| black_box(a & b)); + black_box(tmp); + + // WAIT_COUNT.with(|count| { + // // eprintln!("{} + {p}", count.get()); + // count.set(count.get() + p); + // }); + }); + } + }); + now.elapsed().as_micros() + }; + + eprintln!("total wait count: {}", counter.load(Ordering::Acquire)); + eprintln!("total time: {}ms", elapsed as f32 / 1e3); } }