idk this sucks

This commit is contained in:
Janis 2025-01-31 16:30:22 +01:00
parent a691b614bc
commit 736e4e1a60
2 changed files with 466 additions and 167 deletions

View file

@ -4,8 +4,12 @@ version = "0.1.0"
edition = "2021"
[features]
internal_heartbeat = []
heartbeat = []
spin-slow = []
cpu-pinning = []
work-stealing = []
prefer-local = []
never-local = []
[dependencies]
@ -16,6 +20,7 @@ bevy_tasks = "0.15.1"
parking_lot = "0.12.3"
thread_local = "1.1.8"
crossbeam = "0.8.4"
st3 = "0.4"
async-task = "4.7.1"

View file

@ -1,11 +1,10 @@
use std::{
cell::{OnceCell, UnsafeCell},
collections::VecDeque,
cell::{Cell, UnsafeCell},
future::Future,
mem::MaybeUninit,
num::NonZero,
pin::{pin, Pin},
ptr::NonNull,
ptr::{self, NonNull},
sync::{
atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering},
Arc,
@ -17,7 +16,11 @@ use std::{
use async_task::{Runnable, Task};
use bitflags::bitflags;
use crossbeam::{queue::SegQueue, utils::CachePadded};
use crossbeam::{
atomic::AtomicCell,
deque::{Injector, Stealer, Worker},
utils::CachePadded,
};
use latch::{AtomicLatch, Latch, LatchWaker, MutexLatch, Probe, ThreadWakeLatch};
use parking_lot::{Condvar, Mutex};
use scope::Scope;
@ -337,7 +340,7 @@ pub mod latch {
pub struct ThreadPoolState {
num_threads: AtomicUsize,
lock: Mutex<()>,
heartbeat_state: CachePadded<ThreadState>,
heartbeat_state: CachePadded<ThreadControl>,
}
bitflags! {
@ -348,15 +351,21 @@ bitflags! {
}
}
pub struct ThreadState {
should_shove: AtomicBool,
shoved_task: Slot<TaskRef>,
pub struct ThreadControl {
status: Mutex<ThreadStatus>,
status_changed: Condvar,
should_terminate: AtomicLatch,
}
impl ThreadState {
pub struct ThreadState {
should_shove: AtomicBool,
control: ThreadControl,
stealer: Stealer<TaskRef>,
worker: AtomicCell<Option<Worker<TaskRef>>>,
shoved_task: CachePadded<Slot<TaskRef>>,
}
impl ThreadControl {
/// returns true if thread was sleeping
#[inline]
fn wake(&self) -> bool {
@ -451,40 +460,48 @@ impl ThreadPoolCallbacks {
pub struct ThreadPool {
threads: [CachePadded<ThreadState>; MAX_THREADS],
pool_state: CachePadded<ThreadPoolState>,
global_queue: SegQueue<TaskRef>,
global_queue: Injector<TaskRef>,
callbacks: CachePadded<ThreadPoolCallbacks>,
}
impl ThreadPool {
const INITIAL_THREAD_STATE: CachePadded<ThreadState> = CachePadded::new(ThreadState {
pub fn new() -> Self {
Self::new_with_callbacks(ThreadPoolCallbacks::new_empty())
}
pub fn new_with_callbacks(callbacks: ThreadPoolCallbacks) -> ThreadPool {
let threads = [const { MaybeUninit::uninit() }; MAX_THREADS].map(|mut uninit| {
let worker = Worker::<TaskRef>::new_fifo();
let stealer = worker.stealer();
let thread = CachePadded::new(ThreadState {
should_shove: AtomicBool::new(false),
shoved_task: Slot::new(),
shoved_task: Slot::new().into(),
control: ThreadControl {
status: Mutex::new(ThreadStatus::empty()),
status_changed: Condvar::new(),
should_terminate: AtomicLatch::new(),
},
stealer,
worker: AtomicCell::new(Some(worker)),
});
uninit.write(thread);
unsafe { uninit.assume_init() }
});
pub const fn new() -> Self {
Self {
threads: const { [Self::INITIAL_THREAD_STATE; MAX_THREADS] },
pool_state: CachePadded::new(ThreadPoolState {
num_threads: AtomicUsize::new(0),
lock: Mutex::new(()),
heartbeat_state: Self::INITIAL_THREAD_STATE,
}),
global_queue: SegQueue::new(),
callbacks: CachePadded::new(ThreadPoolCallbacks::new_empty()),
}
}
pub const fn new_with_callbacks(callbacks: ThreadPoolCallbacks) -> ThreadPool {
Self {
threads: const { [Self::INITIAL_THREAD_STATE; MAX_THREADS] },
threads,
pool_state: CachePadded::new(ThreadPoolState {
num_threads: AtomicUsize::new(0),
lock: Mutex::new(()),
heartbeat_state: Self::INITIAL_THREAD_STATE,
heartbeat_state: ThreadControl {
status: Mutex::new(ThreadStatus::empty()),
status_changed: Condvar::new(),
should_terminate: AtomicLatch::new(),
}
.into(),
}),
global_queue: SegQueue::new(),
global_queue: Injector::new(),
callbacks: CachePadded::new(callbacks),
}
}
@ -495,7 +512,7 @@ impl ThreadPool {
}
pub fn wake_thread(&self, index: usize) -> Option<bool> {
Some(self.threads.get(index as usize)?.wake())
Some(self.threads.get(index as usize)?.control.wake())
}
pub fn wake_any(&self, count: usize) -> usize {
@ -503,7 +520,7 @@ impl ThreadPool {
let num_woken = self
.threads
.iter()
.filter_map(|thread| thread.wake().then_some(()))
.filter_map(|thread| thread.control.wake().then_some(()))
.take(count)
.count();
num_woken
@ -517,6 +534,27 @@ impl ThreadPool {
core::ptr::from_ref(self) as usize
}
fn push_local_or_inject_balanced(&self, task: TaskRef) {
let global_len = self.global_queue.len();
WorkerThread::with(|worker| match worker {
Some(worker) if worker.pool.id() == self.id() => {
let worker_len = worker.worker.len();
if worker_len == 0 {
worker.push_task(task);
} else if global_len == 0 {
self.inject(task);
} else {
if worker_len >= global_len {
worker.push_task(task);
} else {
self.inject(task);
}
}
}
_ => self.inject(task),
})
}
fn push_local_or_inject(&self, task: TaskRef) {
WorkerThread::with(|worker| match worker {
Some(worker) if worker.pool.id() == self.id() => worker.push_task(task),
@ -536,6 +574,15 @@ impl ThreadPool {
self.wake_any(n);
}
fn inject_maybe_local(&self, task: TaskRef) {
#[cfg(all(not(feature = "never-local"), feature = "prefer-local"))]
self.push_local_or_inject(task);
#[cfg(all(not(feature = "prefer-local"), feature = "never-local"))]
self.inject(task);
#[cfg(not(any(feature = "prefer-local", feature = "never-local")))]
self.push_local_or_inject_balanced(task);
}
fn inject(&self, task: TaskRef) {
self.global_queue.push(task);
@ -581,10 +628,10 @@ impl ThreadPool {
}
for thread in new_threads {
thread.wait_for_running();
thread.control.wait_for_running();
}
#[cfg(not(feature = "internal_heartbeat"))]
#[cfg(feature = "heartbeat")]
if current_size == 0 {
std::thread::spawn(move || {
heartbeat_loop(self);
@ -601,13 +648,13 @@ impl ThreadPool {
let terminating_threads = &self.threads[new_size..current_size];
for thread in terminating_threads {
thread.notify_should_terminate();
thread.control.notify_should_terminate();
}
for thread in terminating_threads {
thread.wait_for_termination();
thread.control.wait_for_termination();
}
#[cfg(not(feature = "internal_heartbeat"))]
#[cfg(feature = "heartbeat")]
if new_size == 0 {
self.pool_state.heartbeat_state.notify_should_terminate();
self.pool_state.heartbeat_state.wait_for_termination();
@ -712,7 +759,7 @@ impl ThreadPool {
}));
let taskref = task.into_ref().as_task_ref();
self.inject(taskref);
self.push_local_or_inject(taskref);
worker.run_until(&latch);
result.unwrap()
@ -727,7 +774,7 @@ impl ThreadPool {
let task = HeapTask::new(f);
let taskref = unsafe { task.into_static_task_ref() };
self.push_local_or_inject(taskref);
self.inject_maybe_local(taskref);
}
fn spawn_future<Fut, T>(&'static self, future: Fut) -> Task<T>
@ -745,7 +792,7 @@ impl ThreadPool {
})
};
self.push_local_or_inject(taskref);
self.inject_maybe_local(taskref);
};
let (runnable, task) = async_task::spawn(future, schedule);
@ -789,6 +836,30 @@ impl ThreadPool {
}
fn join<F, G, T, U>(&'static self, f: F, g: G) -> (T, U)
where
F: FnOnce() -> T + Send,
G: FnOnce() -> U + Send,
T: Send,
U: Send,
{
self.join_threaded(f, g)
}
fn join_seq<F, G, T, U>(&'static self, f: F, g: G) -> (T, U)
where
F: FnOnce() -> T + Send,
G: FnOnce() -> U + Send,
T: Send,
U: Send,
{
let a = f();
let b = g();
(a, b)
}
#[inline]
fn join_threaded<F, G, T, U>(&'static self, f: F, g: G) -> (T, U)
where
F: FnOnce() -> T + Send,
G: FnOnce() -> U + Send,
@ -808,7 +879,6 @@ impl ThreadPool {
let ref_b = task_b.as_ref().as_task_ref();
let b_id = ref_b.id();
// TODO: maybe try to push this off to another thread immediately first?
worker.push_task(ref_b);
let result_a = f();
@ -817,14 +887,18 @@ impl ThreadPool {
match worker.pop_task() {
Some(task) => {
if task.id() == b_id {
worker.try_promote();
// we're not calling execute() here, so manually try
// shoving a task.
//worker.try_promote();
worker.shove_task();
unsafe {
task_b.run_as_ref();
}
break;
}
} else {
worker.execute(task);
}
}
None => {
worker.run_until(&latch_b);
}
@ -837,12 +911,12 @@ impl ThreadPool {
fn scope<'scope, Fn, T>(&'static self, f: Fn) -> T
where
Fn: FnOnce(Pin<&Scope<'scope>>) -> T + Send,
Fn: FnOnce(&Scope<'scope>) -> T + Send,
T: Send,
{
self.in_worker(|owner, _| {
let scope = pin!(unsafe { Scope::<'scope>::new(owner) });
let result = f(scope.as_ref());
let scope = unsafe { Scope::<'scope>::new(owner) };
let result = f(&scope);
scope.complete(owner);
result
})
@ -850,7 +924,8 @@ impl ThreadPool {
}
pub struct WorkerThread {
queue: TaskQueue<TaskRef>,
// queue: TaskQueue<TaskRef>,
worker: Worker<TaskRef>,
pool: &'static ThreadPool,
index: usize,
rng: rng::XorShift64Star,
@ -860,7 +935,7 @@ pub struct WorkerThread {
const HEARTBEAT_INTERVAL: core::time::Duration = const { core::time::Duration::from_micros(100) };
std::thread_local! {
static WORKER_THREAD_STATE: CachePadded<OnceCell<WorkerThread>> = const {CachePadded::new(OnceCell::new())};
static WORKER_THREAD_STATE: Cell<*const WorkerThread> = const {Cell::new(ptr::null())};
}
impl WorkerThread {
@ -880,25 +955,47 @@ impl WorkerThread {
fn is_worker_thread() -> bool {
Self::with(|worker| worker.is_some())
}
fn with<T, F: FnOnce(Option<&Self>) -> T>(f: F) -> T {
WORKER_THREAD_STATE.with(|thread| f(thread.get()))
WORKER_THREAD_STATE.with(|thread| {
f(NonNull::<WorkerThread>::new(thread.get().cast_mut())
.map(|ptr| unsafe { ptr.as_ref() }))
})
}
#[inline]
fn pop_task(&self) -> Option<TaskRef> {
self.queue.pop_front()
self.worker.pop()
//self.queue.pop_front(task);
}
#[inline]
fn push_task(&self, task: TaskRef) {
self.queue.push_front(task);
self.worker.push(task);
//self.queue.push_front(task);
}
#[inline]
fn drain(&self) -> impl Iterator<Item = TaskRef> {
self.queue.drain()
// self.queue.drain()
core::iter::empty()
}
#[inline]
fn steal_tasks(&self) -> Option<TaskRef> {
// careful not to call threads() here because that omits any threads
// that were killed, which might still have tasks.
let threads = &self.pool.threads;
let (start, end) = threads.split_at(self.rng.next_usize(threads.len()));
end.iter()
.chain(start)
.find_map(|thread: &CachePadded<ThreadState>| {
thread.stealer.steal_batch_and_pop(&self.worker).success()
})
}
#[inline]
fn claim_shoved_task(&self) -> Option<TaskRef> {
// take own shoved task first
if let Some(task) = self.info().shoved_task.try_take() {
return Some(task);
}
@ -916,12 +1013,16 @@ impl WorkerThread {
#[cold]
fn shove_task(&self) {
if let Some(task) = self.queue.pop_back() {
if !self.info().shoved_task.is_occupied() {
if let Some(task) = self.info().stealer.steal().success() {
match self.info().shoved_task.try_put(task) {
// shoved task is occupied, reinsert into queue
Some(task) => self.queue.push_back(task),
// this really shouldn't happen
Some(_task) => unreachable!(),
None => {}
}
}
} else {
// wake thread to execute task
self.pool.wake_any(1);
}
@ -934,24 +1035,15 @@ impl WorkerThread {
#[inline]
fn try_promote(&self) {
#[cfg(feature = "internal_heartbeat")]
let now = std::time::Instant::now();
// SAFETY: workerthread is thread-local non-sync
#[cfg(feature = "internal_heartbeat")]
let should_shove =
unsafe { *self.last_heartbeat.get() }.duration_since(now) > HEARTBEAT_INTERVAL;
#[cfg(not(feature = "internal_heartbeat"))]
#[cfg(feature = "heartbeat")]
let should_shove = self.info().should_shove.load(Ordering::Acquire);
#[cfg(not(feature = "heartbeat"))]
let should_shove = true;
if should_shove {
// SAFETY: workerthread is thread-local non-sync
#[cfg(feature = "internal_heartbeat")]
unsafe {
*&mut *self.last_heartbeat.get() = now;
}
#[cfg(not(feature = "internal_heartbeat"))]
#[cfg(feature = "heartbeat")]
self.info().should_shove.store(false, Ordering::Release);
self.shove_task();
}
}
@ -959,9 +1051,22 @@ impl WorkerThread {
#[inline]
fn find_any_task(&self) -> Option<TaskRef> {
// TODO: attempt stealing work here, too.
self.pop_task()
let mut task = self
.pop_task()
.or_else(|| self.claim_shoved_task())
.or_else(|| self.pool.global_queue.pop())
.or_else(|| {
self.pool
.global_queue
.steal_batch_and_pop(&self.worker)
.success()
});
#[cfg(feature = "work-stealing")]
{
task = task.or_else(|| self.steal_tasks());
}
task
}
#[inline]
@ -991,34 +1096,36 @@ impl WorkerThread {
self.execute(task);
}
None => {
debug!("waiting for tasks");
self.info().wait_for_should_wake();
//debug!("waiting for tasks");
self.info().control.wait_for_should_wake();
}
}
}
fn worker_loop(pool: &'static ThreadPool, index: usize) {
let info = &pool.threads()[index as usize];
WORKER_THREAD_STATE.with(|worker| {
let worker = worker.get_or_init(|| WorkerThread {
queue: TaskQueue::new(),
let worker = CachePadded::new(WorkerThread {
// queue: TaskQueue::new(),
worker: info.worker.take().unwrap(),
pool,
index,
rng: rng::XorShift64Star::new(pool as *const _ as u64 + index as u64),
last_heartbeat: UnsafeCell::new(std::time::Instant::now()),
});
WORKER_THREAD_STATE.with(|cell| {
cell.set(&*worker);
if let Some(callback) = pool.callbacks.at_entry.as_ref() {
callback(worker);
callback(&worker);
}
info.notify_running();
info.control.notify_running();
// info.notify_running();
worker.run_until(&info.should_terminate);
worker.run_until(&info.control.should_terminate);
if let Some(callback) = pool.callbacks.at_exit.as_ref() {
callback(worker);
callback(&worker);
}
for task in worker.drain() {
@ -1028,9 +1135,14 @@ impl WorkerThread {
if let Some(task) = info.shoved_task.try_take() {
pool.inject(task);
}
cell.set(ptr::null());
});
info.notify_termination();
let WorkerThread { worker, .. } = CachePadded::into_inner(worker);
info.worker.store(Some(worker));
info.control.notify_termination();
}
}
@ -1061,56 +1173,68 @@ fn heartbeat_loop(pool: &'static ThreadPool) {
state.notify_termination();
}
use vec_queue::TaskQueue;
mod vec_queue {
use std::{cell::UnsafeCell, collections::VecDeque};
pub struct TaskQueue<T>(UnsafeCell<VecDeque<T>>);
impl<T> TaskQueue<T> {
/// Creates a new [`TaskQueue<T>`].
#[inline]
const fn new() -> Self {
pub const fn new() -> Self {
Self(UnsafeCell::new(VecDeque::new()))
}
#[inline]
fn get_mut(&self) -> &mut VecDeque<T> {
pub fn get_mut(&self) -> &mut VecDeque<T> {
unsafe { &mut *self.0.get() }
}
#[inline]
fn pop_front(&self) -> Option<T> {
pub fn pop_front(&self) -> Option<T> {
self.get_mut().pop_front()
}
#[inline]
fn pop_back(&self) -> Option<T> {
pub fn pop_back(&self) -> Option<T> {
self.get_mut().pop_back()
}
#[inline]
fn push_back(&self, t: T) {
pub fn push_back(&self, t: T) {
self.get_mut().push_back(t);
}
#[inline]
fn push_front(&self, t: T) {
pub fn push_front(&self, t: T) {
self.get_mut().push_front(t);
}
#[inline]
fn take(&self) -> VecDeque<T> {
pub fn take(&self) -> VecDeque<T> {
let this = core::mem::replace(self.get_mut(), VecDeque::new());
this
}
#[inline]
fn drain(&self) -> impl Iterator<Item = T> {
pub fn drain(&self) -> impl Iterator<Item = T> {
self.take().into_iter()
}
}
}
bitflags! {
#[derive(Debug, Clone, Copy)]
pub struct SlotState: u8 {
const LOCKED = 1 << 1;
const OCCUPIED = 1 << 2;
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SlotState {
None,
Locked,
Occupied,
}
impl From<u8> for SlotState {
fn from(value: u8) -> Self {
unsafe { core::mem::transmute(value) }
}
}
impl From<SlotState> for u8 {
fn from(value: SlotState) -> Self {
value.bits()
value as u8
}
}
@ -1125,10 +1249,7 @@ unsafe impl<T> Sync for Slot<T> where T: Send {}
impl<T> Drop for Slot<T> {
fn drop(&mut self) {
if core::mem::needs_drop::<T>() {
if SlotState::from_bits(*self.state.get_mut())
.unwrap()
.contains(SlotState::OCCUPIED)
{
if *self.state.get_mut() == SlotState::Occupied as u8 {
unsafe {
self.slot.get().drop_in_place();
}
@ -1141,15 +1262,56 @@ impl<T> Slot<T> {
pub const fn new() -> Slot<T> {
Self {
slot: UnsafeCell::new(MaybeUninit::uninit()),
state: AtomicU8::new(SlotState::empty().bits()),
state: AtomicU8::new(SlotState::None as u8),
}
}
pub fn is_occupied(&self) -> bool {
self.state.load(Ordering::Acquire) == SlotState::Occupied.into()
}
#[inline]
pub fn insert(&self, t: T) -> Option<T> {
let value = match self
.state
.swap(SlotState::Locked.into(), Ordering::AcqRel)
.into()
{
SlotState::Locked => {
// return early: was already locked.
debug!("slot was already locked");
return None;
}
SlotState::Occupied => {
let slot = self.slot.get();
// replace
unsafe {
let v = (*slot).assume_init_read();
(*slot).write(t);
Some(v)
}
}
SlotState::None => {
let slot = self.slot.get();
// insert
unsafe {
(*slot).write(t);
}
None
}
};
// release lock
self.state
.store(SlotState::Occupied.into(), Ordering::Release);
value
}
#[inline]
pub fn try_put(&self, t: T) -> Option<T> {
match self.state.compare_exchange(
SlotState::empty().into(),
SlotState::LOCKED.into(),
SlotState::None.into(),
SlotState::Locked.into(),
Ordering::Acquire,
Ordering::Relaxed,
) {
@ -1161,7 +1323,7 @@ impl<T> Slot<T> {
// release lock
self.state
.store(SlotState::OCCUPIED.into(), Ordering::Release);
.store(SlotState::Occupied.into(), Ordering::Release);
None
}
}
@ -1170,8 +1332,8 @@ impl<T> Slot<T> {
#[inline]
pub fn try_take(&self) -> Option<T> {
match self.state.compare_exchange(
SlotState::OCCUPIED.into(),
SlotState::LOCKED.into(),
SlotState::Occupied.into(),
SlotState::Locked.into(),
Ordering::Acquire,
Ordering::Relaxed,
) {
@ -1181,8 +1343,7 @@ impl<T> Slot<T> {
let t = unsafe { (*slot).assume_init_read() };
// release lock
self.state
.store(SlotState::empty().into(), Ordering::Release);
self.state.store(SlotState::None.into(), Ordering::Release);
Some(t)
}
Err(_) => None,
@ -1227,14 +1388,15 @@ mod scope {
use std::{
future::Future,
marker::{PhantomData, PhantomPinned},
pin::pin,
ptr::{self, NonNull},
};
use async_task::{Runnable, Task};
use crate::{
latch::{CountWakeLatch, Latch},
task::{HeapTask, TaskRef},
latch::{CountWakeLatch, Latch, Probe, ThreadWakeLatch},
task::{HeapTask, StackTask, TaskRef},
ThreadPool, WorkerThread,
};
@ -1253,6 +1415,16 @@ mod scope {
}
}
pub fn join<F, G, T, U>(&self, f: F, g: G) -> (T, U)
where
F: FnOnce(&Self) -> T + Send,
G: FnOnce(&Self) -> U + Send,
T: Send,
U: Send,
{
self.pool.join(|| f(self), || g(self))
}
pub fn spawn<Fn>(&self, f: Fn)
where
Fn: FnOnce(&Scope<'scope>) + Send + 'scope,
@ -1267,7 +1439,7 @@ mod scope {
});
let taskref = unsafe { task.into_task_ref() };
self.pool.push_local_or_inject(taskref);
self.pool.inject_maybe_local(taskref);
}
pub fn spawn_future<Fut, T>(&self, future: Fut) -> Task<T>
@ -1289,7 +1461,7 @@ mod scope {
};
unsafe {
ptr.as_ref().pool.push_local_or_inject(taskref);
ptr.as_ref().pool.inject_maybe_local(taskref);
}
};
@ -1332,8 +1504,58 @@ mod scope {
mod tests {
use std::{cell::Cell, hint::black_box};
use tracing::info;
use super::*;
mod tree {
pub struct Tree<T> {
nodes: Box<[Node<T>]>,
root: Option<usize>,
}
pub struct Node<T> {
pub leaf: T,
pub left: Option<usize>,
pub right: Option<usize>,
}
impl<T> Tree<T> {
pub fn new(depth: usize, t: T) -> Tree<T>
where
T: Copy,
{
let mut nodes = Vec::with_capacity((0..depth).sum());
let root = Self::build_node(&mut nodes, depth, t);
Self {
nodes: nodes.into_boxed_slice(),
root: Some(root),
}
}
pub fn root(&self) -> Option<usize> {
self.root
}
pub fn get(&self, index: usize) -> &Node<T> {
&self.nodes[index]
}
pub fn build_node(nodes: &mut Vec<Node<T>>, depth: usize, t: T) -> usize
where
T: Copy,
{
let node = Node {
leaf: t,
left: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)),
right: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)),
};
nodes.push(node);
nodes.len() - 1
}
}
}
const PRIMES: &'static [usize] = &[
1181, 1187, 1193, 1201, 1213, 1217, 1223, 1229, 1231, 1237, 1249, 1259, 1277, 1279, 1283,
1289, 1291, 1297, 1301, 1303, 1307, 1319, 1321, 1327, 1361, 1367, 1373, 1381, 1399, 1409,
@ -1344,9 +1566,14 @@ mod tests {
1861, 1867, 1871, 1873, 1877, 1879, 1889, 1901, 1907,
];
const REPEAT: usize = 0x100;
#[cfg(feature = "spin-slow")]
const REPEAT: usize = 0x800;
#[cfg(not(feature = "spin-slow"))]
const REPEAT: usize = 0x8000;
fn run_in_scope<T: Send>(pool: ThreadPool, f: impl FnOnce(Pin<&Scope<'_>>) -> T + Send) -> T {
const TREE_SIZE: usize = 10;
fn run_in_scope<T: Send>(pool: ThreadPool, f: impl FnOnce(&Scope<'_>) -> T + Send) -> T {
let pool = Box::new(pool);
let ptr = Box::into_raw(pool);
@ -1357,9 +1584,9 @@ mod tests {
let now = std::time::Instant::now();
let result = pool.scope(f);
let elapsed = now.elapsed().as_micros();
eprintln!("(mine) total time: {}ms", elapsed as f32 / 1e3);
info!("(mine) total time: {}ms", elapsed as f32 / 1e3);
pool.resize_to(0);
assert!(pool.global_queue.pop().is_none());
assert!(pool.global_queue.is_empty());
result
};
@ -1385,7 +1612,38 @@ mod tests {
});
let elapsed = now.elapsed().as_micros();
eprintln!("(rayon) total time: {}ms", elapsed as f32 / 1e3);
info!("(rayon) total time: {}ms", elapsed as f32 / 1e3);
}
#[test]
#[tracing_test::traced_test]
fn rayon_join() {
let pool = rayon::ThreadPoolBuilder::new()
.num_threads(bevy_tasks::available_parallelism())
.build()
.unwrap();
let tree = tree::Tree::new(TREE_SIZE, 1u32);
fn sum(tree: &tree::Tree<u32>, node: usize) -> u32 {
let node = tree.get(node);
let (l, r) = rayon::join(
|| node.left.map(|node| sum(tree, node)).unwrap_or_default(),
|| node.right.map(|node| sum(tree, node)).unwrap_or_default(),
);
node.leaf + l + r
}
let now = std::time::Instant::now();
let sum = pool.scope(move |s| {
let root = tree.root().unwrap();
sum(&tree, root)
});
let elapsed = now.elapsed().as_micros();
info!("(rayon) total time: {}ms", elapsed as f32 / 1e3);
}
#[test]
@ -1407,7 +1665,7 @@ mod tests {
});
let elapsed = now.elapsed().as_micros();
eprintln!("(bevy_tasks) total time: {}ms", elapsed as f32 / 1e3);
info!("(bevy_tasks) total time: {}ms", elapsed as f32 / 1e3);
}
#[test]
@ -1418,18 +1676,7 @@ mod tests {
}
let counter = Arc::new(AtomicUsize::new(0));
{
let pool = ThreadPool::new_with_callbacks(ThreadPoolCallbacks {
at_entry: Some(Arc::new(|_worker| {
// eprintln!("new worker thread: {}", worker.index);
})),
at_exit: Some(Arc::new({
let counter = counter.clone();
move |_worker: &WorkerThread| {
// eprintln!("thread {}: {}", worker.index, WAIT_COUNT.get());
counter.fetch_add(WAIT_COUNT.get(), Ordering::Relaxed);
}
})),
});
let pool = ThreadPool::new();
run_in_scope(pool, |s| {
for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() {
@ -1443,6 +1690,33 @@ mod tests {
// eprintln!("total wait count: {}", counter.load(Ordering::Acquire));
}
#[test]
#[tracing_test::traced_test]
fn mine_join() {
let pool = ThreadPool::new();
let tree = tree::Tree::new(TREE_SIZE, 1u32);
fn sum(tree: &tree::Tree<u32>, node: usize, scope: &Scope<'_>) -> u32 {
let node = tree.get(node);
let (l, r) = scope.join(
|s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(),
|s| {
node.right
.map(|node| sum(tree, node, s))
.unwrap_or_default()
},
);
node.leaf + l + r
}
let sum = run_in_scope(pool, move |s| {
let root = tree.root().unwrap();
sum(&tree, root, s)
});
}
#[test]
#[tracing_test::traced_test]
fn sync() {
@ -1452,11 +1726,19 @@ mod tests {
}
let elapsed = now.elapsed().as_micros();
eprintln!("(sync) total time: {}ms", elapsed as f32 / 1e3);
info!("(sync) total time: {}ms", elapsed as f32 / 1e3);
}
#[inline]
fn spinning(i: usize) {
#[cfg(feature = "spin-slow")]
spinning_slow(i);
#[cfg(not(feature = "spin-slow"))]
spinning_fast(i);
}
#[inline]
fn spinning_slow(i: usize) {
let rng = rng::XorShift64Star::new(i as u64);
(0..i).reduce(|a, b| {
black_box({
@ -1465,4 +1747,16 @@ mod tests {
})
});
}
#[inline]
fn spinning_fast(i: usize) {
let rng = rng::XorShift64Star::new(i as u64);
//(0..rng.next_usize(i)).reduce(|a, b| {
(0..20).reduce(|a, b| {
black_box({
let a = rng.next_usize(a.max(1));
a ^ b
})
});
}
}