From b83bfeca51016d416fc316d079fea9cbb2f8c778 Mon Sep 17 00:00:00 2001 From: Janis Date: Sat, 1 Feb 2025 00:42:33 +0100 Subject: [PATCH] chili-like executor for joins --- Cargo.toml | 3 +- benches/join.rs | 169 ++++++++++++ rust-toolchain | 1 + src/job/mod.rs | 139 ++++++++++ src/lib.rs | 104 ++++++- src/melange.rs | 722 ++++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 1126 insertions(+), 12 deletions(-) create mode 100644 benches/join.rs create mode 100644 rust-toolchain create mode 100644 src/job/mod.rs create mode 100644 src/melange.rs diff --git a/Cargo.toml b/Cargo.toml index 31f526e..8e5c658 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,10 +17,11 @@ never-local = [] futures = "0.3" rayon = "1.10" bevy_tasks = "0.15.1" -parking_lot = "0.12.3" +parking_lot = {version = "0.12.3"} thread_local = "1.1.8" crossbeam = "0.8.4" st3 = "0.4" +chili = "0.2.0" async-task = "4.7.1" diff --git a/benches/join.rs b/benches/join.rs new file mode 100644 index 0000000..0b3db14 --- /dev/null +++ b/benches/join.rs @@ -0,0 +1,169 @@ +#![feature(test)] +use std::{ + sync::{atomic::AtomicUsize, Arc}, + thread, + time::Duration, +}; + +use bevy_tasks::available_parallelism; +use executor::{self}; +use test::Bencher; +use tree::Node; + +extern crate test; + +mod tree { + + pub struct Tree { + nodes: Box<[Node]>, + root: Option, + } + pub struct Node { + pub leaf: T, + pub left: Option, + pub right: Option, + } + + impl Tree { + pub fn new(depth: usize, t: T) -> Tree + where + T: Copy, + { + let mut nodes = Vec::with_capacity((0..depth).sum()); + let root = Self::build_node(&mut nodes, depth, t); + Self { + nodes: nodes.into_boxed_slice(), + root: Some(root), + } + } + + pub fn root(&self) -> Option { + self.root + } + + pub fn get(&self, index: usize) -> &Node { + &self.nodes[index] + } + + pub fn build_node(nodes: &mut Vec>, depth: usize, t: T) -> usize + where + T: Copy, + { + let node = Node { + leaf: t, + left: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)), + right: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)), + }; + nodes.push(node); + nodes.len() - 1 + } + } +} + +const PRIMES: &'static [usize] = &[ + 1181, 1187, 1193, 1201, 1213, 1217, 1223, 1229, 1231, 1237, 1249, 1259, 1277, 1279, 1283, 1289, + 1291, 1297, 1301, 1303, 1307, 1319, 1321, 1327, 1361, 1367, 1373, 1381, 1399, 1409, 1423, 1427, + 1429, 1433, 1439, 1447, 1451, 1453, 1459, 1471, 1481, 1483, 1487, 1489, 1493, 1499, 1511, 1523, + 1531, 1543, 1549, 1553, 1559, 1567, 1571, 1579, 1583, 1597, 1601, 1607, 1609, 1613, 1619, 1621, + 1627, 1637, 1657, 1663, 1667, 1669, 1693, 1697, 1699, 1709, 1721, 1723, 1733, 1741, 1747, 1753, + 1759, 1777, 1783, 1787, 1789, 1801, 1811, 1823, 1831, 1847, 1861, 1867, 1871, 1873, 1877, 1879, + 1889, 1901, 1907, +]; + +const REPEAT: usize = 0x800; +const TREE_SIZE: usize = 14; + +#[bench] +fn join_melange(b: &mut Bencher) { + let pool = executor::melange::ThreadPool::new(available_parallelism()); + + let mut scope = pool.new_worker(); + + let tree = tree::Tree::new(TREE_SIZE, 1u32); + + fn sum( + tree: &tree::Tree, + node: usize, + scope: &mut executor::melange::WorkerThread, + ) -> u32 { + let node = tree.get(node); + let (l, r) = scope.join( + |s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(), + |s| { + node.right + .map(|node| sum(tree, node, s)) + .unwrap_or_default() + }, + ); + + node.leaf + l + r + } + + b.iter(move || { + assert_ne!(sum(&tree, tree.root().unwrap(), &mut scope), 0); + }); +} + +#[bench] +fn join_sync(b: &mut Bencher) { + let tree = 
tree::Tree::new(TREE_SIZE, 1u32); + + fn sum(tree: &tree::Tree, node: usize) -> u32 { + let node = tree.get(node); + let (l, r) = ( + node.left.map(|node| sum(tree, node)).unwrap_or_default(), + node.right.map(|node| sum(tree, node)).unwrap_or_default(), + ); + + node.leaf + l + r + } + + b.iter(move || { + assert_ne!(sum(&tree, tree.root().unwrap()), 0); + }); +} + +#[bench] +fn join_chili(b: &mut Bencher) { + let tree = tree::Tree::new(TREE_SIZE, 1u32); + + fn sum(tree: &tree::Tree, node: usize, scope: &mut chili::Scope<'_>) -> u32 { + let node = tree.get(node); + let (l, r) = scope.join( + |s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(), + |s| { + node.right + .map(|node| sum(tree, node, s)) + .unwrap_or_default() + }, + ); + + node.leaf + l + r + } + + b.iter(move || { + assert_ne!( + sum(&tree, tree.root().unwrap(), &mut chili::Scope::global()), + 0 + ); + }); +} + +#[bench] +fn join_rayon(b: &mut Bencher) { + let tree = tree::Tree::new(TREE_SIZE, 1u32); + + fn sum(tree: &tree::Tree, node: usize) -> u32 { + let node = tree.get(node); + let (l, r) = rayon::join( + || node.left.map(|node| sum(tree, node)).unwrap_or_default(), + || node.right.map(|node| sum(tree, node)).unwrap_or_default(), + ); + + node.leaf + l + r + } + + b.iter(move || { + assert_ne!(sum(&tree, tree.root().unwrap()), 0); + }); +} diff --git a/rust-toolchain b/rust-toolchain new file mode 100644 index 0000000..bf867e0 --- /dev/null +++ b/rust-toolchain @@ -0,0 +1 @@ +nightly diff --git a/src/job/mod.rs b/src/job/mod.rs new file mode 100644 index 0000000..386ce78 --- /dev/null +++ b/src/job/mod.rs @@ -0,0 +1,139 @@ +///! Rayon's job logic +use std::{cell::UnsafeCell, marker::PhantomPinned, sync::atomic::AtomicBool}; + +use crate::latch::Latch; + +pub trait Job { + unsafe fn execute(this: *const (), args: Args); +} + +pub struct JobRef { + this: *const (), + execute_fn: unsafe fn(*const (), Args), +} + +unsafe impl Send for JobRef {} +unsafe impl Sync for JobRef {} + +impl JobRef { + pub unsafe fn new(data: *const T) -> JobRef + where + T: Job, + { + Self { + this: data.cast(), + execute_fn: >::execute, + } + } + + pub fn id(&self) -> impl Eq { + (self.this, self.execute_fn) + } + + pub unsafe fn execute(self, args: Args) { + unsafe { (self.execute_fn)(self.this, args) } + } +} + +pub struct StackJob +where + L: Latch + Sync, +{ + task: UnsafeCell>, + latch: L, + _phantom: PhantomPinned, +} + +impl StackJob +where + L: Latch + Sync, +{ + pub fn new(task: F, latch: L) -> StackJob { + Self { + task: UnsafeCell::new(Some(task)), + latch, + _phantom: PhantomPinned, + } + } + + pub unsafe fn take_once(self) -> F { + self.task.into_inner().unwrap() + } + + #[inline] + pub fn run(self, args: Args) + where + F: FnOnce(Args), + { + self.task.into_inner().unwrap()(args); + } + + #[inline] + pub unsafe fn as_task_ref(&self) -> JobRef + where + F: FnOnce(Args), + { + unsafe { JobRef::::new(self) } + } +} + +impl Job for StackJob +where + F: FnOnce(Args), + L: Latch + Sync, +{ + unsafe fn execute(this: *const (), args: Args) { + let this = &*this.cast::(); + let func = (*this.task.get()).take().unwrap(); + func(args); + Latch::set_raw(&this.latch); + // set internal latch here? 
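+        // Note: the latch is set only after `func` has returned, so a thread
+        // probing this latch (e.g. `WorkerThread::run_until` in `melange`)
+        // cannot observe completion before the closure has produced its result.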
+ } +} + +pub struct HeapJob +where + F: Send, +{ + func: F, + _phantom: PhantomPinned, +} + +impl HeapJob +where + F: Send, +{ + pub fn new(task: F) -> Box> { + Box::new(Self { + func: task, + _phantom: PhantomPinned, + }) + } + + #[inline] + pub unsafe fn into_static_task_ref(self: Box) -> JobRef + where + F: FnOnce(Args) + 'static, + { + self.into_task_ref() + } + + #[inline] + pub unsafe fn into_task_ref(self: Box) -> JobRef + where + F: FnOnce(Args), + { + JobRef::new(Box::into_raw(self)) + } +} + +impl Job for HeapJob +where + F: FnOnce(Args) + Send, +{ + unsafe fn execute(this: *const (), args: Args) { + let this = Box::from_raw(this.cast::().cast_mut()); + (this.func)(args); + // set internal latch here? + } +} diff --git a/src/lib.rs b/src/lib.rs index 068f513..1b5c87d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,6 +27,8 @@ use scope::Scope; use task::{HeapTask, StackTask, TaskRef}; use tracing::debug; +pub mod job; + pub mod task { use std::{cell::UnsafeCell, marker::PhantomPinned, pin::Pin}; @@ -59,11 +61,18 @@ pub mod task { (self.ptr, self.execute_fn) } - /// caller must ensure that this particular task is [`Send`] #[inline] pub fn execute(self) { unsafe { (self.execute_fn)(self.ptr) } } + #[inline] + pub unsafe fn execute_with_scope(self, scope: &mut T) { + unsafe { + core::mem::transmute::<_, unsafe fn(*const (), &mut T)>(self.execute_fn)( + self.ptr, scope, + ) + } + } } unsafe impl Send for TaskRef {} @@ -191,6 +200,38 @@ pub mod latch { } } + pub struct ClosureLatch { + set: S, + probe: P, + } + + impl ClosureLatch { + pub fn new(set: S, probe: P) -> Self { + Self { set, probe } + } + pub fn new_boxed(set: S, probe: P) -> Box { + Box::new(Self { set, probe }) + } + } + + impl Latch for ClosureLatch + where + S: Fn(), + { + unsafe fn set_raw(this: *const Self) { + let this = &*this; + (this.set)(); + } + } + impl Probe for ClosureLatch + where + P: Fn() -> bool, + { + fn probe(&self) -> bool { + (self.probe)() + } + } + pub struct ThreadWakeLatch { inner: AtomicLatch, index: usize, @@ -337,6 +378,8 @@ pub mod latch { } } +pub mod melange; + pub struct ThreadPoolState { num_threads: AtomicUsize, lock: Mutex<()>, @@ -344,6 +387,7 @@ pub struct ThreadPoolState { } bitflags! 
{ +#[derive(Clone)] pub struct ThreadStatus: u8 { const RUNNING = 1 << 0; const SLEEPING = 1 << 1; @@ -366,9 +410,16 @@ pub struct ThreadState { } impl ThreadControl { + pub const fn new() -> Self { + Self { + status: Mutex::new(ThreadStatus::empty()), + status_changed: Condvar::new(), + should_terminate: AtomicLatch::new(), + } + } /// returns true if thread was sleeping #[inline] - fn wake(&self) -> bool { + pub fn wake(&self) -> bool { let mut guard = self.status.lock(); guard.insert(ThreadStatus::SHOULD_WAKE); self.status_changed.notify_all(); @@ -376,7 +427,7 @@ impl ThreadControl { } #[inline] - fn wait_for_running(&self) { + pub fn wait_for_running(&self) { let mut guard = self.status.lock(); while !guard.contains(ThreadStatus::RUNNING) { self.status_changed.wait(&mut guard); @@ -384,7 +435,7 @@ impl ThreadControl { } #[inline] - fn wait_for_should_wake(&self) { + pub fn wait_for_should_wake(&self) { let mut guard = self.status.lock(); while !guard.contains(ThreadStatus::SHOULD_WAKE) { guard.insert(ThreadStatus::SLEEPING); @@ -394,7 +445,7 @@ impl ThreadControl { } #[inline] - fn wait_for_should_wake_timeout(&self, timeout: Duration) { + pub fn wait_for_should_wake_timeout(&self, timeout: Duration) { let mut guard = self.status.lock(); while !guard.contains(ThreadStatus::SHOULD_WAKE) { guard.insert(ThreadStatus::SLEEPING); @@ -410,7 +461,7 @@ impl ThreadControl { } #[inline] - fn wait_for_termination(&self) { + pub fn wait_for_termination(&self) { let mut guard = self.status.lock(); while guard.contains(ThreadStatus::RUNNING) { self.status_changed.wait(&mut guard); @@ -418,21 +469,21 @@ impl ThreadControl { } #[inline] - fn notify_running(&self) { + pub fn notify_running(&self) { let mut guard = self.status.lock(); guard.insert(ThreadStatus::RUNNING); self.status_changed.notify_all(); } #[inline] - fn notify_termination(&self) { + pub fn notify_termination(&self) { let mut guard = self.status.lock(); *guard = ThreadStatus::empty(); self.status_changed.notify_all(); } #[inline] - fn notify_should_terminate(&self) { + pub fn notify_should_terminate(&self) { unsafe { Latch::set_raw(&self.should_terminate); } @@ -1502,7 +1553,7 @@ mod scope { #[cfg(test)] mod tests { - use std::{cell::Cell, hint::black_box}; + use std::{cell::Cell, hint::black_box, time::Instant}; use tracing::info; @@ -1643,7 +1694,7 @@ mod tests { let elapsed = now.elapsed().as_micros(); - info!("(rayon) total time: {}ms", elapsed as f32 / 1e3); + info!("(rayon) {sum} total time: {}ms", elapsed as f32 / 1e3); } #[test] @@ -1717,6 +1768,37 @@ mod tests { }); } + #[test] + #[tracing_test::traced_test] + fn melange_join() { + let pool = melange::ThreadPool::new(bevy_tasks::available_parallelism()); + + let mut scope = pool.new_worker(); + + let tree = tree::Tree::new(TREE_SIZE, 1u32); + + fn sum(tree: &tree::Tree, node: usize, scope: &mut melange::WorkerThread) -> u32 { + let node = tree.get(node); + let (l, r) = scope.join( + |s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(), + |s| { + node.right + .map(|node| sum(tree, node, s)) + .unwrap_or_default() + }, + ); + + node.leaf + l + r + } + let now = Instant::now(); + let res = sum(&tree, tree.root().unwrap(), &mut scope); + eprintln!( + "res: {res} took {}ms", + now.elapsed().as_micros() as f32 / 1e3 + ); + assert_ne!(res, 0); + } + #[test] #[tracing_test::traced_test] fn sync() { diff --git a/src/melange.rs b/src/melange.rs new file mode 100644 index 0000000..40b1ce9 --- /dev/null +++ b/src/melange.rs @@ -0,0 +1,722 @@ +use std::{ + cell::Cell, + 
collections::VecDeque, + marker::PhantomPinned, + ops::{Deref, DerefMut}, + pin::pin, + ptr::{self, NonNull}, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, Weak, + }, + thread, + time::{Duration, Instant}, +}; + +use crossbeam::utils::CachePadded; +use parking_lot::{Condvar, Mutex}; + +use crate::{latch::*, task::*, ThreadControl, ThreadStatus}; +mod job { + use std::{ + cell::{Cell, UnsafeCell}, + collections::VecDeque, + mem::ManuallyDrop, + panic::{self, AssertUnwindSafe}, + ptr::NonNull, + sync::atomic::{AtomicU8, Ordering}, + thread::{self, Thread}, + }; + + use super::WorkerThread as Scope; + + enum Poll { + Pending, + Ready, + Locked, + } + + #[derive(Debug, Default)] + pub struct Future { + state: AtomicU8, + /// Can only be accessed if `state` is `Poll::Locked`. + waiting_thread: UnsafeCell>, + /// Can only be written if `state` is `Poll::Locked` and read if `state` is + /// `Poll::Ready`. + val: UnsafeCell>>>, + } + + impl Future { + pub fn poll(&self) -> bool { + self.state.load(Ordering::Acquire) == Poll::Ready as u8 + } + + pub fn wait(&self) -> Option> { + loop { + let result = self.state.compare_exchange( + Poll::Pending as u8, + Poll::Locked as u8, + Ordering::AcqRel, + Ordering::Acquire, + ); + + match result { + Ok(_) => { + // SAFETY: + // Lock is acquired, only we are accessing `self.waiting_thread`. + unsafe { *self.waiting_thread.get() = Some(thread::current()) }; + + self.state.store(Poll::Pending as u8, Ordering::Release); + + thread::park(); + + // Skip yielding after being woken up. + continue; + } + Err(state) if state == Poll::Ready as u8 => { + // SAFETY: + // `state` is `Poll::Ready` only after `Self::complete` + // releases the lock. + // + // Calling `Self::complete` when `state` is `Poll::Ready` + // cannot mutate `self.val`. + break unsafe { (*self.val.get()).take().map(|b| *b) }; + } + _ => (), + } + + thread::yield_now(); + } + } + + pub fn complete(&self, val: thread::Result) { + let val = Box::new(val); + + loop { + let result = self.state.compare_exchange( + Poll::Pending as u8, + Poll::Locked as u8, + Ordering::AcqRel, + Ordering::Acquire, + ); + + match result { + Ok(_) => break, + Err(_) => thread::yield_now(), + } + } + + // SAFETY: + // Lock is acquired, only we are accessing `self.val`. + unsafe { + *self.val.get() = Some(val); + } + + // SAFETY: + // Lock is acquired, only we are accessing `self.waiting_thread`. + if let Some(thread) = unsafe { (*self.waiting_thread.get()).take() } { + thread.unpark(); + } + + self.state.store(Poll::Ready as u8, Ordering::Release); + } + } + + pub struct JobStack { + /// All code paths should call either `Job::execute` or `Self::unwrap` to + /// avoid a potential memory leak. + f: UnsafeCell>, + } + + impl JobStack { + pub fn new(f: F) -> Self { + Self { + f: UnsafeCell::new(ManuallyDrop::new(f)), + } + } + + /// SAFETY: + /// It should only be called once. + pub unsafe fn take_once(&self) -> F { + // SAFETY: + // No `Job` has has been executed, therefore `self.f` has not yet been + // `take`n. + unsafe { ManuallyDrop::take(&mut *self.f.get()) } + } + } + + /// `Job` is only sent, not shared between threads. + /// + /// When popped from the `JobQueue`, it gets copied before sending across + /// thread boundaries. 
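+    ///
+    /// The boxed `Future` behind `fut` is allocated lazily by
+    /// `JobQueue::pop_front` and is freed again in `Job::wait`, or in
+    /// `Job::drop` for jobs that were popped from the front and never waited on.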
+ #[derive(Clone, Debug)] + pub struct Job { + stack: NonNull, + harness: unsafe fn(&mut Scope, NonNull, NonNull), + fut: Cell>>>, + } + + impl Job { + pub fn new(stack: &JobStack) -> Self + where + F: FnOnce(&mut Scope) -> T + Send, + T: Send, + { + /// SAFETY: + /// It should only be called while the `stack` is still alive. + unsafe fn harness( + scope: &mut Scope, + stack: NonNull, + fut: NonNull, + ) where + F: FnOnce(&mut Scope) -> T + Send, + T: Send, + { + // SAFETY: + // The `stack` is still alive. + let stack: &JobStack = unsafe { stack.cast().as_ref() }; + // SAFETY: + // This is the first call to `take_once` since `Job::execute` + // (the only place where this harness is called) is called only + // after the job has been popped. + let f = unsafe { stack.take_once() }; + // SAFETY: + // Before being popped, the `JobQueue` allocates and stores a + // `Future` in `self.fur_or_next` that should get passed here. + let fut: &Future = unsafe { fut.cast().as_ref() }; + + fut.complete(panic::catch_unwind(AssertUnwindSafe(|| f(scope)))); + } + + Self { + stack: NonNull::from(stack).cast(), + harness: harness::, + fut: Cell::new(None), + } + } + + pub fn is_waiting(&self) -> bool { + self.fut.get().is_none() + } + + pub fn eq(&self, other: &Job) -> bool { + self.stack == other.stack + } + + /// SAFETY: + /// It should only be called after being popped from a `JobQueue`. + pub unsafe fn poll(&self) -> bool { + self.fut + .get() + .map(|fut| { + // SAFETY: + // Before being popped, the `JobQueue` allocates and stores a + // `Future` in `self.fur_or_next` that should get passed here. + let fut = unsafe { fut.as_ref() }; + fut.poll() + }) + .unwrap_or_default() + } + + /// SAFETY: + /// It should only be called after being popped from a `JobQueue`. + pub unsafe fn wait(&self) -> Option> { + self.fut.get().and_then(|fut| { + // SAFETY: + // Before being popped, the `JobQueue` allocates and stores a + // `Future` in `self.fur_or_next` that should get passed here. + let result = unsafe { fut.as_ref().wait() }; + // SAFETY: + // We only can drop the `Box` *after* waiting on the `Future` + // in order to ensure unique access. + unsafe { + drop(Box::from_raw(fut.as_ptr())); + } + + result + }) + } + + /// SAFETY: + /// It should only be called in the case where the job has been popped + /// from the front and will not be `Job::Wait`ed. + pub unsafe fn drop(&self) { + if let Some(fut) = self.fut.get() { + // SAFETY: + // Before being popped, the `JobQueue` allocates and store a + // `Future` in `self.fur_or_next` that should get passed here. + unsafe { + drop(Box::from_raw(fut.as_ptr())); + } + } + } + } + + impl Job { + /// SAFETY: + /// It should only be called while the `JobStack` it was created with is + /// still alive and after being popped from a `JobQueue`. + pub unsafe fn execute(&self, scope: &mut Scope) { + // SAFETY: + // Before being popped, the `JobQueue` allocates and store a + // `Future` in `self.fur_or_next` that should get passed here. + unsafe { + (self.harness)(scope, self.stack, self.fut.get().unwrap()); + } + } + } + + // SAFETY: + // The job's `stack` will only be accessed after acquiring a lock (in + // `Future`), while `prev` and `fut_or_next` are never accessed after being + // sent across threads. + unsafe impl Send for Job {} + + #[derive(Debug, Default)] + pub struct JobQueue(VecDeque>); + + impl JobQueue { + pub fn len(&self) -> usize { + self.0.len() + } + + /// SAFETY: + /// Any `Job` pushed onto the queue should alive at least until it gets + /// popped. 
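+        /// The queue stores only raw `NonNull` pointers to the caller's
+        /// stack-allocated `Job`s; it never takes ownership of them.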
+ pub unsafe fn push_back(&mut self, job: &Job) { + self.0.push_back(NonNull::from(job).cast()); + } + + pub fn pop_back(&mut self) { + self.0.pop_back(); + } + + pub fn pop_front(&mut self) -> Option { + // SAFETY: + // `Job` is still alive as per contract in `push_back`. + let job = unsafe { self.0.pop_front()?.as_ref() }; + job.fut + .set(Some(Box::leak(Box::new(Future::default())).into())); + + Some(job.clone()) + } + } +} + +//use job::{Future, Job, JobQueue, JobStack}; +use crate::job::{Job, JobRef, StackJob}; + +struct ThreadState { + control: ThreadControl, +} + +struct Heartbeat { + is_set: Weak, + last_time: Cell, +} + +pub struct SharedContext { + shared_tasks: Vec>, + heartbeats: Vec>, + rng: crate::rng::XorShift64Star, +} + +pub struct Context { + shared: Mutex, + threads: Box<[CachePadded]>, + heartbeat_control: CachePadded, + task_shared: Condvar, +} + +pub struct ThreadPool { + context: Arc, +} + +impl SharedContext { + fn new_heartbeat(&mut self) -> (Arc, usize) { + let is_set = Arc::new(AtomicBool::new(true)); + let heartbeat = Heartbeat { + is_set: Arc::downgrade(&is_set), + last_time: Cell::new(Instant::now()), + }; + + let index = match self.heartbeats.iter().position(|a| a.is_none()) { + Some(i) => { + self.heartbeats[i] = Some(heartbeat); + i + } + None => { + self.heartbeats.push(Some(heartbeat)); + self.shared_tasks.push(None); + self.heartbeats.len() - 1 + } + }; + + (is_set, index) + } + + fn pop_first_task(&mut self) -> Option { + self.shared_tasks + .iter_mut() + .filter_map(|task| task.take()) + .next() + } + + fn pop_random_task(&mut self) -> Option { + let i = self.rng.next_usize(self.shared_tasks.len()); + let (a, b) = self.shared_tasks.split_at_mut(i); + a.into_iter().chain(b).filter_map(|task| task.take()).next() + } +} + +std::thread_local! 
{ + static WORKER_THREAD_STATE: Cell<*const WorkerThread> = const { Cell::new(ptr::null())}; +} + +pub struct WorkerThread { + context: Arc, + index: usize, + queue: VecDeque, + heartbeat: Arc, + join_count: u8, + sleep_count: usize, + _marker: PhantomPinned, +} + +impl WorkerThread { + fn new(context: Arc, heartbeat: Arc, index: usize) -> WorkerThread { + WorkerThread { + context, + index, + queue: VecDeque::default(), + join_count: 0, + heartbeat, + sleep_count: 0, + _marker: PhantomPinned, + } + } + unsafe fn set_current(this: *const Self) { + WORKER_THREAD_STATE.with(|ptr| { + assert!(ptr.get().is_null()); + ptr.set(this); + }); + } + unsafe fn unset_current() { + WORKER_THREAD_STATE.with(|ptr| { + assert!(!ptr.get().is_null()); + ptr.set(ptr::null()); + }); + } + unsafe fn current() -> *const WorkerThread { + let ptr = WORKER_THREAD_STATE.with(|ptr| ptr.get()); + + ptr + } + fn state(&self) -> &CachePadded { + &self.context.threads[self.index] + } + fn control(&self) -> &ThreadControl { + &self.context.threads[self.index].control + } + fn shared(&self) -> &Mutex { + &self.context.shared + } + fn ctx(&self) -> &Arc { + &self.context + } + + fn with) -> T>(f: F) -> T { + WORKER_THREAD_STATE.with(|worker| { + f(unsafe { NonNull::new(worker.get().cast_mut()).map(|ptr| ptr.as_ref()) }) + }) + } + + fn with_mut) -> T>(f: F) -> T { + WORKER_THREAD_STATE.with(|worker| { + f(unsafe { NonNull::new(worker.get().cast_mut()).map(|mut ptr| ptr.as_mut()) }) + }) + } +} + +struct CurrentWorker; + +impl Deref for CurrentWorker { + type Target = WorkerThread; + + fn deref(&self) -> &Self::Target { + unsafe { + NonNull::new(WorkerThread::current().cast_mut()) + .unwrap() + .as_ref() + } + } +} + +impl DerefMut for CurrentWorker { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { + NonNull::new(WorkerThread::current().cast_mut()) + .unwrap() + .as_mut() + } + } +} + +// impl Drop for WorkerThread { +// fn drop(&mut self) { +// WORKER_THREAD_STATE.with(|ptr| { +// assert!(!ptr.get().is_null()); +// ptr.set(ptr::null()); +// }); +// } +// } + +impl WorkerThread { + fn worker(self) { + { + let worker = Box::leak(Box::new(self)); + unsafe { + WorkerThread::set_current(worker); + } + } + + CurrentWorker.control().notify_running(); + + loop { + let task = { CurrentWorker.shared().lock().pop_first_task() }; + + if let Some(task) = task { + CurrentWorker.execute_job(task); + } + + if CurrentWorker.control().should_terminate.probe() { + break; + } + + let mut guard = CurrentWorker.shared().lock(); + CurrentWorker.ctx().task_shared.wait(&mut guard); + } + + CurrentWorker.control().notify_termination(); + unsafe { + let worker = Box::from_raw(WorkerThread::current().cast_mut()); + WorkerThread::unset_current(); + } + } + + fn execute_job(&mut self, job: JobRef) { + unsafe { core::mem::transmute::>(job).execute(self) }; + } + + #[cold] + fn heartbeat_cold(&mut self) { + let mut guard = self.context.shared.lock(); + + if guard.shared_tasks[self.index].is_none() { + if let Some(task) = self.queue.pop_front() { + guard.shared_tasks[self.index] = Some(task); + self.context.task_shared.notify_one(); + } + } + + self.heartbeat.store(false, Ordering::Relaxed); + } + + pub fn join(&mut self, a: A, b: B) -> (RA, RB) + where + A: FnOnce(&mut WorkerThread) -> RA + Send, + B: FnOnce(&mut WorkerThread) -> RB + Send, + RA: Send, + RB: Send, + { + self.join_with_every::<64, _, _, _, _>(a, b) + } + + pub fn join_with_every(&mut self, a: A, b: B) -> (RA, RB) + where + A: FnOnce(&mut WorkerThread) -> RA + Send, + B: 
FnOnce(&mut WorkerThread) -> RB + Send, + RA: Send, + RB: Send, + { + self.join_count = self.join_count.wrapping_add(1) % T; + + if self.join_count == 0 || self.queue.len() < 3 { + self.join_heartbeat(a, b) + } else { + self.join_seq(a, b) + } + } + + fn join_seq(&mut self, a: A, b: B) -> (RA, RB) + where + A: FnOnce(&mut WorkerThread) -> RA + Send, + B: FnOnce(&mut WorkerThread) -> RB + Send, + RA: Send, + RB: Send, + { + let rb = b(self); + let ra = a(self); + + (ra, rb) + } + + fn join_heartbeat(&mut self, a: A, b: B) -> (RA, RB) + where + A: FnOnce(&mut WorkerThread) -> RA + Send, + B: FnOnce(&mut WorkerThread) -> RB + Send, + RA: Send, + RB: Send, + { + let mut ra = None; + let a = |scope: &mut WorkerThread| { + if scope.heartbeat.load(Ordering::Relaxed) { + scope.heartbeat_cold(); + } + + ra = Some(a(scope)); + }; + + let latch = AtomicLatch::new(); + let ctx = self.context.clone(); + let idx = self.index; + let stack = StackJob::new(a, latch); + let task: JobRef = + unsafe { core::mem::transmute::, JobRef>(stack.as_task_ref()) }; + + let id = task.id(); + self.queue.push_back(task); + + let rb = b(self); + + if !latch.probe() { + if let Some(job) = self.queue.pop_back() { + if job.id() == id { + unsafe { + (stack.take_once())(self); + } + return (ra.unwrap(), rb); + } else { + self.queue.push_back(job); + } + } + } + + self.run_until(&latch); + + (ra.unwrap(), rb) + } + + fn run_until(&mut self, latch: &L) { + if !latch.probe() { + self.run_until_cold(latch); + } + } + + #[cold] + fn run_until_cold(&mut self, latch: &L) { + let job = self.shared().lock().shared_tasks[self.index].take(); + if let Some(job) = job { + self.execute_job(job); + } + + while !latch.probe() { + let job = self.context.shared.lock().pop_first_task(); + if let Some(job) = job { + self.execute_job(job); + } + } + } +} + +impl Context { + fn heartbeat(self: Arc, interaval: Duration) { + loop { + if self.heartbeat_control.should_terminate.probe() { + break; + } + let sleep_for = { + let guard = self.shared.lock(); + let now = Instant::now(); + + let num_heartbeats = guard + .heartbeats + .iter() + .filter_map(Option::as_ref) + .filter_map(|h| h.is_set.upgrade().map(|is_set| (is_set, &h.last_time))) + .inspect(|(is_set, last_time)| { + if now.duration_since(last_time.get()) >= interaval { + is_set.store(true, Ordering::Relaxed); + last_time.set(now); + } + }) + .count(); + + interaval.checked_div(num_heartbeats as u32) + }; + + if let Some(duration) = sleep_for { + thread::sleep(duration); + } + } + } +} + +impl Drop for Context { + fn drop(&mut self) { + for thread in &self.threads { + thread.control.notify_should_terminate(); + } + self.heartbeat_control.notify_should_terminate(); + + for thread in &self.threads { + thread.control.wait_for_termination(); + } + + self.heartbeat_control.wait_for_termination(); + } +} + +impl ThreadPool { + pub fn new_worker(&self) -> WorkerThread { + let (heartbeat, index) = self.context.shared.lock().new_heartbeat(); + WorkerThread::new(self.context.clone(), heartbeat, index) + } + + pub fn new(num_threads: usize) -> ThreadPool { + let threads = (0..num_threads) + .map(|_| { + CachePadded::new(ThreadState { + control: ThreadControl::new(), + }) + }) + .collect::>(); + + let context = Arc::new(Context { + shared: Mutex::new(SharedContext { + shared_tasks: Vec::with_capacity(num_threads), + heartbeats: Vec::with_capacity(num_threads), + rng: crate::rng::XorShift64Star::new(num_threads as u64), + }), + threads, + heartbeat_control: CachePadded::new(ThreadControl::new()), + 
task_shared: Condvar::new(), + }); + + let this = Self { context }; + + for _ in 0..num_threads { + let worker = this.new_worker(); + std::thread::spawn(move || { + worker.worker(); + }); + } + + let ctx = this.context.clone(); + std::thread::spawn(|| { + ctx.heartbeat(Duration::from_micros(100)); + }); + + this + } +}
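The new `melange` API is driven through `ThreadPool::new`, `ThreadPool::new_worker` and `WorkerThread::join`, mirroring chili's `Scope::join`. Below is a minimal usage sketch of the intended call pattern, assuming the crate name `executor` used by the benches and an arbitrary pool size of 4:

use executor::melange::{ThreadPool, WorkerThread};

fn main() {
    // One worker thread per requested slot; the benches size this with
    // bevy_tasks::available_parallelism(), but any fixed count works.
    let pool = ThreadPool::new(4);
    let mut worker = pool.new_worker();

    // Recursive fork/join: each level splits into two subtasks via
    // WorkerThread::join, just like the `join_melange` bench.
    fn count(depth: usize, scope: &mut WorkerThread) -> u64 {
        if depth == 0 {
            return 1;
        }
        let (l, r) = scope.join(|s| count(depth - 1, s), |s| count(depth - 1, s));
        1 + l + r
    }

    // A full binary tree of depth 10 has 2^11 - 1 = 2047 nodes.
    assert_eq!(count(10, &mut worker), 2047);
}

`join` forwards to `join_with_every::<64, _, _, _, _>`, so only about one join in 64, or any join issued while the local queue holds fewer than three pending jobs, takes the `join_heartbeat` path that can share work with other threads; all other joins run both closures inline through `join_seq`.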