#![feature(
    vec_deque_pop_if,
    unsafe_cell_access,
    debug_closure_helpers,
    cold_path,
    fn_align,
    box_vec_non_null,
    box_as_ptr,
    atomic_try_update,
    let_chains
)]

use std::{
    cell::{Cell, UnsafeCell},
    future::Future,
    mem::MaybeUninit,
    num::NonZero,
    pin::{pin, Pin},
    ptr::{self, NonNull},
    sync::{
        atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering},
        Arc,
    },
    task::Context,
    thread::available_parallelism,
    time::Duration,
};

use async_task::{Runnable, Task};
use bitflags::bitflags;
use crossbeam::{
    atomic::AtomicCell,
    deque::{Injector, Stealer, Worker},
    utils::CachePadded,
};
use latch::{AtomicLatch, Latch, LatchWaker, MutexLatch, Probe, ThreadWakeLatch};
use parking_lot::{Condvar, Mutex};
use scope::Scope;
use task::{HeapTask, StackTask, TaskRef};
use tracing::debug;

pub mod job;
pub mod util;

pub mod task {
    use std::{cell::UnsafeCell, marker::PhantomPinned, pin::Pin};

    pub trait Task {
        unsafe fn execute(this: *const ());
    }

    pub struct TaskRef {
        ptr: *const (),
        execute_fn: unsafe fn(*const ()),
    }

    impl TaskRef {
        pub unsafe fn new<T>(task: *const T) -> TaskRef
        where
            T: Task,
        {
            Self {
                ptr: task.cast(),
                execute_fn: <T as Task>::execute,
            }
        }

        pub unsafe fn new_raw(ptr: *const (), execute_fn: unsafe fn(*const ())) -> TaskRef {
            Self { ptr, execute_fn }
        }

        #[inline]
        pub fn id(&self) -> impl Eq {
            (self.ptr, self.execute_fn)
        }

        #[inline]
        pub fn execute(self) {
            unsafe { (self.execute_fn)(self.ptr) }
        }

        #[inline]
        pub unsafe fn execute_with_scope<T>(self, scope: &mut T) {
            unsafe {
                core::mem::transmute::<_, unsafe fn(*const (), &mut T)>(self.execute_fn)(
                    self.ptr, scope,
                )
            }
        }
    }

    unsafe impl Send for TaskRef {}
    unsafe impl Sync for TaskRef {}

    pub struct StackTask<F: FnOnce()> {
        task: UnsafeCell<Option<F>>,
        _phantom: PhantomPinned,
    }

    impl<F: FnOnce()> StackTask<F> {
        pub fn new(task: F) -> StackTask<F> {
            Self {
                task: UnsafeCell::new(Some(task)),
                _phantom: PhantomPinned,
            }
        }

        #[inline]
        pub fn run(self) {
            self.task.into_inner().unwrap()();
        }

        #[inline]
        pub unsafe fn run_as_ref(&self) {
            ((&mut *self.task.get()).take().unwrap())();
        }

        #[inline]
        pub fn as_task_ref(self: Pin<&Self>) -> TaskRef {
            unsafe { TaskRef::new(&*self) }
        }
    }

    impl<F: FnOnce()> Task for StackTask<F> {
        #[inline]
        unsafe fn execute(this: *const ()) {
            let this = &*this.cast::<Self>();
            let task = (&mut *this.task.get()).take().unwrap();
            task();
        }
    }

    pub struct HeapTask<F: FnOnce()> {
        task: F,
        _phantom: PhantomPinned,
    }

    impl<F: FnOnce()> HeapTask<F> {
        pub fn new(task: F) -> Box<HeapTask<F>> {
            Box::new(Self {
                task,
                _phantom: PhantomPinned,
            })
        }

        #[inline]
        pub unsafe fn into_static_task_ref(self: Box<Self>) -> TaskRef
        where
            F: 'static,
        {
            self.into_task_ref()
        }

        #[inline]
        pub unsafe fn into_task_ref(self: Box<Self>) -> TaskRef {
            TaskRef::new(Box::into_raw(self))
        }
    }

    impl<F: FnOnce()> Task for HeapTask<F> {
        #[inline]
        unsafe fn execute(this: *const ()) {
            let this = Box::from_raw(this.cast::<Self>().cast_mut());
            (this.task)();
        }
    }
}

pub mod latch {
    use core::marker::PhantomData;
    use std::{
        sync::{
            atomic::{AtomicBool, AtomicUsize, Ordering},
            Arc,
        },
        task::Wake,
    };

    use parking_lot::{Condvar, Mutex};

    use crate::{ThreadPool, WorkerThread};

    pub trait Latch {
        unsafe fn set_raw(this: *const Self);
    }

    pub trait Probe {
        fn probe(&self) -> bool;
    }

    #[derive(Debug)]
    pub struct AtomicLatch(AtomicBool);

    impl AtomicLatch {
        #[inline]
        pub const fn new() -> AtomicLatch {
            Self(AtomicBool::new(false))
        }

        #[inline]
        pub fn reset(&self) {
            self.0.store(false, Ordering::Release);
        }
    }

    impl Latch for AtomicLatch {
        #[inline]
        unsafe fn set_raw(this: *const Self) {
            (*this).0.store(true, Ordering::Release);
        }
    }

    impl Probe for AtomicLatch {
        #[inline]
        fn probe(&self) -> bool {
            self.0.load(Ordering::Acquire)
        }
    }

    pub struct ClosureLatch<S, P> {
        set: S,
        probe: P,
    }

    impl<S, P> ClosureLatch<S, P> {
        pub fn new(set: S, probe: P) -> Self {
            Self { set, probe }
        }

        pub fn new_boxed(set: S, probe: P) -> Box<Self> {
            Box::new(Self { set, probe })
        }
    }

    impl<S, P> Latch for ClosureLatch<S, P>
    where
        S: Fn(),
    {
        unsafe fn set_raw(this: *const Self) {
            let this = &*this;
            (this.set)();
        }
    }

    impl<S, P> Probe for ClosureLatch<S, P>
    where
        P: Fn() -> bool,
    {
        fn probe(&self) -> bool {
            (self.probe)()
        }
    }

    pub struct ThreadWakeLatch {
        inner: AtomicLatch,
        index: usize,
        pool: &'static ThreadPool,
    }

    impl ThreadWakeLatch {
        #[inline]
        pub const fn new(thread: &WorkerThread) -> ThreadWakeLatch {
            Self {
                inner: AtomicLatch::new(),
                pool: thread.pool,
                index: thread.index,
            }
        }

        #[inline]
        pub fn reset(&self) {
            self.inner.reset()
        }
    }

    impl Latch for ThreadWakeLatch {
        #[inline]
        unsafe fn set_raw(this: *const Self) {
            let (pool, index) = {
                let this = &*this;
                (this.pool, this.index)
            };

            Latch::set_raw(&(*this).inner);
            pool.wake_thread(index);
        }
    }

    impl Probe for ThreadWakeLatch {
        #[inline]
        fn probe(&self) -> bool {
            self.inner.probe()
        }
    }

    pub struct LatchRef<'a, L: Latch> {
        inner: *const L,
        _marker: PhantomData<&'a L>,
    }

    impl<'a, L: Latch> LatchRef<'a, L> {
        #[inline]
        pub const fn new(latch: &'a L) -> Self {
            Self {
                inner: latch,
                _marker: PhantomData,
            }
        }
    }

    impl<'a, L: Latch> Latch for LatchRef<'a, L> {
        #[inline]
        unsafe fn set_raw(this: *const Self) {
            let this = &*this;
            Latch::set_raw(this.inner);
        }
    }

    impl<'a, L: Latch + Probe> Probe for LatchRef<'a, L> {
        #[inline]
        fn probe(&self) -> bool {
            unsafe {
                let this = &*self.inner;
                Probe::probe(this)
            }
        }
    }

    pub struct NopLatch;

    impl Latch for NopLatch {
        #[inline]
        unsafe fn set_raw(_this: *const Self) {
            // do nothing
        }
    }

    pub struct MutexLatch {
        mutex: Mutex<bool>,
        signal: Condvar,
    }

    impl MutexLatch {
        #[inline]
        pub const fn new() -> MutexLatch {
            Self {
                mutex: Mutex::new(false),
                signal: Condvar::new(),
            }
        }

        #[inline]
        pub fn wait(&self) {
            let mut guard = self.mutex.lock();
            while !*guard {
                self.signal.wait(&mut guard);
            }
        }

        #[inline]
        pub fn wait_and_reset(&self) {
            let mut guard = self.mutex.lock();
            while !*guard {
                self.signal.wait(&mut guard);
            }
            *guard = false;
        }
    }

    impl Latch for MutexLatch {
        #[inline]
        unsafe fn set_raw(this: *const Self) {
            let mut guard = (*this).mutex.lock();
            *guard = true;
            (*this).signal.notify_all();
        }
    }

    pub struct CountWakeLatch {
        counter: AtomicUsize,
        inner: ThreadWakeLatch,
    }

    impl CountWakeLatch {
        #[inline]
        pub const fn new(count: usize, thread: &WorkerThread) -> CountWakeLatch {
            Self {
                counter: AtomicUsize::new(count),
                inner: ThreadWakeLatch::new(thread),
            }
        }

        #[inline]
        pub fn increment(&self) {
            self.counter.fetch_add(1, Ordering::Relaxed);
        }
    }

    impl Latch for CountWakeLatch {
        #[inline]
        unsafe fn set_raw(this: *const Self) {
            if (*this).counter.fetch_sub(1, Ordering::SeqCst) == 1 {
                Latch::set_raw(&(*this).inner);
            }
        }
    }

    impl Probe for CountWakeLatch {
        #[inline]
        fn probe(&self) -> bool {
            self.inner.probe()
        }
    }

    pub struct LatchWaker<L>(L);

    impl<L> LatchWaker<L> {
        #[inline]
        pub fn new(latch: L) -> Arc<Self> {
            Arc::new(Self(latch))
        }

        #[inline]
        pub fn latch(&self) -> &L {
            &self.0
        }
    }

    impl<L> Wake for LatchWaker<L>
    where
        L: Latch,
    {
        #[inline]
        fn wake(self: Arc<Self>) {
            self.wake_by_ref();
        }

        #[inline]
        fn wake_by_ref(self: &Arc<Self>) {
            unsafe {
                Latch::set_raw(&self.0);
            }
        }
    }
}
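
// A minimal usage sketch of the `Latch`/`Probe` pair: a latch is set exactly once
// via `Latch::set_raw`, and waiters observe the transition through `Probe::probe`.
// The pool couples this with `WorkerThread::run_until`, which keeps executing
// tasks until the latch it is given reports set.
#[cfg(test)]
mod latch_usage_sketch {
    use super::latch::{AtomicLatch, Latch, Probe};

    #[test]
    fn set_then_probe() {
        let latch = AtomicLatch::new();
        assert!(!latch.probe());

        // SAFETY: `latch` is a live stack local for the duration of this call.
        unsafe { Latch::set_raw(&latch) };

        assert!(latch.probe());
    }
}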
pub mod melange;
pub mod praetor;

pub struct ThreadPoolState {
    num_threads: AtomicUsize,
    lock: Mutex<()>,
    heartbeat_state: CachePadded<ThreadControl>,
}

bitflags! {
    #[derive(Clone)]
    pub struct ThreadStatus: u8 {
        const RUNNING = 1 << 0;
        const SLEEPING = 1 << 1;
        const SHOULD_WAKE = 1 << 2;
    }
}

pub struct ThreadControl {
    status: Mutex<ThreadStatus>,
    status_changed: Condvar,
    should_terminate: AtomicLatch,
}

pub struct ThreadState {
    should_shove: AtomicBool,
    control: ThreadControl,
    stealer: Stealer<TaskRef>,
    worker: AtomicCell<Option<Worker<TaskRef>>>,
    shoved_task: CachePadded<Slot<TaskRef>>,
}

impl ThreadControl {
    pub const fn new() -> Self {
        Self {
            status: Mutex::new(ThreadStatus::empty()),
            status_changed: Condvar::new(),
            should_terminate: AtomicLatch::new(),
        }
    }

    /// Returns `true` if the thread was sleeping.
    #[inline]
    pub fn wake(&self) -> bool {
        let mut guard = self.status.lock();
        guard.insert(ThreadStatus::SHOULD_WAKE);
        self.status_changed.notify_all();
        guard.contains(ThreadStatus::SLEEPING)
    }

    #[inline]
    pub fn wait_for_running(&self) {
        let mut guard = self.status.lock();
        while !guard.contains(ThreadStatus::RUNNING) {
            self.status_changed.wait(&mut guard);
        }
    }

    #[inline]
    pub fn wait_for_should_wake(&self) {
        let mut guard = self.status.lock();
        while !guard.contains(ThreadStatus::SHOULD_WAKE) {
            guard.insert(ThreadStatus::SLEEPING);
            self.status_changed.wait(&mut guard);
        }
        guard.remove(ThreadStatus::SHOULD_WAKE | ThreadStatus::SLEEPING);
    }

    #[inline]
    pub fn wait_for_should_wake_timeout(&self, timeout: Duration) {
        let mut guard = self.status.lock();
        while !guard.contains(ThreadStatus::SHOULD_WAKE) {
            guard.insert(ThreadStatus::SLEEPING);
            if self
                .status_changed
                .wait_for(&mut guard, timeout)
                .timed_out()
            {
                break;
            }
        }
        guard.remove(ThreadStatus::SHOULD_WAKE | ThreadStatus::SLEEPING);
    }

    #[inline]
    pub fn wait_for_termination(&self) {
        let mut guard = self.status.lock();
        while guard.contains(ThreadStatus::RUNNING) {
            self.status_changed.wait(&mut guard);
        }
    }

    #[inline]
    pub fn notify_running(&self) {
        let mut guard = self.status.lock();
        guard.insert(ThreadStatus::RUNNING);
        self.status_changed.notify_all();
    }

    #[inline]
    pub fn notify_termination(&self) {
        let mut guard = self.status.lock();
        *guard = ThreadStatus::empty();
        self.status_changed.notify_all();
    }

    #[inline]
    pub fn notify_should_terminate(&self) {
        unsafe {
            Latch::set_raw(&self.should_terminate);
        }
        self.wake();
    }
}

const MAX_THREADS: usize = 32;

type ThreadCallback = dyn Fn(&WorkerThread) + Send + Sync + 'static;

pub struct ThreadPoolCallbacks {
    at_entry: Option<Arc<ThreadCallback>>,
    at_exit: Option<Arc<ThreadCallback>>,
}

impl ThreadPoolCallbacks {
    pub const fn new_empty() -> ThreadPoolCallbacks {
        Self {
            at_entry: None,
            at_exit: None,
        }
    }
}

pub struct ThreadPool {
    threads: [CachePadded<ThreadState>; MAX_THREADS],
    pool_state: CachePadded<ThreadPoolState>,
    global_queue: Injector<TaskRef>,
    callbacks: CachePadded<ThreadPoolCallbacks>,
}

impl ThreadPool {
    pub fn new() -> Self {
        Self::new_with_callbacks(ThreadPoolCallbacks::new_empty())
    }

    pub fn new_with_callbacks(callbacks: ThreadPoolCallbacks) -> ThreadPool {
        let threads = [const { MaybeUninit::uninit() }; MAX_THREADS].map(|mut uninit| {
            let worker = Worker::<TaskRef>::new_fifo();
            let stealer = worker.stealer();

            let thread = CachePadded::new(ThreadState {
                should_shove: AtomicBool::new(false),
                shoved_task: Slot::new().into(),
                control: ThreadControl {
                    status: Mutex::new(ThreadStatus::empty()),
                    status_changed: Condvar::new(),
                    should_terminate: AtomicLatch::new(),
                },
                stealer,
                worker: AtomicCell::new(Some(worker)),
            });

            uninit.write(thread);
            unsafe { uninit.assume_init() }
        });

        Self {
            threads,
            pool_state: CachePadded::new(ThreadPoolState {
                num_threads: AtomicUsize::new(0),
                lock: Mutex::new(()),
                heartbeat_state: ThreadControl {
                    status: Mutex::new(ThreadStatus::empty()),
                    status_changed: Condvar::new(),
                    should_terminate: AtomicLatch::new(),
                }
                .into(),
            }),
            global_queue: Injector::new(),
            callbacks: CachePadded::new(callbacks),
        }
    }

    #[inline]
    fn threads(&self) -> &[CachePadded<ThreadState>] {
        &self.threads[..self.pool_state.num_threads.load(Ordering::Relaxed) as usize]
    }

    pub fn wake_thread(&self, index: usize) -> Option<bool> {
        Some(self.threads.get(index as usize)?.control.wake())
    }

    pub fn wake_any(&self, count: usize) -> usize {
        if count > 0 {
            let num_woken = self
                .threads
                .iter()
                .filter_map(|thread| thread.control.wake().then_some(()))
                .take(count)
                .count();
            num_woken
        } else {
            0
        }
    }

    #[inline]
    pub fn id(&self) -> impl Eq {
        core::ptr::from_ref(self) as usize
    }

    #[allow(dead_code)]
    fn push_local_or_inject_balanced(&self, task: TaskRef) {
        let global_len = self.global_queue.len();
        WorkerThread::with(|worker| match worker {
            Some(worker) if worker.pool.id() == self.id() => {
                let worker_len = worker.worker.len();
                if worker_len == 0 {
                    worker.push_task(task);
                } else if global_len == 0 {
                    self.inject(task);
                } else if worker_len >= global_len {
                    worker.push_task(task);
                } else {
                    self.inject(task);
                }
            }
            _ => self.inject(task),
        })
    }

    fn push_local_or_inject(&self, task: TaskRef) {
        WorkerThread::with(|worker| match worker {
            Some(worker) if worker.pool.id() == self.id() => worker.push_task(task),
            _ => self.inject(task),
        })
    }

    #[allow(dead_code)]
    fn inject_many<I>(&self, tasks: I)
    where
        I: Iterator<Item = TaskRef>,
    {
        let mut n = 0;
        for task in tasks {
            n += 1;
            self.global_queue.push(task);
        }
        self.wake_any(n);
    }

    #[allow(unused_variables)]
    fn inject_maybe_local(&self, task: TaskRef) {
        #[cfg(all(not(feature = "never-local"), feature = "prefer-local"))]
        self.push_local_or_inject(task);
        #[cfg(all(not(feature = "prefer-local"), feature = "never-local"))]
        self.inject(task);
        #[cfg(not(any(feature = "prefer-local", feature = "never-local")))]
        self.push_local_or_inject_balanced(task);
    }

    fn inject(&self, task: TaskRef) {
        self.global_queue.push(task);
        self.wake_any(1);
    }

    #[allow(dead_code)]
    fn resize<F: Fn(usize) -> usize>(&'static self, size: F) -> usize {
        if WorkerThread::is_worker_thread() {
            // acquire required here?
            debug!("tried to resize from within threadpool!");
            return self.pool_state.num_threads.load(Ordering::Acquire);
        }

        #[cfg(feature = "cpu-pinning")]
        let cpus = core_affinity::get_core_ids().unwrap();

        let _guard = self.pool_state.lock.lock();

        let current_size = self.pool_state.num_threads.load(Ordering::Acquire);
        let new_size = size(current_size).clamp(0, MAX_THREADS);

        debug!(current_size, new_size, "resizing threadpool");

        if new_size == current_size {
            return current_size;
        }

        self.pool_state
            .num_threads
            .store(new_size, Ordering::Release);

        match new_size.cmp(&current_size) {
            std::cmp::Ordering::Greater => {
                let new_threads = &self.threads[current_size..new_size];

                for (i, _) in new_threads.iter().enumerate() {
                    #[cfg(feature = "cpu-pinning")]
                    let core = cpus[i];
                    std::thread::spawn(move || {
                        #[cfg(feature = "cpu-pinning")]
                        core_affinity::set_for_current(core);
                        WorkerThread::worker_loop(&self, current_size + i);
                    });
                }

                for thread in new_threads {
                    thread.control.wait_for_running();
                }

                #[cfg(feature = "heartbeat")]
                if current_size == 0 {
                    std::thread::spawn(move || {
                        heartbeat_loop(self);
                    });
                    self.pool_state.heartbeat_state.wait_for_running();
                }
            }
            std::cmp::Ordering::Less => {
                debug!(
                    "waiting for threads {:?} to terminate.",
                    new_size..current_size
                );
                let terminating_threads = &self.threads[new_size..current_size];

                for thread in terminating_threads {
                    thread.control.notify_should_terminate();
                }
                for thread in terminating_threads {
                    thread.control.wait_for_termination();
                }

                #[cfg(feature = "heartbeat")]
                if new_size == 0 {
                    self.pool_state.heartbeat_state.notify_should_terminate();
                    self.pool_state.heartbeat_state.wait_for_termination();
                }
            }
            std::cmp::Ordering::Equal => unreachable!(),
        }

        new_size
    }

    #[allow(dead_code)]
    fn ensure_one_worker(&'static self) -> usize {
        self.resize(|current| current.max(1))
    }

    #[allow(dead_code)]
    fn resize_to_available(&'static self) {
        self.resize_to(available_parallelism().map(NonZero::get).unwrap_or(1));
    }

    #[allow(dead_code)]
    fn resize_to(&'static self, new_size: usize) -> usize {
        self.resize(|_| new_size)
    }

    #[allow(dead_code)]
    fn grow_by(&'static self, num_threads: usize) -> usize {
        self.resize(|current| current.saturating_add(num_threads))
    }

    #[allow(dead_code)]
    fn shrink_by(&'static self, num_threads: usize) -> usize {
        self.resize(|current| current.saturating_sub(num_threads))
    }

    #[allow(dead_code)]
    fn shrink_to(&'static self, num_threads: usize) -> usize {
        self.resize(|_| num_threads)
    }

    fn in_worker<F, T>(&'static self, f: F) -> T
    where
        F: FnOnce(&WorkerThread, bool) -> T + Send,
        T: Send,
    {
        WorkerThread::with(|worker| match worker {
            Some(worker) => {
                if worker.pool.id() == self.id() {
                    self.in_worker_cross(worker, f)
                } else {
                    f(worker, false)
                }
            }
            None => self.in_worker_cold(f),
        })
    }

    #[cold]
    fn in_worker_cold<F, T>(&'static self, f: F) -> T
    where
        F: FnOnce(&WorkerThread, bool) -> T + Send,
        T: Send,
    {
        std::thread_local! {
            static LATCH: MutexLatch = const { MutexLatch::new() };
        }

        LATCH.with(|latch| {
            let mut result = None;
            let task = StackTask::new(|| {
                WorkerThread::with(|worker| {
                    let worker = worker.unwrap();
                    result = Some(f(worker, true));
                    unsafe {
                        // SAFETY: static thread-local
                        Latch::set_raw(latch);
                    }
                })
            });

            let pinned = pin!(task);
            let taskref = pinned.as_ref().as_task_ref();
            self.inject(taskref);
            latch.wait_and_reset();

            result.unwrap()
        })
    }

    /// Run `f` in `self`, but block the current thread until the work is complete.
    fn in_worker_cross<F, T>(&'static self, worker: &WorkerThread, f: F) -> T
    where
        F: FnOnce(&WorkerThread, bool) -> T + Send,
        T: Send,
    {
        let latch = ThreadWakeLatch::new(worker);
        let mut result = None;
        let task = pin!(StackTask::new(|| {
            WorkerThread::with(|worker| {
                let worker = worker.unwrap();
                result = Some(f(worker, true));
                unsafe {
                    // SAFETY: the latch outlives the task: `run_until` below
                    // does not return until the latch has been set.
                    Latch::set_raw(&latch);
                }
            })
        }));

        let taskref = task.into_ref().as_task_ref();
        self.push_local_or_inject(taskref);

        worker.run_until(&latch);

        result.unwrap()
    }
}

impl ThreadPool {
    pub fn spawn<Fn>(&'static self, f: Fn)
    where
        Fn: FnOnce() + Send + 'static,
    {
        let task = HeapTask::new(f);
        let taskref = unsafe { task.into_static_task_ref() };
        self.inject_maybe_local(taskref);
    }

    pub fn spawn_future<Fut, T>(&'static self, future: Fut) -> Task<T>
    where
        Fut: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        let schedule = move |runnable: Runnable| {
            let taskref = unsafe {
                TaskRef::new_raw(runnable.into_raw().as_ptr(), |this| {
                    let this = NonNull::new_unchecked(this.cast_mut());
                    let runnable = Runnable::<()>::from_raw(this);
                    runnable.run();
                })
            };

            self.inject_maybe_local(taskref);
        };

        let (runnable, task) = async_task::spawn(future, schedule);
        runnable.schedule();

        task
    }

    pub fn spawn_async<Fn, Fut, T>(&'static self, f: Fn) -> Task<T>
    where
        Fn: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        self.spawn_future(async move { f().await })
    }

    pub fn block_on<Fut, T>(&'static self, mut future: Fut)
    where
        Fut: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        let mut future = unsafe { Pin::new_unchecked(&mut future) };

        self.in_worker(|worker, _| {
            let wake = LatchWaker::new(ThreadWakeLatch::new(worker));
            let ctx_waker = Arc::clone(&wake).into();
            let mut ctx = Context::from_waker(&ctx_waker);

            loop {
                match future.as_mut().poll(&mut ctx) {
                    std::task::Poll::Ready(t) => {
                        return t;
                    }
                    std::task::Poll::Pending => {
                        worker.run_until(wake.latch());
                        wake.latch().reset();
                    }
                }
            }
        });
    }

    pub fn join<F, G, T, U>(&'static self, f: F, g: G) -> (T, U)
    where
        F: FnOnce() -> T + Send,
        G: FnOnce() -> U + Send,
        T: Send,
        U: Send,
    {
        self.join_threaded(f, g)
    }

    #[allow(dead_code)]
    fn join_seq<F, G, T, U>(&'static self, f: F, g: G) -> (T, U)
    where
        F: FnOnce() -> T + Send,
        G: FnOnce() -> U + Send,
        T: Send,
        U: Send,
    {
        let a = f();
        let b = g();

        (a, b)
    }

    #[inline]
    fn join_threaded<F, G, T, U>(&'static self, f: F, g: G) -> (T, U)
    where
        F: FnOnce() -> T + Send,
        G: FnOnce() -> U + Send,
        T: Send,
        U: Send,
    {
        self.in_worker(|worker, _| {
            let mut result_b = None;
            let latch_b = ThreadWakeLatch::new(worker);

            let task_b = pin!(StackTask::new(|| {
                result_b = Some(g());
                unsafe {
                    Latch::set_raw(&latch_b);
                }
            }));

            let ref_b = task_b.as_ref().as_task_ref();
            let b_id = ref_b.id();
            worker.push_task(ref_b);

            let result_a = f();

            while !latch_b.probe() {
                match worker.pop_task() {
                    Some(task) => {
                        if task.id() == b_id {
                            // we're not calling execute() here, so manually try
                            // shoving a task.
                            //worker.try_promote();
                            worker.shove_task();

                            unsafe {
                                task_b.run_as_ref();
                            }
                            break;
                        } else {
                            worker.execute(task);
                        }
                    }
                    None => {
                        worker.run_until(&latch_b);
                    }
                }
            }

            (result_a, result_b.unwrap())
        })
    }

    pub fn scope<'scope, Fn, T>(&'static self, f: Fn) -> T
    where
        Fn: FnOnce(&Scope<'scope>) -> T + Send,
        T: Send,
    {
        self.in_worker(|owner, _| {
            let scope = unsafe { Scope::<'scope>::new(owner) };
            let result = f(&scope);
            scope.complete(owner);
            result
        })
    }
}

pub struct WorkerThread {
    // queue: TaskQueue,
    worker: Worker<TaskRef>,
    pool: &'static ThreadPool,
    index: usize,
    rng: rng::XorShift64Star,
}

const HEARTBEAT_INTERVAL: core::time::Duration = const { core::time::Duration::from_micros(100) };

std::thread_local! {
    static WORKER_THREAD_STATE: Cell<*const WorkerThread> = const { Cell::new(ptr::null()) };
}

impl WorkerThread {
    #[inline]
    fn info(&self) -> &ThreadState {
        &self.pool.threads[self.index as usize]
    }

    #[inline]
    fn pool(&self) -> &'static ThreadPool {
        self.pool
    }

    #[inline]
    #[allow(dead_code)]
    fn index(&self) -> usize {
        self.index
    }

    #[inline]
    fn is_worker_thread() -> bool {
        Self::with(|worker| worker.is_some())
    }

    fn with<T, F: FnOnce(Option<&Self>) -> T>(f: F) -> T {
        WORKER_THREAD_STATE.with(|thread| {
            f(NonNull::<WorkerThread>::new(thread.get().cast_mut())
                .map(|ptr| unsafe { ptr.as_ref() }))
        })
    }

    #[inline]
    fn pop_task(&self) -> Option<TaskRef> {
        self.worker.pop()
        //self.queue.pop_front(task);
    }

    #[inline]
    fn push_task(&self, task: TaskRef) {
        self.worker.push(task);
        //self.queue.push_front(task);
    }

    #[inline]
    fn drain(&self) -> impl Iterator<Item = TaskRef> {
        // self.queue.drain()
        core::iter::empty()
    }

    #[inline]
    fn steal_tasks(&self) -> Option<TaskRef> {
        // careful not to call threads() here because that omits any threads
        // that were killed, which might still have tasks.
        let threads = &self.pool.threads;
        let (start, end) = threads.split_at(self.rng.next_usize(threads.len()));

        end.iter()
            .chain(start)
            .find_map(|thread: &CachePadded<ThreadState>| {
                thread.stealer.steal_batch_and_pop(&self.worker).success()
            })
    }

    #[inline]
    fn claim_shoved_task(&self) -> Option<TaskRef> {
        // take own shoved task first
        if let Some(task) = self.info().shoved_task.try_take() {
            return Some(task);
        }

        let threads = self.pool.threads();
        if threads.is_empty() {
            return None;
        }

        let (start, end) = threads.split_at(self.rng.next_usize(threads.len()));
        end.iter()
            .chain(start)
            .find_map(|thread| thread.shoved_task.try_take())
    }

    #[cold]
    fn shove_task(&self) {
        if !self.info().shoved_task.is_occupied() {
            if let Some(task) = self.info().stealer.steal().success() {
                match self.info().shoved_task.try_put(task) {
                    // shoved task is occupied, reinsert into queue
                    // this really shouldn't happen
                    Some(_task) => unreachable!(),
                    None => {}
                }
            }
        } else {
            // wake thread to execute task
            self.pool.wake_any(1);
        }
    }

    fn execute(&self, task: TaskRef) {
        self.try_promote();
        task.execute();
    }

    #[inline]
    fn try_promote(&self) {
        #[cfg(feature = "heartbeat")]
        let should_shove = self.info().should_shove.load(Ordering::Acquire);
        #[cfg(not(feature = "heartbeat"))]
        let should_shove = true;

        if should_shove {
            #[cfg(feature = "heartbeat")]
            self.info().should_shove.store(false, Ordering::Release);
            self.shove_task();
        }
    }
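
    // Heartbeat promotion, in short: `heartbeat_loop` periodically sets
    // `should_shove` for one worker at a time; the next time that worker calls
    // `try_promote` (via `execute`) it moves a single task from its own deque
    // into its shared `shoved_task` slot, where any idle worker can claim it
    // through `claim_shoved_task` without contending on the deque itself.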
    #[inline]
    fn find_any_task(&self) -> Option<TaskRef> {
        // TODO: attempt stealing work here, too.
        #[allow(unused_mut)]
        let mut task = self
            .pop_task()
            .or_else(|| self.claim_shoved_task())
            .or_else(|| {
                self.pool
                    .global_queue
                    .steal_batch_and_pop(&self.worker)
                    .success()
            });

        #[cfg(feature = "work-stealing")]
        {
            task = task.or_else(|| self.steal_tasks());
        }

        task
    }

    #[inline]
    fn run_until<L>(&self, latch: &L)
    where
        L: Probe,
    {
        if !latch.probe() {
            self.run_until_cold(latch);
        }
    }

    #[cold]
    fn run_until_cold<L>(&self, latch: &L)
    where
        L: Probe,
    {
        while !latch.probe() {
            self.run_until_inner();
        }
    }

    #[inline]
    fn run_until_inner(&self) {
        match self.find_any_task() {
            Some(task) => {
                self.execute(task);
            }
            None => {
                //debug!("waiting for tasks");
                self.info().control.wait_for_should_wake();
            }
        }
    }

    fn worker_loop(pool: &'static ThreadPool, index: usize) {
        let info = &pool.threads()[index as usize];

        let worker = CachePadded::new(WorkerThread {
            // queue: TaskQueue::new(),
            worker: info.worker.take().unwrap(),
            pool,
            index,
            rng: rng::XorShift64Star::new(pool as *const _ as u64 + index as u64),
        });

        WORKER_THREAD_STATE.with(|cell| {
            cell.set(&*worker);

            if let Some(callback) = pool.callbacks.at_entry.as_ref() {
                callback(&worker);
            }

            info.control.notify_running();
            // info.notify_running();

            worker.run_until(&info.control.should_terminate);

            if let Some(callback) = pool.callbacks.at_exit.as_ref() {
                callback(&worker);
            }

            for task in worker.drain() {
                pool.inject(task);
            }
            if let Some(task) = info.shoved_task.try_take() {
                pool.inject(task);
            }

            cell.set(ptr::null());
        });

        let WorkerThread { worker, .. } = CachePadded::into_inner(worker);
        info.worker.store(Some(worker));

        info.control.notify_termination();
    }
}

fn heartbeat_loop(pool: &'static ThreadPool) {
    let state = &pool.pool_state.heartbeat_state;
    state.notify_running();

    let mut i = 0;
    while !state.should_terminate.probe() {
        let threads = pool.threads();
        if threads.is_empty() {
            break;
        }

        if i >= threads.len() {
            i = 0;
            continue;
        }

        threads[i].should_shove.store(true, Ordering::Relaxed);
        i += 1;

        let interval = HEARTBEAT_INTERVAL / threads.len() as u32;
        state.wait_for_should_wake_timeout(interval);
    }

    state.notify_termination();
}

mod vec_queue {
    use std::{cell::UnsafeCell, collections::VecDeque};

    pub struct TaskQueue<T>(UnsafeCell<VecDeque<T>>);

    impl<T> TaskQueue<T> {
        /// Creates a new [`TaskQueue`].
        #[inline]
        #[allow(dead_code)]
        pub const fn new() -> Self {
            Self(UnsafeCell::new(VecDeque::new()))
        }

        #[inline]
        #[allow(dead_code)]
        pub fn get_mut(&self) -> &mut VecDeque<T> {
            unsafe { &mut *self.0.get() }
        }

        #[inline]
        #[allow(dead_code)]
        pub fn pop_front(&self) -> Option<T> {
            self.get_mut().pop_front()
        }

        #[inline]
        #[allow(dead_code)]
        pub fn pop_back(&self) -> Option<T> {
            self.get_mut().pop_back()
        }

        #[inline]
        #[allow(dead_code)]
        pub fn push_back(&self, t: T) {
            self.get_mut().push_back(t);
        }

        #[inline]
        #[allow(dead_code)]
        pub fn push_front(&self, t: T) {
            self.get_mut().push_front(t);
        }

        #[inline]
        #[allow(dead_code)]
        pub fn take(&self) -> VecDeque<T> {
            core::mem::replace(self.get_mut(), VecDeque::new())
        }

        #[inline]
        #[allow(dead_code)]
        pub fn drain(&self) -> impl Iterator<Item = T> {
            self.take().into_iter()
        }
    }
}

#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SlotState {
    None,
    Locked,
    Occupied,
}

impl From<u8> for SlotState {
    fn from(value: u8) -> Self {
        unsafe { core::mem::transmute(value) }
    }
}

impl From<SlotState> for u8 {
    fn from(value: SlotState) -> Self {
        value as u8
    }
}

pub struct Slot<T> {
    slot: UnsafeCell<MaybeUninit<T>>,
    state: AtomicU8,
}

unsafe impl<T> Send for Slot<T> where T: Send {}
unsafe impl<T> Sync for Slot<T> where T: Send {}

impl<T> Drop for Slot<T> {
    fn drop(&mut self) {
        if core::mem::needs_drop::<T>() {
            if *self.state.get_mut() == SlotState::Occupied as u8 {
                unsafe {
                    // drop the initialised value itself, not just the `MaybeUninit` wrapper
                    (*self.slot.get()).assume_init_drop();
                }
            }
        }
    }
}

impl<T> Slot<T> {
    pub const fn new() -> Slot<T> {
        Self {
            slot: UnsafeCell::new(MaybeUninit::uninit()),
            state: AtomicU8::new(SlotState::None as u8),
        }
    }

    pub fn is_occupied(&self) -> bool {
        self.state.load(Ordering::Acquire) == SlotState::Occupied.into()
    }

    #[inline]
    pub fn insert(&self, t: T) -> Option<T> {
        let value = match self
            .state
            .swap(SlotState::Locked.into(), Ordering::AcqRel)
            .into()
        {
            SlotState::Locked => {
                // return early: was already locked.
                debug!("slot was already locked");
                return None;
            }
            SlotState::Occupied => {
                let slot = self.slot.get();
                // replace
                unsafe {
                    let v = (*slot).assume_init_read();
                    (*slot).write(t);
                    Some(v)
                }
            }
            SlotState::None => {
                let slot = self.slot.get();
                // insert
                unsafe {
                    (*slot).write(t);
                }
                None
            }
        };

        // release lock
        self.state
            .store(SlotState::Occupied.into(), Ordering::Release);

        value
    }

    #[inline]
    pub fn try_put(&self, t: T) -> Option<T> {
        match self.state.compare_exchange(
            SlotState::None.into(),
            SlotState::Locked.into(),
            Ordering::Acquire,
            Ordering::Relaxed,
        ) {
            Err(_) => Some(t),
            Ok(_) => {
                let slot = self.slot.get();
                // SAFETY: we hold LOCKED on the spinlock
                unsafe { (*slot).write(t) };

                // release lock
                self.state
                    .store(SlotState::Occupied.into(), Ordering::Release);
                None
            }
        }
    }

    #[inline]
    pub fn try_take(&self) -> Option<T> {
        match self.state.compare_exchange(
            SlotState::Occupied.into(),
            SlotState::Locked.into(),
            Ordering::Acquire,
            Ordering::Relaxed,
        ) {
            Ok(_) => {
                let slot = self.slot.get();
                // SAFETY: we hold LOCKED on the spinlock
                let t = unsafe { (*slot).assume_init_read() };

                // release lock
                self.state.store(SlotState::None.into(), Ordering::Release);
                Some(t)
            }
            Err(_) => None,
        }
    }
}

mod rng {
    use core::cell::Cell;

    pub struct XorShift64Star {
        state: Cell<u64>,
    }

    impl XorShift64Star {
        /// Initializes the prng with a seed. Provided seed must be nonzero.
        pub fn new(seed: u64) -> Self {
            XorShift64Star {
                state: Cell::new(seed),
            }
        }

        /// Returns a pseudorandom number.
        pub fn next(&self) -> u64 {
            let mut x = self.state.get();
            debug_assert_ne!(x, 0);
            x ^= x >> 12;
            x ^= x << 25;
            x ^= x >> 27;
            self.state.set(x);
            x.wrapping_mul(0x2545_f491_4f6c_dd1d)
        }

        /// Returns a pseudorandom number in `0..n`.
        pub fn next_usize(&self, n: usize) -> usize {
            (self.next() % n as u64) as usize
        }
    }
}
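
// A minimal usage sketch of `Slot<T>` (defined above) as a single-element
// hand-off cell: `try_put` only succeeds when the slot is empty, handing the
// value back otherwise, and `try_take` empties an occupied slot exactly once.
#[cfg(test)]
mod slot_usage_sketch {
    use super::Slot;

    #[test]
    fn try_put_try_take_roundtrip() {
        let slot = Slot::new();
        assert!(slot.try_put(1u32).is_none()); // empty -> value accepted
        assert_eq!(slot.try_put(2u32), Some(2)); // occupied -> value handed back
        assert!(slot.is_occupied());
        assert_eq!(slot.try_take(), Some(1)); // take the stored value
        assert_eq!(slot.try_take(), None); // now empty again
    }
}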
pub mod scope {
    use std::{
        future::Future,
        marker::PhantomData,
        ptr::{self, NonNull},
    };

    use async_task::{Runnable, Task};

    use crate::{
        latch::{CountWakeLatch, Latch},
        task::{HeapTask, TaskRef},
        ThreadPool, WorkerThread,
    };

    pub struct Scope<'scope> {
        pool: &'static ThreadPool,
        tasks_completed_latch: CountWakeLatch,
        _marker: PhantomData<Box<dyn FnOnce(&Scope<'scope>) + Send + Sync + 'scope>>,
    }

    impl<'scope> Scope<'scope> {
        pub unsafe fn new(owner: &WorkerThread) -> Scope<'scope> {
            Scope {
                pool: owner.pool(),
                tasks_completed_latch: CountWakeLatch::new(1, owner),
                _marker: PhantomData,
            }
        }

        pub fn join<F, G, T, U>(&self, f: F, g: G) -> (T, U)
        where
            F: FnOnce(&Self) -> T + Send,
            G: FnOnce(&Self) -> U + Send,
            T: Send,
            U: Send,
        {
            self.pool.join(|| f(self), || g(self))
        }

        pub fn spawn<Fn>(&self, f: Fn)
        where
            Fn: FnOnce(&Scope<'scope>) + Send + 'scope,
        {
            self.tasks_completed_latch.increment();

            let ptr = SendPtr::from_ref(self);
            let task = HeapTask::new(move || unsafe {
                let this = ptr.as_ref();
                f(this);
                Latch::set_raw(&this.tasks_completed_latch);
            });

            let taskref = unsafe { task.into_task_ref() };
            self.pool.inject_maybe_local(taskref);
        }

        pub fn spawn_future<Fut, T>(&self, future: Fut) -> Task<T>
        where
            Fut: Future<Output = T> + Send + 'scope,
            T: Send + 'scope,
        {
            self.tasks_completed_latch.increment();

            let ptr = SendPtr::from_ref(self);
            let schedule = move |runnable: Runnable| {
                let taskref = unsafe {
                    TaskRef::new_raw(runnable.into_raw().as_ptr(), |this| {
                        let this = NonNull::new_unchecked(this.cast_mut());
                        let runnable = Runnable::<()>::from_raw(this);
                        runnable.run();
                    })
                };

                unsafe {
                    ptr.as_ref().pool.inject_maybe_local(taskref);
                }
            };

            let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) };
            runnable.schedule();

            task
        }

        pub fn spawn_async<Fn, Fut, T>(&self, f: Fn) -> Task<T>
        where
            Fn: FnOnce() -> Fut + Send + 'scope,
            Fut: Future<Output = T> + Send + 'scope,
            T: Send + 'scope,
        {
            self.spawn_future(async move { f().await })
        }

        pub fn complete(&self, owner: &WorkerThread) {
            unsafe {
                Latch::set_raw(&self.tasks_completed_latch);
            }
            owner.run_until(&self.tasks_completed_latch);
        }
    }

    struct SendPtr<T>(*const T);

    impl<T> SendPtr<T> {
        fn from_ref(t: &T) -> Self {
            Self(ptr::from_ref(t).cast())
        }

        unsafe fn as_ref(&self) -> &T {
            &*self.0
        }
    }

    unsafe impl<T> Send for SendPtr<T> {}
}

// #[cfg(test)]
// mod tests {
//     use std::{cell::Cell, hint::black_box, time::Instant};
//     use tracing::info;
//     use super::*;
//     mod tree {
//         pub struct Tree<T> {
//             nodes: Box<[Node<T>]>,
//             root: Option<usize>,
//         }
//         pub struct Node<T> {
//             pub leaf: T,
//             pub left: Option<usize>,
//             pub right: Option<usize>,
//         }
//         impl<T> Tree<T> {
//             pub fn new(depth: usize, t: T) -> Tree<T>
//             where
//                 T: Copy,
//             {
//                 let mut nodes = Vec::with_capacity((0..depth).sum());
//                 let root = Self::build_node(&mut nodes, depth, t);
//                 Self {
//                     nodes: nodes.into_boxed_slice(),
//                     root: Some(root),
//                 }
//             }
//             pub fn root(&self) -> Option<usize> {
//                 self.root
//             }
//             pub fn get(&self, index: usize) -> &Node<T> {
//                 &self.nodes[index]
//             }
//             pub fn build_node(nodes: &mut Vec<Node<T>>, depth: usize, t: T) -> usize
//             where
//                 T: Copy,
//             {
//                 let node = Node {
//                     leaf: t,
//                     left: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)),
//                     right: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)),
//                 };
//                 nodes.push(node);
//                 nodes.len() - 1
//             }
//         }
//     }
//     const PRIMES: &'static [usize] = &[
//         1181, 1187, 1193, 1201, 1213, 1217, 1223, 1229, 1231, 1237, 1249, 1259, 1277, 1279, 1283,
//         1289, 1291, 1297, 1301, 1303, 1307, 1319, 1321, 1327, 1361, 1367, 1373, 1381, 1399, 1409,
//         1423, 1427, 1429, 1433, 1439, 1447, 1451, 1453, 1459, 1471, 1481, 1483, 1487, 1489, 1493,
//         1499, 1511, 1523, 1531, 1543, 1549, 1553, 1559, 1567, 1571, 1579, 1583, 1597, 1601, 1607,
//         1609, 1613, 1619, 1621, 1627, 1637, 1657, 1663, 1667, 1669, 1693, 1697, 1699, 1709, 1721,
//         1723, 1733, 1741, 1747, 1753, 1759, 1777, 1783, 1787, 1789, 1801, 1811, 1823, 1831, 1847,
//         1861, 1867, 1871, 1873, 1877, 1879, 1889, 1901, 1907,
//     ];
//     #[cfg(feature = "spin-slow")]
//     const REPEAT: usize = 0x800;
//     #[cfg(not(feature = "spin-slow"))]
//     const REPEAT: usize = 0x8000;
//     const TREE_SIZE: usize = 10;
//     fn run_in_scope<T: Send>(pool: ThreadPool, f: impl FnOnce(&Scope<'_>) -> T + Send) -> T {
//         let pool = Box::new(pool);
//         let ptr = Box::into_raw(pool);
//         let result = {
//             let pool: &'static ThreadPool = unsafe { &*ptr };
//             // pool.ensure_one_worker();
//             pool.resize_to_available();
//             let now = std::time::Instant::now();
//             let result = pool.scope(f);
//             let elapsed = now.elapsed().as_micros();
//             info!("(mine) total time: {}ms", elapsed as f32 / 1e3);
//             pool.resize_to(0);
//             assert!(pool.global_queue.is_empty());
//             result
//         };
//         let _pool = unsafe { Box::from_raw(ptr) };
//         result
//     }
//     #[test]
//     #[tracing_test::traced_test]
//     fn rayon() {
//         let pool = rayon::ThreadPoolBuilder::new()
//             .num_threads(bevy_tasks::available_parallelism())
//             .build()
//             .unwrap();
//         let now = std::time::Instant::now();
//         pool.scope(|s| {
//             for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() {
//                 s.spawn(move |_| {
//                     black_box(spinning(p));
//                 });
//             }
//         });
//         let elapsed = now.elapsed().as_micros();
//         info!("(rayon) total time: {}ms", elapsed as f32 / 1e3);
//     }
//     #[test]
//     #[tracing_test::traced_test]
//     fn rayon_join() {
//         let pool = rayon::ThreadPoolBuilder::new()
//             .num_threads(bevy_tasks::available_parallelism())
//             .build()
//             .unwrap();
//         let tree = tree::Tree::new(TREE_SIZE, 1u32);
//         fn sum(tree: &tree::Tree<u32>, node: usize) -> u32 {
//             let node = tree.get(node);
//             let (l, r) = rayon::join(
//                 || node.left.map(|node| sum(tree, node)).unwrap_or_default(),
//                 || node.right.map(|node| sum(tree, node)).unwrap_or_default(),
//             );
//             node.leaf + l + r
//         }
//         let now = std::time::Instant::now();
//         let sum = pool.scope(move |s| {
//             let root = tree.root().unwrap();
//             sum(&tree, root)
//         });
//         let elapsed = now.elapsed().as_micros();
//         info!("(rayon) {sum} total time: {}ms", elapsed as f32 / 1e3);
//     }
//     #[test]
//     #[tracing_test::traced_test]
//     fn bevy_tasks() {
//         let pool = bevy_tasks::ComputeTaskPool::get_or_init(|| {
//             bevy_tasks::TaskPoolBuilder::new()
//                 .num_threads(bevy_tasks::available_parallelism())
//                 .build()
//         });
//         let now = std::time::Instant::now();
//         pool.scope(|s| {
//             for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() {
//                 s.spawn(async move {
//                     black_box(spinning(p));
//                 });
//             }
//         });
//         let elapsed = now.elapsed().as_micros();
//         info!("(bevy_tasks) total time: {}ms", elapsed as f32 / 1e3);
//     }
//     #[test]
//     #[tracing_test::traced_test]
//     fn mine() {
//         std::thread_local! {
//             static WAIT_COUNT: Cell<usize> = const { Cell::new(0) };
//         }
//         let counter = Arc::new(AtomicUsize::new(0));
//         {
//             let pool = ThreadPool::new();
//             run_in_scope(pool, |s| {
//                 for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() {
//                     s.spawn(move |_| {
//                         black_box(spinning(p));
//                     });
//                 }
//             });
//         };
//         // eprintln!("total wait count: {}", counter.load(Ordering::Acquire));
//     }
//     #[test]
//     #[tracing_test::traced_test]
//     fn mine_join() {
//         let pool = ThreadPool::new();
//         let tree = tree::Tree::new(TREE_SIZE, 1u32);
//         fn sum(tree: &tree::Tree<u32>, node: usize, scope: &Scope<'_>) -> u32 {
//             let node = tree.get(node);
//             let (l, r) = scope.join(
//                 |s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(),
//                 |s| {
//                     node.right
//                         .map(|node| sum(tree, node, s))
//                         .unwrap_or_default()
//                 },
//             );
//             node.leaf + l + r
//         }
//         let sum = run_in_scope(pool, move |s| {
//             let root = tree.root().unwrap();
//             sum(&tree, root, s)
//         });
//     }
//     #[test]
//     #[tracing_test::traced_test]
//     fn melange_join() {
//         let pool = melange::ThreadPool::new(bevy_tasks::available_parallelism());
//         let mut scope = pool.new_worker();
//         let tree = tree::Tree::new(TREE_SIZE, 1u32);
//         fn sum(tree: &tree::Tree<u32>, node: usize, scope: &mut melange::WorkerThread) -> u32 {
//             let node = tree.get(node);
//             let (l, r) = scope.join(
//                 |s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(),
//                 |s| {
//                     node.right
//                         .map(|node| sum(tree, node, s))
//                         .unwrap_or_default()
//                 },
//             );
//             node.leaf + l + r
//         }
//         let now = Instant::now();
//         let res = sum(&tree, tree.root().unwrap(), &mut scope);
//         eprintln!(
//             "res: {res} took {}ms",
//             now.elapsed().as_micros() as f32 / 1e3
//         );
//         assert_ne!(res, 0);
//     }
//     #[test]
//     #[tracing_test::traced_test]
//     fn sync() {
//         let now = std::time::Instant::now();
//         for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() {
//             black_box(spinning(p));
//         }
//         let elapsed = now.elapsed().as_micros();
//         info!("(sync) total time: {}ms", elapsed as f32 / 1e3);
//     }
//     #[inline]
//     fn spinning(i: usize) {
//         #[cfg(feature = "spin-slow")]
//         spinning_slow(i);
//         #[cfg(not(feature = "spin-slow"))]
//         spinning_fast(i);
//     }
//     #[inline]
//     fn spinning_slow(i: usize) {
//         let rng = rng::XorShift64Star::new(i as u64);
//         (0..i).reduce(|a, b| {
//             black_box({
//                 let a = rng.next_usize(a.max(1));
//                 ((b as f32).exp() * (a as f32).sin().cbrt()).to_bits() as usize
//             })
//         });
//     }
//     #[inline]
//     fn spinning_fast(i: usize) {
//         let rng = rng::XorShift64Star::new(i as u64);
//         //(0..rng.next_usize(i)).reduce(|a, b| {
//         (0..20).reduce(|a, b| {
//             black_box({
//                 let a = rng.next_usize(a.max(1));
//                 a ^ b
//             })
//         });
//     }
// }
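
// A minimal usage sketch of the public `ThreadPool` surface. The pool is leaked
// here only to satisfy the `&'static self` receivers; a real embedding would
// keep one long-lived pool around. Compile-checked but not exercised as a test.
#[allow(dead_code)]
fn thread_pool_usage_sketch() {
    let pool: &'static ThreadPool = Box::leak(Box::new(ThreadPool::new()));
    pool.resize_to(2);

    // Fire-and-forget task.
    pool.spawn(|| { /* runs on some worker thread */ });

    // Fork-join: both closures may run in parallel; results come back in order.
    let (a, b) = pool.join(|| 21 * 2, || "done");
    assert_eq!((a, b), (42, "done"));

    // Structured scope: spawned work is completed before `scope` returns.
    let answer = pool.scope(|s| {
        s.spawn(|_| { /* borrowed work */ });
        40 + 2
    });
    assert_eq!(answer, 42);

    pool.resize_to(0);
}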