// executor/src/lib.rs
#![feature(
vec_deque_pop_if,
unsafe_cell_access,
debug_closure_helpers,
cold_path,
fn_align,
box_vec_non_null,
box_as_ptr,
atomic_try_update,
let_chains
)]
use std::{
cell::{Cell, UnsafeCell},
future::Future,
mem::MaybeUninit,
num::NonZero,
pin::{pin, Pin},
ptr::{self, NonNull},
sync::{
atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering},
Arc,
},
task::Context,
thread::available_parallelism,
time::Duration,
};
use async_task::{Runnable, Task};
use bitflags::bitflags;
use crossbeam::{
atomic::AtomicCell,
deque::{Injector, Stealer, Worker},
utils::CachePadded,
};
use latch::{AtomicLatch, Latch, LatchWaker, MutexLatch, Probe, ThreadWakeLatch};
use parking_lot::{Condvar, Mutex};
use scope::Scope;
use task::{HeapTask, StackTask, TaskRef};
use tracing::debug;
pub mod job;
pub mod util;
pub mod task {
use std::{cell::UnsafeCell, marker::PhantomPinned, pin::Pin};
pub trait Task {
unsafe fn execute(this: *const ());
}
pub struct TaskRef {
ptr: *const (),
execute_fn: unsafe fn(*const ()),
}
impl TaskRef {
pub unsafe fn new<T>(task: *const T) -> TaskRef
where
T: Task,
{
Self {
ptr: task.cast(),
execute_fn: <T as Task>::execute,
}
}
pub unsafe fn new_raw(ptr: *const (), execute_fn: unsafe fn(*const ())) -> TaskRef {
Self { ptr, execute_fn }
}
#[inline]
pub fn id(&self) -> impl Eq {
(self.ptr, self.execute_fn)
}
#[inline]
pub fn execute(self) {
unsafe { (self.execute_fn)(self.ptr) }
}
#[inline]
pub unsafe fn execute_with_scope<T>(self, scope: &mut T) {
unsafe {
core::mem::transmute::<_, unsafe fn(*const (), &mut T)>(self.execute_fn)(
self.ptr, scope,
)
}
}
}
unsafe impl Send for TaskRef {}
unsafe impl Sync for TaskRef {}
pub struct StackTask<F: FnOnce() + Send> {
task: UnsafeCell<Option<F>>,
_phantom: PhantomPinned,
}
impl<F: FnOnce() + Send> StackTask<F> {
pub fn new(task: F) -> StackTask<F> {
Self {
task: UnsafeCell::new(Some(task)),
_phantom: PhantomPinned,
}
}
#[inline]
pub fn run(self) {
self.task.into_inner().unwrap()();
}
#[inline]
pub unsafe fn run_as_ref(&self) {
((&mut *self.task.get()).take().unwrap())();
}
#[inline]
pub fn as_task_ref(self: Pin<&Self>) -> TaskRef {
unsafe { TaskRef::new(&*self) }
}
}
impl<F: FnOnce() + Send> Task for StackTask<F> {
#[inline]
unsafe fn execute(this: *const ()) {
let this = &*this.cast::<Self>();
let task = (&mut *this.task.get()).take().unwrap();
task();
}
}
pub struct HeapTask<F: FnOnce() + Send> {
task: F,
_phantom: PhantomPinned,
}
impl<F: FnOnce() + Send> HeapTask<F> {
pub fn new(task: F) -> Box<HeapTask<F>> {
Box::new(Self {
task,
_phantom: PhantomPinned,
})
}
#[inline]
pub unsafe fn into_static_task_ref(self: Box<Self>) -> TaskRef
where
F: 'static,
{
self.into_task_ref()
}
#[inline]
pub unsafe fn into_task_ref(self: Box<Self>) -> TaskRef {
TaskRef::new(Box::into_raw(self))
}
}
impl<F: FnOnce() + Send> Task for HeapTask<F> {
#[inline]
unsafe fn execute(this: *const ()) {
let this = Box::from_raw(this.cast::<Self>().cast_mut());
(this.task)();
}
}
}
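// Illustrative sketch (comments only, not compiled) of how the task
// primitives above are meant to be combined:
//
// let stack = core::pin::pin!(task::StackTask::new(|| println!("on the stack")));
// let stack_ref: task::TaskRef = stack.as_ref().as_task_ref();
// stack_ref.execute(); // runs the closure in place
//
// let heap = task::HeapTask::new(|| println!("boxed"));
// let heap_ref = unsafe { heap.into_static_task_ref() };
// heap_ref.execute(); // consumes and frees the box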
pub mod latch {
use core::marker::PhantomData;
use std::{
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc,
},
task::Wake,
};
use parking_lot::{Condvar, Mutex};
use crate::{ThreadPool, WorkerThread};
pub trait Latch {
unsafe fn set_raw(this: *const Self);
}
pub trait Probe {
fn probe(&self) -> bool;
}
#[derive(Debug)]
pub struct AtomicLatch(AtomicBool);
impl AtomicLatch {
#[inline]
pub const fn new() -> AtomicLatch {
Self(AtomicBool::new(false))
}
#[inline]
pub fn reset(&self) {
self.0.store(false, Ordering::Release);
}
}
impl Latch for AtomicLatch {
#[inline]
unsafe fn set_raw(this: *const Self) {
(*this).0.store(true, Ordering::Release);
}
}
impl Probe for AtomicLatch {
#[inline]
fn probe(&self) -> bool {
self.0.load(Ordering::Acquire)
}
}
pub struct ClosureLatch<S, P> {
set: S,
probe: P,
}
impl<S, P> ClosureLatch<S, P> {
pub fn new(set: S, probe: P) -> Self {
Self { set, probe }
}
pub fn new_boxed(set: S, probe: P) -> Box<Self> {
Box::new(Self { set, probe })
}
}
impl<S, P> Latch for ClosureLatch<S, P>
where
S: Fn(),
{
unsafe fn set_raw(this: *const Self) {
let this = &*this;
(this.set)();
}
}
impl<S, P> Probe for ClosureLatch<S, P>
where
P: Fn() -> bool,
{
fn probe(&self) -> bool {
(self.probe)()
}
}
pub struct ThreadWakeLatch {
inner: AtomicLatch,
index: usize,
pool: &'static ThreadPool,
}
impl ThreadWakeLatch {
#[inline]
pub const fn new(thread: &WorkerThread) -> ThreadWakeLatch {
Self {
inner: AtomicLatch::new(),
pool: thread.pool,
index: thread.index,
}
}
#[inline]
pub fn reset(&self) {
self.inner.reset()
}
}
impl Latch for ThreadWakeLatch {
#[inline]
unsafe fn set_raw(this: *const Self) {
let (pool, index) = {
let this = &*this;
(this.pool, this.index)
};
Latch::set_raw(&(*this).inner);
pool.wake_thread(index);
}
}
impl Probe for ThreadWakeLatch {
#[inline]
fn probe(&self) -> bool {
self.inner.probe()
}
}
pub struct LatchRef<'a, L: Latch> {
inner: *const L,
_marker: PhantomData<&'a L>,
}
impl<'a, L: Latch> LatchRef<'a, L> {
#[inline]
pub const fn new(latch: &'a L) -> Self {
Self {
inner: latch,
_marker: PhantomData,
}
}
}
impl<'a, L: Latch> Latch for LatchRef<'a, L> {
#[inline]
unsafe fn set_raw(this: *const Self) {
let this = &*this;
Latch::set_raw(this.inner);
}
}
impl<'a, L: Latch + Probe> Probe for LatchRef<'a, L> {
#[inline]
fn probe(&self) -> bool {
unsafe {
let this = &*self.inner;
Probe::probe(this)
}
}
}
pub struct NopLatch;
impl Latch for NopLatch {
#[inline]
unsafe fn set_raw(_this: *const Self) {
// do nothing
}
}
pub struct MutexLatch {
mutex: Mutex<bool>,
signal: Condvar,
}
impl MutexLatch {
#[inline]
pub const fn new() -> MutexLatch {
Self {
mutex: Mutex::new(false),
signal: Condvar::new(),
}
}
#[inline]
pub fn wait(&self) {
let mut guard = self.mutex.lock();
while !*guard {
self.signal.wait(&mut guard);
}
}
#[inline]
pub fn wait_and_reset(&self) {
let mut guard = self.mutex.lock();
while !*guard {
self.signal.wait(&mut guard);
}
*guard = false;
}
}
impl Latch for MutexLatch {
#[inline]
unsafe fn set_raw(this: *const Self) {
let mut guard = (*this).mutex.lock();
*guard = true;
(*this).signal.notify_all();
}
}
pub struct CountWakeLatch {
counter: AtomicUsize,
inner: ThreadWakeLatch,
}
impl CountWakeLatch {
#[inline]
pub const fn new(count: usize, thread: &WorkerThread) -> CountWakeLatch {
Self {
counter: AtomicUsize::new(count),
inner: ThreadWakeLatch::new(thread),
}
}
#[inline]
pub fn increment(&self) {
self.counter.fetch_add(1, Ordering::Relaxed);
}
}
impl Latch for CountWakeLatch {
#[inline]
unsafe fn set_raw(this: *const Self) {
if (*this).counter.fetch_sub(1, Ordering::SeqCst) == 1 {
Latch::set_raw(&(*this).inner);
}
}
}
impl Probe for CountWakeLatch {
#[inline]
fn probe(&self) -> bool {
self.inner.probe()
}
}
pub struct LatchWaker<L>(L);
impl<L> LatchWaker<L> {
#[inline]
pub fn new(latch: L) -> Arc<Self> {
Arc::new(Self(latch))
}
#[inline]
pub fn latch(&self) -> &L {
&self.0
}
}
impl<L> Wake for LatchWaker<L>
where
L: Latch,
{
#[inline]
fn wake(self: Arc<Self>) {
self.wake_by_ref();
}
#[inline]
fn wake_by_ref(self: &Arc<Self>) {
unsafe {
Latch::set_raw(&self.0);
}
}
}
}
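// Sketch of the latch protocol used throughout this file (illustrative only):
// a producer calls `Latch::set_raw` exactly once, while a consumer either keeps
// working and polling `Probe::probe` (e.g. `WorkerThread::run_until`) or blocks
// in `MutexLatch::wait`.
//
// use latch::{AtomicLatch, Latch, Probe};
// let l = AtomicLatch::new();
// assert!(!l.probe());
// unsafe { Latch::set_raw(&l) };
// assert!(l.probe());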
pub mod melange;
pub mod praetor;
pub struct ThreadPoolState {
num_threads: AtomicUsize,
lock: Mutex<()>,
heartbeat_state: CachePadded<ThreadControl>,
}
bitflags! {
#[derive(Clone)]
pub struct ThreadStatus: u8 {
const RUNNING = 1 << 0;
const SLEEPING = 1 << 1;
const SHOULD_WAKE = 1 << 2;
}
}
pub struct ThreadControl {
status: Mutex<ThreadStatus>,
status_changed: Condvar,
should_terminate: AtomicLatch,
}
pub struct ThreadState {
should_shove: AtomicBool,
control: ThreadControl,
stealer: Stealer<TaskRef>,
worker: AtomicCell<Option<Worker<TaskRef>>>,
shoved_task: CachePadded<Slot<TaskRef>>,
}
impl ThreadControl {
pub const fn new() -> Self {
Self {
status: Mutex::new(ThreadStatus::empty()),
status_changed: Condvar::new(),
should_terminate: AtomicLatch::new(),
}
}
/// Returns `true` if the thread was sleeping.
#[inline]
pub fn wake(&self) -> bool {
let mut guard = self.status.lock();
guard.insert(ThreadStatus::SHOULD_WAKE);
self.status_changed.notify_all();
guard.contains(ThreadStatus::SLEEPING)
}
#[inline]
pub fn wait_for_running(&self) {
let mut guard = self.status.lock();
while !guard.contains(ThreadStatus::RUNNING) {
self.status_changed.wait(&mut guard);
}
}
#[inline]
pub fn wait_for_should_wake(&self) {
let mut guard = self.status.lock();
while !guard.contains(ThreadStatus::SHOULD_WAKE) {
guard.insert(ThreadStatus::SLEEPING);
self.status_changed.wait(&mut guard);
}
guard.remove(ThreadStatus::SHOULD_WAKE | ThreadStatus::SLEEPING);
}
#[inline]
pub fn wait_for_should_wake_timeout(&self, timeout: Duration) {
let mut guard = self.status.lock();
while !guard.contains(ThreadStatus::SHOULD_WAKE) {
guard.insert(ThreadStatus::SLEEPING);
if self
.status_changed
.wait_for(&mut guard, timeout)
.timed_out()
{
break;
}
}
guard.remove(ThreadStatus::SHOULD_WAKE | ThreadStatus::SLEEPING);
}
#[inline]
pub fn wait_for_termination(&self) {
let mut guard = self.status.lock();
while guard.contains(ThreadStatus::RUNNING) {
self.status_changed.wait(&mut guard);
}
}
#[inline]
pub fn notify_running(&self) {
let mut guard = self.status.lock();
guard.insert(ThreadStatus::RUNNING);
self.status_changed.notify_all();
}
#[inline]
pub fn notify_termination(&self) {
let mut guard = self.status.lock();
*guard = ThreadStatus::empty();
self.status_changed.notify_all();
}
#[inline]
pub fn notify_should_terminate(&self) {
unsafe {
Latch::set_raw(&self.should_terminate);
}
self.wake();
}
}
const MAX_THREADS: usize = 32;
type ThreadCallback = dyn Fn(&WorkerThread) + Send + Sync + 'static;
pub struct ThreadPoolCallbacks {
at_entry: Option<Arc<ThreadCallback>>,
at_exit: Option<Arc<ThreadCallback>>,
}
impl ThreadPoolCallbacks {
pub const fn new_empty() -> ThreadPoolCallbacks {
Self {
at_entry: None,
at_exit: None,
}
}
}
pub struct ThreadPool {
threads: [CachePadded<ThreadState>; MAX_THREADS],
pool_state: CachePadded<ThreadPoolState>,
global_queue: Injector<TaskRef>,
callbacks: CachePadded<ThreadPoolCallbacks>,
}
impl ThreadPool {
pub fn new() -> Self {
Self::new_with_callbacks(ThreadPoolCallbacks::new_empty())
}
pub fn new_with_callbacks(callbacks: ThreadPoolCallbacks) -> ThreadPool {
let threads = [(); MAX_THREADS].map(|()| {
let worker = Worker::<TaskRef>::new_fifo();
let stealer = worker.stealer();
CachePadded::new(ThreadState {
should_shove: AtomicBool::new(false),
shoved_task: Slot::new().into(),
control: ThreadControl::new(),
stealer,
worker: AtomicCell::new(Some(worker)),
})
});
Self {
threads,
pool_state: CachePadded::new(ThreadPoolState {
num_threads: AtomicUsize::new(0),
lock: Mutex::new(()),
heartbeat_state: CachePadded::new(ThreadControl::new()),
}),
global_queue: Injector::new(),
callbacks: CachePadded::new(callbacks),
}
}
#[inline]
fn threads(&self) -> &[CachePadded<ThreadState>] {
&self.threads[..self.pool_state.num_threads.load(Ordering::Relaxed)]
}
pub fn wake_thread(&self, index: usize) -> Option<bool> {
Some(self.threads.get(index)?.control.wake())
}
pub fn wake_any(&self, count: usize) -> usize {
if count == 0 {
return 0;
}
self.threads
.iter()
.filter(|thread| thread.control.wake())
.take(count)
.count()
}
#[inline]
pub fn id(&self) -> impl Eq {
core::ptr::from_ref(self) as usize
}
#[allow(dead_code)]
fn push_local_or_inject_balanced(&self, task: TaskRef) {
let global_len = self.global_queue.len();
WorkerThread::with(|worker| match worker {
Some(worker) if worker.pool.id() == self.id() => {
let worker_len = worker.worker.len();
if worker_len == 0 {
worker.push_task(task);
} else if global_len == 0 {
self.inject(task);
} else if worker_len >= global_len {
worker.push_task(task);
} else {
self.inject(task);
}
}
_ => self.inject(task),
})
}
fn push_local_or_inject(&self, task: TaskRef) {
WorkerThread::with(|worker| match worker {
Some(worker) if worker.pool.id() == self.id() => worker.push_task(task),
_ => self.inject(task),
})
}
#[allow(dead_code)]
fn inject_many<I>(&self, tasks: I)
where
I: Iterator<Item = TaskRef>,
{
let mut n = 0;
for task in tasks {
n += 1;
self.global_queue.push(task);
}
self.wake_any(n);
}
#[allow(unused_variables)]
fn inject_maybe_local(&self, task: TaskRef) {
#[cfg(all(not(feature = "never-local"), feature = "prefer-local"))]
self.push_local_or_inject(task);
#[cfg(all(not(feature = "prefer-local"), feature = "never-local"))]
self.inject(task);
#[cfg(not(any(feature = "prefer-local", feature = "never-local")))]
self.push_local_or_inject_balanced(task);
}
fn inject(&self, task: TaskRef) {
self.global_queue.push(task);
self.wake_any(1);
}
#[allow(dead_code)]
fn resize<F: Fn(usize) -> usize>(&'static self, size: F) -> usize {
if WorkerThread::is_worker_thread() {
// acquire required here?
debug!("tried to resize from within threadpool!");
return self.pool_state.num_threads.load(Ordering::Acquire);
}
#[cfg(feature = "cpu-pinning")]
let cpus = core_affinity::get_core_ids().unwrap();
let _guard = self.pool_state.lock.lock();
let current_size = self.pool_state.num_threads.load(Ordering::Acquire);
let new_size = size(current_size).clamp(0, MAX_THREADS);
debug!(current_size, new_size, "resizing threadpool");
if new_size == current_size {
return current_size;
}
self.pool_state
.num_threads
.store(new_size, Ordering::Release);
match new_size.cmp(&current_size) {
std::cmp::Ordering::Greater => {
let new_threads = &self.threads[current_size..new_size];
for i in 0..new_threads.len() {
#[cfg(feature = "cpu-pinning")]
let core = cpus[i];
std::thread::spawn(move || {
#[cfg(feature = "cpu-pinning")]
core_affinity::set_for_current(core);
WorkerThread::worker_loop(self, current_size + i);
});
}
for thread in new_threads {
thread.control.wait_for_running();
}
#[cfg(feature = "heartbeat")]
if current_size == 0 {
std::thread::spawn(move || {
heartbeat_loop(self);
});
self.pool_state.heartbeat_state.wait_for_running();
}
}
std::cmp::Ordering::Less => {
debug!(
"waiting for threads {:?} to terminate.",
new_size..current_size
);
let terminating_threads = &self.threads[new_size..current_size];
for thread in terminating_threads {
thread.control.notify_should_terminate();
}
for thread in terminating_threads {
thread.control.wait_for_termination();
}
#[cfg(feature = "heartbeat")]
if new_size == 0 {
self.pool_state.heartbeat_state.notify_should_terminate();
self.pool_state.heartbeat_state.wait_for_termination();
}
}
std::cmp::Ordering::Equal => unreachable!(),
}
new_size
}
#[allow(dead_code)]
fn ensure_one_worker(&'static self) -> usize {
self.resize(|current| current.max(1))
}
#[allow(dead_code)]
fn resize_to_available(&'static self) {
self.resize_to(available_parallelism().map(NonZero::get).unwrap_or(1));
}
#[allow(dead_code)]
fn resize_to(&'static self, new_size: usize) -> usize {
self.resize(|_| new_size)
}
#[allow(dead_code)]
fn grow_by(&'static self, num_threads: usize) -> usize {
self.resize(|current| current.saturating_add(num_threads))
}
#[allow(dead_code)]
fn shrink_by(&'static self, num_threads: usize) -> usize {
self.resize(|current| current.saturating_sub(num_threads))
}
#[allow(dead_code)]
fn shrink_to(&'static self, num_threads: usize) -> usize {
self.resize(|_| num_threads)
}
fn in_worker<F, T>(&'static self, f: F) -> T
where
F: FnOnce(&WorkerThread, bool) -> T + Send,
T: Send,
{
WorkerThread::with(|worker| match worker {
Some(worker) => {
if worker.pool.id() == self.id() {
// already on a worker thread of this pool: run in place.
f(worker, false)
} else {
// worker thread of a different pool: inject into this pool and keep
// the foreign worker busy until the result is ready.
self.in_worker_cross(worker, f)
}
}
None => self.in_worker_cold(f),
})
}
#[cold]
fn in_worker_cold<F, T>(&'static self, f: F) -> T
where
F: FnOnce(&WorkerThread, bool) -> T + Send,
T: Send,
{
std::thread_local! {static LATCH: MutexLatch = const {MutexLatch::new()}};
LATCH.with(|latch| {
let mut result = None;
let task = StackTask::new(|| {
WorkerThread::with(|worker| {
let worker = worker.unwrap();
result = Some(f(worker, true));
unsafe {
// SAFETY: static thread-local
Latch::set_raw(latch);
}
})
});
let pinned = pin!(task);
let taskref = pinned.as_ref().as_task_ref();
self.inject(taskref);
latch.wait_and_reset();
result.unwrap()
})
}
/// Runs `f` on a worker of `self`, blocking the current (foreign worker) thread until the work completes.
fn in_worker_cross<F, T>(&'static self, worker: &WorkerThread, f: F) -> T
where
F: FnOnce(&WorkerThread, bool) -> T + Send,
T: Send,
{
let latch = ThreadWakeLatch::new(worker);
let mut result = None;
let task = pin!(StackTask::new(|| {
WorkerThread::with(|worker| {
let worker = worker.unwrap();
result = Some(f(worker, true));
unsafe {
// SAFETY: `latch` is kept alive by `run_until` below, which does not
// return until the latch has been set.
Latch::set_raw(&latch);
}
})
}));
let taskref = task.into_ref().as_task_ref();
self.push_local_or_inject(taskref);
worker.run_until(&latch);
result.unwrap()
}
}
impl ThreadPool {
pub fn spawn<Fn>(&'static self, f: Fn)
where
Fn: FnOnce() + Send + 'static,
{
let task = HeapTask::new(f);
let taskref = unsafe { task.into_static_task_ref() };
self.inject_maybe_local(taskref);
}
pub fn spawn_future<Fut, T>(&'static self, future: Fut) -> Task<T>
where
Fut: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
let schedule = move |runnable: Runnable| {
let taskref = unsafe {
TaskRef::new_raw(runnable.into_raw().as_ptr(), |this| {
let this = NonNull::new_unchecked(this.cast_mut());
let runnable = Runnable::<()>::from_raw(this);
runnable.run();
})
};
self.inject_maybe_local(taskref);
};
let (runnable, task) = async_task::spawn(future, schedule);
runnable.schedule();
task
}
pub fn spawn_async<Fn, Fut, T>(&'static self, f: Fn) -> Task<T>
where
Fn: FnOnce() -> Fut + Send + 'static,
Fut: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
self.spawn_future(async move { f().await })
}
pub fn block_on<Fut, T>(&'static self, mut future: Fut) -> T
where
Fut: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
// SAFETY: `future` stays on this stack frame and is never moved again
// after being pinned here.
let mut future = unsafe { Pin::new_unchecked(&mut future) };
self.in_worker(|worker, _| {
let wake = LatchWaker::new(ThreadWakeLatch::new(worker));
let ctx_waker = Arc::clone(&wake).into();
let mut ctx = Context::from_waker(&ctx_waker);
loop {
match future.as_mut().poll(&mut ctx) {
std::task::Poll::Ready(t) => {
return t;
}
std::task::Poll::Pending => {
// run pool work until the waker sets the latch, then poll again.
worker.run_until(wake.latch());
wake.latch().reset();
}
}
}
})
}
pub fn join<F, G, T, U>(&'static self, f: F, g: G) -> (T, U)
where
F: FnOnce() -> T + Send,
G: FnOnce() -> U + Send,
T: Send,
U: Send,
{
self.join_threaded(f, g)
}
#[allow(dead_code)]
fn join_seq<F, G, T, U>(&'static self, f: F, g: G) -> (T, U)
where
F: FnOnce() -> T + Send,
G: FnOnce() -> U + Send,
T: Send,
U: Send,
{
let a = f();
let b = g();
(a, b)
}
#[inline]
fn join_threaded<F, G, T, U>(&'static self, f: F, g: G) -> (T, U)
where
F: FnOnce() -> T + Send,
G: FnOnce() -> U + Send,
T: Send,
U: Send,
{
self.in_worker(|worker, _| {
let mut result_b = None;
let latch_b = ThreadWakeLatch::new(worker);
let task_b = pin!(StackTask::new(|| {
result_b = Some(g());
unsafe {
Latch::set_raw(&latch_b);
}
}));
let ref_b = task_b.as_ref().as_task_ref();
let b_id = ref_b.id();
worker.push_task(ref_b);
let result_a = f();
while !latch_b.probe() {
match worker.pop_task() {
Some(task) => {
if task.id() == b_id {
// we're not calling execute() here, so manually try
// shoving a task.
//worker.try_promote();
worker.shove_task();
unsafe {
task_b.run_as_ref();
}
break;
} else {
worker.execute(task);
}
}
None => {
worker.run_until(&latch_b);
}
}
}
(result_a, result_b.unwrap())
})
}
pub fn scope<'scope, Fn, T>(&'static self, f: Fn) -> T
where
Fn: FnOnce(&Scope<'scope>) -> T + Send,
T: Send,
{
self.in_worker(|owner, _| {
let scope = unsafe { Scope::<'scope>::new(owner) };
let result = f(&scope);
scope.complete(owner);
result
})
}
}
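// Usage sketch for the public `ThreadPool` API above (illustrative only; a
// leaked box is one way to satisfy the `&'static self` receivers):
//
// let pool: &'static ThreadPool = Box::leak(Box::new(ThreadPool::new()));
// pool.resize_to_available();
// let (a, b) = pool.join(|| 1 + 1, || 2 + 2);
// assert_eq!((a, b), (2, 4));
// let answer = pool.scope(|s| {
// s.spawn(|_| { /* fire-and-forget work tracked by the scope */ });
// 21 * 2
// });
// assert_eq!(answer, 42);
// pool.resize_to(0);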
pub struct WorkerThread {
// queue: TaskQueue<TaskRef>,
worker: Worker<TaskRef>,
pool: &'static ThreadPool,
index: usize,
rng: rng::XorShift64Star,
}
const HEARTBEAT_INTERVAL: core::time::Duration = const { core::time::Duration::from_micros(100) };
std::thread_local! {
static WORKER_THREAD_STATE: Cell<*const WorkerThread> = const {Cell::new(ptr::null())};
}
impl WorkerThread {
#[inline]
fn info(&self) -> &ThreadState {
&self.pool.threads[self.index]
}
#[inline]
fn pool(&self) -> &'static ThreadPool {
self.pool
}
#[inline]
#[allow(dead_code)]
fn index(&self) -> usize {
self.index
}
#[inline]
fn is_worker_thread() -> bool {
Self::with(|worker| worker.is_some())
}
fn with<T, F: FnOnce(Option<&Self>) -> T>(f: F) -> T {
WORKER_THREAD_STATE.with(|thread| {
f(NonNull::<WorkerThread>::new(thread.get().cast_mut())
.map(|ptr| unsafe { ptr.as_ref() }))
})
}
#[inline]
fn pop_task(&self) -> Option<TaskRef> {
self.worker.pop()
//self.queue.pop_front(task);
}
#[inline]
fn push_task(&self, task: TaskRef) {
self.worker.push(task);
//self.queue.push_front(task);
}
#[inline]
fn drain(&self) -> impl Iterator<Item = TaskRef> {
// self.queue.drain()
core::iter::empty()
}
#[inline]
fn steal_tasks(&self) -> Option<TaskRef> {
// careful not to call threads() here because that omits any threads
// that were killed, which might still have tasks.
let threads = &self.pool.threads;
let (start, end) = threads.split_at(self.rng.next_usize(threads.len()));
end.iter()
.chain(start)
.find_map(|thread: &CachePadded<ThreadState>| {
thread.stealer.steal_batch_and_pop(&self.worker).success()
})
}
#[inline]
fn claim_shoved_task(&self) -> Option<TaskRef> {
// take own shoved task first
if let Some(task) = self.info().shoved_task.try_take() {
return Some(task);
}
let threads = self.pool.threads();
if threads.is_empty() {
return None;
}
let (start, end) = threads.split_at(self.rng.next_usize(threads.len()));
end.iter()
.chain(start)
.find_map(|thread| thread.shoved_task.try_take())
}
#[cold]
fn shove_task(&self) {
if !self.info().shoved_task.is_occupied() {
if let Some(task) = self.info().stealer.steal().success() {
// another thread may have locked the slot (mid-`try_take`) between
// the occupancy check and this put; if so, keep the task in the
// local queue instead of losing it.
if let Some(task) = self.info().shoved_task.try_put(task) {
self.push_task(task);
}
}
} else {
// slot still holds an unclaimed task: wake a thread to execute it.
self.pool.wake_any(1);
}
}
fn execute(&self, task: TaskRef) {
self.try_promote();
task.execute();
}
#[inline]
fn try_promote(&self) {
#[cfg(feature = "heartbeat")]
let should_shove = self.info().should_shove.load(Ordering::Acquire);
#[cfg(not(feature = "heartbeat"))]
let should_shove = true;
if should_shove {
#[cfg(feature = "heartbeat")]
self.info().should_shove.store(false, Ordering::Release);
self.shove_task();
}
}
#[inline]
fn find_any_task(&self) -> Option<TaskRef> {
// TODO: attempt stealing work here, too.
#[allow(unused_mut)]
let mut task = self
.pop_task()
.or_else(|| self.claim_shoved_task())
.or_else(|| {
self.pool
.global_queue
.steal_batch_and_pop(&self.worker)
.success()
});
#[cfg(feature = "work-stealing")]
{
task = task.or_else(|| self.steal_tasks());
}
task
}
#[inline]
fn run_until<L>(&self, latch: &L)
where
L: Probe,
{
if !latch.probe() {
self.run_until_cold(latch);
}
}
#[cold]
fn run_until_cold<L>(&self, latch: &L)
where
L: Probe,
{
while !latch.probe() {
self.run_until_inner();
}
}
#[inline]
fn run_until_inner(&self) {
match self.find_any_task() {
Some(task) => {
self.execute(task);
}
None => {
//debug!("waiting for tasks");
self.info().control.wait_for_should_wake();
}
}
}
fn worker_loop(pool: &'static ThreadPool, index: usize) {
let info = &pool.threads()[index];
let worker = CachePadded::new(WorkerThread {
// queue: TaskQueue::new(),
worker: info.worker.take().unwrap(),
pool,
index,
rng: rng::XorShift64Star::new(pool as *const _ as u64 + index as u64),
});
WORKER_THREAD_STATE.with(|cell| {
cell.set(&*worker);
if let Some(callback) = pool.callbacks.at_entry.as_ref() {
callback(&worker);
}
info.control.notify_running();
// info.notify_running();
worker.run_until(&info.control.should_terminate);
if let Some(callback) = pool.callbacks.at_exit.as_ref() {
callback(&worker);
}
for task in worker.drain() {
pool.inject(task);
}
if let Some(task) = info.shoved_task.try_take() {
pool.inject(task);
}
cell.set(ptr::null());
});
let WorkerThread { worker, .. } = CachePadded::into_inner(worker);
info.worker.store(Some(worker));
info.control.notify_termination();
}
}
fn heartbeat_loop(pool: &'static ThreadPool) {
let state = &pool.pool_state.heartbeat_state;
state.notify_running();
let mut i = 0;
while !state.should_terminate.probe() {
let threads = pool.threads();
if threads.is_empty() {
break;
}
if i >= threads.len() {
i = 0;
continue;
}
threads[i].should_shove.store(true, Ordering::Relaxed);
i += 1;
let interval = HEARTBEAT_INTERVAL / threads.len() as u32;
state.wait_for_should_wake_timeout(interval);
}
state.notify_termination();
}
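// The heartbeat loop above cooperates with `WorkerThread::try_promote` and
// `WorkerThread::shove_task`: roughly once per `HEARTBEAT_INTERVAL` a worker
// has its `should_shove` flag raised, and the next task it executes promotes
// one task from its local deque into its `shoved_task` slot, where other
// threads can pick it up via `claim_shoved_task`. The interval is divided by
// the number of live threads so that, on average, one worker is signalled per
// tick.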
mod vec_queue {
use std::{cell::UnsafeCell, collections::VecDeque};
pub struct TaskQueue<T>(UnsafeCell<VecDeque<T>>);
impl<T> TaskQueue<T> {
/// Creates a new [`TaskQueue<T>`].
#[inline]
#[allow(dead_code)]
pub const fn new() -> Self {
Self(UnsafeCell::new(VecDeque::new()))
}
#[inline]
#[allow(dead_code)]
pub fn get_mut(&self) -> &mut VecDeque<T> {
unsafe { &mut *self.0.get() }
}
#[inline]
#[allow(dead_code)]
pub fn pop_front(&self) -> Option<T> {
self.get_mut().pop_front()
}
#[inline]
#[allow(dead_code)]
pub fn pop_back(&self) -> Option<T> {
self.get_mut().pop_back()
}
#[inline]
#[allow(dead_code)]
pub fn push_back(&self, t: T) {
self.get_mut().push_back(t);
}
#[inline]
#[allow(dead_code)]
pub fn push_front(&self, t: T) {
self.get_mut().push_front(t);
}
#[inline]
#[allow(dead_code)]
pub fn take(&self) -> VecDeque<T> {
core::mem::take(self.get_mut())
}
#[inline]
#[allow(dead_code)]
pub fn drain(&self) -> impl Iterator<Item = T> {
self.take().into_iter()
}
}
}
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SlotState {
None,
Locked,
Occupied,
}
impl From<u8> for SlotState {
fn from(value: u8) -> Self {
debug_assert!(value <= SlotState::Occupied as u8);
// SAFETY: `Slot` only ever stores `SlotState` discriminants in its
// `state` field, so `value` is always a valid variant.
unsafe { core::mem::transmute(value) }
}
}
impl From<SlotState> for u8 {
fn from(value: SlotState) -> Self {
value as u8
}
}
pub struct Slot<T> {
slot: UnsafeCell<MaybeUninit<T>>,
state: AtomicU8,
}
unsafe impl<T> Send for Slot<T> where T: Send {}
unsafe impl<T> Sync for Slot<T> where T: Send {}
impl<T> Drop for Slot<T> {
fn drop(&mut self) {
if core::mem::needs_drop::<T>() && *self.state.get_mut() == SlotState::Occupied as u8 {
// SAFETY: the slot is occupied, so the value is initialized.
// (`drop_in_place` on the `MaybeUninit` itself would be a no-op.)
unsafe {
self.slot.get_mut().assume_init_drop();
}
}
}
}
impl<T> Slot<T> {
pub const fn new() -> Slot<T> {
Self {
slot: UnsafeCell::new(MaybeUninit::uninit()),
state: AtomicU8::new(SlotState::None as u8),
}
}
pub fn is_occupied(&self) -> bool {
self.state.load(Ordering::Acquire) == SlotState::Occupied.into()
}
#[inline]
pub fn insert(&self, t: T) -> Option<T> {
let value = match self
.state
.swap(SlotState::Locked.into(), Ordering::AcqRel)
.into()
{
SlotState::Locked => {
// another thread holds the lock; bail out. note that `t` is dropped
// here rather than stored.
debug!("slot was already locked");
return None;
}
SlotState::Occupied => {
let slot = self.slot.get();
// replace
unsafe {
let v = (*slot).assume_init_read();
(*slot).write(t);
Some(v)
}
}
SlotState::None => {
let slot = self.slot.get();
// insert
unsafe {
(*slot).write(t);
}
None
}
};
// release lock
self.state
.store(SlotState::Occupied.into(), Ordering::Release);
value
}
#[inline]
pub fn try_put(&self, t: T) -> Option<T> {
match self.state.compare_exchange(
SlotState::None.into(),
SlotState::Locked.into(),
Ordering::Acquire,
Ordering::Relaxed,
) {
Err(_) => Some(t),
Ok(_) => {
let slot = self.slot.get();
// SAFETY: we hold LOCKED on the spinlock
unsafe { (*slot).write(t) };
// release lock
self.state
.store(SlotState::Occupied.into(), Ordering::Release);
None
}
}
}
#[inline]
pub fn try_take(&self) -> Option<T> {
match self.state.compare_exchange(
SlotState::Occupied.into(),
SlotState::Locked.into(),
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(_) => {
let slot = self.slot.get();
// SAFETY: we hold LOCKED on the spinlock
let t = unsafe { (*slot).assume_init_read() };
// release lock
self.state.store(SlotState::None.into(), Ordering::Release);
Some(t)
}
Err(_) => None,
}
}
}
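// `Slot` behaves as a single-element mailbox guarded by a tiny state machine
// (None -> Locked -> Occupied). Illustrative sketch of the non-blocking API:
//
// let slot: Slot<u32> = Slot::new();
// assert_eq!(slot.try_put(1), None); // slot was empty, value stored
// assert_eq!(slot.try_put(2), Some(2)); // already occupied, value handed back
// assert_eq!(slot.try_take(), Some(1));
// assert_eq!(slot.try_take(), None); // empty again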
mod rng {
use core::cell::Cell;
pub struct XorShift64Star {
state: Cell<u64>,
}
impl XorShift64Star {
/// Initializes the PRNG with a seed. The provided seed must be nonzero.
pub fn new(seed: u64) -> Self {
XorShift64Star {
state: Cell::new(seed),
}
}
/// Returns a pseudorandom number.
pub fn next(&self) -> u64 {
let mut x = self.state.get();
debug_assert_ne!(x, 0);
x ^= x >> 12;
x ^= x << 25;
x ^= x >> 27;
self.state.set(x);
x.wrapping_mul(0x2545_f491_4f6c_dd1d)
}
/// Returns a pseudorandom number in `0..n`.
pub fn next_usize(&self, n: usize) -> usize {
(self.next() % n as u64) as usize
}
}
}
pub mod scope {
use std::{
future::Future,
marker::PhantomData,
ptr::{self, NonNull},
};
use async_task::{Runnable, Task};
use crate::{
latch::{CountWakeLatch, Latch},
task::{HeapTask, TaskRef},
ThreadPool, WorkerThread,
};
pub struct Scope<'scope> {
pool: &'static ThreadPool,
tasks_completed_latch: CountWakeLatch,
_marker: PhantomData<Box<dyn FnOnce(&Scope<'scope>) + Send + Sync + 'scope>>,
}
impl<'scope> Scope<'scope> {
pub unsafe fn new(owner: &WorkerThread) -> Scope<'scope> {
Scope {
pool: owner.pool(),
tasks_completed_latch: CountWakeLatch::new(1, owner),
_marker: PhantomData,
}
}
pub fn join<F, G, T, U>(&self, f: F, g: G) -> (T, U)
where
F: FnOnce(&Self) -> T + Send,
G: FnOnce(&Self) -> U + Send,
T: Send,
U: Send,
{
self.pool.join(|| f(self), || g(self))
}
pub fn spawn<Fn>(&self, f: Fn)
where
Fn: FnOnce(&Scope<'scope>) + Send + 'scope,
{
self.tasks_completed_latch.increment();
let ptr = SendPtr::from_ref(self);
let task = HeapTask::new(move || unsafe {
let this = ptr.as_ref();
f(this);
Latch::set_raw(&this.tasks_completed_latch);
});
let taskref = unsafe { task.into_task_ref() };
self.pool.inject_maybe_local(taskref);
}
pub fn spawn_future<Fut, T>(&self, future: Fut) -> Task<T>
where
Fut: Future<Output = T> + Send + 'scope,
T: Send + 'scope,
{
self.tasks_completed_latch.increment();
let ptr = SendPtr::from_ref(self);
let schedule = move |runnable: Runnable| {
let taskref = unsafe {
TaskRef::new_raw(runnable.into_raw().as_ptr(), |this| {
let this = NonNull::new_unchecked(this.cast_mut());
let runnable = Runnable::<()>::from_raw(this);
runnable.run();
})
};
unsafe {
ptr.as_ref().pool.inject_maybe_local(taskref);
}
};
let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) };
runnable.schedule();
task
}
pub fn spawn_async<Fn, Fut, T>(&self, f: Fn) -> Task<T>
where
Fn: FnOnce() -> Fut + Send + 'scope,
Fut: Future<Output = T> + Send + 'scope,
T: Send + 'scope,
{
self.spawn_future(async move { f().await })
}
pub fn complete(&self, owner: &WorkerThread) {
unsafe {
Latch::set_raw(&self.tasks_completed_latch);
}
owner.run_until(&self.tasks_completed_latch);
}
}
struct SendPtr<T>(*const T);
impl<T> SendPtr<T> {
fn from_ref(t: &T) -> Self {
Self(ptr::from_ref(t).cast())
}
unsafe fn as_ref(&self) -> &T {
&*self.0
}
}
unsafe impl<T> Send for SendPtr<T> {}
}
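// Scope usage sketch (illustrative only, given a `&'static ThreadPool` named
// `pool`): spawned work may borrow from the enclosing stack frame because
// `ThreadPool::scope` does not return until `Scope::complete` has observed
// every spawn.
//
// let total = AtomicUsize::new(0);
// pool.scope(|s| {
// let total = &total;
// for i in 0..8 {
// s.spawn(move |_| { total.fetch_add(i, Ordering::Relaxed); });
// }
// });
// assert_eq!(total.into_inner(), 28);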
// #[cfg(test)]
// mod tests {
// use std::{cell::Cell, hint::black_box, time::Instant};
// use tracing::info;
// use super::*;
// mod tree {
// pub struct Tree<T> {
// nodes: Box<[Node<T>]>,
// root: Option<usize>,
// }
// pub struct Node<T> {
// pub leaf: T,
// pub left: Option<usize>,
// pub right: Option<usize>,
// }
// impl<T> Tree<T> {
// pub fn new(depth: usize, t: T) -> Tree<T>
// where
// T: Copy,
// {
// let mut nodes = Vec::with_capacity((0..depth).sum());
// let root = Self::build_node(&mut nodes, depth, t);
// Self {
// nodes: nodes.into_boxed_slice(),
// root: Some(root),
// }
// }
// pub fn root(&self) -> Option<usize> {
// self.root
// }
// pub fn get(&self, index: usize) -> &Node<T> {
// &self.nodes[index]
// }
// pub fn build_node(nodes: &mut Vec<Node<T>>, depth: usize, t: T) -> usize
// where
// T: Copy,
// {
// let node = Node {
// leaf: t,
// left: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)),
// right: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)),
// };
// nodes.push(node);
// nodes.len() - 1
// }
// }
// }
// const PRIMES: &'static [usize] = &[
// 1181, 1187, 1193, 1201, 1213, 1217, 1223, 1229, 1231, 1237, 1249, 1259, 1277, 1279, 1283,
// 1289, 1291, 1297, 1301, 1303, 1307, 1319, 1321, 1327, 1361, 1367, 1373, 1381, 1399, 1409,
// 1423, 1427, 1429, 1433, 1439, 1447, 1451, 1453, 1459, 1471, 1481, 1483, 1487, 1489, 1493,
// 1499, 1511, 1523, 1531, 1543, 1549, 1553, 1559, 1567, 1571, 1579, 1583, 1597, 1601, 1607,
// 1609, 1613, 1619, 1621, 1627, 1637, 1657, 1663, 1667, 1669, 1693, 1697, 1699, 1709, 1721,
// 1723, 1733, 1741, 1747, 1753, 1759, 1777, 1783, 1787, 1789, 1801, 1811, 1823, 1831, 1847,
// 1861, 1867, 1871, 1873, 1877, 1879, 1889, 1901, 1907,
// ];
// #[cfg(feature = "spin-slow")]
// const REPEAT: usize = 0x800;
// #[cfg(not(feature = "spin-slow"))]
// const REPEAT: usize = 0x8000;
// const TREE_SIZE: usize = 10;
// fn run_in_scope<T: Send>(pool: ThreadPool, f: impl FnOnce(&Scope<'_>) -> T + Send) -> T {
// let pool = Box::new(pool);
// let ptr = Box::into_raw(pool);
// let result = {
// let pool: &'static ThreadPool = unsafe { &*ptr };
// // pool.ensure_one_worker();
// pool.resize_to_available();
// let now = std::time::Instant::now();
// let result = pool.scope(f);
// let elapsed = now.elapsed().as_micros();
// info!("(mine) total time: {}ms", elapsed as f32 / 1e3);
// pool.resize_to(0);
// assert!(pool.global_queue.is_empty());
// result
// };
// let _pool = unsafe { Box::from_raw(ptr) };
// result
// }
// #[test]
// #[tracing_test::traced_test]
// fn rayon() {
// let pool = rayon::ThreadPoolBuilder::new()
// .num_threads(bevy_tasks::available_parallelism())
// .build()
// .unwrap();
// let now = std::time::Instant::now();
// pool.scope(|s| {
// for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() {
// s.spawn(move |_| {
// black_box(spinning(p));
// });
// }
// });
// let elapsed = now.elapsed().as_micros();
// info!("(rayon) total time: {}ms", elapsed as f32 / 1e3);
// }
// #[test]
// #[tracing_test::traced_test]
// fn rayon_join() {
// let pool = rayon::ThreadPoolBuilder::new()
// .num_threads(bevy_tasks::available_parallelism())
// .build()
// .unwrap();
// let tree = tree::Tree::new(TREE_SIZE, 1u32);
// fn sum(tree: &tree::Tree<u32>, node: usize) -> u32 {
// let node = tree.get(node);
// let (l, r) = rayon::join(
// || node.left.map(|node| sum(tree, node)).unwrap_or_default(),
// || node.right.map(|node| sum(tree, node)).unwrap_or_default(),
// );
// node.leaf + l + r
// }
// let now = std::time::Instant::now();
// let sum = pool.scope(move |s| {
// let root = tree.root().unwrap();
// sum(&tree, root)
// });
// let elapsed = now.elapsed().as_micros();
// info!("(rayon) {sum} total time: {}ms", elapsed as f32 / 1e3);
// }
// #[test]
// #[tracing_test::traced_test]
// fn bevy_tasks() {
// let pool = bevy_tasks::ComputeTaskPool::get_or_init(|| {
// bevy_tasks::TaskPoolBuilder::new()
// .num_threads(bevy_tasks::available_parallelism())
// .build()
// });
// let now = std::time::Instant::now();
// pool.scope(|s| {
// for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() {
// s.spawn(async move {
// black_box(spinning(p));
// });
// }
// });
// let elapsed = now.elapsed().as_micros();
// info!("(bevy_tasks) total time: {}ms", elapsed as f32 / 1e3);
// }
// #[test]
// #[tracing_test::traced_test]
// fn mine() {
// std::thread_local! {
// static WAIT_COUNT: Cell<usize> = const {Cell::new(0)};
// }
// let counter = Arc::new(AtomicUsize::new(0));
// {
// let pool = ThreadPool::new();
// run_in_scope(pool, |s| {
// for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() {
// s.spawn(move |_| {
// black_box(spinning(p));
// });
// }
// });
// };
// // eprintln!("total wait count: {}", counter.load(Ordering::Acquire));
// }
// #[test]
// #[tracing_test::traced_test]
// fn mine_join() {
// let pool = ThreadPool::new();
// let tree = tree::Tree::new(TREE_SIZE, 1u32);
// fn sum(tree: &tree::Tree<u32>, node: usize, scope: &Scope<'_>) -> u32 {
// let node = tree.get(node);
// let (l, r) = scope.join(
// |s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(),
// |s| {
// node.right
// .map(|node| sum(tree, node, s))
// .unwrap_or_default()
// },
// );
// node.leaf + l + r
// }
// let sum = run_in_scope(pool, move |s| {
// let root = tree.root().unwrap();
// sum(&tree, root, s)
// });
// }
// #[test]
// #[tracing_test::traced_test]
// fn melange_join() {
// let pool = melange::ThreadPool::new(bevy_tasks::available_parallelism());
// let mut scope = pool.new_worker();
// let tree = tree::Tree::new(TREE_SIZE, 1u32);
// fn sum(tree: &tree::Tree<u32>, node: usize, scope: &mut melange::WorkerThread) -> u32 {
// let node = tree.get(node);
// let (l, r) = scope.join(
// |s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(),
// |s| {
// node.right
// .map(|node| sum(tree, node, s))
// .unwrap_or_default()
// },
// );
// node.leaf + l + r
// }
// let now = Instant::now();
// let res = sum(&tree, tree.root().unwrap(), &mut scope);
// eprintln!(
// "res: {res} took {}ms",
// now.elapsed().as_micros() as f32 / 1e3
// );
// assert_ne!(res, 0);
// }
// #[test]
// #[tracing_test::traced_test]
// fn sync() {
// let now = std::time::Instant::now();
// for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() {
// black_box(spinning(p));
// }
// let elapsed = now.elapsed().as_micros();
// info!("(sync) total time: {}ms", elapsed as f32 / 1e3);
// }
// #[inline]
// fn spinning(i: usize) {
// #[cfg(feature = "spin-slow")]
// spinning_slow(i);
// #[cfg(not(feature = "spin-slow"))]
// spinning_fast(i);
// }
// #[inline]
// fn spinning_slow(i: usize) {
// let rng = rng::XorShift64Star::new(i as u64);
// (0..i).reduce(|a, b| {
// black_box({
// let a = rng.next_usize(a.max(1));
// ((b as f32).exp() * (a as f32).sin().cbrt()).to_bits() as usize
// })
// });
// }
// #[inline]
// fn spinning_fast(i: usize) {
// let rng = rng::XorShift64Star::new(i as u64);
// //(0..rng.next_usize(i)).reduce(|a, b| {
// (0..20).reduce(|a, b| {
// black_box({
// let a = rng.next_usize(a.max(1));
// a ^ b
// })
// });
// }
// }