swap the homemade TaskQueue for crossbeam-deque workers/stealers, split ThreadControl out of ThreadState, add join() and scheduling feature flags (heartbeat, spin-slow, work-stealing, prefer-local, never-local)

Janis 2025-01-31 16:30:22 +01:00
parent a691b614bc
commit 736e4e1a60
2 changed files with 466 additions and 167 deletions

Cargo.toml

@@ -4,8 +4,12 @@ version = "0.1.0"
 edition = "2021"

 [features]
-internal_heartbeat = []
+heartbeat = []
+spin-slow = []
 cpu-pinning = []
+work-stealing = []
+prefer-local = []
+never-local = []

 [dependencies]
@@ -16,6 +20,7 @@ bevy_tasks = "0.15.1"
 parking_lot = "0.12.3"
 thread_local = "1.1.8"
 crossbeam = "0.8.4"
+st3 = "0.4"
 async-task = "4.7.1"
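Note on the new flags: heartbeat, spin-slow, work-stealing, prefer-local, and never-local all change scheduling behavior at compile time, so test and benchmark runs have to opt in explicitly, e.g. cargo test --release --features heartbeat,work-stealing (invocation assumed; it is not part of the commit).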

(Rust source: the thread-pool implementation; path not shown on the page)

@@ -1,11 +1,10 @@
 use std::{
-    cell::{OnceCell, UnsafeCell},
-    collections::VecDeque,
+    cell::{Cell, UnsafeCell},
     future::Future,
     mem::MaybeUninit,
     num::NonZero,
     pin::{pin, Pin},
-    ptr::NonNull,
+    ptr::{self, NonNull},
     sync::{
         atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering},
         Arc,
@@ -17,7 +16,11 @@ use std::{
 use async_task::{Runnable, Task};
 use bitflags::bitflags;
-use crossbeam::{queue::SegQueue, utils::CachePadded};
+use crossbeam::{
+    atomic::AtomicCell,
+    deque::{Injector, Stealer, Worker},
+    utils::CachePadded,
+};
 use latch::{AtomicLatch, Latch, LatchWaker, MutexLatch, Probe, ThreadWakeLatch};
 use parking_lot::{Condvar, Mutex};
 use scope::Scope;
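Context for the import swap: the commit replaces the single crossbeam SegQueue with the deque module's Worker/Stealer/Injector triad. A minimal standalone sketch of that API (standard crossbeam usage, independent of this codebase):

    use crossbeam::deque::{Injector, Steal, Worker};

    fn demo() {
        let injector: Injector<u32> = Injector::new(); // shared FIFO "inbox"
        let local: Worker<u32> = Worker::new_fifo();   // this thread's deque
        let stealer = local.stealer();                 // handle other threads hold

        injector.push(1);
        // Move a batch from the injector into the local deque, popping one task.
        if let Steal::Success(task) = injector.steal_batch_and_pop(&local) {
            assert_eq!(task, 1);
        }

        local.push(2);
        // Other threads take work through the Stealer, never through the Worker.
        match stealer.steal() {
            Steal::Success(task) => assert_eq!(task, 2),
            Steal::Empty | Steal::Retry => unreachable!(),
        }
    }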
@@ -337,7 +340,7 @@ pub mod latch {
 pub struct ThreadPoolState {
     num_threads: AtomicUsize,
     lock: Mutex<()>,
-    heartbeat_state: CachePadded<ThreadState>,
+    heartbeat_state: CachePadded<ThreadControl>,
 }

 bitflags! {
@@ -348,15 +351,21 @@ bitflags! {
     }
 }

-pub struct ThreadState {
-    should_shove: AtomicBool,
-    shoved_task: Slot<TaskRef>,
+pub struct ThreadControl {
     status: Mutex<ThreadStatus>,
     status_changed: Condvar,
     should_terminate: AtomicLatch,
 }

-impl ThreadState {
+pub struct ThreadState {
+    should_shove: AtomicBool,
+    control: ThreadControl,
+    stealer: Stealer<TaskRef>,
+    worker: AtomicCell<Option<Worker<TaskRef>>>,
+    shoved_task: CachePadded<Slot<TaskRef>>,
+}
+
+impl ThreadControl {
     /// returns true if thread was sleeping
     #[inline]
     fn wake(&self) -> bool {
@@ -451,40 +460,48 @@ impl ThreadPoolCallbacks {
 pub struct ThreadPool {
     threads: [CachePadded<ThreadState>; MAX_THREADS],
     pool_state: CachePadded<ThreadPoolState>,
-    global_queue: SegQueue<TaskRef>,
+    global_queue: Injector<TaskRef>,
     callbacks: CachePadded<ThreadPoolCallbacks>,
 }

 impl ThreadPool {
-    const INITIAL_THREAD_STATE: CachePadded<ThreadState> = CachePadded::new(ThreadState {
-        should_shove: AtomicBool::new(false),
-        shoved_task: Slot::new(),
-        status: Mutex::new(ThreadStatus::empty()),
-        status_changed: Condvar::new(),
-        should_terminate: AtomicLatch::new(),
-    });
-
-    pub const fn new() -> Self {
-        Self {
-            threads: const { [Self::INITIAL_THREAD_STATE; MAX_THREADS] },
-            pool_state: CachePadded::new(ThreadPoolState {
-                num_threads: AtomicUsize::new(0),
-                lock: Mutex::new(()),
-                heartbeat_state: Self::INITIAL_THREAD_STATE,
-            }),
-            global_queue: SegQueue::new(),
-            callbacks: CachePadded::new(ThreadPoolCallbacks::new_empty()),
-        }
+    pub fn new() -> Self {
+        Self::new_with_callbacks(ThreadPoolCallbacks::new_empty())
     }

-    pub const fn new_with_callbacks(callbacks: ThreadPoolCallbacks) -> ThreadPool {
+    pub fn new_with_callbacks(callbacks: ThreadPoolCallbacks) -> ThreadPool {
+        let threads = [const { MaybeUninit::uninit() }; MAX_THREADS].map(|mut uninit| {
+            let worker = Worker::<TaskRef>::new_fifo();
+            let stealer = worker.stealer();
+            let thread = CachePadded::new(ThreadState {
+                should_shove: AtomicBool::new(false),
+                shoved_task: Slot::new().into(),
+                control: ThreadControl {
+                    status: Mutex::new(ThreadStatus::empty()),
+                    status_changed: Condvar::new(),
+                    should_terminate: AtomicLatch::new(),
+                },
+                stealer,
+                worker: AtomicCell::new(Some(worker)),
+            });
+            uninit.write(thread);
+            unsafe { uninit.assume_init() }
+        });
         Self {
-            threads: const { [Self::INITIAL_THREAD_STATE; MAX_THREADS] },
+            threads,
             pool_state: CachePadded::new(ThreadPoolState {
                 num_threads: AtomicUsize::new(0),
                 lock: Mutex::new(()),
-                heartbeat_state: Self::INITIAL_THREAD_STATE,
+                heartbeat_state: ThreadControl {
+                    status: Mutex::new(ThreadStatus::empty()),
+                    status_changed: Condvar::new(),
+                    should_terminate: AtomicLatch::new(),
+                }
+                .into(),
             }),
-            global_queue: SegQueue::new(),
+            global_queue: Injector::new(),
             callbacks: CachePadded::new(callbacks),
         }
     }
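The MaybeUninit-array map exists because crossbeam's Worker is not const-constructible, so the thread array can no longer be built in a const fn as before. Assuming plain runtime construction is acceptable (it is here), std::array::from_fn expresses the same per-slot initialization more directly; sketch, with make_thread_state standing in for the closure body above (hypothetical helper, not in the commit):

    let threads: [CachePadded<ThreadState>; MAX_THREADS] =
        std::array::from_fn(|_| make_thread_state());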
@@ -495,7 +512,7 @@ impl ThreadPool {
     }

     pub fn wake_thread(&self, index: usize) -> Option<bool> {
-        Some(self.threads.get(index as usize)?.wake())
+        Some(self.threads.get(index as usize)?.control.wake())
     }

     pub fn wake_any(&self, count: usize) -> usize {
@@ -503,7 +520,7 @@ impl ThreadPool {
         let num_woken = self
             .threads
             .iter()
-            .filter_map(|thread| thread.wake().then_some(()))
+            .filter_map(|thread| thread.control.wake().then_some(()))
             .take(count)
             .count();
         num_woken
@@ -517,6 +534,27 @@ impl ThreadPool {
         core::ptr::from_ref(self) as usize
     }

+    fn push_local_or_inject_balanced(&self, task: TaskRef) {
+        let global_len = self.global_queue.len();
+        WorkerThread::with(|worker| match worker {
+            Some(worker) if worker.pool.id() == self.id() => {
+                let worker_len = worker.worker.len();
+                if worker_len == 0 {
+                    worker.push_task(task);
+                } else if global_len == 0 {
+                    self.inject(task);
+                } else {
+                    if worker_len >= global_len {
+                        worker.push_task(task);
+                    } else {
+                        self.inject(task);
+                    }
+                }
+            }
+            _ => self.inject(task),
+        })
+    }
+
     fn push_local_or_inject(&self, task: TaskRef) {
         WorkerThread::with(|worker| match worker {
             Some(worker) if worker.pool.id() == self.id() => worker.push_task(task),
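In words, the balanced variant works like this: an empty local deque always takes the task; failing that, an empty injector gets it; and when both hold work, the task stays local as long as the local deque is at least as long as the injector, otherwise it is injected globally.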
@@ -536,6 +574,15 @@ impl ThreadPool {
         self.wake_any(n);
     }

+    fn inject_maybe_local(&self, task: TaskRef) {
+        #[cfg(all(not(feature = "never-local"), feature = "prefer-local"))]
+        self.push_local_or_inject(task);
+        #[cfg(all(not(feature = "prefer-local"), feature = "never-local"))]
+        self.inject(task);
+        #[cfg(not(any(feature = "prefer-local", feature = "never-local")))]
+        self.push_local_or_inject_balanced(task);
+    }
+
     fn inject(&self, task: TaskRef) {
         self.global_queue.push(task);
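One edge worth noting: if a build enables both prefer-local and never-local, none of the three cfg arms matches and inject_maybe_local silently compiles to a no-op that drops the task. A guard along these lines (editor's sketch, not in the commit) would turn that into a build failure:

    #[cfg(all(feature = "prefer-local", feature = "never-local"))]
    compile_error!("features `prefer-local` and `never-local` are mutually exclusive");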
@@ -581,10 +628,10 @@ impl ThreadPool {
         }

         for thread in new_threads {
-            thread.wait_for_running();
+            thread.control.wait_for_running();
         }

-        #[cfg(not(feature = "internal_heartbeat"))]
+        #[cfg(feature = "heartbeat")]
         if current_size == 0 {
             std::thread::spawn(move || {
                 heartbeat_loop(self);
@@ -601,13 +648,13 @@ impl ThreadPool {
         let terminating_threads = &self.threads[new_size..current_size];

         for thread in terminating_threads {
-            thread.notify_should_terminate();
+            thread.control.notify_should_terminate();
         }
         for thread in terminating_threads {
-            thread.wait_for_termination();
+            thread.control.wait_for_termination();
         }

-        #[cfg(not(feature = "internal_heartbeat"))]
+        #[cfg(feature = "heartbeat")]
         if new_size == 0 {
             self.pool_state.heartbeat_state.notify_should_terminate();
             self.pool_state.heartbeat_state.wait_for_termination();
@@ -712,7 +759,7 @@ impl ThreadPool {
         }));

         let taskref = task.into_ref().as_task_ref();
-        self.inject(taskref);
+        self.push_local_or_inject(taskref);
         worker.run_until(&latch);

         result.unwrap()
@@ -727,7 +774,7 @@ impl ThreadPool {
         let task = HeapTask::new(f);

         let taskref = unsafe { task.into_static_task_ref() };
-        self.push_local_or_inject(taskref);
+        self.inject_maybe_local(taskref);
     }

     fn spawn_future<Fut, T>(&'static self, future: Fut) -> Task<T>
@@ -745,7 +792,7 @@ impl ThreadPool {
                 })
             };

-            self.push_local_or_inject(taskref);
+            self.inject_maybe_local(taskref);
         };

         let (runnable, task) = async_task::spawn(future, schedule);
@@ -789,6 +836,30 @@ impl ThreadPool {
     }

     fn join<F, G, T, U>(&'static self, f: F, g: G) -> (T, U)
+    where
+        F: FnOnce() -> T + Send,
+        G: FnOnce() -> U + Send,
+        T: Send,
+        U: Send,
+    {
+        self.join_threaded(f, g)
+    }
+
+    fn join_seq<F, G, T, U>(&'static self, f: F, g: G) -> (T, U)
+    where
+        F: FnOnce() -> T + Send,
+        G: FnOnce() -> U + Send,
+        T: Send,
+        U: Send,
+    {
+        let a = f();
+        let b = g();
+
+        (a, b)
+    }
+
+    #[inline]
+    fn join_threaded<F, G, T, U>(&'static self, f: F, g: G) -> (T, U)
     where
         F: FnOnce() -> T + Send,
         G: FnOnce() -> U + Send,
@@ -808,7 +879,6 @@ impl ThreadPool {
         let ref_b = task_b.as_ref().as_task_ref();
         let b_id = ref_b.id();

-        // TODO: maybe try to push this off to another thread immediately first?
         worker.push_task(ref_b);

         let result_a = f();
@@ -817,13 +887,17 @@ impl ThreadPool {
             match worker.pop_task() {
                 Some(task) => {
                     if task.id() == b_id {
-                        worker.try_promote();
+                        // we're not calling execute() here, so manually try
+                        // shoving a task.
+                        //worker.try_promote();
+                        worker.shove_task();
                         unsafe {
                             task_b.run_as_ref();
                         }
                         break;
+                    } else {
+                        worker.execute(task);
                     }
-                    worker.execute(task);
                 }
                 None => {
                     worker.run_until(&latch_b);
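For orientation, this is the classic join back-off: the worker pushes b, runs f itself, then drains its own queue until it either finds b (and runs it inline, shoving a replacement task first) or runs dry and parks on b's latch. A hypothetical in-crate caller of the new API, assuming a leaked pool as the tests below do and that resize_to takes the desired thread count:

    fn fib(pool: &'static ThreadPool, n: u64) -> u64 {
        if n < 2 {
            return n;
        }
        // either closure may end up running on another worker thread
        let (a, b) = pool.join(|| fib(pool, n - 1), || fib(pool, n - 2));
        a + b
    }

    fn sketch() {
        let pool: &'static ThreadPool = Box::leak(Box::new(ThreadPool::new()));
        pool.resize_to(4);
        assert_eq!(fib(pool, 10), 55);
    }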
@@ -837,12 +911,12 @@ impl ThreadPool {
     fn scope<'scope, Fn, T>(&'static self, f: Fn) -> T
     where
-        Fn: FnOnce(Pin<&Scope<'scope>>) -> T + Send,
+        Fn: FnOnce(&Scope<'scope>) -> T + Send,
         T: Send,
     {
         self.in_worker(|owner, _| {
-            let scope = pin!(unsafe { Scope::<'scope>::new(owner) });
-            let result = f(scope.as_ref());
+            let scope = unsafe { Scope::<'scope>::new(owner) };
+            let result = f(&scope);
             scope.complete(owner);
             result
         })
@@ -850,7 +924,8 @@ impl ThreadPool {
 }

 pub struct WorkerThread {
-    queue: TaskQueue<TaskRef>,
+    // queue: TaskQueue<TaskRef>,
+    worker: Worker<TaskRef>,
     pool: &'static ThreadPool,
     index: usize,
     rng: rng::XorShift64Star,
@@ -860,7 +935,7 @@ pub struct WorkerThread {
 const HEARTBEAT_INTERVAL: core::time::Duration = const { core::time::Duration::from_micros(100) };

 std::thread_local! {
-    static WORKER_THREAD_STATE: CachePadded<OnceCell<WorkerThread>> = const {CachePadded::new(OnceCell::new())};
+    static WORKER_THREAD_STATE: Cell<*const WorkerThread> = const {Cell::new(ptr::null())};
 }

 impl WorkerThread {
@@ -880,25 +955,47 @@ impl WorkerThread {
     fn is_worker_thread() -> bool {
         Self::with(|worker| worker.is_some())
     }

     fn with<T, F: FnOnce(Option<&Self>) -> T>(f: F) -> T {
-        WORKER_THREAD_STATE.with(|thread| f(thread.get()))
+        WORKER_THREAD_STATE.with(|thread| {
+            f(NonNull::<WorkerThread>::new(thread.get().cast_mut())
+                .map(|ptr| unsafe { ptr.as_ref() }))
+        })
     }

     #[inline]
     fn pop_task(&self) -> Option<TaskRef> {
-        self.queue.pop_front()
+        self.worker.pop()
+        //self.queue.pop_front(task);
     }

     #[inline]
     fn push_task(&self, task: TaskRef) {
-        self.queue.push_front(task);
+        self.worker.push(task);
+        //self.queue.push_front(task);
     }

     #[inline]
     fn drain(&self) -> impl Iterator<Item = TaskRef> {
-        self.queue.drain()
+        // self.queue.drain()
+        core::iter::empty()
     }

+    #[inline]
+    fn steal_tasks(&self) -> Option<TaskRef> {
+        // careful not to call threads() here because that omits any threads
+        // that were killed, which might still have tasks.
+        let threads = &self.pool.threads;
+        let (start, end) = threads.split_at(self.rng.next_usize(threads.len()));
+
+        end.iter()
+            .chain(start)
+            .find_map(|thread: &CachePadded<ThreadState>| {
+                thread.stealer.steal_batch_and_pop(&self.worker).success()
+            })
+    }
+
     #[inline]
     fn claim_shoved_task(&self) -> Option<TaskRef> {
+        // take own shoved task first
         if let Some(task) = self.info().shoved_task.try_take() {
             return Some(task);
         }
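A subtlety steal_tasks leans on: crossbeam's steal operations return Steal::Success, Steal::Empty, or Steal::Retry, and .success() collapses both non-success cases to None, so a steal that loses a race is simply skipped in favor of the next victim rather than retried. For contrast, a standalone retry loop would look like this (generic sketch, not part of the commit):

    use crossbeam::deque::{Steal, Stealer};

    fn steal_with_retry<T>(stealer: &Stealer<T>) -> Option<T> {
        loop {
            match stealer.steal() {
                Steal::Success(task) => return Some(task),
                Steal::Empty => return None,
                Steal::Retry => continue, // lost a race with another thief; try again
            }
        }
    }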
@@ -916,12 +1013,16 @@ impl WorkerThread {
     #[cold]
     fn shove_task(&self) {
-        if let Some(task) = self.queue.pop_back() {
-            match self.info().shoved_task.try_put(task) {
-                // shoved task is occupied, reinsert into queue
-                Some(task) => self.queue.push_back(task),
-                None => {}
+        if !self.info().shoved_task.is_occupied() {
+            if let Some(task) = self.info().stealer.steal().success() {
+                match self.info().shoved_task.try_put(task) {
+                    // shoved task is occupied, reinsert into queue
+                    // this really shouldn't happen
+                    Some(_task) => unreachable!(),
+                    None => {}
+                }
             }
+        } else {
             // wake thread to execute task
             self.pool.wake_any(1);
         }
@@ -934,24 +1035,15 @@ impl WorkerThread {
     #[inline]
     fn try_promote(&self) {
-        #[cfg(feature = "internal_heartbeat")]
-        let now = std::time::Instant::now();
-        // SAFETY: workerthread is thread-local non-sync
-        #[cfg(feature = "internal_heartbeat")]
-        let should_shove =
-            unsafe { *self.last_heartbeat.get() }.duration_since(now) > HEARTBEAT_INTERVAL;
-        #[cfg(not(feature = "internal_heartbeat"))]
+        #[cfg(feature = "heartbeat")]
         let should_shove = self.info().should_shove.load(Ordering::Acquire);
+        #[cfg(not(feature = "heartbeat"))]
+        let should_shove = true;
         if should_shove {
-            // SAFETY: workerthread is thread-local non-sync
-            #[cfg(feature = "internal_heartbeat")]
-            unsafe {
-                *&mut *self.last_heartbeat.get() = now;
-            }
-            #[cfg(not(feature = "internal_heartbeat"))]
+            #[cfg(feature = "heartbeat")]
             self.info().should_shove.store(false, Ordering::Release);
             self.shove_task();
         }
     }
@@ -959,9 +1051,22 @@ impl WorkerThread {
     #[inline]
     fn find_any_task(&self) -> Option<TaskRef> {
         // TODO: attempt stealing work here, too.
-        self.pop_task()
+        let mut task = self
+            .pop_task()
             .or_else(|| self.claim_shoved_task())
-            .or_else(|| self.pool.global_queue.pop())
+            .or_else(|| {
+                self.pool
+                    .global_queue
+                    .steal_batch_and_pop(&self.worker)
+                    .success()
+            });
+
+        #[cfg(feature = "work-stealing")]
+        {
+            task = task.or_else(|| self.steal_tasks());
+        }
+
+        task
     }

     #[inline]
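So the lookup order is now: own deque first, then the thread's shoved-task slot, then a batch pulled from the global injector, and, only with work-stealing enabled, a raid on the other workers' deques starting from a random index.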
@@ -991,34 +1096,36 @@ impl WorkerThread {
                     self.execute(task);
                 }
                 None => {
-                    debug!("waiting for tasks");
-                    self.info().wait_for_should_wake();
+                    //debug!("waiting for tasks");
+                    self.info().control.wait_for_should_wake();
                 }
             }
         }
     }

     fn worker_loop(pool: &'static ThreadPool, index: usize) {
         let info = &pool.threads()[index as usize];
+        let worker = CachePadded::new(WorkerThread {
+            // queue: TaskQueue::new(),
+            worker: info.worker.take().unwrap(),
+            pool,
+            index,
+            rng: rng::XorShift64Star::new(pool as *const _ as u64 + index as u64),
+            last_heartbeat: UnsafeCell::new(std::time::Instant::now()),
+        });

-        WORKER_THREAD_STATE.with(|worker| {
-            let worker = worker.get_or_init(|| WorkerThread {
-                queue: TaskQueue::new(),
-                pool,
-                index,
-                rng: rng::XorShift64Star::new(pool as *const _ as u64 + index as u64),
-                last_heartbeat: UnsafeCell::new(std::time::Instant::now()),
-            });
+        WORKER_THREAD_STATE.with(|cell| {
+            cell.set(&*worker);

             if let Some(callback) = pool.callbacks.at_entry.as_ref() {
-                callback(worker);
+                callback(&worker);
             }

-            info.notify_running();
+            info.control.notify_running();
+            // info.notify_running();

-            worker.run_until(&info.should_terminate);
+            worker.run_until(&info.control.should_terminate);

             if let Some(callback) = pool.callbacks.at_exit.as_ref() {
-                callback(worker);
+                callback(&worker);
             }

             for task in worker.drain() {
@@ -1028,9 +1135,14 @@ impl WorkerThread {
             if let Some(task) = info.shoved_task.try_take() {
                 pool.inject(task);
             }
+
+            cell.set(ptr::null());
         });

-        info.notify_termination();
+        let WorkerThread { worker, .. } = CachePadded::into_inner(worker);
+        info.worker.store(Some(worker));
+
+        info.control.notify_termination();
     }
 }
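The OnceCell thread-local is gone: worker_loop now publishes a raw pointer to a stack-owned WorkerThread and clears it before that value is torn down, which is what lets with() hand out &WorkerThread safely. The pattern in isolation (generic sketch using only std, not code from the commit):

    use std::{cell::Cell, ptr};

    thread_local! {
        // null means "no worker registered on this thread"
        static CURRENT: Cell<*const u32> = const { Cell::new(ptr::null()) };
    }

    fn with_published<R>(value: &u32, f: impl FnOnce() -> R) -> R {
        CURRENT.with(|c| c.set(value));
        let result = f(); // code on this thread may now read CURRENT
        // must clear before `value` goes out of scope, or readers would dangle
        CURRENT.with(|c| c.set(ptr::null()));
        result
    }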
@@ -1061,56 +1173,68 @@ fn heartbeat_loop(pool: &'static ThreadPool) {
     state.notify_termination();
 }

-pub struct TaskQueue<T>(UnsafeCell<VecDeque<T>>);
+use vec_queue::TaskQueue;

-impl<T> TaskQueue<T> {
-    /// Creates a new [`TaskQueue<T>`].
-    #[inline]
-    const fn new() -> Self {
-        Self(UnsafeCell::new(VecDeque::new()))
-    }
-    #[inline]
-    fn get_mut(&self) -> &mut VecDeque<T> {
-        unsafe { &mut *self.0.get() }
-    }
-    #[inline]
-    fn pop_front(&self) -> Option<T> {
-        self.get_mut().pop_front()
-    }
-    #[inline]
-    fn pop_back(&self) -> Option<T> {
-        self.get_mut().pop_back()
-    }
-    #[inline]
-    fn push_back(&self, t: T) {
-        self.get_mut().push_back(t);
-    }
-    #[inline]
-    fn push_front(&self, t: T) {
-        self.get_mut().push_front(t);
-    }
-    #[inline]
-    fn take(&self) -> VecDeque<T> {
-        let this = core::mem::replace(self.get_mut(), VecDeque::new());
-        this
-    }
-    #[inline]
-    fn drain(&self) -> impl Iterator<Item = T> {
-        self.take().into_iter()
+mod vec_queue {
+    use std::{cell::UnsafeCell, collections::VecDeque};
+
+    pub struct TaskQueue<T>(UnsafeCell<VecDeque<T>>);
+
+    impl<T> TaskQueue<T> {
+        /// Creates a new [`TaskQueue<T>`].
+        #[inline]
+        pub const fn new() -> Self {
+            Self(UnsafeCell::new(VecDeque::new()))
+        }
+        #[inline]
+        pub fn get_mut(&self) -> &mut VecDeque<T> {
+            unsafe { &mut *self.0.get() }
+        }
+        #[inline]
+        pub fn pop_front(&self) -> Option<T> {
+            self.get_mut().pop_front()
+        }
+        #[inline]
+        pub fn pop_back(&self) -> Option<T> {
+            self.get_mut().pop_back()
+        }
+        #[inline]
+        pub fn push_back(&self, t: T) {
+            self.get_mut().push_back(t);
+        }
+        #[inline]
+        pub fn push_front(&self, t: T) {
+            self.get_mut().push_front(t);
+        }
+        #[inline]
+        pub fn take(&self) -> VecDeque<T> {
+            let this = core::mem::replace(self.get_mut(), VecDeque::new());
+            this
+        }
+        #[inline]
+        pub fn drain(&self) -> impl Iterator<Item = T> {
+            self.take().into_iter()
+        }
     }
 }

-bitflags! {
-    #[derive(Debug, Clone, Copy)]
-    pub struct SlotState: u8 {
-        const LOCKED = 1 << 1;
-        const OCCUPIED = 1 << 2;
+#[repr(u8)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum SlotState {
+    None,
+    Locked,
+    Occupied,
+}
+
+impl From<u8> for SlotState {
+    fn from(value: u8) -> Self {
+        unsafe { core::mem::transmute(value) }
     }
 }

 impl From<SlotState> for u8 {
     fn from(value: SlotState) -> Self {
-        value.bits()
+        value as u8
     }
 }
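The u8-to-SlotState transmute is sound only while the enum stays field-less, #[repr(u8)], and every stored byte originates from SlotState as u8, which the Slot code below does uphold. A match-based conversion (editor's sketch, not in the commit) would drop that invariant at no runtime cost:

    impl SlotState {
        #[inline]
        fn from_u8(value: u8) -> SlotState {
            match value {
                0 => SlotState::None,
                1 => SlotState::Locked,
                2 => SlotState::Occupied,
                _ => unreachable!("state bytes only ever come from SlotState as u8"),
            }
        }
    }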
@@ -1125,10 +1249,7 @@ unsafe impl<T> Sync for Slot<T> where T: Send {}
 impl<T> Drop for Slot<T> {
     fn drop(&mut self) {
         if core::mem::needs_drop::<T>() {
-            if SlotState::from_bits(*self.state.get_mut())
-                .unwrap()
-                .contains(SlotState::OCCUPIED)
-            {
+            if *self.state.get_mut() == SlotState::Occupied as u8 {
                 unsafe {
                     self.slot.get().drop_in_place();
                 }
@@ -1141,15 +1262,56 @@ impl<T> Slot<T> {
     pub const fn new() -> Slot<T> {
         Self {
             slot: UnsafeCell::new(MaybeUninit::uninit()),
-            state: AtomicU8::new(SlotState::empty().bits()),
+            state: AtomicU8::new(SlotState::None as u8),
         }
     }

+    pub fn is_occupied(&self) -> bool {
+        self.state.load(Ordering::Acquire) == SlotState::Occupied.into()
+    }
+
+    #[inline]
+    pub fn insert(&self, t: T) -> Option<T> {
+        let value = match self
+            .state
+            .swap(SlotState::Locked.into(), Ordering::AcqRel)
+            .into()
+        {
+            SlotState::Locked => {
+                // return early: was already locked.
+                debug!("slot was already locked");
+                return None;
+            }
+            SlotState::Occupied => {
+                let slot = self.slot.get();
+                // replace
+                unsafe {
+                    let v = (*slot).assume_init_read();
+                    (*slot).write(t);
+                    Some(v)
+                }
+            }
+            SlotState::None => {
+                let slot = self.slot.get();
+                // insert
+                unsafe {
+                    (*slot).write(t);
+                }
+                None
+            }
+        };
+
+        // release lock
+        self.state
+            .store(SlotState::Occupied.into(), Ordering::Release);
+        value
+    }
+
     #[inline]
     pub fn try_put(&self, t: T) -> Option<T> {
         match self.state.compare_exchange(
-            SlotState::empty().into(),
-            SlotState::LOCKED.into(),
+            SlotState::None.into(),
+            SlotState::Locked.into(),
             Ordering::Acquire,
             Ordering::Relaxed,
         ) {
@@ -1161,7 +1323,7 @@ impl<T> Slot<T> {
                 // release lock
                 self.state
-                    .store(SlotState::OCCUPIED.into(), Ordering::Release);
+                    .store(SlotState::Occupied.into(), Ordering::Release);
                 None
             }
         }
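Taken together, Slot is a single-value mailbox driven by a three-state byte (None/Locked/Occupied). Expected behavior of the non-blocking API as a sketch; the Err arm of try_put handing the value back is implied by its Option return but lies outside this hunk, so treat that line as an assumption:

    let slot: Slot<u32> = Slot::new();

    assert_eq!(slot.try_put(1), None);    // None -> Occupied, value stored
    assert!(slot.is_occupied());
    assert_eq!(slot.try_put(2), Some(2)); // already occupied: caller keeps 2
    assert_eq!(slot.try_take(), Some(1)); // Occupied -> None
    assert_eq!(slot.try_take(), None);    // empty again
    assert_eq!(slot.insert(3), None);     // insert fills the empty slot
    assert_eq!(slot.insert(4), Some(3));  // ...or replaces, returning the old value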
@@ -1170,8 +1332,8 @@ impl<T> Slot<T> {
     #[inline]
     pub fn try_take(&self) -> Option<T> {
         match self.state.compare_exchange(
-            SlotState::OCCUPIED.into(),
-            SlotState::LOCKED.into(),
+            SlotState::Occupied.into(),
+            SlotState::Locked.into(),
             Ordering::Acquire,
             Ordering::Relaxed,
         ) {
@@ -1181,8 +1343,7 @@ impl<T> Slot<T> {
                 let t = unsafe { (*slot).assume_init_read() };

                 // release lock
-                self.state
-                    .store(SlotState::empty().into(), Ordering::Release);
+                self.state.store(SlotState::None.into(), Ordering::Release);
                 Some(t)
             }
             Err(_) => None,
@@ -1227,14 +1388,15 @@ mod scope {
     use std::{
         future::Future,
         marker::{PhantomData, PhantomPinned},
-        pin::pin,
         ptr::{self, NonNull},
     };

     use async_task::{Runnable, Task};

     use crate::{
-        latch::{CountWakeLatch, Latch},
-        task::{HeapTask, TaskRef},
+        latch::{CountWakeLatch, Latch, Probe, ThreadWakeLatch},
+        task::{HeapTask, StackTask, TaskRef},
         ThreadPool, WorkerThread,
     };
@@ -1253,6 +1415,16 @@ mod scope {
         }
     }

+    pub fn join<F, G, T, U>(&self, f: F, g: G) -> (T, U)
+    where
+        F: FnOnce(&Self) -> T + Send,
+        G: FnOnce(&Self) -> U + Send,
+        T: Send,
+        U: Send,
+    {
+        self.pool.join(|| f(self), || g(self))
+    }
+
     pub fn spawn<Fn>(&self, f: Fn)
     where
         Fn: FnOnce(&Scope<'scope>) + Send + 'scope,
@@ -1267,7 +1439,7 @@ mod scope {
         });

         let taskref = unsafe { task.into_task_ref() };
-        self.pool.push_local_or_inject(taskref);
+        self.pool.inject_maybe_local(taskref);
     }

     pub fn spawn_future<Fut, T>(&self, future: Fut) -> Task<T>
@@ -1289,7 +1461,7 @@ mod scope {
         };

         unsafe {
-            ptr.as_ref().pool.push_local_or_inject(taskref);
+            ptr.as_ref().pool.inject_maybe_local(taskref);
         }
     };
@@ -1332,8 +1504,58 @@ mod tests {
 mod tests {
     use std::{cell::Cell, hint::black_box};

+    use tracing::info;
+
     use super::*;

+    mod tree {
+        pub struct Tree<T> {
+            nodes: Box<[Node<T>]>,
+            root: Option<usize>,
+        }
+
+        pub struct Node<T> {
+            pub leaf: T,
+            pub left: Option<usize>,
+            pub right: Option<usize>,
+        }
+
+        impl<T> Tree<T> {
+            pub fn new(depth: usize, t: T) -> Tree<T>
+            where
+                T: Copy,
+            {
+                let mut nodes = Vec::with_capacity((0..depth).sum());
+                let root = Self::build_node(&mut nodes, depth, t);
+
+                Self {
+                    nodes: nodes.into_boxed_slice(),
+                    root: Some(root),
+                }
+            }
+
+            pub fn root(&self) -> Option<usize> {
+                self.root
+            }
+
+            pub fn get(&self, index: usize) -> &Node<T> {
+                &self.nodes[index]
+            }
+
+            pub fn build_node(nodes: &mut Vec<Node<T>>, depth: usize, t: T) -> usize
+            where
+                T: Copy,
+            {
+                let node = Node {
+                    leaf: t,
+                    left: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)),
+                    right: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)),
+                };
+
+                nodes.push(node);
+                nodes.len() - 1
+            }
+        }
+    }
+
     const PRIMES: &'static [usize] = &[
         1181, 1187, 1193, 1201, 1213, 1217, 1223, 1229, 1231, 1237, 1249, 1259, 1277, 1279, 1283,
         1289, 1291, 1297, 1301, 1303, 1307, 1319, 1321, 1327, 1361, 1367, 1373, 1381, 1399, 1409,
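Small observation on the test fixture: a tree built this way is a full binary tree with 2^(depth+1) − 1 nodes, while (0..depth).sum() is only depth·(depth−1)/2, so the Vec will reallocate during construction; harmless for a benchmark helper. An exact preallocation would be (editor's sketch):

    let mut nodes: Vec<Node<T>> = Vec::with_capacity((1 << (depth + 1)) - 1);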
@@ -1344,9 +1566,14 @@ mod tests {
         1861, 1867, 1871, 1873, 1877, 1879, 1889, 1901, 1907,
     ];

-    const REPEAT: usize = 0x100;
+    #[cfg(feature = "spin-slow")]
+    const REPEAT: usize = 0x800;
+    #[cfg(not(feature = "spin-slow"))]
+    const REPEAT: usize = 0x8000;
+
+    const TREE_SIZE: usize = 10;

-    fn run_in_scope<T: Send>(pool: ThreadPool, f: impl FnOnce(Pin<&Scope<'_>>) -> T + Send) -> T {
+    fn run_in_scope<T: Send>(pool: ThreadPool, f: impl FnOnce(&Scope<'_>) -> T + Send) -> T {
         let pool = Box::new(pool);
         let ptr = Box::into_raw(pool);
@@ -1357,9 +1584,9 @@ mod tests {
         let now = std::time::Instant::now();
         let result = pool.scope(f);
         let elapsed = now.elapsed().as_micros();
-        eprintln!("(mine) total time: {}ms", elapsed as f32 / 1e3);
+        info!("(mine) total time: {}ms", elapsed as f32 / 1e3);
         pool.resize_to(0);
-        assert!(pool.global_queue.pop().is_none());
+        assert!(pool.global_queue.is_empty());

         result
     };
@@ -1385,7 +1612,38 @@ mod tests {
         });
         let elapsed = now.elapsed().as_micros();
-        eprintln!("(rayon) total time: {}ms", elapsed as f32 / 1e3);
+        info!("(rayon) total time: {}ms", elapsed as f32 / 1e3);
+    }
+
+    #[test]
+    #[tracing_test::traced_test]
+    fn rayon_join() {
+        let pool = rayon::ThreadPoolBuilder::new()
+            .num_threads(bevy_tasks::available_parallelism())
+            .build()
+            .unwrap();
+
+        let tree = tree::Tree::new(TREE_SIZE, 1u32);
+
+        fn sum(tree: &tree::Tree<u32>, node: usize) -> u32 {
+            let node = tree.get(node);
+            let (l, r) = rayon::join(
+                || node.left.map(|node| sum(tree, node)).unwrap_or_default(),
+                || node.right.map(|node| sum(tree, node)).unwrap_or_default(),
+            );
+
+            node.leaf + l + r
+        }
+
+        let now = std::time::Instant::now();
+        let sum = pool.scope(move |s| {
+            let root = tree.root().unwrap();
+            sum(&tree, root)
+        });
+        let elapsed = now.elapsed().as_micros();
+        info!("(rayon) total time: {}ms", elapsed as f32 / 1e3);
     }

     #[test]
@@ -1407,7 +1665,7 @@ mod tests {
         });
         let elapsed = now.elapsed().as_micros();
-        eprintln!("(bevy_tasks) total time: {}ms", elapsed as f32 / 1e3);
+        info!("(bevy_tasks) total time: {}ms", elapsed as f32 / 1e3);
     }

     #[test]
@@ -1418,18 +1676,7 @@ mod tests {
        }

        let counter = Arc::new(AtomicUsize::new(0));
        {
-            let pool = ThreadPool::new_with_callbacks(ThreadPoolCallbacks {
-                at_entry: Some(Arc::new(|_worker| {
-                    // eprintln!("new worker thread: {}", worker.index);
-                })),
-                at_exit: Some(Arc::new({
-                    let counter = counter.clone();
-                    move |_worker: &WorkerThread| {
-                        // eprintln!("thread {}: {}", worker.index, WAIT_COUNT.get());
-                        counter.fetch_add(WAIT_COUNT.get(), Ordering::Relaxed);
-                    }
-                })),
-            });
+            let pool = ThreadPool::new();

             run_in_scope(pool, |s| {
                 for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() {
@@ -1443,6 +1690,33 @@ mod tests {
         // eprintln!("total wait count: {}", counter.load(Ordering::Acquire));
     }

+    #[test]
+    #[tracing_test::traced_test]
+    fn mine_join() {
+        let pool = ThreadPool::new();
+
+        let tree = tree::Tree::new(TREE_SIZE, 1u32);
+
+        fn sum(tree: &tree::Tree<u32>, node: usize, scope: &Scope<'_>) -> u32 {
+            let node = tree.get(node);
+            let (l, r) = scope.join(
+                |s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(),
+                |s| {
+                    node.right
+                        .map(|node| sum(tree, node, s))
+                        .unwrap_or_default()
+                },
+            );
+
+            node.leaf + l + r
+        }
+
+        let sum = run_in_scope(pool, move |s| {
+            let root = tree.root().unwrap();
+            sum(&tree, root, s)
+        });
+    }
+
     #[test]
     #[tracing_test::traced_test]
     fn sync() {
@@ -1452,11 +1726,19 @@ mod tests {
         }
         let elapsed = now.elapsed().as_micros();
-        eprintln!("(sync) total time: {}ms", elapsed as f32 / 1e3);
+        info!("(sync) total time: {}ms", elapsed as f32 / 1e3);
     }

     #[inline]
     fn spinning(i: usize) {
+        #[cfg(feature = "spin-slow")]
+        spinning_slow(i);
+
+        #[cfg(not(feature = "spin-slow"))]
+        spinning_fast(i);
+    }
+
+    #[inline]
+    fn spinning_slow(i: usize) {
         let rng = rng::XorShift64Star::new(i as u64);
         (0..i).reduce(|a, b| {
             black_box({
@@ -1465,4 +1747,16 @@ mod tests {
             })
         });
     }
+
+    #[inline]
+    fn spinning_fast(i: usize) {
+        let rng = rng::XorShift64Star::new(i as u64);
+        //(0..rng.next_usize(i)).reduce(|a, b| {
+        (0..20).reduce(|a, b| {
+            black_box({
+                let a = rng.next_usize(a.max(1));
+                a ^ b
+            })
+        });
+    }
 }