From 736e4e1a605a8bcff8dcd5e1186129f8c741d7e7 Mon Sep 17 00:00:00 2001
From: Janis
Date: Fri, 31 Jan 2025 16:30:22 +0100
Subject: [PATCH] use crossbeam deques, split ThreadControl out of
 ThreadState, add feature flags for heartbeat, work stealing, and placement

---
 Cargo.toml |   7 +-
 src/lib.rs | 626 +++++++++++++++++++++++++++++++++++++++--------------
 2 files changed, 466 insertions(+), 167 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 641c08c..31f526e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -4,8 +4,12 @@ version = "0.1.0"
 edition = "2021"
 
 [features]
-internal_heartbeat = []
+heartbeat = []
+spin-slow = []
 cpu-pinning = []
+work-stealing = []
+prefer-local = []
+never-local = []
 
 [dependencies]
 
@@ -16,6 +20,7 @@ bevy_tasks = "0.15.1"
 parking_lot = "0.12.3"
 thread_local = "1.1.8"
 crossbeam = "0.8.4"
+st3 = "0.4"
 async-task = "4.7.1"
 
diff --git a/src/lib.rs b/src/lib.rs
index 09faa9e..068f513 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,11 +1,10 @@
 use std::{
-    cell::{OnceCell, UnsafeCell},
-    collections::VecDeque,
+    cell::{Cell, UnsafeCell},
     future::Future,
     mem::MaybeUninit,
     num::NonZero,
     pin::{pin, Pin},
-    ptr::NonNull,
+    ptr::{self, NonNull},
     sync::{
         atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering},
         Arc,
@@ -17,7 +16,11 @@ use std::{
 use async_task::{Runnable, Task};
 use bitflags::bitflags;
-use crossbeam::{queue::SegQueue, utils::CachePadded};
+use crossbeam::{
+    atomic::AtomicCell,
+    deque::{Injector, Stealer, Worker},
+    utils::CachePadded,
+};
 use latch::{AtomicLatch, Latch, LatchWaker, MutexLatch, Probe, ThreadWakeLatch};
 use parking_lot::{Condvar, Mutex};
 use scope::Scope;
@@ -337,7 +340,7 @@ pub mod latch {
 pub struct ThreadPoolState {
     num_threads: AtomicUsize,
     lock: Mutex<()>,
-    heartbeat_state: CachePadded<ThreadState>,
+    heartbeat_state: CachePadded<ThreadControl>,
 }
 
 bitflags! {
@@ -348,15 +351,21 @@ bitflags! {
     }
 }
 
-pub struct ThreadState {
-    should_shove: AtomicBool,
-    shoved_task: Slot<TaskRef>,
+pub struct ThreadControl {
     status: Mutex<ThreadStatus>,
     status_changed: Condvar,
     should_terminate: AtomicLatch,
 }
 
-impl ThreadState {
+pub struct ThreadState {
+    should_shove: AtomicBool,
+    control: ThreadControl,
+    stealer: Stealer<TaskRef>,
+    worker: AtomicCell<Option<Worker<TaskRef>>>,
+    shoved_task: CachePadded<Slot<TaskRef>>,
+}
+
+impl ThreadControl {
     /// returns true if thread was sleeping
     #[inline]
     fn wake(&self) -> bool {
@@ -451,40 +460,48 @@ impl ThreadPoolCallbacks {
 pub struct ThreadPool {
     threads: [CachePadded<ThreadState>; MAX_THREADS],
     pool_state: CachePadded<ThreadPoolState>,
-    global_queue: SegQueue<TaskRef>,
+    global_queue: Injector<TaskRef>,
     callbacks: CachePadded<ThreadPoolCallbacks>,
 }
 
 impl ThreadPool {
-    const INITIAL_THREAD_STATE: CachePadded<ThreadState> = CachePadded::new(ThreadState {
-        should_shove: AtomicBool::new(false),
-        shoved_task: Slot::new(),
-        status: Mutex::new(ThreadStatus::empty()),
-        status_changed: Condvar::new(),
-        should_terminate: AtomicLatch::new(),
-    });
-    pub const fn new() -> Self {
-        Self {
-            threads: const { [Self::INITIAL_THREAD_STATE; MAX_THREADS] },
-            pool_state: CachePadded::new(ThreadPoolState {
-                num_threads: AtomicUsize::new(0),
-                lock: Mutex::new(()),
-                heartbeat_state: Self::INITIAL_THREAD_STATE,
-            }),
-            global_queue: SegQueue::new(),
-            callbacks: CachePadded::new(ThreadPoolCallbacks::new_empty()),
-        }
+    pub fn new() -> Self {
+        Self::new_with_callbacks(ThreadPoolCallbacks::new_empty())
     }
 
-    pub const fn new_with_callbacks(callbacks: ThreadPoolCallbacks) -> ThreadPool {
+    pub fn new_with_callbacks(callbacks: ThreadPoolCallbacks) -> ThreadPool {
+        let threads = [const { MaybeUninit::uninit() }; MAX_THREADS].map(|mut uninit| {
+            let worker = Worker::<TaskRef>::new_fifo();
+            let stealer = worker.stealer();
+
+            let thread = CachePadded::new(ThreadState {
+                should_shove: AtomicBool::new(false),
+                shoved_task: Slot::new().into(),
+                control: ThreadControl {
+                    status: Mutex::new(ThreadStatus::empty()),
+                    status_changed: Condvar::new(),
+                    should_terminate: AtomicLatch::new(),
+                },
+                stealer,
+                worker: AtomicCell::new(Some(worker)),
+            });
+            uninit.write(thread);
+            unsafe { uninit.assume_init() }
+        });
+
         Self {
-            threads: const { [Self::INITIAL_THREAD_STATE; MAX_THREADS] },
+            threads,
             pool_state: CachePadded::new(ThreadPoolState {
                 num_threads: AtomicUsize::new(0),
                 lock: Mutex::new(()),
-                heartbeat_state: Self::INITIAL_THREAD_STATE,
+                heartbeat_state: ThreadControl {
+                    status: Mutex::new(ThreadStatus::empty()),
+                    status_changed: Condvar::new(),
+                    should_terminate: AtomicLatch::new(),
+                }
+                .into(),
             }),
-            global_queue: SegQueue::new(),
+            global_queue: Injector::new(),
             callbacks: CachePadded::new(callbacks),
         }
     }
@@ -495,7 +512,7 @@ impl ThreadPool {
     }
 
     pub fn wake_thread(&self, index: usize) -> Option<bool> {
-        Some(self.threads.get(index as usize)?.wake())
+        Some(self.threads.get(index as usize)?.control.wake())
     }
 
     pub fn wake_any(&self, count: usize) -> usize {
@@ -503,7 +520,7 @@ impl ThreadPool {
         let num_woken = self
             .threads
             .iter()
-            .filter_map(|thread| thread.wake().then_some(()))
+            .filter_map(|thread| thread.control.wake().then_some(()))
             .take(count)
             .count();
         num_woken
@@ -517,6 +534,27 @@ impl ThreadPool {
         core::ptr::from_ref(self) as usize
     }
 
+    fn push_local_or_inject_balanced(&self, task: TaskRef) {
+        let global_len = self.global_queue.len();
+        WorkerThread::with(|worker| match worker {
+            Some(worker) if worker.pool.id() == self.id() => {
+                let worker_len = worker.worker.len();
+                if worker_len == 0 {
+                    worker.push_task(task);
+                } else if global_len == 0 {
+                    self.inject(task);
+                } else if worker_len >= global_len {
+                    worker.push_task(task);
+                } else {
+                    self.inject(task);
+                }
+            }
+            _ => self.inject(task),
+        })
+    }
+
     fn push_local_or_inject(&self, task: TaskRef) {
         WorkerThread::with(|worker| match worker {
             Some(worker) if worker.pool.id() == self.id() => worker.push_task(task),
@@ -536,6 +574,15 @@ impl ThreadPool {
         self.wake_any(n);
     }
 
+    fn inject_maybe_local(&self, task: TaskRef) {
+        #[cfg(all(not(feature = "never-local"), feature = "prefer-local"))]
+        self.push_local_or_inject(task);
+        #[cfg(all(not(feature = "prefer-local"), feature = "never-local"))]
+        self.inject(task);
+        #[cfg(not(any(feature = "prefer-local", feature = "never-local")))]
+        self.push_local_or_inject_balanced(task);
+    }
+
     fn inject(&self, task: TaskRef) {
         self.global_queue.push(task);
 
@@ -581,10 +628,10 @@ impl ThreadPool {
         }
 
         for thread in new_threads {
-            thread.wait_for_running();
+            thread.control.wait_for_running();
         }
 
-        #[cfg(not(feature = "internal_heartbeat"))]
+        #[cfg(feature = "heartbeat")]
        if current_size == 0 {
             std::thread::spawn(move || {
                 heartbeat_loop(self);
@@ -601,13 +648,13 @@ impl ThreadPool {
         let terminating_threads = &self.threads[new_size..current_size];
 
         for thread in terminating_threads {
-            thread.notify_should_terminate();
+            thread.control.notify_should_terminate();
         }
         for thread in terminating_threads {
-            thread.wait_for_termination();
+            thread.control.wait_for_termination();
         }
 
-        #[cfg(not(feature = "internal_heartbeat"))]
+        #[cfg(feature = "heartbeat")]
         if new_size == 0 {
             self.pool_state.heartbeat_state.notify_should_terminate();
             self.pool_state.heartbeat_state.wait_for_termination();
@@ -712,7 +759,7 @@ impl ThreadPool {
         }));
 
         let taskref = task.into_ref().as_task_ref();
-        self.inject(taskref);
+        self.push_local_or_inject(taskref);
 
         worker.run_until(&latch);
         result.unwrap()
@@ -727,7 +774,7 @@ impl ThreadPool {
         let task = HeapTask::new(f);
 
         let taskref = unsafe { task.into_static_task_ref() };
-        self.push_local_or_inject(taskref);
+        self.inject_maybe_local(taskref);
     }
 
     fn spawn_future<Fut, T>(&'static self, future: Fut) -> Task<T>
@@ -745,7 +792,7 @@ impl ThreadPool {
             })
         };
 
-            self.push_local_or_inject(taskref);
+            self.inject_maybe_local(taskref);
         };
 
         let (runnable, task) = async_task::spawn(future, schedule);
@@ -789,6 +836,30 @@ impl ThreadPool {
     }
 
     fn join<F, G, T, U>(&'static self, f: F, g: G) -> (T, U)
+    where
+        F: FnOnce() -> T + Send,
+        G: FnOnce() -> U + Send,
+        T: Send,
+        U: Send,
+    {
+        self.join_threaded(f, g)
+    }
+
+    fn join_seq<F, G, T, U>(&'static self, f: F, g: G) -> (T, U)
+    where
+        F: FnOnce() -> T + Send,
+        G: FnOnce() -> U + Send,
+        T: Send,
+        U: Send,
+    {
+        let a = f();
+        let b = g();
+
+        (a, b)
+    }
+
+    #[inline]
+    fn join_threaded<F, G, T, U>(&'static self, f: F, g: G) -> (T, U)
     where
         F: FnOnce() -> T + Send,
         G: FnOnce() -> U + Send,
         T: Send,
         U: Send,
     {
@@ -808,7 +879,6 @@ impl ThreadPool {
         let ref_b = task_b.as_ref().as_task_ref();
         let b_id = ref_b.id();
 
-        // TODO: maybe try to push this off to another thread immediately first?
         worker.push_task(ref_b);
 
         let result_a = f();
@@ -817,13 +887,17 @@ impl ThreadPool {
             match worker.pop_task() {
                 Some(task) => {
                     if task.id() == b_id {
+                        // we're not calling execute() here, so manually try
+                        // shoving a task.
-                        worker.try_promote();
+                        worker.shove_task();
                         unsafe {
                             task_b.run_as_ref();
                         }
                         break;
+                    } else {
+                        worker.execute(task);
                     }
-                    worker.execute(task);
                 }
                 None => {
                     worker.run_until(&latch_b);
@@ -837,12 +911,12 @@ impl ThreadPool {
 
     fn scope<'scope, Fn, T>(&'static self, f: Fn) -> T
     where
-        Fn: FnOnce(Pin<&Scope<'scope>>) -> T + Send,
+        Fn: FnOnce(&Scope<'scope>) -> T + Send,
         T: Send,
     {
         self.in_worker(|owner, _| {
-            let scope = pin!(unsafe { Scope::<'scope>::new(owner) });
-            let result = f(scope.as_ref());
+            let scope = unsafe { Scope::<'scope>::new(owner) };
+            let result = f(&scope);
             scope.complete(owner);
             result
         })
     }
@@ -848,7 +922,8 @@ impl ThreadPool {
 }
 
 pub struct WorkerThread {
-    queue: TaskQueue<TaskRef>,
+    worker: Worker<TaskRef>,
     pool: &'static ThreadPool,
     index: usize,
     rng: rng::XorShift64Star,
@@ -860,7 +935,7 @@ pub struct WorkerThread {
 const HEARTBEAT_INTERVAL: core::time::Duration = const { core::time::Duration::from_micros(100) };
 
 std::thread_local! {
-    static WORKER_THREAD_STATE: CachePadded<OnceCell<WorkerThread>> = const { CachePadded::new(OnceCell::new()) };
+    static WORKER_THREAD_STATE: Cell<*const WorkerThread> = const { Cell::new(ptr::null()) };
 }
 
 impl WorkerThread {
@@ -880,25 +955,47 @@ impl WorkerThread {
     fn is_worker_thread() -> bool {
         Self::with(|worker| worker.is_some())
     }
+
     fn with<T, F: FnOnce(Option<&WorkerThread>) -> T>(f: F) -> T {
-        WORKER_THREAD_STATE.with(|thread| f(thread.get()))
+        WORKER_THREAD_STATE.with(|thread| {
+            f(NonNull::<WorkerThread>::new(thread.get().cast_mut())
+                .map(|ptr| unsafe { ptr.as_ref() }))
+        })
     }
 
     #[inline]
     fn pop_task(&self) -> Option<TaskRef> {
-        self.queue.pop_front()
+        self.worker.pop()
     }
 
     #[inline]
     fn push_task(&self, task: TaskRef) {
-        self.queue.push_front(task);
+        self.worker.push(task);
     }
 
     #[inline]
     fn drain(&self) -> impl Iterator<Item = TaskRef> {
-        self.queue.drain()
+        // the crossbeam deque has no drain; leftover tasks stay in it and the
+        // deque is handed back to ThreadState at the end of worker_loop.
+        core::iter::empty()
     }
 
+    #[inline]
+    fn steal_tasks(&self) -> Option<TaskRef> {
+        // careful not to call threads() here because that omits any threads
+        // that were killed, which might still have tasks.
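+        // start the scan at a random victim so steal pressure spreads across
+        // the pool instead of always hammering the low-index threads.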
+        let threads = &self.pool.threads;
+        let (start, end) = threads.split_at(self.rng.next_usize(threads.len()));
+
+        end.iter()
+            .chain(start)
+            .find_map(|thread: &CachePadded<ThreadState>| {
+                thread.stealer.steal_batch_and_pop(&self.worker).success()
+            })
+    }
 
     #[inline]
     fn claim_shoved_task(&self) -> Option<TaskRef> {
+        // take own shoved task first
         if let Some(task) = self.info().shoved_task.try_take() {
             return Some(task);
         }
@@ -916,12 +1013,16 @@ impl WorkerThread {
 
     #[cold]
     fn shove_task(&self) {
-        if let Some(task) = self.queue.pop_back() {
-            match self.info().shoved_task.try_put(task) {
-                // shoved task is occupied, reinsert into queue
-                Some(task) => self.queue.push_back(task),
-                None => {}
+        if !self.info().shoved_task.is_occupied() {
+            if let Some(task) = self.info().stealer.steal().success() {
+                match self.info().shoved_task.try_put(task) {
+                    // a thief can hold the slot's lock mid-take between our
+                    // is_occupied() check and try_put(), so this can still
+                    // fail; put the task back on our own deque rather than
+                    // treating it as unreachable.
+                    Some(task) => self.worker.push(task),
+                    None => {}
+                }
             }
+        } else {
             // slot already holds a task: wake a thread to come take it
             self.pool.wake_any(1);
         }
     }
@@ -934,24 +1035,15 @@ impl WorkerThread {
 
     #[inline]
     fn try_promote(&self) {
-        #[cfg(feature = "internal_heartbeat")]
-        let now = std::time::Instant::now();
-        // SAFETY: workerthread is thread-local non-sync
-        #[cfg(feature = "internal_heartbeat")]
-        let should_shove =
-            unsafe { *self.last_heartbeat.get() }.duration_since(now) > HEARTBEAT_INTERVAL;
-        #[cfg(not(feature = "internal_heartbeat"))]
+        #[cfg(feature = "heartbeat")]
         let should_shove = self.info().should_shove.load(Ordering::Acquire);
+        #[cfg(not(feature = "heartbeat"))]
+        let should_shove = true;
 
         if should_shove {
-            // SAFETY: workerthread is thread-local non-sync
-            #[cfg(feature = "internal_heartbeat")]
-            unsafe {
-                *&mut *self.last_heartbeat.get() = now;
-            }
-            #[cfg(not(feature = "internal_heartbeat"))]
+            #[cfg(feature = "heartbeat")]
             self.info().should_shove.store(false, Ordering::Release);
+
             self.shove_task();
         }
     }
@@ -959,9 +1051,22 @@ impl WorkerThread {
 
     #[inline]
     fn find_any_task(&self) -> Option<TaskRef> {
-        // TODO: attempt stealing work here, too.
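+        // look for work in priority order: our own deque first, then the
+        // shoved-task slots, then the global injector, and only then (with
+        // the work-stealing feature) other workers' deques.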
-        self.pop_task()
+        let mut task = self
+            .pop_task()
             .or_else(|| self.claim_shoved_task())
-            .or_else(|| self.pool.global_queue.pop())
+            .or_else(|| {
+                self.pool
+                    .global_queue
+                    .steal_batch_and_pop(&self.worker)
+                    .success()
+            });
+
+        #[cfg(feature = "work-stealing")]
+        {
+            task = task.or_else(|| self.steal_tasks());
+        }
+
+        task
     }
 
     #[inline]
@@ -991,34 +1096,36 @@ impl WorkerThread {
                     self.execute(task);
                 }
                 None => {
-                    debug!("waiting for tasks");
-                    self.info().wait_for_should_wake();
+                    self.info().control.wait_for_should_wake();
                 }
             }
         }
     }
 
     fn worker_loop(pool: &'static ThreadPool, index: usize) {
         let info = &pool.threads()[index as usize];
+        let worker = CachePadded::new(WorkerThread {
+            worker: info.worker.take().unwrap(),
+            pool,
+            index,
+            rng: rng::XorShift64Star::new(pool as *const _ as u64 + index as u64),
+            last_heartbeat: UnsafeCell::new(std::time::Instant::now()),
+        });
 
-        WORKER_THREAD_STATE.with(|worker| {
-            let worker = worker.get_or_init(|| WorkerThread {
-                queue: TaskQueue::new(),
-                pool,
-                index,
-                rng: rng::XorShift64Star::new(pool as *const _ as u64 + index as u64),
-                last_heartbeat: UnsafeCell::new(std::time::Instant::now()),
-            });
+        WORKER_THREAD_STATE.with(|cell| {
+            cell.set(&*worker);
 
             if let Some(callback) = pool.callbacks.at_entry.as_ref() {
-                callback(worker);
+                callback(&worker);
             }
 
-            info.notify_running();
+            info.control.notify_running();
 
-            worker.run_until(&info.should_terminate);
+            worker.run_until(&info.control.should_terminate);
 
             if let Some(callback) = pool.callbacks.at_exit.as_ref() {
-                callback(worker);
+                callback(&worker);
             }
 
             for task in worker.drain() {
@@ -1028,9 +1135,14 @@ impl WorkerThread {
             if let Some(task) = info.shoved_task.try_take() {
                 pool.inject(task);
             }
+
+            cell.set(ptr::null());
         });
 
-        info.notify_termination();
+        // hand the deque back so a future worker on this slot can reuse it
+        let WorkerThread { worker, .. } = CachePadded::into_inner(worker);
+        info.worker.store(Some(worker));
+
+        info.control.notify_termination();
     }
 }
@@ -1061,56 +1173,68 @@ fn heartbeat_loop(pool: &'static ThreadPool) {
     state.notify_termination();
 }
 
-pub struct TaskQueue<T>(UnsafeCell<VecDeque<T>>);
+use vec_queue::TaskQueue;
 
-impl<T> TaskQueue<T> {
-    /// Creates a new [`TaskQueue`].
-    #[inline]
-    const fn new() -> Self {
-        Self(UnsafeCell::new(VecDeque::new()))
-    }
-    #[inline]
-    fn get_mut(&self) -> &mut VecDeque<T> {
-        unsafe { &mut *self.0.get() }
-    }
-    #[inline]
-    fn pop_front(&self) -> Option<T> {
-        self.get_mut().pop_front()
-    }
-    #[inline]
-    fn pop_back(&self) -> Option<T> {
-        self.get_mut().pop_back()
-    }
-    #[inline]
-    fn push_back(&self, t: T) {
-        self.get_mut().push_back(t);
-    }
-    #[inline]
-    fn push_front(&self, t: T) {
-        self.get_mut().push_front(t);
-    }
-    #[inline]
-    fn take(&self) -> VecDeque<T> {
-        let this = core::mem::replace(self.get_mut(), VecDeque::new());
-        this
-    }
-    #[inline]
-    fn drain(&self) -> impl Iterator<Item = T> {
-        self.take().into_iter()
+mod vec_queue {
+    use std::{cell::UnsafeCell, collections::VecDeque};
+
+    pub struct TaskQueue<T>(UnsafeCell<VecDeque<T>>);
+
+    impl<T> TaskQueue<T> {
+        /// Creates a new [`TaskQueue`].
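+        ///
+        /// This hands out `&mut` through `&self` via `UnsafeCell`, so it is
+        /// `!Sync` and must stay on the one thread that owns it.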
+        #[inline]
+        pub const fn new() -> Self {
+            Self(UnsafeCell::new(VecDeque::new()))
+        }
+        #[inline]
+        pub fn get_mut(&self) -> &mut VecDeque<T> {
+            unsafe { &mut *self.0.get() }
+        }
+        #[inline]
+        pub fn pop_front(&self) -> Option<T> {
+            self.get_mut().pop_front()
+        }
+        #[inline]
+        pub fn pop_back(&self) -> Option<T> {
+            self.get_mut().pop_back()
+        }
+        #[inline]
+        pub fn push_back(&self, t: T) {
+            self.get_mut().push_back(t);
+        }
+        #[inline]
+        pub fn push_front(&self, t: T) {
+            self.get_mut().push_front(t);
+        }
+        #[inline]
+        pub fn take(&self) -> VecDeque<T> {
+            core::mem::take(self.get_mut())
+        }
+        #[inline]
+        pub fn drain(&self) -> impl Iterator<Item = T> {
+            self.take().into_iter()
+        }
+    }
+}
 
-bitflags! {
-    #[derive(Debug, Clone, Copy)]
-    pub struct SlotState: u8 {
-        const LOCKED = 1 << 1;
-        const OCCUPIED = 1 << 2;
+#[repr(u8)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum SlotState {
+    None,
+    Locked,
+    Occupied,
+}
+
+impl From<u8> for SlotState {
+    fn from(value: u8) -> Self {
+        // SAFETY: only called with values previously stored from a SlotState
+        unsafe { core::mem::transmute(value) }
     }
 }
 
 impl From<SlotState> for u8 {
     fn from(value: SlotState) -> Self {
-        value.bits()
+        value as u8
     }
 }
@@ -1125,10 +1249,7 @@ unsafe impl<T> Sync for Slot<T> where T: Send {}
 impl<T> Drop for Slot<T> {
     fn drop(&mut self) {
         if core::mem::needs_drop::<T>() {
-            if SlotState::from_bits(*self.state.get_mut())
-                .unwrap()
-                .contains(SlotState::OCCUPIED)
-            {
+            if *self.state.get_mut() == SlotState::Occupied as u8 {
                 unsafe {
                     self.slot.get().drop_in_place();
                 }
@@ -1141,15 +1262,56 @@ impl<T> Slot<T> {
     pub const fn new() -> Slot<T> {
         Self {
             slot: UnsafeCell::new(MaybeUninit::uninit()),
-            state: AtomicU8::new(SlotState::empty().bits()),
+            state: AtomicU8::new(SlotState::None as u8),
         }
     }
 
+    pub fn is_occupied(&self) -> bool {
+        self.state.load(Ordering::Acquire) == SlotState::Occupied.into()
+    }
+
+    #[inline]
+    pub fn insert(&self, t: T) -> Option<T> {
+        let value = match self
+            .state
+            .swap(SlotState::Locked.into(), Ordering::AcqRel)
+            .into()
+        {
+            SlotState::Locked => {
+                // return early: was already locked.
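+                // `t` is dropped on this path, so the caller cannot tell
+                // "inserted" apart from "lost the race for the lock".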
+ debug!("slot was already locked"); + return None; + } + SlotState::Occupied => { + let slot = self.slot.get(); + // replace + unsafe { + let v = (*slot).assume_init_read(); + (*slot).write(t); + Some(v) + } + } + SlotState::None => { + let slot = self.slot.get(); + // insert + unsafe { + (*slot).write(t); + } + None + } + }; + + // release lock + self.state + .store(SlotState::Occupied.into(), Ordering::Release); + value + } + #[inline] pub fn try_put(&self, t: T) -> Option { match self.state.compare_exchange( - SlotState::empty().into(), - SlotState::LOCKED.into(), + SlotState::None.into(), + SlotState::Locked.into(), Ordering::Acquire, Ordering::Relaxed, ) { @@ -1161,7 +1323,7 @@ impl Slot { // release lock self.state - .store(SlotState::OCCUPIED.into(), Ordering::Release); + .store(SlotState::Occupied.into(), Ordering::Release); None } } @@ -1170,8 +1332,8 @@ impl Slot { #[inline] pub fn try_take(&self) -> Option { match self.state.compare_exchange( - SlotState::OCCUPIED.into(), - SlotState::LOCKED.into(), + SlotState::Occupied.into(), + SlotState::Locked.into(), Ordering::Acquire, Ordering::Relaxed, ) { @@ -1181,8 +1343,7 @@ impl Slot { let t = unsafe { (*slot).assume_init_read() }; // release lock - self.state - .store(SlotState::empty().into(), Ordering::Release); + self.state.store(SlotState::None.into(), Ordering::Release); Some(t) } Err(_) => None, @@ -1227,14 +1388,15 @@ mod scope { use std::{ future::Future, marker::{PhantomData, PhantomPinned}, + pin::pin, ptr::{self, NonNull}, }; use async_task::{Runnable, Task}; use crate::{ - latch::{CountWakeLatch, Latch}, - task::{HeapTask, TaskRef}, + latch::{CountWakeLatch, Latch, Probe, ThreadWakeLatch}, + task::{HeapTask, StackTask, TaskRef}, ThreadPool, WorkerThread, }; @@ -1253,6 +1415,16 @@ mod scope { } } + pub fn join(&self, f: F, g: G) -> (T, U) + where + F: FnOnce(&Self) -> T + Send, + G: FnOnce(&Self) -> U + Send, + T: Send, + U: Send, + { + self.pool.join(|| f(self), || g(self)) + } + pub fn spawn(&self, f: Fn) where Fn: FnOnce(&Scope<'scope>) + Send + 'scope, @@ -1267,7 +1439,7 @@ mod scope { }); let taskref = unsafe { task.into_task_ref() }; - self.pool.push_local_or_inject(taskref); + self.pool.inject_maybe_local(taskref); } pub fn spawn_future(&self, future: Fut) -> Task @@ -1289,7 +1461,7 @@ mod scope { }; unsafe { - ptr.as_ref().pool.push_local_or_inject(taskref); + ptr.as_ref().pool.inject_maybe_local(taskref); } }; @@ -1332,8 +1504,58 @@ mod scope { mod tests { use std::{cell::Cell, hint::black_box}; + use tracing::info; + use super::*; + mod tree { + + pub struct Tree { + nodes: Box<[Node]>, + root: Option, + } + pub struct Node { + pub leaf: T, + pub left: Option, + pub right: Option, + } + + impl Tree { + pub fn new(depth: usize, t: T) -> Tree + where + T: Copy, + { + let mut nodes = Vec::with_capacity((0..depth).sum()); + let root = Self::build_node(&mut nodes, depth, t); + Self { + nodes: nodes.into_boxed_slice(), + root: Some(root), + } + } + + pub fn root(&self) -> Option { + self.root + } + + pub fn get(&self, index: usize) -> &Node { + &self.nodes[index] + } + + pub fn build_node(nodes: &mut Vec>, depth: usize, t: T) -> usize + where + T: Copy, + { + let node = Node { + leaf: t, + left: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)), + right: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)), + }; + nodes.push(node); + nodes.len() - 1 + } + } + } + const PRIMES: &'static [usize] = &[ 1181, 1187, 1193, 1201, 1213, 1217, 1223, 1229, 1231, 1237, 1249, 1259, 1277, 1279, 1283, 1289, 1291, 
         1289, 1291, 1297, 1301, 1303, 1307, 1319, 1321, 1327, 1361, 1367, 1373, 1381, 1399, 1409,
@@ -1344,9 +1566,14 @@ mod tests {
         1861, 1867, 1871, 1873, 1877, 1879, 1889, 1901, 1907,
     ];
 
-    const REPEAT: usize = 0x100;
+    #[cfg(feature = "spin-slow")]
+    const REPEAT: usize = 0x800;
+    #[cfg(not(feature = "spin-slow"))]
+    const REPEAT: usize = 0x8000;
 
-    fn run_in_scope<T: Send>(pool: ThreadPool, f: impl FnOnce(Pin<&Scope<'_>>) -> T + Send) -> T {
+    const TREE_SIZE: usize = 10;
+
+    fn run_in_scope<T: Send>(pool: ThreadPool, f: impl FnOnce(&Scope<'_>) -> T + Send) -> T {
         let pool = Box::new(pool);
         let ptr = Box::into_raw(pool);
@@ -1357,9 +1584,9 @@ mod tests {
             let now = std::time::Instant::now();
             let result = pool.scope(f);
             let elapsed = now.elapsed().as_micros();
-            eprintln!("(mine) total time: {}ms", elapsed as f32 / 1e3);
+            info!("(mine) total time: {}ms", elapsed as f32 / 1e3);
             pool.resize_to(0);
-            assert!(pool.global_queue.pop().is_none());
+            assert!(pool.global_queue.is_empty());
             result
         };
@@ -1385,7 +1612,38 @@ mod tests {
         });
 
         let elapsed = now.elapsed().as_micros();
-        eprintln!("(rayon) total time: {}ms", elapsed as f32 / 1e3);
+        info!("(rayon) total time: {}ms", elapsed as f32 / 1e3);
     }
 
+    #[test]
+    #[tracing_test::traced_test]
+    fn rayon_join() {
+        let pool = rayon::ThreadPoolBuilder::new()
+            .num_threads(bevy_tasks::available_parallelism())
+            .build()
+            .unwrap();
+
+        let tree = tree::Tree::new(TREE_SIZE, 1u32);
+
+        fn sum(tree: &tree::Tree<u32>, node: usize) -> u32 {
+            let node = tree.get(node);
+            let (l, r) = rayon::join(
+                || node.left.map(|node| sum(tree, node)).unwrap_or_default(),
+                || node.right.map(|node| sum(tree, node)).unwrap_or_default(),
+            );
+
+            node.leaf + l + r
+        }
+
+        let now = std::time::Instant::now();
+        let sum = pool.scope(move |s| {
+            let root = tree.root().unwrap();
+            sum(&tree, root)
+        });
+
+        let elapsed = now.elapsed().as_micros();
+
+        info!("(rayon) total time: {}ms", elapsed as f32 / 1e3);
+    }
 
     #[test]
@@ -1407,7 +1665,7 @@ mod tests {
         });
 
         let elapsed = now.elapsed().as_micros();
-        eprintln!("(bevy_tasks) total time: {}ms", elapsed as f32 / 1e3);
+        info!("(bevy_tasks) total time: {}ms", elapsed as f32 / 1e3);
     }
 
     #[test]
@@ -1418,18 +1676,7 @@ mod tests {
         }
         let counter = Arc::new(AtomicUsize::new(0));
         {
-            let pool = ThreadPool::new_with_callbacks(ThreadPoolCallbacks {
-                at_entry: Some(Arc::new(|_worker| {
-                    // eprintln!("new worker thread: {}", worker.index);
-                })),
-                at_exit: Some(Arc::new({
-                    let counter = counter.clone();
-                    move |_worker: &WorkerThread| {
-                        // eprintln!("thread {}: {}", worker.index, WAIT_COUNT.get());
-                        counter.fetch_add(WAIT_COUNT.get(), Ordering::Relaxed);
-                    }
-                })),
-            });
+            let pool = ThreadPool::new();
 
             run_in_scope(pool, |s| {
                 for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() {
@@ -1443,6 +1690,33 @@ mod tests {
         // eprintln!("total wait count: {}", counter.load(Ordering::Acquire));
     }
 
+    #[test]
+    #[tracing_test::traced_test]
+    fn mine_join() {
+        let pool = ThreadPool::new();
+
+        let tree = tree::Tree::new(TREE_SIZE, 1u32);
+
+        fn sum(tree: &tree::Tree<u32>, node: usize, scope: &Scope<'_>) -> u32 {
+            let node = tree.get(node);
+            let (l, r) = scope.join(
+                |s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(),
+                |s| {
+                    node.right
+                        .map(|node| sum(tree, node, s))
+                        .unwrap_or_default()
+                },
+            );
+
+            node.leaf + l + r
+        }
+
+        let sum = run_in_scope(pool, move |s| {
+            let root = tree.root().unwrap();
+            sum(&tree, root, s)
+        });
+    }
 
     #[test]
     #[tracing_test::traced_test]
     fn sync() {
@@ -1452,11 +1726,19 @@ mod tests {
         }
 
         let elapsed = now.elapsed().as_micros();
eprintln!("(sync) total time: {}ms", elapsed as f32 / 1e3); + info!("(sync) total time: {}ms", elapsed as f32 / 1e3); } #[inline] fn spinning(i: usize) { + #[cfg(feature = "spin-slow")] + spinning_slow(i); + #[cfg(not(feature = "spin-slow"))] + spinning_fast(i); + } + + #[inline] + fn spinning_slow(i: usize) { let rng = rng::XorShift64Star::new(i as u64); (0..i).reduce(|a, b| { black_box({ @@ -1465,4 +1747,16 @@ mod tests { }) }); } + + #[inline] + fn spinning_fast(i: usize) { + let rng = rng::XorShift64Star::new(i as u64); + //(0..rng.next_usize(i)).reduce(|a, b| { + (0..20).reduce(|a, b| { + black_box({ + let a = rng.next_usize(a.max(1)); + a ^ b + }) + }); + } }