commit e75094d2a5d9a0fcd74f26533025bf2cefc78fd3 Author: Janis Date: Thu Jan 30 22:52:17 2025 +0100 initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4fffb2f --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +/Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..daf874e --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "executor" +version = "0.1.0" +edition = "2021" + +[features] +internal_heartbeat = [] + + +[dependencies] + +futures = "0.3" +rayon = "1.10" +parking_lot = "0.12.3" +thread_local = "1.1.8" +crossbeam = "0.8.4" + +async-task = "4.7.1" + +tracing = "0.1.40" +tracing-subscriber = {version ="0.3.18", features = ["env-filter"]} + +anyhow = "1.0.89" +thiserror = "2.0" +bitflags = "2.6" +# derive_more = "1.0.0" \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..a0ab289 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,1084 @@ +use std::{ + cell::{OnceCell, UnsafeCell}, + collections::VecDeque, + future::Future, + mem::MaybeUninit, + num::NonZero, + pin::{pin, Pin}, + ptr::NonNull, + sync::{ + atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering}, + Arc, + }, + task::Context, + thread::available_parallelism, + time::Duration, +}; + +use async_task::{Runnable, Task}; +use bitflags::bitflags; +use crossbeam::{queue::SegQueue, utils::CachePadded}; +use latch::{AtomicLatch, Latch, LatchWaker, MutexLatch, Probe, ThreadWakeLatch}; +use parking_lot::{Condvar, Mutex}; +use task::{HeapTask, StackTask, TaskRef}; + +pub mod task { + use std::{cell::UnsafeCell, marker::PhantomPinned, pin::Pin}; + + pub trait Task { + unsafe fn execute(this: *const ()); + } + + pub struct TaskRef { + ptr: *const (), + execute_fn: unsafe fn(*const ()), + } + + impl TaskRef { + pub unsafe fn new(task: *const T) -> TaskRef + where + T: Task, + { + Self { + ptr: task.cast(), + execute_fn: ::execute, + } + } + + pub unsafe fn new_raw(ptr: *const (), execute_fn: unsafe fn(*const ())) -> TaskRef { + Self { ptr, execute_fn } + } + + pub fn id(&self) -> impl Eq { + (self.ptr, self.execute_fn) + } + + /// caller must ensure that this particular task is [`Send`] + pub fn execute(self) { + unsafe { (self.execute_fn)(self.ptr) } + } + } + + unsafe impl Send for TaskRef {} + unsafe impl Sync for TaskRef {} + + pub struct StackTask { + task: UnsafeCell>, + _phantom: PhantomPinned, + } + + impl StackTask { + pub fn new(task: F) -> StackTask { + Self { + task: UnsafeCell::new(Some(task)), + _phantom: PhantomPinned, + } + } + + pub fn run(self) { + self.task.into_inner().unwrap()(); + } + + pub fn as_task_ref(self: Pin<&Self>) -> TaskRef { + unsafe { TaskRef::new(&*self) } + } + } + + impl Task for StackTask { + unsafe fn execute(this: *const ()) { + let this = &*this.cast::(); + let task = (&mut *this.task.get()).take().unwrap(); + task(); + } + } + + pub struct HeapTask { + task: F, + _phantom: PhantomPinned, + } + + impl HeapTask { + pub fn new(task: F) -> Box> { + Box::new(Self { + task, + _phantom: PhantomPinned, + }) + } + + pub unsafe fn into_static_task_ref(self: Box) -> TaskRef + where + F: 'static, + { + self.into_task_ref() + } + + pub unsafe fn into_task_ref(self: Box) -> TaskRef { + TaskRef::new(Box::into_raw(self)) + } + } + impl Task for HeapTask { + unsafe fn execute(this: *const ()) { + let this = Box::from_raw(this.cast::().cast_mut()); + (this.task)(); + } + } +} + +pub mod latch { + use std::{ + sync::{ + atomic::{AtomicBool, AtomicUsize, Ordering}, + Arc, + }, + task::Wake, + 
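        // `Wake` is imported for `LatchWaker` at the end of this module, which
        // adapts a latch into an async `Waker` (used by `ThreadPool::block_on`).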
}; + + use parking_lot::{Condvar, Mutex}; + + use crate::{ThreadPool, WorkerThread}; + + pub trait Latch { + unsafe fn set_raw(this: *const Self); + } + + pub trait Probe { + fn probe(&self) -> bool; + } + + #[derive(Debug)] + pub struct AtomicLatch(AtomicBool); + + impl AtomicLatch { + pub const fn new() -> AtomicLatch { + Self(AtomicBool::new(false)) + } + pub fn reset(&self) { + self.0.store(false, Ordering::Release); + } + } + + impl Latch for AtomicLatch { + unsafe fn set_raw(this: *const Self) { + (*this).0.store(true, Ordering::Release); + } + } + + impl Probe for AtomicLatch { + fn probe(&self) -> bool { + self.0.load(Ordering::Acquire) + } + } + + pub struct ThreadWakeLatch { + inner: AtomicLatch, + index: usize, + pool: &'static ThreadPool, + } + + impl ThreadWakeLatch { + pub const fn new(thread: &WorkerThread) -> ThreadWakeLatch { + Self { + inner: AtomicLatch::new(), + pool: thread.pool, + index: thread.index, + } + } + pub fn reset(&self) { + self.inner.reset() + } + } + + impl Latch for ThreadWakeLatch { + unsafe fn set_raw(this: *const Self) { + let (pool, index) = { + let this = &*this; + (this.pool, this.index) + }; + Latch::set_raw(&(*this).inner); + pool.wake_thread(index); + } + } + + impl Probe for ThreadWakeLatch { + fn probe(&self) -> bool { + self.inner.probe() + } + } + + pub struct MutexLatch { + mutex: Mutex, + signal: Condvar, + } + + impl MutexLatch { + pub const fn new() -> MutexLatch { + Self { + mutex: Mutex::new(false), + signal: Condvar::new(), + } + } + + pub fn wait(&self) { + let mut guard = self.mutex.lock(); + while !*guard { + self.signal.wait(&mut guard); + } + } + pub fn wait_and_reset(&self) { + let mut guard = self.mutex.lock(); + while !*guard { + self.signal.wait(&mut guard); + } + *guard = false; + } + } + + impl Latch for MutexLatch { + unsafe fn set_raw(this: *const Self) { + let mut guard = (*this).mutex.lock(); + *guard = true; + (*this).signal.notify_all(); + } + } + + pub struct CountWakeLatch { + counter: AtomicUsize, + inner: ThreadWakeLatch, + } + + impl CountWakeLatch { + pub const fn new(count: usize, thread: &WorkerThread) -> CountWakeLatch { + Self { + counter: AtomicUsize::new(count), + inner: ThreadWakeLatch::new(thread), + } + } + + pub fn increment(&self) { + self.counter.fetch_add(1, Ordering::Relaxed); + } + } + + impl Latch for CountWakeLatch { + unsafe fn set_raw(this: *const Self) { + if (*this).counter.fetch_sub(1, Ordering::SeqCst) == 1 { + Latch::set_raw(&(*this).inner); + } + } + } + + impl Probe for CountWakeLatch { + fn probe(&self) -> bool { + self.inner.probe() + } + } + + pub struct LatchWaker(L); + + impl LatchWaker { + pub fn new(latch: L) -> Arc { + Arc::new(Self(latch)) + } + pub fn latch(&self) -> &L { + &self.0 + } + } + + impl Wake for LatchWaker + where + L: Latch, + { + fn wake(self: Arc) { + self.wake_by_ref(); + } + fn wake_by_ref(self: &Arc) { + unsafe { + Latch::set_raw(&self.0); + } + } + } +} + +pub struct ThreadPoolState { + num_threads: AtomicUsize, + lock: Mutex<()>, + heartbeat_state: CachePadded, +} + +bitflags! 
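// Per-worker status bits. Each worker's `ThreadState` keeps them behind the
// `status` mutex and pairs them with the `status_changed` condvar, so a thread
// can sleep until the bit it cares about flips.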
{ + pub struct ThreadStatus: u8 { + const RUNNING = 1 << 0; + const SLEEPING = 1 << 1; + const SHOULD_WAKE = 1 << 2; + } +} + +pub struct ThreadState { + should_shove: AtomicBool, + shoved_task: Slot, + status: Mutex, + status_changed: Condvar, + should_terminate: AtomicLatch, +} + +impl ThreadState { + /// returns true if thread was sleeping + fn wake(&self) -> bool { + let mut guard = self.status.lock(); + guard.insert(ThreadStatus::SHOULD_WAKE); + self.status_changed.notify_all(); + guard.contains(ThreadStatus::SLEEPING) + } + + fn wait_for_running(&self) { + let mut guard = self.status.lock(); + while !guard.contains(ThreadStatus::RUNNING) { + self.status_changed.wait(&mut guard); + } + } + + fn wait_for_should_wake(&self) { + let mut guard = self.status.lock(); + while !guard.contains(ThreadStatus::SHOULD_WAKE) { + guard.insert(ThreadStatus::SLEEPING); + self.status_changed.wait(&mut guard); + } + guard.remove(ThreadStatus::SHOULD_WAKE | ThreadStatus::SLEEPING); + } + + fn wait_for_should_wake_timeout(&self, timeout: Duration) { + let mut guard = self.status.lock(); + while !guard.contains(ThreadStatus::SHOULD_WAKE) { + guard.insert(ThreadStatus::SLEEPING); + if self + .status_changed + .wait_for(&mut guard, timeout) + .timed_out() + { + break; + } + } + guard.remove(ThreadStatus::SHOULD_WAKE | ThreadStatus::SLEEPING); + } + + fn wait_for_termination(&self) { + let mut guard = self.status.lock(); + while guard.contains(ThreadStatus::RUNNING) { + self.status_changed.wait(&mut guard); + } + } + + fn notify_running(&self) { + let mut guard = self.status.lock(); + guard.insert(ThreadStatus::RUNNING); + self.status_changed.notify_all(); + } + + fn notify_termination(&self) { + let mut guard = self.status.lock(); + *guard = ThreadStatus::empty(); + self.status_changed.notify_all(); + } + + fn set_should_terminate(&self) { + unsafe { + Latch::set_raw(&self.should_terminate); + } + } +} + +const MAX_THREADS: usize = 32; + +pub struct ThreadPool { + threads: [CachePadded; MAX_THREADS], + pool_state: CachePadded, + global_queue: SegQueue, +} + +impl ThreadPool { + pub const fn new() -> Self { + const INITIAL_THREAD_STATE: CachePadded = CachePadded::new(ThreadState { + should_shove: AtomicBool::new(false), + shoved_task: Slot::new(), + status: Mutex::new(ThreadStatus::empty()), + status_changed: Condvar::new(), + should_terminate: AtomicLatch::new(), + }); + + Self { + threads: const { [INITIAL_THREAD_STATE; MAX_THREADS] }, + pool_state: CachePadded::new(ThreadPoolState { + num_threads: AtomicUsize::new(0), + lock: Mutex::new(()), + heartbeat_state: INITIAL_THREAD_STATE, + }), + global_queue: SegQueue::new(), + } + } + + fn threads(&self) -> &[CachePadded] { + &self.threads[..self.pool_state.num_threads.load(Ordering::Relaxed) as usize] + } + + pub fn wake_thread(&self, index: usize) -> Option { + Some(self.threads.get(index as usize)?.wake()) + } + + pub fn wake_any(&self, count: usize) -> usize { + if count > 0 { + let num_woken = self + .threads + .iter() + .filter_map(|thread| thread.wake().then_some(())) + .take(count) + .count(); + num_woken + } else { + 0 + } + } + + pub fn id(&self) -> impl Eq { + core::ptr::from_ref(self) as usize + } + + fn push_local_or_inject(&self, task: TaskRef) { + WorkerThread::with(|worker| match worker { + Some(worker) if worker.pool.id() == self.id() => worker.push_task(task), + _ => self.inject(task), + }) + } + + fn inject_many(&self, tasks: I) + where + I: Iterator, + { + let mut n = 0; + for task in tasks { + n += 1; + self.global_queue.push(task); + } + 
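        // wake up to `n` sleeping workers so the freshly queued tasks get picked up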
self.wake_any(n); + } + + fn inject(&self, task: TaskRef) { + self.global_queue.push(task); + + self.wake_any(1); + } + + fn resize usize>(&'static self, size: F) -> usize { + if WorkerThread::is_worker_thread() { + // acquire required here? + return self.pool_state.num_threads.load(Ordering::Acquire); + } + + let _guard = self.pool_state.lock.lock(); + + let current_size = self.pool_state.num_threads.load(Ordering::Acquire); + let new_size = size(current_size).max(MAX_THREADS); + + if new_size == current_size { + return current_size; + } + + self.pool_state + .num_threads + .store(new_size, Ordering::Release); + + match new_size.cmp(¤t_size) { + std::cmp::Ordering::Greater => { + let new_threads = &self.threads[current_size..new_size]; + + for (i, thread) in new_threads.iter().enumerate() { + std::thread::spawn(move || { + WorkerThread::worker_loop(&self, current_size + i); + }); + } + + for thread in new_threads { + thread.wait_for_running(); + } + + #[cfg(not(feature = "internal_heartbeat"))] + if current_size == 0 { + std::thread::spawn(move || { + heartbeat_loop(self); + }); + + self.pool_state.heartbeat_state.wait_for_running(); + } + } + std::cmp::Ordering::Less => { + let terminating_threads = &self.threads[new_size..current_size]; + + for thread in terminating_threads { + thread.set_should_terminate(); + } + for thread in terminating_threads { + thread.wait_for_termination(); + } + + #[cfg(not(feature = "internal_heartbeat"))] + if new_size == 0 { + self.pool_state.heartbeat_state.set_should_terminate(); + self.pool_state.heartbeat_state.wait_for_termination(); + } + } + std::cmp::Ordering::Equal => unreachable!(), + } + + new_size + } + + fn ensure_one_worker(&'static self) -> usize { + self.resize(|current| current.max(1)) + } + + fn resize_to_available(&'static self) { + self.resize_to(available_parallelism().map(NonZero::get).unwrap_or(1)); + } + fn resize_to(&'static self, new_size: usize) -> usize { + self.resize(|_| new_size) + } + fn grow_by(&'static self, num_threads: usize) -> usize { + self.resize(|current| current.saturating_add(num_threads)) + } + fn shrink_by(&'static self, num_threads: usize) -> usize { + self.resize(|current| current.saturating_sub(num_threads)) + } + fn shrink_to(&'static self, num_threads: usize) -> usize { + self.resize(|_| num_threads) + } + + fn in_worker(&'static self, f: F) -> T + where + F: FnOnce(&WorkerThread, bool) -> T + Send, + T: Send, + { + WorkerThread::with(|worker| match worker { + Some(worker) => { + if worker.pool.id() == self.id() { + self.in_worker_cross(worker, f) + } else { + f(worker, false) + } + } + None => self.in_worker_cold(f), + }) + } + + #[cold] + fn in_worker_cold(&'static self, f: F) -> T + where + F: FnOnce(&WorkerThread, bool) -> T + Send, + T: Send, + { + std::thread_local! {static LATCH: MutexLatch = const {MutexLatch::new()}}; + + LATCH.with(|latch| { + let mut result = None; + let task = StackTask::new(|| { + WorkerThread::with(|worker| { + let worker = worker.unwrap(); + + result = Some(f(worker, true)); + + unsafe { + // SAFETY: static thread-local + Latch::set_raw(latch); + } + }) + }); + + let pinned = pin!(task); + let taskref = pinned.as_ref().as_task_ref(); + self.inject(taskref); + + latch.wait_and_reset(); + result.unwrap() + }) + } + + /// run f in `self`, but block current thread until work is complete. 
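    /// As called from `in_worker`, the current thread is already a worker of
    /// this pool, so rather than parking on a `MutexLatch` it injects the job
    /// and keeps executing pool tasks via `run_until` until the job sets the
    /// latch.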
+ fn in_worker_cross(&'static self, worker: &WorkerThread, f: F) -> T + where + F: FnOnce(&WorkerThread, bool) -> T + Send, + T: Send, + { + let latch = ThreadWakeLatch::new(worker); + + let mut result = None; + + let task = pin!(StackTask::new(|| { + WorkerThread::with(|worker| { + let worker = worker.unwrap(); + + result = Some(f(worker, true)); + + unsafe { + // SAFETY: static thread-local + Latch::set_raw(&latch); + } + }) + })); + + let taskref = task.into_ref().as_task_ref(); + self.inject(taskref); + + worker.run_until(&latch); + result.unwrap() + } +} + +impl ThreadPool { + fn spawn(&'static self, f: Fn) + where + Fn: FnOnce() + Send + 'static, + { + let task = HeapTask::new(f); + + let taskref = unsafe { task.into_static_task_ref() }; + self.push_local_or_inject(taskref); + } + + fn spawn_future(&'static self, future: Fut) -> Task + where + Fut: Future + Send + 'static, + T: Send + 'static, + { + let schedule = move |runnable: Runnable| { + let taskref = unsafe { + TaskRef::new_raw(runnable.into_raw().as_ptr(), |this| { + let this = NonNull::new_unchecked(this.cast_mut()); + + let runnable = Runnable::<()>::from_raw(this); + runnable.run(); + }) + }; + + self.push_local_or_inject(taskref); + }; + + let (runnable, task) = async_task::spawn(future, schedule); + + runnable.schedule(); + task + } + + fn spawn_async(&'static self, f: Fn) -> Task + where + Fn: FnOnce() -> Fut + Send + 'static, + Fut: Future + Send + 'static, + T: Send + 'static, + { + self.spawn_future(async move { f().await }) + } + + fn block_on(&'static self, mut future: Fut) + where + Fut: Future + Send + 'static, + T: Send + 'static, + { + let mut future = unsafe { Pin::new_unchecked(&mut future) }; + self.in_worker(|worker, _| { + let wake = LatchWaker::new(ThreadWakeLatch::new(worker)); + let ctx_waker = Arc::clone(&wake).into(); + let mut ctx = Context::from_waker(&ctx_waker); + + loop { + match future.as_mut().poll(&mut ctx) { + std::task::Poll::Ready(t) => { + return t; + } + std::task::Poll::Pending => { + worker.run_until(wake.latch()); + wake.latch().reset(); + } + } + } + }); + } +} + +pub struct WorkerThread { + queue: TaskQueue, + pool: &'static ThreadPool, + index: usize, + rng: rng::XorShift64Star, + last_heartbeat: UnsafeCell, +} + +const HEARTBEAT_INTERVAL: core::time::Duration = const { core::time::Duration::from_micros(100) }; + +std::thread_local! 
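// Every thread entering `worker_loop` installs its `WorkerThread` into this
// thread-local `OnceCell`; `WorkerThread::with` reads it to tell whether the
// current thread is a worker and, if so, which pool it belongs to.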
{ + static WORKER_THREAD_STATE: CachePadded> = const {CachePadded::new(OnceCell::new())}; +} + +impl WorkerThread { + fn info(&self) -> &ThreadState { + &self.pool.threads[self.index as usize] + } + fn is_worker_thread() -> bool { + Self::with(|worker| worker.is_some()) + } + fn with) -> T>(f: F) -> T { + WORKER_THREAD_STATE.with(|thread| f(thread.get())) + } + fn pop_task(&self) -> Option { + self.queue.pop_front() + } + fn push_task(&self, task: TaskRef) { + self.queue.push_front(task); + } + + fn drain(&self) -> impl Iterator { + self.queue.drain() + } + + fn claim_shoved_task(&self) -> Option { + if let Some(task) = self.info().shoved_task.try_take() { + return Some(task); + } + + let threads = self.pool.threads(); + if threads.is_empty() { + return None; + } + let (start, end) = threads.split_at(self.rng.next_usize(threads.len())); + + end.iter() + .chain(start) + .find_map(|thread| thread.shoved_task.try_take()) + } + + #[cold] + fn shove_task(&self) { + if let Some(task) = self.queue.pop_back() { + match self.info().shoved_task.try_put(task) { + // shoved task is occupied, reinsert into queue + Some(task) => self.queue.push_back(task), + None => { + // wake thread to execute task + self.pool.wake_any(1); + } + } + } + } + + fn execute(&self, task: TaskRef) { + self.try_promote(); + task.execute(); + } + + fn try_promote(&self) { + #[cfg(feature = "internal_heartbeat")] + let now = std::time::Instant::now(); + // SAFETY: workerthread is thread-local non-sync + + #[cfg(feature = "internal_heartbeat")] + let should_shove = + unsafe { *self.last_heartbeat.get() }.duration_since(now) > HEARTBEAT_INTERVAL; + #[cfg(not(feature = "internal_heartbeat"))] + let should_shove = self.info().should_shove.load(Ordering::Acquire); + + if should_shove { + // SAFETY: workerthread is thread-local non-sync + #[cfg(feature = "internal_heartbeat")] + unsafe { + *&mut *self.last_heartbeat.get() = now; + } + #[cfg(not(feature = "internal_heartbeat"))] + self.info().should_shove.store(false, Ordering::Release); + self.shove_task(); + } + } + + fn find_any_task(&self) -> Option { + // TODO: attempt stealing work here, too. 
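        // Current order: the local queue, then a shoved task (our own slot,
        // then the other workers' slots starting at a random offset), and
        // finally the pool-wide injector queue.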
+ self.pop_task() + .or_else(|| self.claim_shoved_task()) + .or_else(|| self.pool.global_queue.pop()) + } + + fn run_until(&self, latch: &L) + where + L: Probe, + { + if !latch.probe() { + self.run_until_cold(latch); + } + } + + #[cold] + fn run_until_cold(&self, latch: &L) + where + L: Probe, + { + while !latch.probe() { + self.run_until_inner(); + } + } + + fn run_until_inner(&self) { + match self.find_any_task() { + Some(task) => { + self.execute(task); + } + None => { + self.info().wait_for_should_wake(); + } + } + } + + fn worker_loop(pool: &'static ThreadPool, index: usize) { + let info = &pool.threads()[index as usize]; + + WORKER_THREAD_STATE.with(|worker| { + let worker = worker.get_or_init(|| WorkerThread { + queue: TaskQueue::new(), + pool, + index, + rng: rng::XorShift64Star::new(pool as *const _ as u64 + index as u64), + last_heartbeat: UnsafeCell::new(std::time::Instant::now()), + }); + + info.notify_running(); + // info.notify_running(); + worker.run_until(&info.should_terminate); + + for task in worker.drain() { + pool.inject(task); + } + + if let Some(task) = info.shoved_task.try_take() { + pool.inject(task); + } + }); + + info.notify_termination(); + } +} + +fn heartbeat_loop(pool: &'static ThreadPool) { + let state = &pool.pool_state.heartbeat_state; + + state.notify_running(); + let mut i = 0; + while !state.should_terminate.probe() { + let threads = pool.threads(); + if threads.is_empty() { + break; + } + + if i >= threads.len() { + i = 0; + continue; + } + + threads[i].should_shove.store(true, Ordering::Relaxed); + i += 1; + + let interval = HEARTBEAT_INTERVAL / threads.len() as u32; + + state.wait_for_should_wake_timeout(interval); + } + + state.notify_termination(); +} + +pub struct TaskQueue(UnsafeCell>); + +impl TaskQueue { + /// Creates a new [`TaskQueue`]. + const fn new() -> Self { + Self(UnsafeCell::new(VecDeque::new())) + } + fn get_mut(&self) -> &mut VecDeque { + unsafe { &mut *self.0.get() } + } + fn pop_front(&self) -> Option { + self.get_mut().pop_front() + } + fn pop_back(&self) -> Option { + self.get_mut().pop_back() + } + fn push_back(&self, t: T) { + self.get_mut().push_back(t); + } + fn push_front(&self, t: T) { + self.get_mut().push_front(t); + } + fn take(&self) -> VecDeque { + let this = core::mem::replace(self.get_mut(), VecDeque::new()); + this + } + fn drain(&self) -> impl Iterator { + self.take().into_iter() + } +} + +bitflags! 
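// State machine for `Slot`: EMPTY -> LOCKED -> OCCUPIED in `try_put` and
// OCCUPIED -> LOCKED -> EMPTY in `try_take`. Both transitions start with a
// compare_exchange, so a caller either claims the slot outright or leaves it
// untouched.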
{ + #[derive(Debug, Clone, Copy)] + pub struct SlotState: u8 { + const LOCKED = 1 << 1; + const OCCUPIED = 1 << 2; + } +} + +impl From for u8 { + fn from(value: SlotState) -> Self { + value.bits() + } +} + +pub struct Slot { + slot: UnsafeCell>, + state: AtomicU8, +} + +unsafe impl Send for Slot where T: Send {} +unsafe impl Sync for Slot where T: Send {} + +impl Drop for Slot { + fn drop(&mut self) { + if core::mem::needs_drop::() { + if SlotState::from_bits(*self.state.get_mut()) + .unwrap() + .contains(SlotState::OCCUPIED) + { + unsafe { + self.slot.get().drop_in_place(); + } + } + } + } +} + +impl Slot { + pub const fn new() -> Slot { + Self { + slot: UnsafeCell::new(MaybeUninit::uninit()), + state: AtomicU8::new(SlotState::empty().bits()), + } + } + + pub fn try_put(&self, t: T) -> Option { + match self.state.compare_exchange( + SlotState::empty().into(), + SlotState::LOCKED.into(), + Ordering::Acquire, + Ordering::Relaxed, + ) { + Err(_) => Some(t), + Ok(_) => { + let slot = self.slot.get(); + // SAFETY: we hold LOCKED on the spinlock + unsafe { (*slot).write(t) }; + + // release lock + self.state + .store(SlotState::OCCUPIED.into(), Ordering::Release); + None + } + } + } + + pub fn try_take(&self) -> Option { + match self.state.compare_exchange( + SlotState::OCCUPIED.into(), + SlotState::LOCKED.into(), + Ordering::Acquire, + Ordering::Relaxed, + ) { + Ok(_) => { + let slot = self.slot.get(); + // SAFETY: we hold LOCKED on the spinlock + let t = unsafe { (*slot).assume_init_read() }; + + // release lock + self.state + .store(SlotState::empty().into(), Ordering::Release); + Some(t) + } + Err(_) => None, + } + } +} + +mod rng { + use core::cell::Cell; + + pub struct XorShift64Star { + state: Cell, + } + + impl XorShift64Star { + /// Initializes the prng with a seed. Provided seed must be nonzero. + pub fn new(seed: u64) -> Self { + XorShift64Star { + state: Cell::new(seed), + } + } + + /// Returns a pseudorandom number. + pub fn next(&self) -> u64 { + let mut x = self.state.get(); + debug_assert_ne!(x, 0); + x ^= x >> 12; + x ^= x << 25; + x ^= x >> 27; + self.state.set(x); + x.wrapping_mul(0x2545_f491_4f6c_dd1d) + } + + /// Return a pseudorandom number from `0..n`. + pub fn next_usize(&self, n: usize) -> usize { + (self.next() % n as u64) as usize + } + } +} + +#[cfg(test)] +mod tests { + use std::cell::Cell; + + use super::*; + + const PRIMES: &'static [usize] = &[ + 1181, 1187, 1193, 1201, 1213, 1217, 1223, 1229, 1231, 1237, 1249, 1259, 1277, 1279, 1283, + 1289, 1291, 1297, 1301, 1303, 1307, 1319, 1321, 1327, 1361, 1367, 1373, 1381, 1399, 1409, + 1423, 1427, 1429, 1433, 1439, 1447, 1451, 1453, 1459, 1471, 1481, 1483, 1487, 1489, 1493, + 1499, 1511, 1523, 1531, 1543, 1549, 1553, 1559, 1567, 1571, 1579, 1583, 1597, 1601, 1607, + 1609, 1613, 1619, 1621, 1627, 1637, 1657, 1663, 1667, 1669, 1693, 1697, 1699, 1709, 1721, + 1723, 1733, 1741, 1747, 1753, 1759, 1777, 1783, 1787, 1789, 1801, 1811, 1823, 1831, 1847, + 1861, 1867, 1871, 1873, 1877, 1879, 1889, 1901, 1907, + ]; + + fn run_on_static_pool(f: impl FnOnce(&'static ThreadPool)) { + let pool = Box::new(ThreadPool::new()); + let ptr = Box::into_raw(pool); + + { + let pool: &'static ThreadPool = unsafe { &*ptr }; + pool.ensure_one_worker(); + f(pool); + pool.resize_to(0); + assert!(pool.global_queue.pop().is_none()); + } + + let _pool = unsafe { Box::from_raw(ptr) }; + } + + #[test] + fn spawn_random() { + std::thread_local! 
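        // Declared for this test but never read or updated below.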
        {static WAIT_COUNT: Cell<usize> = Cell::new(0);}

        run_on_static_pool(|pool| {
            for &p in PRIMES {
                pool.spawn(move || {
                    std::thread::sleep(Duration::from_micros(p as u64));
                });
            }
        });
    }
}
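A small companion check for `Slot`, the one-task mailbox used for heartbeat shoving, could look like the sketch below. It is not part of the commit: the module and test names are invented here, and it assumes `Slot` is generic over its payload (`Slot<T>`), as its `MaybeUninit` storage and the `try_put`/`try_take` signatures imply.

#[cfg(test)]
mod slot_sketch {
    use super::*;

    #[test]
    fn slot_put_take_roundtrip() {
        let slot: Slot<u32> = Slot::new();

        // An empty slot accepts exactly one value...
        assert!(slot.try_put(7).is_none());
        // ...and hands any further value straight back instead of overwriting.
        assert_eq!(slot.try_put(8), Some(8));

        // Taking moves the value out and returns the slot to the empty state,
        // so it can be reused.
        assert_eq!(slot.try_take(), Some(7));
        assert!(slot.try_take().is_none());
        assert!(slot.try_put(9).is_none());
    }
}

Because both transitions go through a compare_exchange, `claim_shoved_task` can scan other workers' slots concurrently without any extra locking.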