#![feature(
    vec_deque_pop_if,
    unsafe_cell_access,
    debug_closure_helpers,
    cold_path,
    fn_align,
    box_vec_non_null,
    box_as_ptr,
    atomic_try_update,
    let_chains
)]

use std::{
    cell::{Cell, UnsafeCell},
    future::Future,
    mem::MaybeUninit,
    num::NonZero,
    pin::{pin, Pin},
    ptr::{self, NonNull},
    sync::{
        atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering},
        Arc,
    },
    task::Context,
    thread::available_parallelism,
    time::Duration,
};

use async_task::{Runnable, Task};
use bitflags::bitflags;
use crossbeam::{
    atomic::AtomicCell,
    deque::{Injector, Stealer, Worker},
    utils::CachePadded,
};
use latch::{AtomicLatch, Latch, LatchWaker, MutexLatch, Probe, ThreadWakeLatch};
use parking_lot::{Condvar, Mutex};
use scope::Scope;
use task::{HeapTask, StackTask, TaskRef};
use tracing::debug;

pub mod job;
pub mod util;
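
/// Type-erased task handles.
///
/// The pool never stores closures directly: `TaskRef` pairs a raw data
/// pointer with the monomorphized `execute` function, acting as a
/// hand-rolled, two-word vtable. A minimal lifecycle sketch (illustrative,
/// assuming a `StackTask` pinned on the current frame):
///
/// ```ignore
/// use std::pin::pin;
///
/// let task = pin!(crate::task::StackTask::new(|| println!("ran")));
/// let task_ref = task.as_ref().as_task_ref();
/// // `execute` consumes the ref and runs the closure exactly once;
/// // the pinned StackTask must outlive this call.
/// task_ref.execute();
/// ```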
pub mod task {
    use std::{cell::UnsafeCell, marker::PhantomPinned, pin::Pin};

    pub trait Task {
        /// # Safety
        ///
        /// `this` must point to a live task of the implementing type; the
        /// callee takes logical ownership, so this must be called at most
        /// once per task.
        unsafe fn execute(this: *const ());
    }

    /// A type-erased, two-word handle to a task: a data pointer plus the
    /// monomorphized `execute` entry point.
    pub struct TaskRef {
        ptr: *const (),
        execute_fn: unsafe fn(*const ()),
    }

    impl TaskRef {
        pub unsafe fn new<T>(task: *const T) -> TaskRef
        where
            T: Task,
        {
            Self {
                ptr: task.cast(),
                execute_fn: <T as Task>::execute,
            }
        }

        pub unsafe fn new_raw(ptr: *const (), execute_fn: unsafe fn(*const ())) -> TaskRef {
            Self { ptr, execute_fn }
        }

        #[inline]
        pub fn id(&self) -> impl Eq {
            (self.ptr, self.execute_fn)
        }

        #[inline]
        pub fn execute(self) {
            unsafe { (self.execute_fn)(self.ptr) }
        }

        #[inline]
        pub unsafe fn execute_with_scope<T>(self, scope: &mut T) {
            unsafe {
                core::mem::transmute::<_, unsafe fn(*const (), &mut T)>(self.execute_fn)(
                    self.ptr, scope,
                )
            }
        }
    }

    unsafe impl Send for TaskRef {}
    unsafe impl Sync for TaskRef {}

    /// A task that lives on the caller's stack; it must be pinned before a
    /// `TaskRef` can be taken from it.
    pub struct StackTask<F: FnOnce() + Send> {
        task: UnsafeCell<Option<F>>,
        _phantom: PhantomPinned,
    }

    impl<F: FnOnce() + Send> StackTask<F> {
        pub fn new(task: F) -> StackTask<F> {
            Self {
                task: UnsafeCell::new(Some(task)),
                _phantom: PhantomPinned,
            }
        }

        #[inline]
        pub fn run(self) {
            self.task.into_inner().unwrap()();
        }

        #[inline]
        pub unsafe fn run_as_ref(&self) {
            ((&mut *self.task.get()).take().unwrap())();
        }

        #[inline]
        pub fn as_task_ref(self: Pin<&Self>) -> TaskRef {
            unsafe { TaskRef::new(&*self) }
        }
    }

    impl<F: FnOnce() + Send> Task for StackTask<F> {
        #[inline]
        unsafe fn execute(this: *const ()) {
            let this = &*this.cast::<Self>();
            let task = (&mut *this.task.get()).take().unwrap();
            task();
        }
    }

    /// A heap-allocated task; the box is reconstructed and dropped when the
    /// task executes.
    pub struct HeapTask<F: FnOnce() + Send> {
        task: F,
        _phantom: PhantomPinned,
    }

    impl<F: FnOnce() + Send> HeapTask<F> {
        pub fn new(task: F) -> Box<HeapTask<F>> {
            Box::new(Self {
                task,
                _phantom: PhantomPinned,
            })
        }

        #[inline]
        pub unsafe fn into_static_task_ref(self: Box<Self>) -> TaskRef
        where
            F: 'static,
        {
            self.into_task_ref()
        }

        #[inline]
        pub unsafe fn into_task_ref(self: Box<Self>) -> TaskRef {
            TaskRef::new(Box::into_raw(self))
        }
    }

    impl<F: FnOnce() + Send> Task for HeapTask<F> {
        #[inline]
        unsafe fn execute(this: *const ()) {
            let this = Box::from_raw(this.cast::<Self>().cast_mut());
            (this.task)();
        }
    }
}
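
/// One-shot latches in the style of rayon-core.
///
/// `Latch::set_raw` takes a raw pointer because setting a latch can wake a
/// waiter that immediately frees the latch's memory, so the setter must not
/// touch `this` afterwards. A minimal sketch of the polling side, assuming
/// a latch that is guaranteed to outlive the set:
///
/// ```ignore
/// use crate::latch::{AtomicLatch, Latch, Probe};
///
/// let latch = AtomicLatch::new();
/// // ... hand &latch to another thread, which eventually runs:
/// //     unsafe { Latch::set_raw(&latch) };
/// while !latch.probe() {
///     std::hint::spin_loop();
/// }
/// ```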
pub mod latch {
    use core::marker::PhantomData;
    use std::{
        sync::{
            atomic::{AtomicBool, AtomicUsize, Ordering},
            Arc,
        },
        task::Wake,
    };

    use parking_lot::{Condvar, Mutex};

    use crate::{ThreadPool, WorkerThread};

    pub trait Latch {
        /// # Safety
        ///
        /// Setting the latch may wake a waiter that immediately frees the
        /// latch, so `this` must not be dereferenced after the set.
        unsafe fn set_raw(this: *const Self);
    }

    pub trait Probe {
        fn probe(&self) -> bool;
    }

    #[derive(Debug)]
    pub struct AtomicLatch(AtomicBool);

    impl AtomicLatch {
        #[inline]
        pub const fn new() -> AtomicLatch {
            Self(AtomicBool::new(false))
        }

        #[inline]
        pub fn reset(&self) {
            self.0.store(false, Ordering::Release);
        }
    }

    impl Latch for AtomicLatch {
        #[inline]
        unsafe fn set_raw(this: *const Self) {
            (*this).0.store(true, Ordering::Release);
        }
    }

    impl Probe for AtomicLatch {
        #[inline]
        fn probe(&self) -> bool {
            self.0.load(Ordering::Acquire)
        }
    }

    pub struct ClosureLatch<S, P> {
        set: S,
        probe: P,
    }

    impl<S, P> ClosureLatch<S, P> {
        pub fn new(set: S, probe: P) -> Self {
            Self { set, probe }
        }

        pub fn new_boxed(set: S, probe: P) -> Box<Self> {
            Box::new(Self { set, probe })
        }
    }

    impl<S, P> Latch for ClosureLatch<S, P>
    where
        S: Fn(),
    {
        unsafe fn set_raw(this: *const Self) {
            let this = &*this;
            (this.set)();
        }
    }

    impl<S, P> Probe for ClosureLatch<S, P>
    where
        P: Fn() -> bool,
    {
        fn probe(&self) -> bool {
            (self.probe)()
        }
    }

    /// A latch that wakes a specific worker thread when set.
    pub struct ThreadWakeLatch {
        inner: AtomicLatch,
        index: usize,
        pool: &'static ThreadPool,
    }

    impl ThreadWakeLatch {
        #[inline]
        pub const fn new(thread: &WorkerThread) -> ThreadWakeLatch {
            Self {
                inner: AtomicLatch::new(),
                pool: thread.pool,
                index: thread.index,
            }
        }

        #[inline]
        pub fn reset(&self) {
            self.inner.reset()
        }
    }

    impl Latch for ThreadWakeLatch {
        #[inline]
        unsafe fn set_raw(this: *const Self) {
            // Copy out everything we need first: once the inner latch is
            // set, the woken thread may free `this` at any moment.
            let (pool, index) = {
                let this = &*this;
                (this.pool, this.index)
            };
            Latch::set_raw(&(*this).inner);
            pool.wake_thread(index);
        }
    }

    impl Probe for ThreadWakeLatch {
        #[inline]
        fn probe(&self) -> bool {
            self.inner.probe()
        }
    }

    /// Borrows a latch as a raw pointer, so a `Latch` impl can be handed
    /// out without ownership.
    pub struct LatchRef<'a, L: Latch> {
        inner: *const L,
        _marker: PhantomData<&'a L>,
    }

    impl<'a, L: Latch> LatchRef<'a, L> {
        #[inline]
        pub const fn new(latch: &'a L) -> Self {
            Self {
                inner: latch,
                _marker: PhantomData,
            }
        }
    }

    impl<'a, L: Latch> Latch for LatchRef<'a, L> {
        #[inline]
        unsafe fn set_raw(this: *const Self) {
            let this = &*this;
            Latch::set_raw(this.inner);
        }
    }

    impl<'a, L: Latch + Probe> Probe for LatchRef<'a, L> {
        #[inline]
        fn probe(&self) -> bool {
            unsafe {
                let this = &*self.inner;
                Probe::probe(this)
            }
        }
    }

    pub struct NopLatch;

    impl Latch for NopLatch {
        #[inline]
        unsafe fn set_raw(_this: *const Self) {
            // do nothing
        }
    }

    /// A blocking latch for threads outside the pool: `wait` parks on a
    /// condvar instead of polling.
    pub struct MutexLatch {
        mutex: Mutex<bool>,
        signal: Condvar,
    }

    impl MutexLatch {
        #[inline]
        pub const fn new() -> MutexLatch {
            Self {
                mutex: Mutex::new(false),
                signal: Condvar::new(),
            }
        }

        #[inline]
        pub fn wait(&self) {
            let mut guard = self.mutex.lock();
            while !*guard {
                self.signal.wait(&mut guard);
            }
        }

        #[inline]
        pub fn wait_and_reset(&self) {
            let mut guard = self.mutex.lock();
            while !*guard {
                self.signal.wait(&mut guard);
            }
            *guard = false;
        }
    }

    impl Latch for MutexLatch {
        #[inline]
        unsafe fn set_raw(this: *const Self) {
            let mut guard = (*this).mutex.lock();
            *guard = true;
            (*this).signal.notify_all();
        }
    }

    /// A counting latch: starts at `count` and only sets the inner
    /// thread-waking latch once the count reaches zero.
    pub struct CountWakeLatch {
        counter: AtomicUsize,
        inner: ThreadWakeLatch,
    }

    impl CountWakeLatch {
        #[inline]
        pub const fn new(count: usize, thread: &WorkerThread) -> CountWakeLatch {
            Self {
                counter: AtomicUsize::new(count),
                inner: ThreadWakeLatch::new(thread),
            }
        }

        #[inline]
        pub fn increment(&self) {
            self.counter.fetch_add(1, Ordering::Relaxed);
        }
    }

    impl Latch for CountWakeLatch {
        #[inline]
        unsafe fn set_raw(this: *const Self) {
            if (*this).counter.fetch_sub(1, Ordering::SeqCst) == 1 {
                Latch::set_raw(&(*this).inner);
            }
        }
    }

    impl Probe for CountWakeLatch {
        #[inline]
        fn probe(&self) -> bool {
            self.inner.probe()
        }
    }

    /// Adapts any latch into a `std::task::Waker` via the `Wake` trait.
    pub struct LatchWaker<L>(L);

    impl<L> LatchWaker<L> {
        #[inline]
        pub fn new(latch: L) -> Arc<Self> {
            Arc::new(Self(latch))
        }

        #[inline]
        pub fn latch(&self) -> &L {
            &self.0
        }
    }

    impl<L> Wake for LatchWaker<L>
    where
        L: Latch,
    {
        #[inline]
        fn wake(self: Arc<Self>) {
            self.wake_by_ref();
        }

        #[inline]
        fn wake_by_ref(self: &Arc<Self>) {
            unsafe {
                Latch::set_raw(&self.0);
            }
        }
    }
}

pub mod melange;
pub mod praetor;

pub struct ThreadPoolState {
    num_threads: AtomicUsize,
    lock: Mutex<()>,
    heartbeat_state: CachePadded<ThreadControl>,
}

bitflags! {
    #[derive(Clone)]
    pub struct ThreadStatus: u8 {
        const RUNNING = 1 << 0;
        const SLEEPING = 1 << 1;
        const SHOULD_WAKE = 1 << 2;
    }
}

pub struct ThreadControl {
    status: Mutex<ThreadStatus>,
    status_changed: Condvar,
    should_terminate: AtomicLatch,
}

pub struct ThreadState {
    should_shove: AtomicBool,
    control: ThreadControl,
    stealer: Stealer<TaskRef>,
    worker: AtomicCell<Option<Worker<TaskRef>>>,
    shoved_task: CachePadded<Slot<TaskRef>>,
}

impl ThreadControl {
    pub const fn new() -> Self {
        Self {
            status: Mutex::new(ThreadStatus::empty()),
            status_changed: Condvar::new(),
            should_terminate: AtomicLatch::new(),
        }
    }

    /// Requests a wakeup; returns `true` if the thread was sleeping.
    #[inline]
    pub fn wake(&self) -> bool {
        let mut guard = self.status.lock();
        guard.insert(ThreadStatus::SHOULD_WAKE);
        self.status_changed.notify_all();
        guard.contains(ThreadStatus::SLEEPING)
    }

    #[inline]
    pub fn wait_for_running(&self) {
        let mut guard = self.status.lock();
        while !guard.contains(ThreadStatus::RUNNING) {
            self.status_changed.wait(&mut guard);
        }
    }

    #[inline]
    pub fn wait_for_should_wake(&self) {
        let mut guard = self.status.lock();
        while !guard.contains(ThreadStatus::SHOULD_WAKE) {
            guard.insert(ThreadStatus::SLEEPING);
            self.status_changed.wait(&mut guard);
        }
        guard.remove(ThreadStatus::SHOULD_WAKE | ThreadStatus::SLEEPING);
    }

    #[inline]
    pub fn wait_for_should_wake_timeout(&self, timeout: Duration) {
        let mut guard = self.status.lock();
        while !guard.contains(ThreadStatus::SHOULD_WAKE) {
            guard.insert(ThreadStatus::SLEEPING);
            if self
                .status_changed
                .wait_for(&mut guard, timeout)
                .timed_out()
            {
                break;
            }
        }
        guard.remove(ThreadStatus::SHOULD_WAKE | ThreadStatus::SLEEPING);
    }

    #[inline]
    pub fn wait_for_termination(&self) {
        let mut guard = self.status.lock();
        while guard.contains(ThreadStatus::RUNNING) {
            self.status_changed.wait(&mut guard);
        }
    }

    #[inline]
    pub fn notify_running(&self) {
        let mut guard = self.status.lock();
        guard.insert(ThreadStatus::RUNNING);
        self.status_changed.notify_all();
    }

    #[inline]
    pub fn notify_termination(&self) {
        let mut guard = self.status.lock();
        *guard = ThreadStatus::empty();
        self.status_changed.notify_all();
    }

    #[inline]
    pub fn notify_should_terminate(&self) {
        unsafe {
            Latch::set_raw(&self.should_terminate);
        }
        self.wake();
    }
}
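
// The flags above implement a lost-wakeup-free sleep handshake: a worker
// inserts SLEEPING and blocks on `status_changed`, while a producer inserts
// SHOULD_WAKE and notifies under the same mutex, so a wake that races with
// a worker going to sleep is never dropped. A standalone sketch of the same
// pattern (illustrative only, not part of this crate):
//
//     use std::sync::Arc;
//     use parking_lot::{Condvar, Mutex};
//
//     let state = Arc::new((Mutex::new(false), Condvar::new()));
//     let consumer = {
//         let state = Arc::clone(&state);
//         std::thread::spawn(move || {
//             let (lock, cvar) = &*state;
//             let mut should_wake = lock.lock();
//             while !*should_wake {
//                 cvar.wait(&mut should_wake); // unlocks while parked
//             }
//             *should_wake = false; // consume the wakeup
//         })
//     };
//     let (lock, cvar) = &*state;
//     *lock.lock() = true; // set the flag under the lock,
//     cvar.notify_all(); // then notify
//     consumer.join().unwrap();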

const MAX_THREADS: usize = 32;

type ThreadCallback = dyn Fn(&WorkerThread) + Send + Sync + 'static;

pub struct ThreadPoolCallbacks {
    at_entry: Option<Arc<ThreadCallback>>,
    at_exit: Option<Arc<ThreadCallback>>,
}

impl ThreadPoolCallbacks {
    pub const fn new_empty() -> ThreadPoolCallbacks {
        Self {
            at_entry: None,
            at_exit: None,
        }
    }
}

pub struct ThreadPool {
    threads: [CachePadded<ThreadState>; MAX_THREADS],
    pool_state: CachePadded<ThreadPoolState>,
    global_queue: Injector<TaskRef>,
    callbacks: CachePadded<ThreadPoolCallbacks>,
}

impl ThreadPool {
    pub fn new() -> Self {
        Self::new_with_callbacks(ThreadPoolCallbacks::new_empty())
    }

    pub fn new_with_callbacks(callbacks: ThreadPoolCallbacks) -> ThreadPool {
        let threads = std::array::from_fn(|_| {
            let worker = Worker::<TaskRef>::new_fifo();
            let stealer = worker.stealer();

            CachePadded::new(ThreadState {
                should_shove: AtomicBool::new(false),
                shoved_task: Slot::new().into(),
                control: ThreadControl::new(),
                stealer,
                worker: AtomicCell::new(Some(worker)),
            })
        });

        Self {
            threads,
            pool_state: CachePadded::new(ThreadPoolState {
                num_threads: AtomicUsize::new(0),
                lock: Mutex::new(()),
                heartbeat_state: ThreadControl::new().into(),
            }),
            global_queue: Injector::new(),
            callbacks: CachePadded::new(callbacks),
        }
    }

    /// The currently running prefix of the thread table.
    #[inline]
    fn threads(&self) -> &[CachePadded<ThreadState>] {
        &self.threads[..self.pool_state.num_threads.load(Ordering::Relaxed)]
    }

    pub fn wake_thread(&self, index: usize) -> Option<bool> {
        Some(self.threads.get(index)?.control.wake())
    }

    /// Wakes up to `count` sleeping threads; returns how many were actually
    /// woken.
    pub fn wake_any(&self, count: usize) -> usize {
        if count > 0 {
            self.threads
                .iter()
                .filter(|thread| thread.control.wake())
                .take(count)
                .count()
        } else {
            0
        }
    }

    #[inline]
    pub fn id(&self) -> impl Eq {
        core::ptr::from_ref(self) as usize
    }

    #[allow(dead_code)]
    fn push_local_or_inject_balanced(&self, task: TaskRef) {
        let global_len = self.global_queue.len();
        WorkerThread::with(|worker| match worker {
            Some(worker) if worker.pool.id() == self.id() => {
                let worker_len = worker.worker.len();
                if worker_len == 0 {
                    worker.push_task(task);
                } else if global_len == 0 {
                    self.inject(task);
                } else if worker_len >= global_len {
                    worker.push_task(task);
                } else {
                    self.inject(task);
                }
            }
            _ => self.inject(task),
        })
    }

    fn push_local_or_inject(&self, task: TaskRef) {
        WorkerThread::with(|worker| match worker {
            Some(worker) if worker.pool.id() == self.id() => worker.push_task(task),
            _ => self.inject(task),
        })
    }

    #[allow(dead_code)]
    fn inject_many<I>(&self, tasks: I)
    where
        I: Iterator<Item = TaskRef>,
    {
        let mut n = 0;
        for task in tasks {
            n += 1;
            self.global_queue.push(task);
        }
        self.wake_any(n);
    }

    #[allow(unused_variables)]
    fn inject_maybe_local(&self, task: TaskRef) {
        #[cfg(all(not(feature = "never-local"), feature = "prefer-local"))]
        self.push_local_or_inject(task);
        #[cfg(all(not(feature = "prefer-local"), feature = "never-local"))]
        self.inject(task);
        #[cfg(not(any(feature = "prefer-local", feature = "never-local")))]
        self.push_local_or_inject_balanced(task);
    }
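
    // Exactly one of the three branches above survives `cfg` expansion:
    // `prefer-local` always favors the calling worker's own queue,
    // `never-local` always goes through the global injector, and the
    // default build balances between the two based on queue lengths. Note
    // that enabling both features compiles all three branches out and
    // silently drops the task; a `compile_error!` guard for that
    // combination would be a reasonable hardening.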

    fn inject(&self, task: TaskRef) {
        self.global_queue.push(task);

        self.wake_any(1);
    }

    #[allow(dead_code)]
    fn resize<F: Fn(usize) -> usize>(&'static self, size: F) -> usize {
        if WorkerThread::is_worker_thread() {
            // refuse to resize from inside the pool: the waits below would
            // block on the calling worker itself.
            debug!("tried to resize from within threadpool!");
            // acquire required here?
            return self.pool_state.num_threads.load(Ordering::Acquire);
        }

        #[cfg(feature = "cpu-pinning")]
        let cpus = core_affinity::get_core_ids().unwrap();

        let _guard = self.pool_state.lock.lock();

        let current_size = self.pool_state.num_threads.load(Ordering::Acquire);
        let new_size = size(current_size).clamp(0, MAX_THREADS);
        debug!(current_size, new_size, "resizing threadpool");

        if new_size == current_size {
            return current_size;
        }

        self.pool_state
            .num_threads
            .store(new_size, Ordering::Release);

        match new_size.cmp(&current_size) {
            std::cmp::Ordering::Greater => {
                let new_threads = &self.threads[current_size..new_size];

                for (i, _) in new_threads.iter().enumerate() {
                    #[cfg(feature = "cpu-pinning")]
                    let core = cpus[i];
                    std::thread::spawn(move || {
                        #[cfg(feature = "cpu-pinning")]
                        core_affinity::set_for_current(core);
                        WorkerThread::worker_loop(self, current_size + i);
                    });
                }

                for thread in new_threads {
                    thread.control.wait_for_running();
                }

                #[cfg(feature = "heartbeat")]
                if current_size == 0 {
                    std::thread::spawn(move || {
                        heartbeat_loop(self);
                    });

                    self.pool_state.heartbeat_state.wait_for_running();
                }
            }
            std::cmp::Ordering::Less => {
                debug!(
                    "waiting for threads {:?} to terminate.",
                    new_size..current_size
                );
                let terminating_threads = &self.threads[new_size..current_size];

                for thread in terminating_threads {
                    thread.control.notify_should_terminate();
                }
                for thread in terminating_threads {
                    thread.control.wait_for_termination();
                }

                #[cfg(feature = "heartbeat")]
                if new_size == 0 {
                    self.pool_state.heartbeat_state.notify_should_terminate();
                    self.pool_state.heartbeat_state.wait_for_termination();
                }
            }
            std::cmp::Ordering::Equal => unreachable!(),
        }

        new_size
    }

    #[allow(dead_code)]
    fn ensure_one_worker(&'static self) -> usize {
        self.resize(|current| current.max(1))
    }

    #[allow(dead_code)]
    fn resize_to_available(&'static self) {
        self.resize_to(available_parallelism().map(NonZero::get).unwrap_or(1));
    }

    #[allow(dead_code)]
    fn resize_to(&'static self, new_size: usize) -> usize {
        self.resize(|_| new_size)
    }

    #[allow(dead_code)]
    fn grow_by(&'static self, num_threads: usize) -> usize {
        self.resize(|current| current.saturating_add(num_threads))
    }

    #[allow(dead_code)]
    fn shrink_by(&'static self, num_threads: usize) -> usize {
        self.resize(|current| current.saturating_sub(num_threads))
    }

    #[allow(dead_code)]
    fn shrink_to(&'static self, num_threads: usize) -> usize {
        self.resize(|_| num_threads)
    }

    /// Runs `f` on a worker thread, entering the pool from outside if
    /// necessary; the `bool` passed to `f` is true when `f` runs on one of
    /// this pool's own workers.
    fn in_worker<F, T>(&'static self, f: F) -> T
    where
        F: FnOnce(&WorkerThread, bool) -> T + Send,
        T: Send,
    {
        WorkerThread::with(|worker| match worker {
            Some(worker) => {
                if worker.pool.id() == self.id() {
                    self.in_worker_cross(worker, f)
                } else {
                    f(worker, false)
                }
            }
            None => self.in_worker_cold(f),
        })
    }

    #[cold]
    fn in_worker_cold<F, T>(&'static self, f: F) -> T
    where
        F: FnOnce(&WorkerThread, bool) -> T + Send,
        T: Send,
    {
        std::thread_local! {static LATCH: MutexLatch = const { MutexLatch::new() }};

        LATCH.with(|latch| {
            let mut result = None;
            let task = StackTask::new(|| {
                WorkerThread::with(|worker| {
                    let worker = worker.unwrap();

                    result = Some(f(worker, true));

                    unsafe {
                        // SAFETY: the latch is a static thread-local, so it
                        // outlives the task.
                        Latch::set_raw(latch);
                    }
                })
            });

            let pinned = pin!(task);
            let taskref = pinned.as_ref().as_task_ref();
            self.inject(taskref);

            latch.wait_and_reset();
            result.unwrap()
        })
    }

    /// Queues `f` in `self` and blocks the current worker until it has run,
    /// helping with other work in the meantime.
    fn in_worker_cross<F, T>(&'static self, worker: &WorkerThread, f: F) -> T
    where
        F: FnOnce(&WorkerThread, bool) -> T + Send,
        T: Send,
    {
        let latch = ThreadWakeLatch::new(worker);

        let mut result = None;

        let task = pin!(StackTask::new(|| {
            WorkerThread::with(|worker| {
                let worker = worker.unwrap();

                result = Some(f(worker, true));

                unsafe {
                    // SAFETY: the latch lives on this frame, and
                    // `run_until` below keeps the frame alive until the
                    // latch is set.
                    Latch::set_raw(&latch);
                }
            })
        }));

        let taskref = task.into_ref().as_task_ref();
        self.push_local_or_inject(taskref);

        worker.run_until(&latch);
        result.unwrap()
    }
}

impl ThreadPool {
    pub fn spawn<Fn>(&'static self, f: Fn)
    where
        Fn: FnOnce() + Send + 'static,
    {
        let task = HeapTask::new(f);

        let taskref = unsafe { task.into_static_task_ref() };
        self.inject_maybe_local(taskref);
    }

    pub fn spawn_future<Fut, T>(&'static self, future: Fut) -> Task<T>
    where
        Fut: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        let schedule = move |runnable: Runnable| {
            let taskref = unsafe {
                TaskRef::new_raw(runnable.into_raw().as_ptr(), |this| {
                    let this = NonNull::new_unchecked(this.cast_mut());

                    let runnable = Runnable::<()>::from_raw(this);
                    runnable.run();
                })
            };

            self.inject_maybe_local(taskref);
        };

        let (runnable, task) = async_task::spawn(future, schedule);

        runnable.schedule();
        task
    }

    pub fn spawn_async<Fn, Fut, T>(&'static self, f: Fn) -> Task<T>
    where
        Fn: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        self.spawn_future(async move { f().await })
    }

    pub fn block_on<Fut, T>(&'static self, mut future: Fut) -> T
    where
        Fut: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        // SAFETY: `future` is a local that is never moved again after this
        // borrow.
        let mut future = unsafe { Pin::new_unchecked(&mut future) };
        self.in_worker(|worker, _| {
            let wake = LatchWaker::new(ThreadWakeLatch::new(worker));
            let ctx_waker = Arc::clone(&wake).into();
            let mut ctx = Context::from_waker(&ctx_waker);

            loop {
                match future.as_mut().poll(&mut ctx) {
                    std::task::Poll::Ready(t) => {
                        return t;
                    }
                    std::task::Poll::Pending => {
                        // run other tasks until the waker sets the latch,
                        // then poll again.
                        worker.run_until(wake.latch());
                        wake.latch().reset();
                    }
                }
            }
        })
    }
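
    // Usage sketch for `block_on` (illustrative; the `&'static self`
    // requirement is satisfied here by leaking the pool):
    //
    //     let pool: &'static ThreadPool = Box::leak(Box::new(ThreadPool::new()));
    //     pool.resize_to(2);
    //     let four = pool.block_on(async { 2 + 2 });
    //     assert_eq!(four, 4);
    //
    // Every `Poll::Pending` lets the worker run queued tasks until the
    // `LatchWaker` fires, so the calling thread is never idle while
    // waiting.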

    pub fn join<F, G, T, U>(&'static self, f: F, g: G) -> (T, U)
    where
        F: FnOnce() -> T + Send,
        G: FnOnce() -> U + Send,
        T: Send,
        U: Send,
    {
        self.join_threaded(f, g)
    }

    #[allow(dead_code)]
    fn join_seq<F, G, T, U>(&'static self, f: F, g: G) -> (T, U)
    where
        F: FnOnce() -> T + Send,
        G: FnOnce() -> U + Send,
        T: Send,
        U: Send,
    {
        let a = f();
        let b = g();

        (a, b)
    }

    #[inline]
    fn join_threaded<F, G, T, U>(&'static self, f: F, g: G) -> (T, U)
    where
        F: FnOnce() -> T + Send,
        G: FnOnce() -> U + Send,
        T: Send,
        U: Send,
    {
        self.in_worker(|worker, _| {
            let mut result_b = None;

            let latch_b = ThreadWakeLatch::new(worker);
            let task_b = pin!(StackTask::new(|| {
                result_b = Some(g());
                unsafe {
                    Latch::set_raw(&latch_b);
                }
            }));

            let ref_b = task_b.as_ref().as_task_ref();
            let b_id = ref_b.id();
            worker.push_task(ref_b);

            let result_a = f();

            while !latch_b.probe() {
                match worker.pop_task() {
                    Some(task) => {
                        if task.id() == b_id {
                            // `task_b` never left our queue: run it inline.
                            // We are not going through `execute()` here, so
                            // manually try shoving a task to keep other
                            // workers fed.
                            // worker.try_promote();
                            worker.shove_task();
                            unsafe {
                                task_b.run_as_ref();
                            }
                            break;
                        } else {
                            worker.execute(task);
                        }
                    }
                    None => {
                        // `task_b` was stolen: help with other work until
                        // the thief sets `latch_b`.
                        worker.run_until(&latch_b);
                    }
                }
            }

            (result_a, result_b.unwrap())
        })
    }
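
    // Usage sketch for `join` (illustrative, with a leaked pool as above):
    //
    //     let (a, b) = pool.join(|| 1 + 1, || 2 + 2);
    //     assert_eq!((a, b), (2, 4));
    //
    // `g` is pushed to the local queue and may be stolen; if it is, the
    // calling worker helps run other tasks until the thief finishes it.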

    pub fn scope<'scope, Fn, T>(&'static self, f: Fn) -> T
    where
        Fn: FnOnce(&Scope<'scope>) -> T + Send,
        T: Send,
    {
        self.in_worker(|owner, _| {
            let scope = unsafe { Scope::<'scope>::new(owner) };
            let result = f(&scope);
            scope.complete(owner);
            result
        })
    }
}
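
// Usage sketch for `scope` (illustrative): spawned closures may borrow data
// declared outside the scope, because `complete` blocks until every spawned
// task has run.
//
//     let counter = AtomicUsize::new(0);
//     pool.scope(|s| {
//         for _ in 0..8 {
//             s.spawn(|_| {
//                 counter.fetch_add(1, Ordering::Relaxed);
//             });
//         }
//     });
//     assert_eq!(counter.load(Ordering::Relaxed), 8);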

pub struct WorkerThread {
    // queue: TaskQueue<TaskRef>,
    worker: Worker<TaskRef>,
    pool: &'static ThreadPool,
    index: usize,
    rng: rng::XorShift64Star,
}

const HEARTBEAT_INTERVAL: core::time::Duration = core::time::Duration::from_micros(100);

std::thread_local! {
    static WORKER_THREAD_STATE: Cell<*const WorkerThread> = const { Cell::new(ptr::null()) };
}

impl WorkerThread {
    #[inline]
    fn info(&self) -> &ThreadState {
        &self.pool.threads[self.index]
    }

    #[inline]
    fn pool(&self) -> &'static ThreadPool {
        self.pool
    }

    #[inline]
    #[allow(dead_code)]
    fn index(&self) -> usize {
        self.index
    }

    #[inline]
    fn is_worker_thread() -> bool {
        Self::with(|worker| worker.is_some())
    }

    fn with<T, F: FnOnce(Option<&Self>) -> T>(f: F) -> T {
        WORKER_THREAD_STATE.with(|thread| {
            f(NonNull::<WorkerThread>::new(thread.get().cast_mut())
                .map(|ptr| unsafe { ptr.as_ref() }))
        })
    }

    #[inline]
    fn pop_task(&self) -> Option<TaskRef> {
        self.worker.pop()
        // self.queue.pop_front(task);
    }

    #[inline]
    fn push_task(&self, task: TaskRef) {
        self.worker.push(task);
        // self.queue.push_front(task);
    }

    #[inline]
    fn drain(&self) -> impl Iterator<Item = TaskRef> {
        // self.queue.drain()
        core::iter::empty()
    }

    #[inline]
    fn steal_tasks(&self) -> Option<TaskRef> {
        // Careful not to call `threads()` here, because that omits any
        // threads that were killed, which might still have tasks.
        let threads = &self.pool.threads;
        // rotate the scan by a random offset so workers don't all gang up
        // on the same victim.
        let (start, end) = threads.split_at(self.rng.next_usize(threads.len()));

        end.iter()
            .chain(start)
            .find_map(|thread: &CachePadded<ThreadState>| {
                thread.stealer.steal_batch_and_pop(&self.worker).success()
            })
    }

    #[inline]
    fn claim_shoved_task(&self) -> Option<TaskRef> {
        // take our own shoved task first
        if let Some(task) = self.info().shoved_task.try_take() {
            return Some(task);
        }

        let threads = self.pool.threads();
        if threads.is_empty() {
            return None;
        }
        let (start, end) = threads.split_at(self.rng.next_usize(threads.len()));

        end.iter()
            .chain(start)
            .find_map(|thread| thread.shoved_task.try_take())
    }

    #[cold]
    fn shove_task(&self) {
        if !self.info().shoved_task.is_occupied() {
            if let Some(task) = self.info().stealer.steal().success() {
                match self.info().shoved_task.try_put(task) {
                    // the slot filled up between the check and the put;
                    // this really shouldn't happen, since only this thread
                    // ever puts into its own slot.
                    Some(_task) => unreachable!(),
                    None => {}
                }
            }
        } else {
            // the slot already holds a task; wake a thread to come take it.
            self.pool.wake_any(1);
        }
    }

    fn execute(&self, task: TaskRef) {
        self.try_promote();
        task.execute();
    }

    #[inline]
    fn try_promote(&self) {
        #[cfg(feature = "heartbeat")]
        let should_shove = self.info().should_shove.load(Ordering::Acquire);
        #[cfg(not(feature = "heartbeat"))]
        let should_shove = true;

        if should_shove {
            #[cfg(feature = "heartbeat")]
            self.info().should_shove.store(false, Ordering::Release);

            self.shove_task();
        }
    }
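
    // Heartbeat promotion: a worker moves at most one task from its own
    // deque into its shared `shoved_task` mailbox, and (with the
    // `heartbeat` feature) only when the heartbeat thread has flagged it
    // via `should_shove`. This amortizes the cost of making work visible
    // to other threads to roughly once per heartbeat instead of once per
    // task.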

    #[inline]
    fn find_any_task(&self) -> Option<TaskRef> {
        // TODO: attempt stealing work here, too.
        #[allow(unused_mut)]
        let mut task = self
            .pop_task()
            .or_else(|| self.claim_shoved_task())
            .or_else(|| {
                self.pool
                    .global_queue
                    .steal_batch_and_pop(&self.worker)
                    .success()
            });

        #[cfg(feature = "work-stealing")]
        {
            task = task.or_else(|| self.steal_tasks());
        }

        task
    }

    #[inline]
    fn run_until<L>(&self, latch: &L)
    where
        L: Probe,
    {
        if !latch.probe() {
            self.run_until_cold(latch);
        }
    }

    #[cold]
    fn run_until_cold<L>(&self, latch: &L)
    where
        L: Probe,
    {
        while !latch.probe() {
            self.run_until_inner();
        }
    }

    #[inline]
    fn run_until_inner(&self) {
        match self.find_any_task() {
            Some(task) => {
                self.execute(task);
            }
            None => {
                // debug!("waiting for tasks");
                self.info().control.wait_for_should_wake();
            }
        }
    }

    fn worker_loop(pool: &'static ThreadPool, index: usize) {
        let info = &pool.threads()[index];
        let worker = CachePadded::new(WorkerThread {
            // queue: TaskQueue::new(),
            worker: info.worker.take().unwrap(),
            pool,
            index,
            rng: rng::XorShift64Star::new(pool as *const _ as u64 + index as u64),
        });

        WORKER_THREAD_STATE.with(|cell| {
            cell.set(&*worker);

            if let Some(callback) = pool.callbacks.at_entry.as_ref() {
                callback(&worker);
            }

            info.control.notify_running();
            worker.run_until(&info.control.should_terminate);

            if let Some(callback) = pool.callbacks.at_exit.as_ref() {
                callback(&worker);
            }

            // return any leftover work to the global queue before dying.
            for task in worker.drain() {
                pool.inject(task);
            }

            if let Some(task) = info.shoved_task.try_take() {
                pool.inject(task);
            }

            cell.set(ptr::null());
        });

        let WorkerThread { worker, .. } = CachePadded::into_inner(worker);
        info.worker.store(Some(worker));

        info.control.notify_termination();
    }
}

fn heartbeat_loop(pool: &'static ThreadPool) {
    let state = &pool.pool_state.heartbeat_state;

    state.notify_running();
    let mut i = 0;
    while !state.should_terminate.probe() {
        let threads = pool.threads();
        if threads.is_empty() {
            break;
        }

        if i >= threads.len() {
            i = 0;
            continue;
        }

        threads[i].should_shove.store(true, Ordering::Relaxed);
        i += 1;

        let interval = HEARTBEAT_INTERVAL / threads.len() as u32;

        state.wait_for_should_wake_timeout(interval);
    }

    state.notify_termination();
}
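
// Worked example of the heartbeat cadence: with HEARTBEAT_INTERVAL = 100µs
// and 8 running workers, the loop sleeps 100µs / 8 = 12.5µs per iteration
// and flags one worker each time, so every worker sees `should_shove`
// roughly once per 100µs regardless of pool size.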

mod vec_queue {
    use std::{cell::UnsafeCell, collections::VecDeque};

    pub struct TaskQueue<T>(UnsafeCell<VecDeque<T>>);

    impl<T> TaskQueue<T> {
        /// Creates a new [`TaskQueue<T>`].
        #[inline]
        #[allow(dead_code)]
        pub const fn new() -> Self {
            Self(UnsafeCell::new(VecDeque::new()))
        }

        #[inline]
        #[allow(dead_code)]
        pub fn get_mut(&self) -> &mut VecDeque<T> {
            unsafe { &mut *self.0.get() }
        }

        #[inline]
        #[allow(dead_code)]
        pub fn pop_front(&self) -> Option<T> {
            self.get_mut().pop_front()
        }

        #[inline]
        #[allow(dead_code)]
        pub fn pop_back(&self) -> Option<T> {
            self.get_mut().pop_back()
        }

        #[inline]
        #[allow(dead_code)]
        pub fn push_back(&self, t: T) {
            self.get_mut().push_back(t);
        }

        #[inline]
        #[allow(dead_code)]
        pub fn push_front(&self, t: T) {
            self.get_mut().push_front(t);
        }

        #[inline]
        #[allow(dead_code)]
        pub fn take(&self) -> VecDeque<T> {
            core::mem::take(self.get_mut())
        }

        #[inline]
        #[allow(dead_code)]
        pub fn drain(&self) -> impl Iterator<Item = T> {
            self.take().into_iter()
        }
    }
}

#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SlotState {
    None,
    Locked,
    Occupied,
}

impl From<u8> for SlotState {
    fn from(value: u8) -> Self {
        // SAFETY: `Slot` only ever stores the three declared discriminants.
        unsafe { core::mem::transmute(value) }
    }
}

impl From<SlotState> for u8 {
    fn from(value: SlotState) -> Self {
        value as u8
    }
}

/// A single-element mailbox guarded by a three-state spinlock:
/// `None` (empty) -> `Locked` (being written or read) -> `Occupied`.
pub struct Slot<T> {
    slot: UnsafeCell<MaybeUninit<T>>,
    state: AtomicU8,
}

unsafe impl<T> Send for Slot<T> where T: Send {}
unsafe impl<T> Sync for Slot<T> where T: Send {}

impl<T> Drop for Slot<T> {
    fn drop(&mut self) {
        if core::mem::needs_drop::<T>() {
            if *self.state.get_mut() == SlotState::Occupied as u8 {
                unsafe {
                    // drop the initialized value, not the MaybeUninit
                    // wrapper (which has no drop glue of its own).
                    (*self.slot.get()).assume_init_drop();
                }
            }
        }
    }
}

impl<T> Slot<T> {
    pub const fn new() -> Slot<T> {
        Self {
            slot: UnsafeCell::new(MaybeUninit::uninit()),
            state: AtomicU8::new(SlotState::None as u8),
        }
    }

    pub fn is_occupied(&self) -> bool {
        self.state.load(Ordering::Acquire) == SlotState::Occupied.into()
    }

    #[inline]
    pub fn insert(&self, t: T) -> Option<T> {
        let value = match self
            .state
            .swap(SlotState::Locked.into(), Ordering::AcqRel)
            .into()
        {
            SlotState::Locked => {
                // another thread holds the lock; back off without writing.
                // NOTE: `t` is dropped here rather than handed back.
                debug!("slot was already locked");
                return None;
            }
            SlotState::Occupied => {
                let slot = self.slot.get();
                // replace the current value, handing the old one back
                unsafe {
                    let v = (*slot).assume_init_read();
                    (*slot).write(t);
                    Some(v)
                }
            }
            SlotState::None => {
                let slot = self.slot.get();
                // plain insert into an empty slot
                unsafe {
                    (*slot).write(t);
                }
                None
            }
        };

        // release the lock
        self.state
            .store(SlotState::Occupied.into(), Ordering::Release);
        value
    }

    #[inline]
    pub fn try_put(&self, t: T) -> Option<T> {
        match self.state.compare_exchange(
            SlotState::None.into(),
            SlotState::Locked.into(),
            Ordering::Acquire,
            Ordering::Relaxed,
        ) {
            // occupied or locked: hand the value back to the caller
            Err(_) => Some(t),
            Ok(_) => {
                let slot = self.slot.get();
                // SAFETY: we hold LOCKED on the spinlock
                unsafe { (*slot).write(t) };

                // release the lock
                self.state
                    .store(SlotState::Occupied.into(), Ordering::Release);
                None
            }
        }
    }

    #[inline]
    pub fn try_take(&self) -> Option<T> {
        match self.state.compare_exchange(
            SlotState::Occupied.into(),
            SlotState::Locked.into(),
            Ordering::Acquire,
            Ordering::Relaxed,
        ) {
            Ok(_) => {
                let slot = self.slot.get();
                // SAFETY: we hold LOCKED on the spinlock
                let t = unsafe { (*slot).assume_init_read() };

                // release the lock
                self.state.store(SlotState::None.into(), Ordering::Release);
                Some(t)
            }
            Err(_) => None,
        }
    }
}
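
// Usage sketch for `Slot` (illustrative), showing the mailbox protocol:
//
//     let slot: Slot<u32> = Slot::new();
//     assert_eq!(slot.try_put(1), None);    // empty -> occupied
//     assert_eq!(slot.try_put(2), Some(2)); // occupied: value handed back
//     assert_eq!(slot.try_take(), Some(1)); // occupied -> empty
//     assert_eq!(slot.try_take(), None);    // already empty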

mod rng {
    use core::cell::Cell;

    pub struct XorShift64Star {
        state: Cell<u64>,
    }

    impl XorShift64Star {
        /// Initializes the prng with a seed. The provided seed must be
        /// nonzero.
        pub fn new(seed: u64) -> Self {
            XorShift64Star {
                state: Cell::new(seed),
            }
        }

        /// Returns a pseudorandom number.
        pub fn next(&self) -> u64 {
            let mut x = self.state.get();
            debug_assert_ne!(x, 0);
            x ^= x >> 12;
            x ^= x << 25;
            x ^= x >> 27;
            self.state.set(x);
            x.wrapping_mul(0x2545_f491_4f6c_dd1d)
        }

        /// Returns a pseudorandom number in `0..n`.
        pub fn next_usize(&self, n: usize) -> usize {
            (self.next() % n as u64) as usize
        }
    }
}
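
// `next_usize` uses a plain modulo, which is slightly biased whenever `n`
// does not divide 2^64; that is acceptable here because it only picks
// steal victims. Illustrative expectation:
//
//     let rng = rng::XorShift64Star::new(0xDEAD_BEEF);
//     assert!((0..1000).all(|_| rng.next_usize(8) < 8));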

pub mod scope {
    use std::{
        future::Future,
        marker::PhantomData,
        ptr::{self, NonNull},
    };

    use async_task::{Runnable, Task};

    use crate::{
        latch::{CountWakeLatch, Latch},
        task::{HeapTask, TaskRef},
        ThreadPool, WorkerThread,
    };

    pub struct Scope<'scope> {
        pool: &'static ThreadPool,
        /// starts at one; `complete` drops the initial count and then waits
        /// for every spawned task to drop its own.
        tasks_completed_latch: CountWakeLatch,
        _marker: PhantomData<Box<dyn FnOnce(&Scope<'scope>) + Send + Sync + 'scope>>,
    }

    impl<'scope> Scope<'scope> {
        pub unsafe fn new(owner: &WorkerThread) -> Scope<'scope> {
            Scope {
                pool: owner.pool(),
                tasks_completed_latch: CountWakeLatch::new(1, owner),
                _marker: PhantomData,
            }
        }

        pub fn join<F, G, T, U>(&self, f: F, g: G) -> (T, U)
        where
            F: FnOnce(&Self) -> T + Send,
            G: FnOnce(&Self) -> U + Send,
            T: Send,
            U: Send,
        {
            self.pool.join(|| f(self), || g(self))
        }

        pub fn spawn<Fn>(&self, f: Fn)
        where
            Fn: FnOnce(&Scope<'scope>) + Send + 'scope,
        {
            self.tasks_completed_latch.increment();

            let ptr = SendPtr::from_ref(self);
            let task = HeapTask::new(move || unsafe {
                let this = ptr.as_ref();
                f(this);
                Latch::set_raw(&this.tasks_completed_latch);
            });

            let taskref = unsafe { task.into_task_ref() };
            self.pool.inject_maybe_local(taskref);
        }

        pub fn spawn_future<Fut, T>(&self, future: Fut) -> Task<T>
        where
            Fut: Future<Output = T> + Send + 'scope,
            T: Send + 'scope,
        {
            self.tasks_completed_latch.increment();

            // wrap the future so the scope's task count is decremented as
            // soon as it completes. (A task cancelled before completion
            // will never decrement the count.)
            let ptr = SendPtr::from_ref(self);
            let future = async move {
                let result = future.await;
                unsafe {
                    Latch::set_raw(&ptr.as_ref().tasks_completed_latch);
                }
                result
            };

            let ptr = SendPtr::from_ref(self);
            let schedule = move |runnable: Runnable| {
                let taskref = unsafe {
                    TaskRef::new_raw(runnable.into_raw().as_ptr(), |this| {
                        let this = NonNull::new_unchecked(this.cast_mut());

                        let runnable = Runnable::<()>::from_raw(this);
                        runnable.run();
                    })
                };

                unsafe {
                    ptr.as_ref().pool.inject_maybe_local(taskref);
                }
            };

            let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) };

            runnable.schedule();
            task
        }

        pub fn spawn_async<Fn, Fut, T>(&self, f: Fn) -> Task<T>
        where
            Fn: FnOnce() -> Fut + Send + 'scope,
            Fut: Future<Output = T> + Send + 'scope,
            T: Send + 'scope,
        {
            self.spawn_future(async move { f().await })
        }

        pub fn complete(&self, owner: &WorkerThread) {
            // drop the initial count, then help run tasks until every
            // spawned task has finished.
            unsafe {
                Latch::set_raw(&self.tasks_completed_latch);
            }
            owner.run_until(&self.tasks_completed_latch);
        }
    }

    /// A raw pointer that asserts `Send`; used to smuggle `&Scope` into
    /// spawned tasks, which is sound because `complete` keeps the scope
    /// alive until every spawned task has run.
    struct SendPtr<T>(*const T);

    impl<T> SendPtr<T> {
        fn from_ref(t: &T) -> Self {
            Self(ptr::from_ref(t))
        }

        unsafe fn as_ref(&self) -> &T {
            &*self.0
        }
    }

    unsafe impl<T> Send for SendPtr<T> {}
}

// #[cfg(test)]
// mod tests {
//     use std::{cell::Cell, hint::black_box, time::Instant};
//
//     use tracing::info;
//
//     use super::*;
//
//     mod tree {
//         pub struct Tree<T> {
//             nodes: Box<[Node<T>]>,
//             root: Option<usize>,
//         }
//
//         pub struct Node<T> {
//             pub leaf: T,
//             pub left: Option<usize>,
//             pub right: Option<usize>,
//         }
//
//         impl<T> Tree<T> {
//             pub fn new(depth: usize, t: T) -> Tree<T>
//             where
//                 T: Copy,
//             {
//                 let mut nodes = Vec::with_capacity((0..depth).sum());
//                 let root = Self::build_node(&mut nodes, depth, t);
//                 Self {
//                     nodes: nodes.into_boxed_slice(),
//                     root: Some(root),
//                 }
//             }
//
//             pub fn root(&self) -> Option<usize> {
//                 self.root
//             }
//
//             pub fn get(&self, index: usize) -> &Node<T> {
//                 &self.nodes[index]
//             }
//
//             pub fn build_node(nodes: &mut Vec<Node<T>>, depth: usize, t: T) -> usize
//             where
//                 T: Copy,
//             {
//                 let node = Node {
//                     leaf: t,
//                     left: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)),
//                     right: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)),
//                 };
//                 nodes.push(node);
//                 nodes.len() - 1
//             }
//         }
//     }
//
//     const PRIMES: &'static [usize] = &[
//         1181, 1187, 1193, 1201, 1213, 1217, 1223, 1229, 1231, 1237, 1249, 1259, 1277, 1279, 1283,
//         1289, 1291, 1297, 1301, 1303, 1307, 1319, 1321, 1327, 1361, 1367, 1373, 1381, 1399, 1409,
//         1423, 1427, 1429, 1433, 1439, 1447, 1451, 1453, 1459, 1471, 1481, 1483, 1487, 1489, 1493,
//         1499, 1511, 1523, 1531, 1543, 1549, 1553, 1559, 1567, 1571, 1579, 1583, 1597, 1601, 1607,
//         1609, 1613, 1619, 1621, 1627, 1637, 1657, 1663, 1667, 1669, 1693, 1697, 1699, 1709, 1721,
//         1723, 1733, 1741, 1747, 1753, 1759, 1777, 1783, 1787, 1789, 1801, 1811, 1823, 1831, 1847,
//         1861, 1867, 1871, 1873, 1877, 1879, 1889, 1901, 1907,
//     ];
//
//     #[cfg(feature = "spin-slow")]
//     const REPEAT: usize = 0x800;
//     #[cfg(not(feature = "spin-slow"))]
//     const REPEAT: usize = 0x8000;
//
//     const TREE_SIZE: usize = 10;
//
//     fn run_in_scope<T: Send>(pool: ThreadPool, f: impl FnOnce(&Scope<'_>) -> T + Send) -> T {
//         let pool = Box::new(pool);
//         let ptr = Box::into_raw(pool);
//
//         let result = {
//             let pool: &'static ThreadPool = unsafe { &*ptr };
//             // pool.ensure_one_worker();
//             pool.resize_to_available();
//             let now = std::time::Instant::now();
//             let result = pool.scope(f);
//             let elapsed = now.elapsed().as_micros();
//             info!("(mine) total time: {}ms", elapsed as f32 / 1e3);
//             pool.resize_to(0);
//             assert!(pool.global_queue.is_empty());
//             result
//         };
//
//         let _pool = unsafe { Box::from_raw(ptr) };
//         result
//     }
//
//     #[test]
//     #[tracing_test::traced_test]
//     fn rayon() {
//         let pool = rayon::ThreadPoolBuilder::new()
//             .num_threads(bevy_tasks::available_parallelism())
//             .build()
//             .unwrap();
//
//         let now = std::time::Instant::now();
//         pool.scope(|s| {
//             for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() {
//                 s.spawn(move |_| {
//                     black_box(spinning(p));
//                 });
//             }
//         });
//         let elapsed = now.elapsed().as_micros();
//
//         info!("(rayon) total time: {}ms", elapsed as f32 / 1e3);
//     }
//
//     #[test]
//     #[tracing_test::traced_test]
//     fn rayon_join() {
//         let pool = rayon::ThreadPoolBuilder::new()
//             .num_threads(bevy_tasks::available_parallelism())
//             .build()
//             .unwrap();
//
//         let tree = tree::Tree::new(TREE_SIZE, 1u32);
//
//         fn sum(tree: &tree::Tree<u32>, node: usize) -> u32 {
//             let node = tree.get(node);
//             let (l, r) = rayon::join(
//                 || node.left.map(|node| sum(tree, node)).unwrap_or_default(),
//                 || node.right.map(|node| sum(tree, node)).unwrap_or_default(),
//             );
//
//             node.leaf + l + r
//         }
//
//         let now = std::time::Instant::now();
//         let sum = pool.scope(move |s| {
//             let root = tree.root().unwrap();
//             sum(&tree, root)
//         });
//
//         let elapsed = now.elapsed().as_micros();
//
//         info!("(rayon) {sum} total time: {}ms", elapsed as f32 / 1e3);
//     }
//
//     #[test]
//     #[tracing_test::traced_test]
//     fn bevy_tasks() {
//         let pool = bevy_tasks::ComputeTaskPool::get_or_init(|| {
//             bevy_tasks::TaskPoolBuilder::new()
//                 .num_threads(bevy_tasks::available_parallelism())
//                 .build()
//         });
//
//         let now = std::time::Instant::now();
//         pool.scope(|s| {
//             for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() {
//                 s.spawn(async move {
//                     black_box(spinning(p));
//                 });
//             }
//         });
//         let elapsed = now.elapsed().as_micros();
//
//         info!("(bevy_tasks) total time: {}ms", elapsed as f32 / 1e3);
//     }
//
//     #[test]
//     #[tracing_test::traced_test]
//     fn mine() {
//         std::thread_local! {
//             static WAIT_COUNT: Cell<usize> = const { Cell::new(0) };
//         }
//         let counter = Arc::new(AtomicUsize::new(0));
//         {
//             let pool = ThreadPool::new();
//
//             run_in_scope(pool, |s| {
//                 for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() {
//                     s.spawn(move |_| {
//                         black_box(spinning(p));
//                     });
//                 }
//             });
//         };
//
//         // eprintln!("total wait count: {}", counter.load(Ordering::Acquire));
//     }
//
//     #[test]
//     #[tracing_test::traced_test]
//     fn mine_join() {
//         let pool = ThreadPool::new();
//
//         let tree = tree::Tree::new(TREE_SIZE, 1u32);
//
//         fn sum(tree: &tree::Tree<u32>, node: usize, scope: &Scope<'_>) -> u32 {
//             let node = tree.get(node);
//             let (l, r) = scope.join(
//                 |s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(),
//                 |s| {
//                     node.right
//                         .map(|node| sum(tree, node, s))
//                         .unwrap_or_default()
//                 },
//             );
//
//             node.leaf + l + r
//         }
//
//         let sum = run_in_scope(pool, move |s| {
//             let root = tree.root().unwrap();
//             sum(&tree, root, s)
//         });
//     }
//
//     #[test]
//     #[tracing_test::traced_test]
//     fn melange_join() {
//         let pool = melange::ThreadPool::new(bevy_tasks::available_parallelism());
//
//         let mut scope = pool.new_worker();
//
//         let tree = tree::Tree::new(TREE_SIZE, 1u32);
//
//         fn sum(tree: &tree::Tree<u32>, node: usize, scope: &mut melange::WorkerThread) -> u32 {
//             let node = tree.get(node);
//             let (l, r) = scope.join(
//                 |s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(),
//                 |s| {
//                     node.right
//                         .map(|node| sum(tree, node, s))
//                         .unwrap_or_default()
//                 },
//             );
//
//             node.leaf + l + r
//         }
//         let now = Instant::now();
//         let res = sum(&tree, tree.root().unwrap(), &mut scope);
//         eprintln!(
//             "res: {res} took {}ms",
//             now.elapsed().as_micros() as f32 / 1e3
//         );
//         assert_ne!(res, 0);
//     }
//
//     #[test]
//     #[tracing_test::traced_test]
//     fn sync() {
//         let now = std::time::Instant::now();
//         for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() {
//             black_box(spinning(p));
//         }
//         let elapsed = now.elapsed().as_micros();
//
//         info!("(sync) total time: {}ms", elapsed as f32 / 1e3);
//     }
//
//     #[inline]
//     fn spinning(i: usize) {
//         #[cfg(feature = "spin-slow")]
//         spinning_slow(i);
//         #[cfg(not(feature = "spin-slow"))]
//         spinning_fast(i);
//     }
//
//     #[inline]
//     fn spinning_slow(i: usize) {
//         let rng = rng::XorShift64Star::new(i as u64);
//         (0..i).reduce(|a, b| {
//             black_box({
//                 let a = rng.next_usize(a.max(1));
//                 ((b as f32).exp() * (a as f32).sin().cbrt()).to_bits() as usize
//             })
//         });
//     }
//
//     #[inline]
//     fn spinning_fast(i: usize) {
//         let rng = rng::XorShift64Star::new(i as u64);
//         // (0..rng.next_usize(i)).reduce(|a, b| {
//         (0..20).reduce(|a, b| {
//             black_box({
//                 let a = rng.next_usize(a.max(1));
//                 a ^ b
//             })
//         });
//     }
// }