chili-like executor for joins

Janis 2025-02-01 00:42:33 +01:00
parent 736e4e1a60
commit b83bfeca51
6 changed files with 1126 additions and 12 deletions
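
In short: src/melange.rs adds a heartbeat-scheduled executor modelled on the chili crate. Each worker owns a private job deque; join(a, b) usually runs both closures inline and only publishes the first one for other threads when the worker's periodic heartbeat flag is set. A minimal usage sketch of the new API, mirroring the melange benchmark below (illustrative, not part of the commit):

use executor::melange;

fn main() {
    // One worker per hardware thread, plus a dedicated heartbeat thread.
    let pool = melange::ThreadPool::new(bevy_tasks::available_parallelism());
    // A worker handle for the calling thread; `join` takes `&mut self`.
    let mut worker = pool.new_worker();
    // Each closure receives the worker, so joins can nest.
    let (a, (b, c)) = worker.join(|_| 1 + 2, |s| s.join(|_| 3, |_| 4));
    assert_eq!((a, b, c), (3, 3, 4));
}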


@@ -17,10 +17,11 @@ never-local = []
 futures = "0.3"
 rayon = "1.10"
 bevy_tasks = "0.15.1"
-parking_lot = "0.12.3"
+parking_lot = {version = "0.12.3"}
 thread_local = "1.1.8"
 crossbeam = "0.8.4"
 st3 = "0.4"
+chili = "0.2.0"
 async-task = "4.7.1"

benches/join.rs (new file, 169 lines)

@@ -0,0 +1,169 @@
#![feature(test)]
use std::{
sync::{atomic::AtomicUsize, Arc},
thread,
time::Duration,
};
use bevy_tasks::available_parallelism;
use executor::{self};
use test::Bencher;
use tree::Node;
extern crate test;
mod tree {
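    /// Binary tree flattened into a single allocation: children are stored as
    /// indices into `nodes` rather than as pointers, so traversal is just
    /// slice indexing.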
pub struct Tree<T> {
nodes: Box<[Node<T>]>,
root: Option<usize>,
}
pub struct Node<T> {
pub leaf: T,
pub left: Option<usize>,
pub right: Option<usize>,
}
impl<T> Tree<T> {
pub fn new(depth: usize, t: T) -> Tree<T>
where
T: Copy,
{
            // A full binary tree of depth `depth` has 2^(depth + 1) - 1 nodes.
            let mut nodes = Vec::with_capacity((1 << (depth + 1)) - 1);
let root = Self::build_node(&mut nodes, depth, t);
Self {
nodes: nodes.into_boxed_slice(),
root: Some(root),
}
}
pub fn root(&self) -> Option<usize> {
self.root
}
pub fn get(&self, index: usize) -> &Node<T> {
&self.nodes[index]
}
pub fn build_node(nodes: &mut Vec<Node<T>>, depth: usize, t: T) -> usize
where
T: Copy,
{
let node = Node {
leaf: t,
left: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)),
right: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)),
};
nodes.push(node);
nodes.len() - 1
}
}
}
const PRIMES: &[usize] = &[
1181, 1187, 1193, 1201, 1213, 1217, 1223, 1229, 1231, 1237, 1249, 1259, 1277, 1279, 1283, 1289,
1291, 1297, 1301, 1303, 1307, 1319, 1321, 1327, 1361, 1367, 1373, 1381, 1399, 1409, 1423, 1427,
1429, 1433, 1439, 1447, 1451, 1453, 1459, 1471, 1481, 1483, 1487, 1489, 1493, 1499, 1511, 1523,
1531, 1543, 1549, 1553, 1559, 1567, 1571, 1579, 1583, 1597, 1601, 1607, 1609, 1613, 1619, 1621,
1627, 1637, 1657, 1663, 1667, 1669, 1693, 1697, 1699, 1709, 1721, 1723, 1733, 1741, 1747, 1753,
1759, 1777, 1783, 1787, 1789, 1801, 1811, 1823, 1831, 1847, 1861, 1867, 1871, 1873, 1877, 1879,
1889, 1901, 1907,
];
const REPEAT: usize = 0x800;
const TREE_SIZE: usize = 14;
#[bench]
fn join_melange(b: &mut Bencher) {
let pool = executor::melange::ThreadPool::new(available_parallelism());
let mut scope = pool.new_worker();
let tree = tree::Tree::new(TREE_SIZE, 1u32);
fn sum(
tree: &tree::Tree<u32>,
node: usize,
scope: &mut executor::melange::WorkerThread,
) -> u32 {
let node = tree.get(node);
let (l, r) = scope.join(
|s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(),
|s| {
node.right
.map(|node| sum(tree, node, s))
.unwrap_or_default()
},
);
node.leaf + l + r
}
b.iter(move || {
assert_ne!(sum(&tree, tree.root().unwrap(), &mut scope), 0);
});
}
#[bench]
fn join_sync(b: &mut Bencher) {
let tree = tree::Tree::new(TREE_SIZE, 1u32);
fn sum(tree: &tree::Tree<u32>, node: usize) -> u32 {
let node = tree.get(node);
let (l, r) = (
node.left.map(|node| sum(tree, node)).unwrap_or_default(),
node.right.map(|node| sum(tree, node)).unwrap_or_default(),
);
node.leaf + l + r
}
b.iter(move || {
assert_ne!(sum(&tree, tree.root().unwrap()), 0);
});
}
#[bench]
fn join_chili(b: &mut Bencher) {
let tree = tree::Tree::new(TREE_SIZE, 1u32);
fn sum(tree: &tree::Tree<u32>, node: usize, scope: &mut chili::Scope<'_>) -> u32 {
let node = tree.get(node);
let (l, r) = scope.join(
|s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(),
|s| {
node.right
.map(|node| sum(tree, node, s))
.unwrap_or_default()
},
);
node.leaf + l + r
}
b.iter(move || {
assert_ne!(
sum(&tree, tree.root().unwrap(), &mut chili::Scope::global()),
0
);
});
}
#[bench]
fn join_rayon(b: &mut Bencher) {
let tree = tree::Tree::new(TREE_SIZE, 1u32);
fn sum(tree: &tree::Tree<u32>, node: usize) -> u32 {
let node = tree.get(node);
let (l, r) = rayon::join(
|| node.left.map(|node| sum(tree, node)).unwrap_or_default(),
|| node.right.map(|node| sum(tree, node)).unwrap_or_default(),
);
node.leaf + l + r
}
b.iter(move || {
assert_ne!(sum(&tree, tree.root().unwrap()), 0);
});
}

rust-toolchain (new file, 1 line)

@@ -0,0 +1 @@
nightly

src/job/mod.rs (new file, 139 lines)

@@ -0,0 +1,139 @@
//! Rayon's job logic
use std::{cell::UnsafeCell, marker::PhantomPinned, sync::atomic::AtomicBool};
use crate::latch::Latch;
pub trait Job<Args = ()> {
unsafe fn execute(this: *const (), args: Args);
}
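/// Type-erased reference to a job: a raw pointer to the job's data plus the
/// monomorphized `execute` function for that concrete job type. Whoever
/// creates a `JobRef` must keep the pointee alive until `execute` is called.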
pub struct JobRef<Args = ()> {
this: *const (),
execute_fn: unsafe fn(*const (), Args),
}
unsafe impl<Args> Send for JobRef<Args> {}
unsafe impl<Args> Sync for JobRef<Args> {}
impl<Args> JobRef<Args> {
pub unsafe fn new<T>(data: *const T) -> JobRef<Args>
where
T: Job<Args>,
{
Self {
this: data.cast(),
execute_fn: <T as Job<Args>>::execute,
}
}
pub fn id(&self) -> impl Eq {
(self.this, self.execute_fn)
}
pub unsafe fn execute(self, args: Args) {
unsafe { (self.execute_fn)(self.this, args) }
}
}
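/// A job that borrows storage on the spawning thread's stack; the embedded
/// `latch` is set once the closure has run, signalling that the stack frame
/// may be released.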
pub struct StackJob<F, L>
where
L: Latch + Sync,
{
task: UnsafeCell<Option<F>>,
latch: L,
_phantom: PhantomPinned,
}
impl<F, L> StackJob<F, L>
where
L: Latch + Sync,
{
pub fn new(task: F, latch: L) -> StackJob<F, L> {
Self {
task: UnsafeCell::new(Some(task)),
latch,
_phantom: PhantomPinned,
}
}
pub unsafe fn take_once(self) -> F {
self.task.into_inner().unwrap()
}
#[inline]
pub fn run<Args>(self, args: Args)
where
F: FnOnce(Args),
{
self.task.into_inner().unwrap()(args);
}
#[inline]
pub unsafe fn as_task_ref<Args>(&self) -> JobRef<Args>
where
F: FnOnce(Args),
{
unsafe { JobRef::<Args>::new(self) }
}
}
impl<Args, F, L> Job<Args> for StackJob<F, L>
where
F: FnOnce(Args),
L: Latch + Sync,
{
unsafe fn execute(this: *const (), args: Args) {
let this = &*this.cast::<Self>();
let func = (*this.task.get()).take().unwrap();
func(args);
Latch::set_raw(&this.latch);
// set internal latch here?
}
}
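/// An owned, heap-allocated job: `execute` reconstructs the `Box` and drops
/// it after running the closure, so each `JobRef` to it must be executed
/// exactly once.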
pub struct HeapJob<F = ()>
where
F: Send,
{
func: F,
_phantom: PhantomPinned,
}
impl<F> HeapJob<F>
where
F: Send,
{
pub fn new(task: F) -> Box<HeapJob<F>> {
Box::new(Self {
func: task,
_phantom: PhantomPinned,
})
}
#[inline]
pub unsafe fn into_static_task_ref<Args>(self: Box<Self>) -> JobRef<Args>
where
F: FnOnce(Args) + 'static,
{
self.into_task_ref()
}
#[inline]
pub unsafe fn into_task_ref<Args>(self: Box<Self>) -> JobRef<Args>
where
F: FnOnce(Args),
{
JobRef::new(Box::into_raw(self))
}
}
impl<Args, F> Job<Args> for HeapJob<F>
where
F: FnOnce(Args) + Send,
{
unsafe fn execute(this: *const (), args: Args) {
let this = Box::from_raw(this.cast::<Self>().cast_mut());
(this.func)(args);
// set internal latch here?
}
}


@@ -27,6 +27,8 @@ use scope::Scope;
 use task::{HeapTask, StackTask, TaskRef};
 use tracing::debug;
 
+pub mod job;
+
 pub mod task {
     use std::{cell::UnsafeCell, marker::PhantomPinned, pin::Pin};
@@ -59,11 +61,18 @@ pub mod task {
             (self.ptr, self.execute_fn)
         }
 
+        /// caller must ensure that this particular task is [`Send`]
         #[inline]
         pub fn execute(self) {
             unsafe { (self.execute_fn)(self.ptr) }
         }
+
+        #[inline]
+        pub unsafe fn execute_with_scope<T>(self, scope: &mut T) {
+            unsafe {
+                core::mem::transmute::<_, unsafe fn(*const (), &mut T)>(self.execute_fn)(
+                    self.ptr, scope,
+                )
+            }
+        }
     }
 
     unsafe impl Send for TaskRef {}
@@ -191,6 +200,38 @@ pub mod latch {
         }
     }
 
+    pub struct ClosureLatch<S, P> {
+        set: S,
+        probe: P,
+    }
+
+    impl<S, P> ClosureLatch<S, P> {
+        pub fn new(set: S, probe: P) -> Self {
+            Self { set, probe }
+        }
+
+        pub fn new_boxed(set: S, probe: P) -> Box<Self> {
+            Box::new(Self { set, probe })
+        }
+    }
+
+    impl<S, P> Latch for ClosureLatch<S, P>
+    where
+        S: Fn(),
+    {
+        unsafe fn set_raw(this: *const Self) {
+            let this = &*this;
+            (this.set)();
+        }
+    }
+
+    impl<S, P> Probe for ClosureLatch<S, P>
+    where
+        P: Fn() -> bool,
+    {
+        fn probe(&self) -> bool {
+            (self.probe)()
+        }
+    }
+
     pub struct ThreadWakeLatch {
         inner: AtomicLatch,
         index: usize,
@@ -337,6 +378,8 @@
     }
 }
 
+pub mod melange;
+
 pub struct ThreadPoolState {
     num_threads: AtomicUsize,
     lock: Mutex<()>,
@@ -344,6 +387,7 @@
 }
 
 bitflags! {
+    #[derive(Clone)]
     pub struct ThreadStatus: u8 {
         const RUNNING = 1 << 0;
         const SLEEPING = 1 << 1;
@@ -366,9 +410,16 @@ pub struct ThreadState {
 }
 
 impl ThreadControl {
+    pub const fn new() -> Self {
+        Self {
+            status: Mutex::new(ThreadStatus::empty()),
+            status_changed: Condvar::new(),
+            should_terminate: AtomicLatch::new(),
+        }
+    }
+
     /// returns true if thread was sleeping
     #[inline]
-    fn wake(&self) -> bool {
+    pub fn wake(&self) -> bool {
         let mut guard = self.status.lock();
         guard.insert(ThreadStatus::SHOULD_WAKE);
         self.status_changed.notify_all();
@@ -376,7 +427,7 @@
     }
 
     #[inline]
-    fn wait_for_running(&self) {
+    pub fn wait_for_running(&self) {
         let mut guard = self.status.lock();
         while !guard.contains(ThreadStatus::RUNNING) {
             self.status_changed.wait(&mut guard);
@@ -384,7 +435,7 @@
     }
 
     #[inline]
-    fn wait_for_should_wake(&self) {
+    pub fn wait_for_should_wake(&self) {
         let mut guard = self.status.lock();
         while !guard.contains(ThreadStatus::SHOULD_WAKE) {
             guard.insert(ThreadStatus::SLEEPING);
@@ -394,7 +445,7 @@
     }
 
     #[inline]
-    fn wait_for_should_wake_timeout(&self, timeout: Duration) {
+    pub fn wait_for_should_wake_timeout(&self, timeout: Duration) {
         let mut guard = self.status.lock();
         while !guard.contains(ThreadStatus::SHOULD_WAKE) {
             guard.insert(ThreadStatus::SLEEPING);
@@ -410,7 +461,7 @@
     }
 
     #[inline]
-    fn wait_for_termination(&self) {
+    pub fn wait_for_termination(&self) {
         let mut guard = self.status.lock();
         while guard.contains(ThreadStatus::RUNNING) {
             self.status_changed.wait(&mut guard);
@@ -418,21 +469,21 @@
     }
 
     #[inline]
-    fn notify_running(&self) {
+    pub fn notify_running(&self) {
         let mut guard = self.status.lock();
         guard.insert(ThreadStatus::RUNNING);
         self.status_changed.notify_all();
     }
 
     #[inline]
-    fn notify_termination(&self) {
+    pub fn notify_termination(&self) {
         let mut guard = self.status.lock();
         *guard = ThreadStatus::empty();
         self.status_changed.notify_all();
     }
 
     #[inline]
-    fn notify_should_terminate(&self) {
+    pub fn notify_should_terminate(&self) {
         unsafe {
             Latch::set_raw(&self.should_terminate);
         }
@@ -1502,7 +1553,7 @@ mod scope {
 #[cfg(test)]
 mod tests {
-    use std::{cell::Cell, hint::black_box};
+    use std::{cell::Cell, hint::black_box, time::Instant};
 
     use tracing::info;
@@ -1643,7 +1694,7 @@ mod tests {
         let elapsed = now.elapsed().as_micros();
 
-        info!("(rayon) total time: {}ms", elapsed as f32 / 1e3);
+        info!("(rayon) {sum} total time: {}ms", elapsed as f32 / 1e3);
     }
 
     #[test]
@@ -1717,6 +1768,37 @@ mod tests {
         });
     }
 
+    #[test]
+    #[tracing_test::traced_test]
+    fn melange_join() {
+        let pool = melange::ThreadPool::new(bevy_tasks::available_parallelism());
+        let mut scope = pool.new_worker();
+
+        let tree = tree::Tree::new(TREE_SIZE, 1u32);
+
+        fn sum(tree: &tree::Tree<u32>, node: usize, scope: &mut melange::WorkerThread) -> u32 {
+            let node = tree.get(node);
+
+            let (l, r) = scope.join(
+                |s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(),
+                |s| {
+                    node.right
+                        .map(|node| sum(tree, node, s))
+                        .unwrap_or_default()
+                },
+            );
+
+            node.leaf + l + r
+        }
+
+        let now = Instant::now();
+        let res = sum(&tree, tree.root().unwrap(), &mut scope);
+        eprintln!(
+            "res: {res} took {}ms",
+            now.elapsed().as_micros() as f32 / 1e3
+        );
+        assert_ne!(res, 0);
+    }
+
     #[test]
     #[tracing_test::traced_test]
     fn sync() {

src/melange.rs (new file, 722 lines)

@@ -0,0 +1,722 @@
use std::{
cell::Cell,
collections::VecDeque,
marker::PhantomPinned,
ops::{Deref, DerefMut},
pin::pin,
ptr::{self, NonNull},
sync::{
atomic::{AtomicBool, Ordering},
Arc, Weak,
},
thread,
time::{Duration, Instant},
};
use crossbeam::utils::CachePadded;
use parking_lot::{Condvar, Mutex};
use crate::{latch::*, task::*, ThreadControl, ThreadStatus};
mod job {
use std::{
cell::{Cell, UnsafeCell},
collections::VecDeque,
mem::ManuallyDrop,
panic::{self, AssertUnwindSafe},
ptr::NonNull,
sync::atomic::{AtomicU8, Ordering},
thread::{self, Thread},
};
use super::WorkerThread as Scope;
enum Poll {
Pending,
Ready,
Locked,
}
#[derive(Debug, Default)]
pub struct Future<T = ()> {
state: AtomicU8,
/// Can only be accessed if `state` is `Poll::Locked`.
waiting_thread: UnsafeCell<Option<Thread>>,
/// Can only be written if `state` is `Poll::Locked` and read if `state` is
/// `Poll::Ready`.
val: UnsafeCell<Option<Box<thread::Result<T>>>>,
}
impl<T> Future<T> {
pub fn poll(&self) -> bool {
self.state.load(Ordering::Acquire) == Poll::Ready as u8
}
pub fn wait(&self) -> Option<thread::Result<T>> {
loop {
let result = self.state.compare_exchange(
Poll::Pending as u8,
Poll::Locked as u8,
Ordering::AcqRel,
Ordering::Acquire,
);
match result {
Ok(_) => {
// SAFETY:
// Lock is acquired, only we are accessing `self.waiting_thread`.
unsafe { *self.waiting_thread.get() = Some(thread::current()) };
self.state.store(Poll::Pending as u8, Ordering::Release);
thread::park();
// Skip yielding after being woken up.
continue;
}
Err(state) if state == Poll::Ready as u8 => {
// SAFETY:
// `state` is `Poll::Ready` only after `Self::complete`
// releases the lock.
//
// Calling `Self::complete` when `state` is `Poll::Ready`
// cannot mutate `self.val`.
break unsafe { (*self.val.get()).take().map(|b| *b) };
}
_ => (),
}
thread::yield_now();
}
}
pub fn complete(&self, val: thread::Result<T>) {
let val = Box::new(val);
loop {
let result = self.state.compare_exchange(
Poll::Pending as u8,
Poll::Locked as u8,
Ordering::AcqRel,
Ordering::Acquire,
);
match result {
Ok(_) => break,
Err(_) => thread::yield_now(),
}
}
// SAFETY:
// Lock is acquired, only we are accessing `self.val`.
unsafe {
*self.val.get() = Some(val);
}
// SAFETY:
// Lock is acquired, only we are accessing `self.waiting_thread`.
if let Some(thread) = unsafe { (*self.waiting_thread.get()).take() } {
thread.unpark();
}
self.state.store(Poll::Ready as u8, Ordering::Release);
}
}
pub struct JobStack<F = ()> {
        /// All code paths should call either `Job::execute` or `Self::take_once` to
/// avoid a potential memory leak.
f: UnsafeCell<ManuallyDrop<F>>,
}
impl<F> JobStack<F> {
pub fn new(f: F) -> Self {
Self {
f: UnsafeCell::new(ManuallyDrop::new(f)),
}
}
/// SAFETY:
/// It should only be called once.
pub unsafe fn take_once(&self) -> F {
// SAFETY:
            // No `Job` has been executed, therefore `self.f` has not yet been
// `take`n.
unsafe { ManuallyDrop::take(&mut *self.f.get()) }
}
}
/// `Job` is only sent, not shared between threads.
///
/// When popped from the `JobQueue`, it gets copied before sending across
/// thread boundaries.
#[derive(Clone, Debug)]
pub struct Job<T = ()> {
stack: NonNull<JobStack>,
harness: unsafe fn(&mut Scope, NonNull<JobStack>, NonNull<Future>),
fut: Cell<Option<NonNull<Future<T>>>>,
}
impl<T> Job<T> {
pub fn new<F>(stack: &JobStack<F>) -> Self
where
F: FnOnce(&mut Scope) -> T + Send,
T: Send,
{
/// SAFETY:
/// It should only be called while the `stack` is still alive.
unsafe fn harness<F, T>(
scope: &mut Scope,
stack: NonNull<JobStack>,
fut: NonNull<Future>,
) where
F: FnOnce(&mut Scope) -> T + Send,
T: Send,
{
// SAFETY:
// The `stack` is still alive.
let stack: &JobStack<F> = unsafe { stack.cast().as_ref() };
// SAFETY:
// This is the first call to `take_once` since `Job::execute`
// (the only place where this harness is called) is called only
// after the job has been popped.
let f = unsafe { stack.take_once() };
// SAFETY:
// Before being popped, the `JobQueue` allocates and stores a
                // `Future` in `self.fut` that should get passed here.
let fut: &Future<T> = unsafe { fut.cast().as_ref() };
fut.complete(panic::catch_unwind(AssertUnwindSafe(|| f(scope))));
}
Self {
stack: NonNull::from(stack).cast(),
harness: harness::<F, T>,
fut: Cell::new(None),
}
}
pub fn is_waiting(&self) -> bool {
self.fut.get().is_none()
}
pub fn eq(&self, other: &Job) -> bool {
self.stack == other.stack
}
/// SAFETY:
/// It should only be called after being popped from a `JobQueue`.
pub unsafe fn poll(&self) -> bool {
self.fut
.get()
.map(|fut| {
// SAFETY:
// Before being popped, the `JobQueue` allocates and stores a
                    // `Future` in `self.fut` that should get passed here.
let fut = unsafe { fut.as_ref() };
fut.poll()
})
.unwrap_or_default()
}
/// SAFETY:
/// It should only be called after being popped from a `JobQueue`.
pub unsafe fn wait(&self) -> Option<thread::Result<T>> {
self.fut.get().and_then(|fut| {
// SAFETY:
// Before being popped, the `JobQueue` allocates and stores a
                // `Future` in `self.fut` that should get passed here.
let result = unsafe { fut.as_ref().wait() };
// SAFETY:
// We only can drop the `Box` *after* waiting on the `Future`
// in order to ensure unique access.
unsafe {
drop(Box::from_raw(fut.as_ptr()));
}
result
})
}
/// SAFETY:
/// It should only be called in the case where the job has been popped
        /// from the front and will not be waited on via `Job::wait`.
pub unsafe fn drop(&self) {
if let Some(fut) = self.fut.get() {
// SAFETY:
                // Before being popped, the `JobQueue` allocates and stores a
                // `Future` in `self.fut` that should get passed here.
unsafe {
drop(Box::from_raw(fut.as_ptr()));
}
}
}
}
impl Job {
/// SAFETY:
/// It should only be called while the `JobStack` it was created with is
/// still alive and after being popped from a `JobQueue`.
pub unsafe fn execute(&self, scope: &mut Scope) {
// SAFETY:
            // Before being popped, the `JobQueue` allocates and stores a
            // `Future` in `self.fut` that should get passed here.
unsafe {
(self.harness)(scope, self.stack, self.fut.get().unwrap());
}
}
}
// SAFETY:
// The job's `stack` will only be accessed after acquiring a lock (in
// `Future`), while `prev` and `fut_or_next` are never accessed after being
// sent across threads.
unsafe impl Send for Job {}
#[derive(Debug, Default)]
pub struct JobQueue(VecDeque<NonNull<Job>>);
impl JobQueue {
pub fn len(&self) -> usize {
self.0.len()
}
/// SAFETY:
        /// Any `Job` pushed onto the queue should stay alive at least until it gets
/// popped.
pub unsafe fn push_back<T>(&mut self, job: &Job<T>) {
self.0.push_back(NonNull::from(job).cast());
}
pub fn pop_back(&mut self) {
self.0.pop_back();
}
pub fn pop_front(&mut self) -> Option<Job> {
// SAFETY:
// `Job` is still alive as per contract in `push_back`.
let job = unsafe { self.0.pop_front()?.as_ref() };
job.fut
.set(Some(Box::leak(Box::new(Future::default())).into()));
Some(job.clone())
}
}
}
//use job::{Future, Job, JobQueue, JobStack};
use crate::job::{Job, JobRef, StackJob};
struct ThreadState {
control: ThreadControl,
}
struct Heartbeat {
is_set: Weak<AtomicBool>,
last_time: Cell<Instant>,
}
pub struct SharedContext {
shared_tasks: Vec<Option<JobRef>>,
heartbeats: Vec<Option<Heartbeat>>,
rng: crate::rng::XorShift64Star,
}
pub struct Context {
shared: Mutex<SharedContext>,
threads: Box<[CachePadded<ThreadState>]>,
heartbeat_control: CachePadded<ThreadControl>,
task_shared: Condvar,
}
pub struct ThreadPool {
context: Arc<Context>,
}
impl SharedContext {
fn new_heartbeat(&mut self) -> (Arc<AtomicBool>, usize) {
let is_set = Arc::new(AtomicBool::new(true));
let heartbeat = Heartbeat {
is_set: Arc::downgrade(&is_set),
last_time: Cell::new(Instant::now()),
};
let index = match self.heartbeats.iter().position(|a| a.is_none()) {
Some(i) => {
self.heartbeats[i] = Some(heartbeat);
i
}
None => {
self.heartbeats.push(Some(heartbeat));
self.shared_tasks.push(None);
self.heartbeats.len() - 1
}
};
(is_set, index)
}
fn pop_first_task(&mut self) -> Option<JobRef> {
self.shared_tasks
.iter_mut()
.filter_map(|task| task.take())
.next()
}
    fn pop_random_task(&mut self) -> Option<JobRef> {
        // Start the scan at a random slot so no single worker's shared task
        // is always preferred over the others.
        let i = self.rng.next_usize(self.shared_tasks.len());
        let (a, b) = self.shared_tasks.split_at_mut(i);
        b.iter_mut().chain(a).filter_map(|task| task.take()).next()
    }
}
std::thread_local! {
static WORKER_THREAD_STATE: Cell<*const WorkerThread> = const { Cell::new(ptr::null())};
}
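// Raw-pointer TLS slot binding an OS thread to its `WorkerThread`; it is set
// and cleared by `set_current`/`unset_current` and dereferenced through
// `CurrentWorker`.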
pub struct WorkerThread {
context: Arc<Context>,
index: usize,
queue: VecDeque<JobRef>,
heartbeat: Arc<AtomicBool>,
join_count: u8,
sleep_count: usize,
_marker: PhantomPinned,
}
impl WorkerThread {
fn new(context: Arc<Context>, heartbeat: Arc<AtomicBool>, index: usize) -> WorkerThread {
WorkerThread {
context,
index,
queue: VecDeque::default(),
join_count: 0,
heartbeat,
sleep_count: 0,
_marker: PhantomPinned,
}
}
unsafe fn set_current(this: *const Self) {
WORKER_THREAD_STATE.with(|ptr| {
assert!(ptr.get().is_null());
ptr.set(this);
});
}
unsafe fn unset_current() {
WORKER_THREAD_STATE.with(|ptr| {
assert!(!ptr.get().is_null());
ptr.set(ptr::null());
});
}
unsafe fn current() -> *const WorkerThread {
let ptr = WORKER_THREAD_STATE.with(|ptr| ptr.get());
ptr
}
fn state(&self) -> &CachePadded<ThreadState> {
&self.context.threads[self.index]
}
fn control(&self) -> &ThreadControl {
&self.context.threads[self.index].control
}
fn shared(&self) -> &Mutex<SharedContext> {
&self.context.shared
}
fn ctx(&self) -> &Arc<Context> {
&self.context
}
fn with<T, F: FnOnce(Option<&WorkerThread>) -> T>(f: F) -> T {
WORKER_THREAD_STATE.with(|worker| {
f(unsafe { NonNull::new(worker.get().cast_mut()).map(|ptr| ptr.as_ref()) })
})
}
fn with_mut<T, F: FnOnce(Option<&mut WorkerThread>) -> T>(f: F) -> T {
WORKER_THREAD_STATE.with(|worker| {
f(unsafe { NonNull::new(worker.get().cast_mut()).map(|mut ptr| ptr.as_mut()) })
})
}
}
struct CurrentWorker;
impl Deref for CurrentWorker {
type Target = WorkerThread;
fn deref(&self) -> &Self::Target {
unsafe {
NonNull::new(WorkerThread::current().cast_mut())
.unwrap()
.as_ref()
}
}
}
impl DerefMut for CurrentWorker {
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe {
NonNull::new(WorkerThread::current().cast_mut())
.unwrap()
.as_mut()
}
}
}
// impl Drop for WorkerThread {
// fn drop(&mut self) {
// WORKER_THREAD_STATE.with(|ptr| {
// assert!(!ptr.get().is_null());
// ptr.set(ptr::null());
// });
// }
// }
impl WorkerThread {
fn worker(self) {
{
let worker = Box::leak(Box::new(self));
unsafe {
WorkerThread::set_current(worker);
}
}
CurrentWorker.control().notify_running();
loop {
let task = { CurrentWorker.shared().lock().pop_first_task() };
if let Some(task) = task {
CurrentWorker.execute_job(task);
}
if CurrentWorker.control().should_terminate.probe() {
break;
}
let mut guard = CurrentWorker.shared().lock();
CurrentWorker.ctx().task_shared.wait(&mut guard);
}
CurrentWorker.control().notify_termination();
unsafe {
let worker = Box::from_raw(WorkerThread::current().cast_mut());
WorkerThread::unset_current();
}
}
fn execute_job(&mut self, job: JobRef) {
unsafe { core::mem::transmute::<JobRef, JobRef<&mut WorkerThread>>(job).execute(self) };
}
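    /// Heartbeat response, kept out of the hot path: if this worker's shared
    /// slot is free, promote the oldest local job into it and wake one
    /// sleeping worker, then clear the heartbeat flag.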
#[cold]
fn heartbeat_cold(&mut self) {
let mut guard = self.context.shared.lock();
if guard.shared_tasks[self.index].is_none() {
if let Some(task) = self.queue.pop_front() {
guard.shared_tasks[self.index] = Some(task);
self.context.task_shared.notify_one();
}
}
self.heartbeat.store(false, Ordering::Relaxed);
}
pub fn join<A, B, RA, RB>(&mut self, a: A, b: B) -> (RA, RB)
where
A: FnOnce(&mut WorkerThread) -> RA + Send,
B: FnOnce(&mut WorkerThread) -> RB + Send,
RA: Send,
RB: Send,
{
self.join_with_every::<64, _, _, _, _>(a, b)
}
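    /// Like `join`, but takes the work-sharing heartbeat path only on every
    /// `T`-th call or while the local queue is short; every other call runs
    /// both closures sequentially on this thread.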
pub fn join_with_every<const T: u8, A, B, RA, RB>(&mut self, a: A, b: B) -> (RA, RB)
where
A: FnOnce(&mut WorkerThread) -> RA + Send,
B: FnOnce(&mut WorkerThread) -> RB + Send,
RA: Send,
RB: Send,
{
self.join_count = self.join_count.wrapping_add(1) % T;
if self.join_count == 0 || self.queue.len() < 3 {
self.join_heartbeat(a, b)
} else {
self.join_seq(a, b)
}
}
fn join_seq<A, B, RA, RB>(&mut self, a: A, b: B) -> (RA, RB)
where
A: FnOnce(&mut WorkerThread) -> RA + Send,
B: FnOnce(&mut WorkerThread) -> RB + Send,
RA: Send,
RB: Send,
{
let rb = b(self);
let ra = a(self);
(ra, rb)
}
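    /// Work-sharing join in the style of chili: push `a` onto the local queue
    /// as a stealable job and run `b` inline; if nobody took `a`, pop it back
    /// and run it here, otherwise execute other jobs until `a`'s latch is set.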
fn join_heartbeat<A, B, RA, RB>(&mut self, a: A, b: B) -> (RA, RB)
where
A: FnOnce(&mut WorkerThread) -> RA + Send,
B: FnOnce(&mut WorkerThread) -> RB + Send,
RA: Send,
RB: Send,
{
let mut ra = None;
let a = |scope: &mut WorkerThread| {
if scope.heartbeat.load(Ordering::Relaxed) {
scope.heartbeat_cold();
}
ra = Some(a(scope));
};
let latch = AtomicLatch::new();
let ctx = self.context.clone();
let idx = self.index;
let stack = StackJob::new(a, latch);
let task: JobRef =
unsafe { core::mem::transmute::<JobRef<_>, JobRef>(stack.as_task_ref()) };
let id = task.id();
self.queue.push_back(task);
let rb = b(self);
if !latch.probe() {
if let Some(job) = self.queue.pop_back() {
if job.id() == id {
unsafe {
(stack.take_once())(self);
}
return (ra.unwrap(), rb);
} else {
self.queue.push_back(job);
}
}
}
self.run_until(&latch);
(ra.unwrap(), rb)
}
fn run_until<L: Probe>(&mut self, latch: &L) {
if !latch.probe() {
self.run_until_cold(latch);
}
}
#[cold]
fn run_until_cold<L: Probe>(&mut self, latch: &L) {
let job = self.shared().lock().shared_tasks[self.index].take();
if let Some(job) = job {
self.execute_job(job);
}
while !latch.probe() {
let job = self.context.shared.lock().pop_first_task();
if let Some(job) = job {
self.execute_job(job);
}
}
}
}
impl Context {
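    /// Heartbeat loop, run on a dedicated thread: flags each live worker
    /// roughly once per `interval`, sleeping `interval / num_workers` between
    /// passes so wakeups stay staggered across workers.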
    fn heartbeat(self: Arc<Self>, interval: Duration) {
loop {
if self.heartbeat_control.should_terminate.probe() {
break;
}
let sleep_for = {
let guard = self.shared.lock();
let now = Instant::now();
let num_heartbeats = guard
.heartbeats
.iter()
.filter_map(Option::as_ref)
.filter_map(|h| h.is_set.upgrade().map(|is_set| (is_set, &h.last_time)))
.inspect(|(is_set, last_time)| {
                        if now.duration_since(last_time.get()) >= interval {
is_set.store(true, Ordering::Relaxed);
last_time.set(now);
}
})
.count();
            interval.checked_div(num_heartbeats as u32)
};
if let Some(duration) = sleep_for {
thread::sleep(duration);
}
}
}
}
impl Drop for Context {
fn drop(&mut self) {
for thread in &self.threads {
thread.control.notify_should_terminate();
}
self.heartbeat_control.notify_should_terminate();
for thread in &self.threads {
thread.control.wait_for_termination();
}
self.heartbeat_control.wait_for_termination();
}
}
impl ThreadPool {
pub fn new_worker(&self) -> WorkerThread {
let (heartbeat, index) = self.context.shared.lock().new_heartbeat();
WorkerThread::new(self.context.clone(), heartbeat, index)
}
pub fn new(num_threads: usize) -> ThreadPool {
let threads = (0..num_threads)
.map(|_| {
CachePadded::new(ThreadState {
control: ThreadControl::new(),
})
})
.collect::<Box<_>>();
let context = Arc::new(Context {
shared: Mutex::new(SharedContext {
shared_tasks: Vec::with_capacity(num_threads),
heartbeats: Vec::with_capacity(num_threads),
rng: crate::rng::XorShift64Star::new(num_threads as u64),
}),
threads,
heartbeat_control: CachePadded::new(ThreadControl::new()),
task_shared: Condvar::new(),
});
let this = Self { context };
for _ in 0..num_threads {
let worker = this.new_worker();
std::thread::spawn(move || {
worker.worker();
});
}
let ctx = this.context.clone();
std::thread::spawn(|| {
ctx.heartbeat(Duration::from_micros(100));
});
this
}
}