idk this sucks
This commit is contained in:
parent
a691b614bc
commit
736e4e1a60
|
@ -4,8 +4,12 @@ version = "0.1.0"
|
|||
edition = "2021"
|
||||
|
||||
[features]
|
||||
internal_heartbeat = []
|
||||
heartbeat = []
|
||||
spin-slow = []
|
||||
cpu-pinning = []
|
||||
work-stealing = []
|
||||
prefer-local = []
|
||||
never-local = []
|
||||
|
||||
|
||||
[dependencies]
|
||||
|
@ -16,6 +20,7 @@ bevy_tasks = "0.15.1"
|
|||
parking_lot = "0.12.3"
|
||||
thread_local = "1.1.8"
|
||||
crossbeam = "0.8.4"
|
||||
st3 = "0.4"
|
||||
|
||||
async-task = "4.7.1"
|
||||
|
||||
|
|
554
src/lib.rs
554
src/lib.rs
|
@ -1,11 +1,10 @@
|
|||
use std::{
|
||||
cell::{OnceCell, UnsafeCell},
|
||||
collections::VecDeque,
|
||||
cell::{Cell, UnsafeCell},
|
||||
future::Future,
|
||||
mem::MaybeUninit,
|
||||
num::NonZero,
|
||||
pin::{pin, Pin},
|
||||
ptr::NonNull,
|
||||
ptr::{self, NonNull},
|
||||
sync::{
|
||||
atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering},
|
||||
Arc,
|
||||
|
@ -17,7 +16,11 @@ use std::{
|
|||
|
||||
use async_task::{Runnable, Task};
|
||||
use bitflags::bitflags;
|
||||
use crossbeam::{queue::SegQueue, utils::CachePadded};
|
||||
use crossbeam::{
|
||||
atomic::AtomicCell,
|
||||
deque::{Injector, Stealer, Worker},
|
||||
utils::CachePadded,
|
||||
};
|
||||
use latch::{AtomicLatch, Latch, LatchWaker, MutexLatch, Probe, ThreadWakeLatch};
|
||||
use parking_lot::{Condvar, Mutex};
|
||||
use scope::Scope;
|
||||
|
@ -337,7 +340,7 @@ pub mod latch {
|
|||
pub struct ThreadPoolState {
|
||||
num_threads: AtomicUsize,
|
||||
lock: Mutex<()>,
|
||||
heartbeat_state: CachePadded<ThreadState>,
|
||||
heartbeat_state: CachePadded<ThreadControl>,
|
||||
}
|
||||
|
||||
bitflags! {
|
||||
|
@ -348,15 +351,21 @@ bitflags! {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct ThreadState {
|
||||
should_shove: AtomicBool,
|
||||
shoved_task: Slot<TaskRef>,
|
||||
pub struct ThreadControl {
|
||||
status: Mutex<ThreadStatus>,
|
||||
status_changed: Condvar,
|
||||
should_terminate: AtomicLatch,
|
||||
}
|
||||
|
||||
impl ThreadState {
|
||||
pub struct ThreadState {
|
||||
should_shove: AtomicBool,
|
||||
control: ThreadControl,
|
||||
stealer: Stealer<TaskRef>,
|
||||
worker: AtomicCell<Option<Worker<TaskRef>>>,
|
||||
shoved_task: CachePadded<Slot<TaskRef>>,
|
||||
}
|
||||
|
||||
impl ThreadControl {
|
||||
/// returns true if thread was sleeping
|
||||
#[inline]
|
||||
fn wake(&self) -> bool {
|
||||
|
@ -451,40 +460,48 @@ impl ThreadPoolCallbacks {
|
|||
pub struct ThreadPool {
|
||||
threads: [CachePadded<ThreadState>; MAX_THREADS],
|
||||
pool_state: CachePadded<ThreadPoolState>,
|
||||
global_queue: SegQueue<TaskRef>,
|
||||
global_queue: Injector<TaskRef>,
|
||||
callbacks: CachePadded<ThreadPoolCallbacks>,
|
||||
}
|
||||
|
||||
impl ThreadPool {
|
||||
const INITIAL_THREAD_STATE: CachePadded<ThreadState> = CachePadded::new(ThreadState {
|
||||
pub fn new() -> Self {
|
||||
Self::new_with_callbacks(ThreadPoolCallbacks::new_empty())
|
||||
}
|
||||
|
||||
pub fn new_with_callbacks(callbacks: ThreadPoolCallbacks) -> ThreadPool {
|
||||
let threads = [const { MaybeUninit::uninit() }; MAX_THREADS].map(|mut uninit| {
|
||||
let worker = Worker::<TaskRef>::new_fifo();
|
||||
let stealer = worker.stealer();
|
||||
|
||||
let thread = CachePadded::new(ThreadState {
|
||||
should_shove: AtomicBool::new(false),
|
||||
shoved_task: Slot::new(),
|
||||
shoved_task: Slot::new().into(),
|
||||
control: ThreadControl {
|
||||
status: Mutex::new(ThreadStatus::empty()),
|
||||
status_changed: Condvar::new(),
|
||||
should_terminate: AtomicLatch::new(),
|
||||
},
|
||||
stealer,
|
||||
worker: AtomicCell::new(Some(worker)),
|
||||
});
|
||||
uninit.write(thread);
|
||||
unsafe { uninit.assume_init() }
|
||||
});
|
||||
pub const fn new() -> Self {
|
||||
Self {
|
||||
threads: const { [Self::INITIAL_THREAD_STATE; MAX_THREADS] },
|
||||
pool_state: CachePadded::new(ThreadPoolState {
|
||||
num_threads: AtomicUsize::new(0),
|
||||
lock: Mutex::new(()),
|
||||
heartbeat_state: Self::INITIAL_THREAD_STATE,
|
||||
}),
|
||||
global_queue: SegQueue::new(),
|
||||
callbacks: CachePadded::new(ThreadPoolCallbacks::new_empty()),
|
||||
}
|
||||
}
|
||||
|
||||
pub const fn new_with_callbacks(callbacks: ThreadPoolCallbacks) -> ThreadPool {
|
||||
Self {
|
||||
threads: const { [Self::INITIAL_THREAD_STATE; MAX_THREADS] },
|
||||
threads,
|
||||
pool_state: CachePadded::new(ThreadPoolState {
|
||||
num_threads: AtomicUsize::new(0),
|
||||
lock: Mutex::new(()),
|
||||
heartbeat_state: Self::INITIAL_THREAD_STATE,
|
||||
heartbeat_state: ThreadControl {
|
||||
status: Mutex::new(ThreadStatus::empty()),
|
||||
status_changed: Condvar::new(),
|
||||
should_terminate: AtomicLatch::new(),
|
||||
}
|
||||
.into(),
|
||||
}),
|
||||
global_queue: SegQueue::new(),
|
||||
global_queue: Injector::new(),
|
||||
callbacks: CachePadded::new(callbacks),
|
||||
}
|
||||
}
|
||||
|
@ -495,7 +512,7 @@ impl ThreadPool {
|
|||
}
|
||||
|
||||
pub fn wake_thread(&self, index: usize) -> Option<bool> {
|
||||
Some(self.threads.get(index as usize)?.wake())
|
||||
Some(self.threads.get(index as usize)?.control.wake())
|
||||
}
|
||||
|
||||
pub fn wake_any(&self, count: usize) -> usize {
|
||||
|
@ -503,7 +520,7 @@ impl ThreadPool {
|
|||
let num_woken = self
|
||||
.threads
|
||||
.iter()
|
||||
.filter_map(|thread| thread.wake().then_some(()))
|
||||
.filter_map(|thread| thread.control.wake().then_some(()))
|
||||
.take(count)
|
||||
.count();
|
||||
num_woken
|
||||
|
@ -517,6 +534,27 @@ impl ThreadPool {
|
|||
core::ptr::from_ref(self) as usize
|
||||
}
|
||||
|
||||
fn push_local_or_inject_balanced(&self, task: TaskRef) {
|
||||
let global_len = self.global_queue.len();
|
||||
WorkerThread::with(|worker| match worker {
|
||||
Some(worker) if worker.pool.id() == self.id() => {
|
||||
let worker_len = worker.worker.len();
|
||||
if worker_len == 0 {
|
||||
worker.push_task(task);
|
||||
} else if global_len == 0 {
|
||||
self.inject(task);
|
||||
} else {
|
||||
if worker_len >= global_len {
|
||||
worker.push_task(task);
|
||||
} else {
|
||||
self.inject(task);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => self.inject(task),
|
||||
})
|
||||
}
|
||||
|
||||
fn push_local_or_inject(&self, task: TaskRef) {
|
||||
WorkerThread::with(|worker| match worker {
|
||||
Some(worker) if worker.pool.id() == self.id() => worker.push_task(task),
|
||||
|
@ -536,6 +574,15 @@ impl ThreadPool {
|
|||
self.wake_any(n);
|
||||
}
|
||||
|
||||
/// Queues `task` using the placement policy selected at compile time.
///
/// * `never-local`: always inject into the global queue.
/// * `prefer-local` (without `never-local`): use `push_local_or_inject`.
/// * neither feature: use `push_local_or_inject_balanced`.
///
/// Exactly one of the three statements below is compiled for every feature
/// combination. Cargo unifies features across the dependency graph, so both
/// policy features can end up enabled at once; in that case `never-local`
/// takes precedence instead of leaving the body empty and silently dropping
/// the task.
fn inject_maybe_local(&self, task: TaskRef) {
    #[cfg(feature = "never-local")]
    self.inject(task);

    #[cfg(all(feature = "prefer-local", not(feature = "never-local")))]
    self.push_local_or_inject(task);

    #[cfg(not(any(feature = "prefer-local", feature = "never-local")))]
    self.push_local_or_inject_balanced(task);
}
|
||||
|
||||
fn inject(&self, task: TaskRef) {
|
||||
self.global_queue.push(task);
|
||||
|
||||
|
@ -581,10 +628,10 @@ impl ThreadPool {
|
|||
}
|
||||
|
||||
for thread in new_threads {
|
||||
thread.wait_for_running();
|
||||
thread.control.wait_for_running();
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "internal_heartbeat"))]
|
||||
#[cfg(feature = "heartbeat")]
|
||||
if current_size == 0 {
|
||||
std::thread::spawn(move || {
|
||||
heartbeat_loop(self);
|
||||
|
@ -601,13 +648,13 @@ impl ThreadPool {
|
|||
let terminating_threads = &self.threads[new_size..current_size];
|
||||
|
||||
for thread in terminating_threads {
|
||||
thread.notify_should_terminate();
|
||||
thread.control.notify_should_terminate();
|
||||
}
|
||||
for thread in terminating_threads {
|
||||
thread.wait_for_termination();
|
||||
thread.control.wait_for_termination();
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "internal_heartbeat"))]
|
||||
#[cfg(feature = "heartbeat")]
|
||||
if new_size == 0 {
|
||||
self.pool_state.heartbeat_state.notify_should_terminate();
|
||||
self.pool_state.heartbeat_state.wait_for_termination();
|
||||
|
@ -712,7 +759,7 @@ impl ThreadPool {
|
|||
}));
|
||||
|
||||
let taskref = task.into_ref().as_task_ref();
|
||||
self.inject(taskref);
|
||||
self.push_local_or_inject(taskref);
|
||||
|
||||
worker.run_until(&latch);
|
||||
result.unwrap()
|
||||
|
@ -727,7 +774,7 @@ impl ThreadPool {
|
|||
let task = HeapTask::new(f);
|
||||
|
||||
let taskref = unsafe { task.into_static_task_ref() };
|
||||
self.push_local_or_inject(taskref);
|
||||
self.inject_maybe_local(taskref);
|
||||
}
|
||||
|
||||
fn spawn_future<Fut, T>(&'static self, future: Fut) -> Task<T>
|
||||
|
@ -745,7 +792,7 @@ impl ThreadPool {
|
|||
})
|
||||
};
|
||||
|
||||
self.push_local_or_inject(taskref);
|
||||
self.inject_maybe_local(taskref);
|
||||
};
|
||||
|
||||
let (runnable, task) = async_task::spawn(future, schedule);
|
||||
|
@ -789,6 +836,30 @@ impl ThreadPool {
|
|||
}
|
||||
|
||||
fn join<F, G, T, U>(&'static self, f: F, g: G) -> (T, U)
|
||||
where
|
||||
F: FnOnce() -> T + Send,
|
||||
G: FnOnce() -> U + Send,
|
||||
T: Send,
|
||||
U: Send,
|
||||
{
|
||||
self.join_threaded(f, g)
|
||||
}
|
||||
|
||||
fn join_seq<F, G, T, U>(&'static self, f: F, g: G) -> (T, U)
|
||||
where
|
||||
F: FnOnce() -> T + Send,
|
||||
G: FnOnce() -> U + Send,
|
||||
T: Send,
|
||||
U: Send,
|
||||
{
|
||||
let a = f();
|
||||
let b = g();
|
||||
|
||||
(a, b)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn join_threaded<F, G, T, U>(&'static self, f: F, g: G) -> (T, U)
|
||||
where
|
||||
F: FnOnce() -> T + Send,
|
||||
G: FnOnce() -> U + Send,
|
||||
|
@ -808,7 +879,6 @@ impl ThreadPool {
|
|||
|
||||
let ref_b = task_b.as_ref().as_task_ref();
|
||||
let b_id = ref_b.id();
|
||||
// TODO: maybe try to push this off to another thread immediately first?
|
||||
worker.push_task(ref_b);
|
||||
|
||||
let result_a = f();
|
||||
|
@ -817,14 +887,18 @@ impl ThreadPool {
|
|||
match worker.pop_task() {
|
||||
Some(task) => {
|
||||
if task.id() == b_id {
|
||||
worker.try_promote();
|
||||
// we're not calling execute() here, so manually try
|
||||
// shoving a task.
|
||||
//worker.try_promote();
|
||||
worker.shove_task();
|
||||
unsafe {
|
||||
task_b.run_as_ref();
|
||||
}
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
worker.execute(task);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
worker.run_until(&latch_b);
|
||||
}
|
||||
|
@ -837,12 +911,12 @@ impl ThreadPool {
|
|||
|
||||
fn scope<'scope, Fn, T>(&'static self, f: Fn) -> T
|
||||
where
|
||||
Fn: FnOnce(Pin<&Scope<'scope>>) -> T + Send,
|
||||
Fn: FnOnce(&Scope<'scope>) -> T + Send,
|
||||
T: Send,
|
||||
{
|
||||
self.in_worker(|owner, _| {
|
||||
let scope = pin!(unsafe { Scope::<'scope>::new(owner) });
|
||||
let result = f(scope.as_ref());
|
||||
let scope = unsafe { Scope::<'scope>::new(owner) };
|
||||
let result = f(&scope);
|
||||
scope.complete(owner);
|
||||
result
|
||||
})
|
||||
|
@ -850,7 +924,8 @@ impl ThreadPool {
|
|||
}
|
||||
|
||||
pub struct WorkerThread {
|
||||
queue: TaskQueue<TaskRef>,
|
||||
// queue: TaskQueue<TaskRef>,
|
||||
worker: Worker<TaskRef>,
|
||||
pool: &'static ThreadPool,
|
||||
index: usize,
|
||||
rng: rng::XorShift64Star,
|
||||
|
@ -860,7 +935,7 @@ pub struct WorkerThread {
|
|||
const HEARTBEAT_INTERVAL: core::time::Duration = const { core::time::Duration::from_micros(100) };
|
||||
|
||||
std::thread_local! {
|
||||
static WORKER_THREAD_STATE: CachePadded<OnceCell<WorkerThread>> = const {CachePadded::new(OnceCell::new())};
|
||||
static WORKER_THREAD_STATE: Cell<*const WorkerThread> = const {Cell::new(ptr::null())};
|
||||
}
|
||||
|
||||
impl WorkerThread {
|
||||
|
@ -880,25 +955,47 @@ impl WorkerThread {
|
|||
fn is_worker_thread() -> bool {
|
||||
Self::with(|worker| worker.is_some())
|
||||
}
|
||||
|
||||
fn with<T, F: FnOnce(Option<&Self>) -> T>(f: F) -> T {
|
||||
WORKER_THREAD_STATE.with(|thread| f(thread.get()))
|
||||
WORKER_THREAD_STATE.with(|thread| {
|
||||
f(NonNull::<WorkerThread>::new(thread.get().cast_mut())
|
||||
.map(|ptr| unsafe { ptr.as_ref() }))
|
||||
})
|
||||
}
|
||||
#[inline]
|
||||
fn pop_task(&self) -> Option<TaskRef> {
|
||||
self.queue.pop_front()
|
||||
self.worker.pop()
|
||||
//self.queue.pop_front(task);
|
||||
}
|
||||
#[inline]
|
||||
fn push_task(&self, task: TaskRef) {
|
||||
self.queue.push_front(task);
|
||||
self.worker.push(task);
|
||||
//self.queue.push_front(task);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn drain(&self) -> impl Iterator<Item = TaskRef> {
|
||||
self.queue.drain()
|
||||
// self.queue.drain()
|
||||
core::iter::empty()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn steal_tasks(&self) -> Option<TaskRef> {
|
||||
// careful not to call threads() here because that omits any threads
|
||||
// that were killed, which might still have tasks.
|
||||
let threads = &self.pool.threads;
|
||||
let (start, end) = threads.split_at(self.rng.next_usize(threads.len()));
|
||||
|
||||
end.iter()
|
||||
.chain(start)
|
||||
.find_map(|thread: &CachePadded<ThreadState>| {
|
||||
thread.stealer.steal_batch_and_pop(&self.worker).success()
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn claim_shoved_task(&self) -> Option<TaskRef> {
|
||||
// take own shoved task first
|
||||
if let Some(task) = self.info().shoved_task.try_take() {
|
||||
return Some(task);
|
||||
}
|
||||
|
@ -916,12 +1013,16 @@ impl WorkerThread {
|
|||
|
||||
#[cold]
|
||||
fn shove_task(&self) {
|
||||
if let Some(task) = self.queue.pop_back() {
|
||||
if !self.info().shoved_task.is_occupied() {
|
||||
if let Some(task) = self.info().stealer.steal().success() {
|
||||
match self.info().shoved_task.try_put(task) {
|
||||
// shoved task is occupied, reinsert into queue
|
||||
Some(task) => self.queue.push_back(task),
|
||||
// this really shouldn't happen
|
||||
Some(_task) => unreachable!(),
|
||||
None => {}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// wake thread to execute task
|
||||
self.pool.wake_any(1);
|
||||
}
|
||||
|
@ -934,24 +1035,15 @@ impl WorkerThread {
|
|||
|
||||
#[inline]
|
||||
fn try_promote(&self) {
|
||||
#[cfg(feature = "internal_heartbeat")]
|
||||
let now = std::time::Instant::now();
|
||||
// SAFETY: workerthread is thread-local non-sync
|
||||
|
||||
#[cfg(feature = "internal_heartbeat")]
|
||||
let should_shove =
|
||||
unsafe { *self.last_heartbeat.get() }.duration_since(now) > HEARTBEAT_INTERVAL;
|
||||
#[cfg(not(feature = "internal_heartbeat"))]
|
||||
#[cfg(feature = "heartbeat")]
|
||||
let should_shove = self.info().should_shove.load(Ordering::Acquire);
|
||||
#[cfg(not(feature = "heartbeat"))]
|
||||
let should_shove = true;
|
||||
|
||||
if should_shove {
|
||||
// SAFETY: workerthread is thread-local non-sync
|
||||
#[cfg(feature = "internal_heartbeat")]
|
||||
unsafe {
|
||||
*&mut *self.last_heartbeat.get() = now;
|
||||
}
|
||||
#[cfg(not(feature = "internal_heartbeat"))]
|
||||
#[cfg(feature = "heartbeat")]
|
||||
self.info().should_shove.store(false, Ordering::Release);
|
||||
|
||||
self.shove_task();
|
||||
}
|
||||
}
|
||||
|
@ -959,9 +1051,22 @@ impl WorkerThread {
|
|||
#[inline]
|
||||
fn find_any_task(&self) -> Option<TaskRef> {
|
||||
// TODO: attempt stealing work here, too.
|
||||
self.pop_task()
|
||||
let mut task = self
|
||||
.pop_task()
|
||||
.or_else(|| self.claim_shoved_task())
|
||||
.or_else(|| self.pool.global_queue.pop())
|
||||
.or_else(|| {
|
||||
self.pool
|
||||
.global_queue
|
||||
.steal_batch_and_pop(&self.worker)
|
||||
.success()
|
||||
});
|
||||
|
||||
#[cfg(feature = "work-stealing")]
|
||||
{
|
||||
task = task.or_else(|| self.steal_tasks());
|
||||
}
|
||||
|
||||
task
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
@ -991,34 +1096,36 @@ impl WorkerThread {
|
|||
self.execute(task);
|
||||
}
|
||||
None => {
|
||||
debug!("waiting for tasks");
|
||||
self.info().wait_for_should_wake();
|
||||
//debug!("waiting for tasks");
|
||||
self.info().control.wait_for_should_wake();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn worker_loop(pool: &'static ThreadPool, index: usize) {
|
||||
let info = &pool.threads()[index as usize];
|
||||
|
||||
WORKER_THREAD_STATE.with(|worker| {
|
||||
let worker = worker.get_or_init(|| WorkerThread {
|
||||
queue: TaskQueue::new(),
|
||||
let worker = CachePadded::new(WorkerThread {
|
||||
// queue: TaskQueue::new(),
|
||||
worker: info.worker.take().unwrap(),
|
||||
pool,
|
||||
index,
|
||||
rng: rng::XorShift64Star::new(pool as *const _ as u64 + index as u64),
|
||||
last_heartbeat: UnsafeCell::new(std::time::Instant::now()),
|
||||
});
|
||||
|
||||
WORKER_THREAD_STATE.with(|cell| {
|
||||
cell.set(&*worker);
|
||||
|
||||
if let Some(callback) = pool.callbacks.at_entry.as_ref() {
|
||||
callback(worker);
|
||||
callback(&worker);
|
||||
}
|
||||
|
||||
info.notify_running();
|
||||
info.control.notify_running();
|
||||
// info.notify_running();
|
||||
worker.run_until(&info.should_terminate);
|
||||
worker.run_until(&info.control.should_terminate);
|
||||
|
||||
if let Some(callback) = pool.callbacks.at_exit.as_ref() {
|
||||
callback(worker);
|
||||
callback(&worker);
|
||||
}
|
||||
|
||||
for task in worker.drain() {
|
||||
|
@ -1028,9 +1135,14 @@ impl WorkerThread {
|
|||
if let Some(task) = info.shoved_task.try_take() {
|
||||
pool.inject(task);
|
||||
}
|
||||
|
||||
cell.set(ptr::null());
|
||||
});
|
||||
|
||||
info.notify_termination();
|
||||
let WorkerThread { worker, .. } = CachePadded::into_inner(worker);
|
||||
info.worker.store(Some(worker));
|
||||
|
||||
info.control.notify_termination();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1061,56 +1173,68 @@ fn heartbeat_loop(pool: &'static ThreadPool) {
|
|||
state.notify_termination();
|
||||
}
|
||||
|
||||
pub struct TaskQueue<T>(UnsafeCell<VecDeque<T>>);
|
||||
use vec_queue::TaskQueue;
|
||||
|
||||
impl<T> TaskQueue<T> {
|
||||
mod vec_queue {
|
||||
use std::{cell::UnsafeCell, collections::VecDeque};
|
||||
|
||||
pub struct TaskQueue<T>(UnsafeCell<VecDeque<T>>);
|
||||
|
||||
impl<T> TaskQueue<T> {
|
||||
/// Creates a new [`TaskQueue<T>`].
|
||||
#[inline]
|
||||
const fn new() -> Self {
|
||||
pub const fn new() -> Self {
|
||||
Self(UnsafeCell::new(VecDeque::new()))
|
||||
}
|
||||
#[inline]
|
||||
fn get_mut(&self) -> &mut VecDeque<T> {
|
||||
pub fn get_mut(&self) -> &mut VecDeque<T> {
|
||||
unsafe { &mut *self.0.get() }
|
||||
}
|
||||
#[inline]
|
||||
fn pop_front(&self) -> Option<T> {
|
||||
pub fn pop_front(&self) -> Option<T> {
|
||||
self.get_mut().pop_front()
|
||||
}
|
||||
#[inline]
|
||||
fn pop_back(&self) -> Option<T> {
|
||||
pub fn pop_back(&self) -> Option<T> {
|
||||
self.get_mut().pop_back()
|
||||
}
|
||||
#[inline]
|
||||
fn push_back(&self, t: T) {
|
||||
pub fn push_back(&self, t: T) {
|
||||
self.get_mut().push_back(t);
|
||||
}
|
||||
#[inline]
|
||||
fn push_front(&self, t: T) {
|
||||
pub fn push_front(&self, t: T) {
|
||||
self.get_mut().push_front(t);
|
||||
}
|
||||
#[inline]
|
||||
fn take(&self) -> VecDeque<T> {
|
||||
pub fn take(&self) -> VecDeque<T> {
|
||||
let this = core::mem::replace(self.get_mut(), VecDeque::new());
|
||||
this
|
||||
}
|
||||
#[inline]
|
||||
fn drain(&self) -> impl Iterator<Item = T> {
|
||||
pub fn drain(&self) -> impl Iterator<Item = T> {
|
||||
self.take().into_iter()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bitflags! {
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct SlotState: u8 {
|
||||
const LOCKED = 1 << 1;
|
||||
const OCCUPIED = 1 << 2;
|
||||
#[repr(u8)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum SlotState {
|
||||
None,
|
||||
Locked,
|
||||
Occupied,
|
||||
}
|
||||
|
||||
impl From<u8> for SlotState {
|
||||
fn from(value: u8) -> Self {
|
||||
unsafe { core::mem::transmute(value) }
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SlotState> for u8 {
|
||||
fn from(value: SlotState) -> Self {
|
||||
value.bits()
|
||||
value as u8
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1125,10 +1249,7 @@ unsafe impl<T> Sync for Slot<T> where T: Send {}
|
|||
impl<T> Drop for Slot<T> {
|
||||
fn drop(&mut self) {
|
||||
if core::mem::needs_drop::<T>() {
|
||||
if SlotState::from_bits(*self.state.get_mut())
|
||||
.unwrap()
|
||||
.contains(SlotState::OCCUPIED)
|
||||
{
|
||||
if *self.state.get_mut() == SlotState::Occupied as u8 {
|
||||
unsafe {
|
||||
self.slot.get().drop_in_place();
|
||||
}
|
||||
|
@ -1141,15 +1262,56 @@ impl<T> Slot<T> {
|
|||
pub const fn new() -> Slot<T> {
|
||||
Self {
|
||||
slot: UnsafeCell::new(MaybeUninit::uninit()),
|
||||
state: AtomicU8::new(SlotState::empty().bits()),
|
||||
state: AtomicU8::new(SlotState::None as u8),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_occupied(&self) -> bool {
|
||||
self.state.load(Ordering::Acquire) == SlotState::Occupied.into()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn insert(&self, t: T) -> Option<T> {
|
||||
let value = match self
|
||||
.state
|
||||
.swap(SlotState::Locked.into(), Ordering::AcqRel)
|
||||
.into()
|
||||
{
|
||||
SlotState::Locked => {
|
||||
// return early: was already locked.
|
||||
debug!("slot was already locked");
|
||||
return None;
|
||||
}
|
||||
SlotState::Occupied => {
|
||||
let slot = self.slot.get();
|
||||
// replace
|
||||
unsafe {
|
||||
let v = (*slot).assume_init_read();
|
||||
(*slot).write(t);
|
||||
Some(v)
|
||||
}
|
||||
}
|
||||
SlotState::None => {
|
||||
let slot = self.slot.get();
|
||||
// insert
|
||||
unsafe {
|
||||
(*slot).write(t);
|
||||
}
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
// release lock
|
||||
self.state
|
||||
.store(SlotState::Occupied.into(), Ordering::Release);
|
||||
value
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn try_put(&self, t: T) -> Option<T> {
|
||||
match self.state.compare_exchange(
|
||||
SlotState::empty().into(),
|
||||
SlotState::LOCKED.into(),
|
||||
SlotState::None.into(),
|
||||
SlotState::Locked.into(),
|
||||
Ordering::Acquire,
|
||||
Ordering::Relaxed,
|
||||
) {
|
||||
|
@ -1161,7 +1323,7 @@ impl<T> Slot<T> {
|
|||
|
||||
// release lock
|
||||
self.state
|
||||
.store(SlotState::OCCUPIED.into(), Ordering::Release);
|
||||
.store(SlotState::Occupied.into(), Ordering::Release);
|
||||
None
|
||||
}
|
||||
}
|
||||
|
@ -1170,8 +1332,8 @@ impl<T> Slot<T> {
|
|||
#[inline]
|
||||
pub fn try_take(&self) -> Option<T> {
|
||||
match self.state.compare_exchange(
|
||||
SlotState::OCCUPIED.into(),
|
||||
SlotState::LOCKED.into(),
|
||||
SlotState::Occupied.into(),
|
||||
SlotState::Locked.into(),
|
||||
Ordering::Acquire,
|
||||
Ordering::Relaxed,
|
||||
) {
|
||||
|
@ -1181,8 +1343,7 @@ impl<T> Slot<T> {
|
|||
let t = unsafe { (*slot).assume_init_read() };
|
||||
|
||||
// release lock
|
||||
self.state
|
||||
.store(SlotState::empty().into(), Ordering::Release);
|
||||
self.state.store(SlotState::None.into(), Ordering::Release);
|
||||
Some(t)
|
||||
}
|
||||
Err(_) => None,
|
||||
|
@ -1227,14 +1388,15 @@ mod scope {
|
|||
use std::{
|
||||
future::Future,
|
||||
marker::{PhantomData, PhantomPinned},
|
||||
pin::pin,
|
||||
ptr::{self, NonNull},
|
||||
};
|
||||
|
||||
use async_task::{Runnable, Task};
|
||||
|
||||
use crate::{
|
||||
latch::{CountWakeLatch, Latch},
|
||||
task::{HeapTask, TaskRef},
|
||||
latch::{CountWakeLatch, Latch, Probe, ThreadWakeLatch},
|
||||
task::{HeapTask, StackTask, TaskRef},
|
||||
ThreadPool, WorkerThread,
|
||||
};
|
||||
|
||||
|
@ -1253,6 +1415,16 @@ mod scope {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn join<F, G, T, U>(&self, f: F, g: G) -> (T, U)
|
||||
where
|
||||
F: FnOnce(&Self) -> T + Send,
|
||||
G: FnOnce(&Self) -> U + Send,
|
||||
T: Send,
|
||||
U: Send,
|
||||
{
|
||||
self.pool.join(|| f(self), || g(self))
|
||||
}
|
||||
|
||||
pub fn spawn<Fn>(&self, f: Fn)
|
||||
where
|
||||
Fn: FnOnce(&Scope<'scope>) + Send + 'scope,
|
||||
|
@ -1267,7 +1439,7 @@ mod scope {
|
|||
});
|
||||
|
||||
let taskref = unsafe { task.into_task_ref() };
|
||||
self.pool.push_local_or_inject(taskref);
|
||||
self.pool.inject_maybe_local(taskref);
|
||||
}
|
||||
|
||||
pub fn spawn_future<Fut, T>(&self, future: Fut) -> Task<T>
|
||||
|
@ -1289,7 +1461,7 @@ mod scope {
|
|||
};
|
||||
|
||||
unsafe {
|
||||
ptr.as_ref().pool.push_local_or_inject(taskref);
|
||||
ptr.as_ref().pool.inject_maybe_local(taskref);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1332,8 +1504,58 @@ mod scope {
|
|||
mod tests {
|
||||
use std::{cell::Cell, hint::black_box};
|
||||
|
||||
use tracing::info;
|
||||
|
||||
use super::*;
|
||||
|
||||
/// A small, index-based full binary tree used as a workload for the join
/// tests.
mod tree {

    /// Full binary tree whose nodes live in one boxed slice and refer to
    /// their children by index into that slice.
    pub struct Tree<T> {
        nodes: Box<[Node<T>]>,
        root: Option<usize>,
    }

    /// One node of a [`Tree`]: a payload plus optional child indices.
    pub struct Node<T> {
        pub leaf: T,
        pub left: Option<usize>,
        pub right: Option<usize>,
    }

    impl<T> Tree<T> {
        /// Builds a full binary tree with `depth` levels below the root,
        /// storing a copy of `t` in every node.
        ///
        /// A `depth` of 0 yields a single node without children.
        pub fn new(depth: usize, t: T) -> Tree<T>
        where
            T: Copy,
        {
            // A full tree with `depth` levels below the root has
            // 2^(depth + 1) - 1 nodes. Reserve exactly that so the vector
            // never reallocates while building; fall back to no
            // pre-allocation if the count does not fit in `usize`.
            let capacity = if depth < (usize::BITS as usize) - 1 {
                (1usize << (depth + 1)) - 1
            } else {
                0
            };

            let mut nodes = Vec::with_capacity(capacity);
            let root = Self::build_node(&mut nodes, depth, t);

            Self {
                nodes: nodes.into_boxed_slice(),
                root: Some(root),
            }
        }

        /// Index of the root node.
        pub fn root(&self) -> Option<usize> {
            self.root
        }

        /// Returns the node stored at `index`.
        ///
        /// # Panics
        ///
        /// Panics if `index` is out of bounds.
        pub fn get(&self, index: usize) -> &Node<T> {
            &self.nodes[index]
        }

        /// Recursively appends a subtree with `depth` levels below its root
        /// to `nodes` and returns the index of that subtree's root.
        ///
        /// Children are built before their parent (left subtree first, then
        /// right), so a subtree's root is always the last node it pushes.
        pub fn build_node(nodes: &mut Vec<Node<T>>, depth: usize, t: T) -> usize
        where
            T: Copy,
        {
            let (left, right) = if depth == 0 {
                (None, None)
            } else {
                let left = Self::build_node(nodes, depth - 1, t);
                let right = Self::build_node(nodes, depth - 1, t);
                (Some(left), Some(right))
            };

            nodes.push(Node {
                leaf: t,
                left,
                right,
            });

            nodes.len() - 1
        }
    }
}
|
||||
|
||||
const PRIMES: &'static [usize] = &[
|
||||
1181, 1187, 1193, 1201, 1213, 1217, 1223, 1229, 1231, 1237, 1249, 1259, 1277, 1279, 1283,
|
||||
1289, 1291, 1297, 1301, 1303, 1307, 1319, 1321, 1327, 1361, 1367, 1373, 1381, 1399, 1409,
|
||||
|
@ -1344,9 +1566,14 @@ mod tests {
|
|||
1861, 1867, 1871, 1873, 1877, 1879, 1889, 1901, 1907,
|
||||
];
|
||||
|
||||
const REPEAT: usize = 0x100;
|
||||
#[cfg(feature = "spin-slow")]
|
||||
const REPEAT: usize = 0x800;
|
||||
#[cfg(not(feature = "spin-slow"))]
|
||||
const REPEAT: usize = 0x8000;
|
||||
|
||||
fn run_in_scope<T: Send>(pool: ThreadPool, f: impl FnOnce(Pin<&Scope<'_>>) -> T + Send) -> T {
|
||||
const TREE_SIZE: usize = 10;
|
||||
|
||||
fn run_in_scope<T: Send>(pool: ThreadPool, f: impl FnOnce(&Scope<'_>) -> T + Send) -> T {
|
||||
let pool = Box::new(pool);
|
||||
let ptr = Box::into_raw(pool);
|
||||
|
||||
|
@ -1357,9 +1584,9 @@ mod tests {
|
|||
let now = std::time::Instant::now();
|
||||
let result = pool.scope(f);
|
||||
let elapsed = now.elapsed().as_micros();
|
||||
eprintln!("(mine) total time: {}ms", elapsed as f32 / 1e3);
|
||||
info!("(mine) total time: {}ms", elapsed as f32 / 1e3);
|
||||
pool.resize_to(0);
|
||||
assert!(pool.global_queue.pop().is_none());
|
||||
assert!(pool.global_queue.is_empty());
|
||||
result
|
||||
};
|
||||
|
||||
|
@ -1385,7 +1612,38 @@ mod tests {
|
|||
});
|
||||
let elapsed = now.elapsed().as_micros();
|
||||
|
||||
eprintln!("(rayon) total time: {}ms", elapsed as f32 / 1e3);
|
||||
info!("(rayon) total time: {}ms", elapsed as f32 / 1e3);
|
||||
}
|
||||
|
||||
/// Reference measurement: sums a full binary tree with recursive
/// `rayon::join` calls inside a dedicated rayon pool, logs the elapsed time
/// and checks the result.
#[test]
#[tracing_test::traced_test]
fn rayon_join() {
    let pool = rayon::ThreadPoolBuilder::new()
        .num_threads(bevy_tasks::available_parallelism())
        .build()
        .unwrap();

    let tree = tree::Tree::new(TREE_SIZE, 1u32);

    fn sum(tree: &tree::Tree<u32>, node: usize) -> u32 {
        let node = tree.get(node);
        let (l, r) = rayon::join(
            || node.left.map(|node| sum(tree, node)).unwrap_or_default(),
            || node.right.map(|node| sum(tree, node)).unwrap_or_default(),
        );

        node.leaf + l + r
    }

    // Every node carries 1, so the sum equals the node count of a full
    // binary tree with `TREE_SIZE` levels below the root.
    let expected = (1u32 << (TREE_SIZE + 1)) - 1;

    let now = std::time::Instant::now();
    let total = pool.scope(move |_| {
        let root = tree.root().unwrap();
        sum(&tree, root)
    });

    let elapsed = now.elapsed().as_micros();

    info!("(rayon) total time: {}ms", elapsed as f32 / 1e3);

    assert_eq!(total, expected);
}
|
||||
|
||||
#[test]
|
||||
|
@ -1407,7 +1665,7 @@ mod tests {
|
|||
});
|
||||
let elapsed = now.elapsed().as_micros();
|
||||
|
||||
eprintln!("(bevy_tasks) total time: {}ms", elapsed as f32 / 1e3);
|
||||
info!("(bevy_tasks) total time: {}ms", elapsed as f32 / 1e3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -1418,18 +1676,7 @@ mod tests {
|
|||
}
|
||||
let counter = Arc::new(AtomicUsize::new(0));
|
||||
{
|
||||
let pool = ThreadPool::new_with_callbacks(ThreadPoolCallbacks {
|
||||
at_entry: Some(Arc::new(|_worker| {
|
||||
// eprintln!("new worker thread: {}", worker.index);
|
||||
})),
|
||||
at_exit: Some(Arc::new({
|
||||
let counter = counter.clone();
|
||||
move |_worker: &WorkerThread| {
|
||||
// eprintln!("thread {}: {}", worker.index, WAIT_COUNT.get());
|
||||
counter.fetch_add(WAIT_COUNT.get(), Ordering::Relaxed);
|
||||
}
|
||||
})),
|
||||
});
|
||||
let pool = ThreadPool::new();
|
||||
|
||||
run_in_scope(pool, |s| {
|
||||
for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() {
|
||||
|
@ -1443,6 +1690,33 @@ mod tests {
|
|||
// eprintln!("total wait count: {}", counter.load(Ordering::Acquire));
|
||||
}
|
||||
|
||||
/// Sums a full binary tree with recursive `Scope::join` calls on this
/// crate's thread pool and checks the result.
#[test]
#[tracing_test::traced_test]
fn mine_join() {
    let pool = ThreadPool::new();

    let tree = tree::Tree::new(TREE_SIZE, 1u32);

    fn sum(tree: &tree::Tree<u32>, node: usize, scope: &Scope<'_>) -> u32 {
        let node = tree.get(node);
        let (l, r) = scope.join(
            |s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(),
            |s| {
                node.right
                    .map(|node| sum(tree, node, s))
                    .unwrap_or_default()
            },
        );

        node.leaf + l + r
    }

    // Every node carries 1, so the sum equals the node count of a full
    // binary tree with `TREE_SIZE` levels below the root.
    let expected = (1u32 << (TREE_SIZE + 1)) - 1;

    let total = run_in_scope(pool, move |s| {
        let root = tree.root().unwrap();
        sum(&tree, root, s)
    });

    // Without this check the test could only catch panics and hangs, not a
    // join that loses or duplicates work.
    assert_eq!(total, expected);
}
|
||||
|
||||
#[test]
|
||||
#[tracing_test::traced_test]
|
||||
fn sync() {
|
||||
|
@ -1452,11 +1726,19 @@ mod tests {
|
|||
}
|
||||
let elapsed = now.elapsed().as_micros();
|
||||
|
||||
eprintln!("(sync) total time: {}ms", elapsed as f32 / 1e3);
|
||||
info!("(sync) total time: {}ms", elapsed as f32 / 1e3);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn spinning(i: usize) {
|
||||
#[cfg(feature = "spin-slow")]
|
||||
spinning_slow(i);
|
||||
#[cfg(not(feature = "spin-slow"))]
|
||||
spinning_fast(i);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn spinning_slow(i: usize) {
|
||||
let rng = rng::XorShift64Star::new(i as u64);
|
||||
(0..i).reduce(|a, b| {
|
||||
black_box({
|
||||
|
@ -1465,4 +1747,16 @@ mod tests {
|
|||
})
|
||||
});
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn spinning_fast(i: usize) {
|
||||
let rng = rng::XorShift64Star::new(i as u64);
|
||||
//(0..rng.next_usize(i)).reduce(|a, b| {
|
||||
(0..20).reduce(|a, b| {
|
||||
black_box({
|
||||
let a = rng.next_usize(a.max(1));
|
||||
a ^ b
|
||||
})
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue