does stuff, doesn't deadlock, faster than rayon (maybe?)

Janis 2025-06-19 18:00:06 +02:00
parent d611535994
commit 9b0cc41834
4 changed files with 271 additions and 87 deletions


@@ -56,7 +56,7 @@ mod tree {
     }
 }

-const TREE_SIZE: usize = 16;
+const TREE_SIZE: usize = 13;

 #[bench]
 fn join_melange(b: &mut Bencher) {
@@ -93,6 +93,7 @@ fn join_melange(b: &mut Bencher) {

 #[bench]
 fn join_praetor(b: &mut Bencher) {
+    tracing_subscriber::fmt().with_test_writer().init();
     use executor::praetor::Scope;
     let pool = executor::praetor::ThreadPool::global();
@@ -121,6 +122,7 @@ fn join_praetor(b: &mut Bencher) {

 #[bench]
 fn join_sync(b: &mut Bencher) {
+    tracing_subscriber::fmt().with_test_writer().init();
     let tree = tree::Tree::new(TREE_SIZE, 1u32);

     fn sum(tree: &tree::Tree<u32>, node: usize) -> u32 {
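
One caveat with these two init() calls: tracing_subscriber's init() panics if a global subscriber is already installed, so whichever bench initializes second in the same process will abort the run. A sketch of the fallible alternative (not part of this commit, same builder API):

fn init_test_tracing() {
    // try_init() returns Err instead of panicking when a global subscriber
    // already exists, so calling this from every bench is harmless.
    let _ = tracing_subscriber::fmt().with_test_writer().try_init();
}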


@@ -426,12 +426,46 @@ mod job {
         }
     }

-    #[derive(Debug)]
+    // for some reason I confused head and tail here and the list is something like this:
+    // tail <-> job1 <-> job2 <-> ... <-> head
    pub struct JobList {
        head: Box<Job>,
        tail: Box<Job>,
    }

+    impl Debug for JobList {
+        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+            f.debug_struct("JobList")
+                .field("head", &self.head)
+                .field("tail", &self.tail)
+                .field_with("jobs", |f| {
+                    let mut list = f.debug_list();
+
+                    // SAFETY: head always has prev
+                    let mut job = unsafe { self.head().as_ref().link_mut().prev.unwrap() };
+                    loop {
+                        if job == self.tail() {
+                            break;
+                        }
+
+                        let job_ref = unsafe { job.as_ref() };
+                        list.entry(&job_ref);
+
+                        // SAFETY: we are iterating over the linked list
+                        if let Some(next) = unsafe { job_ref.link_mut().prev } {
+                            job = next;
+                        } else {
+                            tracing::trace!("prev job is none?");
+                            break;
+                        };
+                    }
+
+                    list.finish()
+                })
+                .finish()
+        }
+    }
+
    impl JobList {
        pub fn new() -> JobList {
            let head = Box::new(Job::empty());
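
The head/tail comment is easier to follow with a concrete model. Below is a hypothetical, index-based sketch (plain Rust, no raw pointers, u32 standing in for Job) of a deque with two sentinels, where walking from head along prev pointers visits the newest job first and stops at tail, the same traversal the new Debug impl performs:

// Hypothetical sketch of the layout described above, using Vec indices
// instead of pointers: tail <-> job1 <-> job2 <-> ... <-> head.
struct Node {
    value: Option<u32>, // sentinels carry no value
    prev: Option<usize>,
    next: Option<usize>,
}

fn main() {
    // nodes[0] = tail sentinel, nodes[1] = head sentinel
    let mut nodes = vec![
        Node { value: None, prev: None, next: None }, // tail
        Node { value: None, prev: None, next: None }, // head
    ];
    nodes[0].next = Some(1);
    nodes[1].prev = Some(0);

    // push a job at the head end: splice it between head.prev and head
    for v in [10, 20, 30] {
        let idx = nodes.len();
        let before = nodes[1].prev.unwrap();
        nodes.push(Node { value: Some(v), prev: Some(before), next: Some(1) });
        nodes[before].next = Some(idx);
        nodes[1].prev = Some(idx);
    }

    // walk from head toward tail via prev, like JobList's Debug impl
    let mut cur = nodes[1].prev.unwrap();
    while cur != 0 {
        println!("{:?}", nodes[cur].value); // prints 30, 20, 10: newest first
        cur = nodes[cur].prev.unwrap();
    }
}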
@@ -743,6 +777,7 @@ mod job {
                    return JobResult::new(result);
                } else {
                    // spin until lock is released.
+                    tracing::trace!("spin-waiting for job: {:?}", self);
                    spin.spin();
                }
            }
        }
@@ -779,6 +814,12 @@ mod job {
        }

        pub fn execute(job: NonNull<Self>) {
+            tracing::trace!(
+                "thread {:?}: executing job: {:?}",
+                std::thread::current().name(),
+                job
+            );
+
            // SAFETY: self is non-null
            unsafe {
                let this = job.as_ref();
@@ -990,9 +1031,10 @@ use async_task::Runnable;
 use crossbeam::utils::CachePadded;
 use job::*;
 use parking_lot::{Condvar, Mutex};
+use parking_lot_core::SpinWait;
 use util::{DropGuard, SendPtr};

-use crate::latch::{AtomicLatch, LatchRef, NopLatch};
+use crate::latch::{AtomicLatch, LatchRef, NopLatch, Probe};

 #[derive(Debug, Default)]
 pub struct JobCounter {
@@ -1212,11 +1254,12 @@ impl WorkerThread {
        if !guard.jobs.contains_key(&self.index) {
            if let Some(job) = self.pop_back() {
+                tracing::trace!("heartbeat: sharing job: {:?}", job);
                unsafe {
                    job.as_ref().set_pending();
                }
                guard.jobs.insert(self.index, job);
-                self.context.shared_job.notify_one();
+                self.context.notify_shared_job();
            }
        }
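
For context, the heartbeat path above does three things under the shared lock: check that this worker has nothing parked yet, promote its oldest local job, and wake one sleeping worker. A simplified, hypothetical reduction of that shape (std types, u32 in place of job pointers):

use std::collections::{BTreeMap, VecDeque};
use std::sync::{Condvar, Mutex};

struct Shared {
    // worker index -> the one job that worker currently shares
    jobs: BTreeMap<usize, u32>,
}

// Illustrative reduction of the heartbeat path: promote the oldest local
// job into the shared map and wake one sleeping worker.
fn heartbeat(index: usize, local: &mut VecDeque<u32>, shared: &Mutex<Shared>, cv: &Condvar) {
    let mut guard = shared.lock().unwrap();
    if !guard.jobs.contains_key(&index) {
        if let Some(job) = local.pop_back() {
            guard.jobs.insert(index, job);
            cv.notify_one();
        }
    }
}

fn main() {
    let shared = Mutex::new(Shared { jobs: BTreeMap::new() });
    let cv = Condvar::new();

    // new jobs go to the front, so the back holds the oldest job
    let mut local = VecDeque::new();
    for job in [1u32, 2, 3] {
        local.push_front(job);
    }

    heartbeat(0, &mut local, &shared, &cv);
    // the oldest job (1) is now visible to other workers
    assert_eq!(shared.lock().unwrap().jobs.get(&0), Some(&1));
}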
@@ -1228,69 +1271,72 @@ impl WorkerThread {
        // does this optimise?
        assert!(!latch.probe());

-        'outer: while !latch.probe() {
-            // take the shared job, if it exists
-            if let Some(shared_job) = self.context.shared.lock().jobs.remove(&self.index) {
-                self.execute(shared_job);
-            }
-
-            while !latch.probe() {
-                let mut guard = self.context.shared.lock();
-
-                match guard.jobs.pop_first().map(|(_, job)| job) {
-                    Some(job) => {
-                        drop(guard);
-                        self.execute(job);
-                        continue 'outer;
-                    }
-                    None => {
-                        // TODO: spin2win
-                        self.context.shared_job.wait(&mut guard);
-                    }
-                }
-            }
-        }
+        self.wait_until_predicate(|| latch.probe())
    }

    pub fn wait_until_latch<Latch: crate::Probe>(&self, latch: &Latch) {
        if !latch.probe() {
-            self.wait_until_latch_cold(latch);
+            self.wait_until_latch_cold(latch)
        }
    }

+    #[inline]
+    fn wait_until_predicate<F>(&self, pred: F)
+    where
+        F: Fn() -> bool,
+    {
+        'outer: while !pred() {
+            // take a shared job, if it exists
+            if let Some(shared_job) = self.context.shared.lock().jobs.remove(&self.index) {
+                self.execute(shared_job);
+            }
+
+            // process local jobs before locking shared context
+            while let Some(job) = self.pop_front() {
+                unsafe {
+                    job.as_ref().set_pending();
+                }
+                self.execute(job);
+            }
+
+            while !pred() {
+                let mut guard = self.context.shared.lock();
+                let mut spin = SpinWait::new();
+
+                match guard.pop_job() {
+                    Some(job) => {
+                        drop(guard);
+                        self.execute(job);
+                        continue 'outer;
+                    }
+                    None => {
+                        // TODO: spin2win
+                        tracing::trace!("waiting for shared job, thread id: {:?}", self.index);
+
+                        // TODO: wait on latch? if we have something that can
+                        // signal being done, e.g. can be waited on instead of
+                        // shared jobs, we should wait on it instead, but we
+                        // would also want to receive shared jobs still?
+                        // self.context.shared_job.wait(&mut guard);
+                        // if spin.spin() {
+                        //     // wait for more shared jobs.
+                        //     // self.context.shared_job.wait(&mut guard);
+                        //     return;
+                        // }
+                        std::thread::yield_now();
+                    }
+                }
+            }
+        }
+
+        return;
+    }
+
    pub fn wait_until_job<T>(&self, job: &Job<T>) -> Option<JobResult<T>> {
-        // take the shared job and check if it is our job
-        let shared_job = self.context.shared.lock().jobs.remove(&self.index);
-
-        if let Some(ptr) = shared_job {
-            if ptr.as_ptr() == &*job as *const _ as *mut _ {
-                // we can more efficiently run the job inline
-                return None;
-            } else {
-                // execute this job since it hasn't been taken
-                self.execute(ptr);
-            }
-        }
-
-        while job.state() != JobState::Finished as u8 {
-            let Some(job) = self
-                .context
-                .shared
-                .lock()
-                .jobs
-                .pop_first()
-                .map(|(_, job)| job)
-                // .or_else(|| {
-                //     self.pop_front().inspect(|job| unsafe {
-                //         job.as_ref().set_pending();
-                //     })
-                // })
-            else {
-                break;
-            };
-
-            self.execute(job);
-        }
+        self.wait_until_predicate(|| {
+            // check if job is finished
+            job.state() == JobState::Finished as u8
+        });

        // someone else has this job and is working on it,
        // while job isn't done, suspend thread.
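
The new wait_until_predicate loop has three tiers: first run a job that was shared specifically under this worker's index, then drain the local queue, then fall back to the shared pool and yield when it is empty. A single-threaded model of that control flow (hypothetical names, u32 jobs, a Cell counter as the predicate):

use std::cell::Cell;
use std::collections::VecDeque;

// Single-threaded model of the three-tier wait loop above; all names and
// types here are illustrative, not the crate's real API.
fn wait_until(
    pred: impl Fn() -> bool,
    my_shared: &mut Option<u32>,
    local: &mut VecDeque<u32>,
    pool: &mut VecDeque<u32>,
    run: &mut impl FnMut(u32),
) {
    'outer: while !pred() {
        // tier 1: a job that was shared under this worker's index
        if let Some(job) = my_shared.take() {
            run(job);
        }
        // tier 2: the worker's own queue
        while let Some(job) = local.pop_front() {
            run(job);
        }
        // tier 3: anything in the shared pool, else yield
        while !pred() {
            match pool.pop_front() {
                Some(job) => {
                    run(job);
                    continue 'outer; // re-check the faster tiers first
                }
                None => std::thread::yield_now(),
            }
        }
    }
}

fn main() {
    let remaining = Cell::new(4u32);
    let mut my_shared = Some(1u32);
    let mut local = VecDeque::from([2, 3]);
    let mut pool = VecDeque::from([4]);
    let mut run = |job: u32| {
        println!("ran job {job}");
        remaining.set(remaining.get() - 1);
    };
    wait_until(
        || remaining.get() == 0,
        &mut my_shared,
        &mut local,
        &mut pool,
        &mut run,
    );
}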
@@ -1308,10 +1354,14 @@ where
 impl<'scope> Scope<'scope> {
     fn wait_for_jobs(&self) {
-        tracing::trace!("waiting for {} jobs to finish.", self.job_counter.count());
         let thread = WorkerThread::current_ref().unwrap();
+        tracing::trace!("waiting for {} jobs to finish.", self.job_counter.count());
+        tracing::trace!("thread id: {:?}, jobs: {:?}", thread.index, unsafe {
+            thread.queue.as_ref_unchecked()
+        });

         thread.wait_until_latch(&self.job_counter);
+        unsafe { self.job_counter.wait() };
     }

     pub fn scope<F, R>(f: F) -> R
@@ -1325,7 +1375,7 @@ impl<'scope> Scope<'scope> {
         })
     }

-    pub fn scope_with_context<F, R>(context: Arc<Context>, f: F) -> R
+    fn scope_with_context<F, R>(context: Arc<Context>, f: F) -> R
     where
         F: FnOnce(&Self) -> R + Send,
         R: Send,
@@ -1416,7 +1466,7 @@ impl<'scope> Scope<'scope> {
     where
         F: FnOnce(&Scope<'scope>) + Send,
     {
-        WorkerThread::with_in(&self.context, |worker| {
+        self.context.run_in_worker(|worker| {
             self.job_counter.increment();

             let this = SendPtr::new_const(self).unwrap();
@@ -1522,13 +1572,14 @@ impl<'scope> Scope<'scope> {
         A: FnOnce(&Self) -> RA + Send,
         B: FnOnce(&Self) -> RB + Send,
     {
-        // let count = self.join_count.get();
-        // self.join_count.set(count.wrapping_add(1) % TIMES);
-        let count = self
-            .join_count
-            .update(Ordering::Relaxed, Ordering::Relaxed, |n| {
-                n.wrapping_add(1) % TIMES
-            });
+        let count = self.join_count.load(Ordering::Relaxed);
+        self.join_count
+            .store(count.wrapping_add(1) % TIMES, Ordering::Relaxed);
+        // let count = self
+        //     .join_count
+        //     .update(Ordering::Relaxed, Ordering::Relaxed, |n| {
+        //         n.wrapping_add(1) % TIMES
+        //     });

         if count == 1 {
             self.join_heartbeat(a, b)
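
The counter routes roughly one in TIMES calls down the heartbeat path. Note the load/store pair is not an atomic read-modify-write, so a concurrent race can skew the sampling rate but not correctness, which is presumably why Relaxed is acceptable here. A standalone sketch of the sampling logic (TIMES = 64 is an assumed value; the crate defines its own):

use std::sync::atomic::{AtomicUsize, Ordering};

const TIMES: usize = 64; // assumed value for illustration

// Cheap, racy sampling counter: a lost update only perturbs how often the
// slow path is taken, never whether a join is correct.
fn should_heartbeat(join_count: &AtomicUsize) -> bool {
    let count = join_count.load(Ordering::Relaxed);
    join_count.store(count.wrapping_add(1) % TIMES, Ordering::Relaxed);
    count == 1
}

fn main() {
    let join_count = AtomicUsize::new(0);
    let hits = (0..256).filter(|_| should_heartbeat(&join_count)).count();
    assert_eq!(hits, 256 / TIMES); // one heartbeat join per TIMES calls
    println!("heartbeat joins: {hits}");
}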
@@ -1562,13 +1613,23 @@ impl<'scope> Scope<'scope> {
         let job = a.as_job();
         worker.push_front(&job);

-        let rb = b(self);
+        use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
+        let rb = match catch_unwind(AssertUnwindSafe(|| b(self))) {
+            Ok(val) => val,
+            Err(payload) => {
+                cold_path();
+                // if b panicked, we need to wait for a to finish
+                worker.wait_until_job::<RA>(unsafe { job.transmute_ref::<RA>() });
+                resume_unwind(payload);
+            }
+        };

         let ra = if job.state() == JobState::Empty as u8 {
             unsafe {
                 job.unlink();
             }

+            // a is allowed to panic here, because we already finished b.
             unsafe { a.unwrap()() }
         } else {
             match worker.wait_until_job::<RA>(unsafe { job.transmute_ref::<RA>() }) {
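
The catch_unwind dance exists because job refers to stack-allocated state that was pushed into the worker's queue: unwinding out of join_heartbeat while another thread may still execute a would free that state from under it. A standalone sketch of the catch/cleanup/rethrow idiom, with the wait replaced by a generic cleanup:

use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};

// Standalone sketch of the catch/cleanup/rethrow idiom: make sure shared
// state is settled before letting the panic continue unwinding.
fn with_cleanup<R>(body: impl FnOnce() -> R, cleanup: impl FnOnce()) -> R {
    match catch_unwind(AssertUnwindSafe(body)) {
        Ok(val) => {
            cleanup();
            val
        }
        Err(payload) => {
            cleanup(); // runs even on panic, like waiting out job `a` above
            resume_unwind(payload);
        }
    }
}

fn main() {
    let result = catch_unwind(AssertUnwindSafe(|| {
        with_cleanup(|| -> u32 { panic!("boom") }, || println!("cleanup ran"))
    }));
    assert!(result.is_err()); // the panic was propagated after cleanup
}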
@@ -1592,16 +1653,17 @@ impl<'scope> Scope<'scope> {
     }
 }

-// #[allow(dead_code)]
-// pub fn join<A, B, RA, RB>(a: A, b: B) -> (RA, RB)
-// where
-//     RA: Send,
-//     RB: Send,
-//     A: FnOnce() -> RA + Send,
-//     B: FnOnce() -> RB + Send,
-// {
-//     Scope::with(|scope| scope.join(|_| a(), |_| b()))
-// }
+/// run two closures potentially in parallel, in the global threadpool.
+#[allow(dead_code)]
+pub fn join<A, B, RA, RB>(a: A, b: B) -> (RA, RB)
+where
+    RA: Send,
+    RB: Send,
+    A: FnOnce() -> RA + Send,
+    B: FnOnce() -> RB + Send,
+{
+    Scope::scope(|scope| scope.join(|_| a(), |_| b()))
+}

 pub struct ThreadPool {
     context: Arc<Context>,
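
With join restored as a public free function, the usual recursive-divide pattern works against the global pool. A sketch of the classic use (quicksort), assuming join is in scope; the two halves of the slice are disjoint, so the closures can safely run in parallel:

// Sketch only: assumes `join` from this module is in scope.
fn quicksort(v: &mut [i32]) {
    if v.len() <= 1 {
        return;
    }
    let mid = partition(v);
    let (lo, rest) = v.split_at_mut(mid);
    let hi = &mut rest[1..]; // skip the pivot, it is already in place
    join(|| quicksort(lo), || quicksort(hi));
}

// Lomuto partition: returns the pivot's final index.
fn partition(v: &mut [i32]) -> usize {
    let pivot = v.len() - 1;
    let mut i = 0;
    for j in 0..pivot {
        if v[j] <= v[pivot] {
            v.swap(i, j);
            i += 1;
        }
    }
    v.swap(i, pivot);
    i
}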
@@ -1648,7 +1710,7 @@ unsafe impl Send for SharedContext {}
 impl SharedContext {
     fn new_heartbeat(&mut self) -> (Arc<CachePadded<AtomicBool>>, usize) {
         let index = self.heartbeats_id;
-        self.heartbeats_id.checked_add(1).unwrap();
+        self.heartbeats_id = self.heartbeats_id.checked_add(1).unwrap();

         let is_set = Arc::new(CachePadded::new(AtomicBool::new(false)));
         let weak = Arc::downgrade(&is_set);
@@ -1693,14 +1755,21 @@ impl Context {
         // let num_threads = 2;
         let barrier = Arc::new(std::sync::Barrier::new(num_threads + 1));

-        for _ in 0..num_threads {
+        for i in 0..num_threads {
             let ctx = this.clone();
             let barrier = barrier.clone();
-            std::thread::spawn(|| worker(ctx, barrier));
+            std::thread::Builder::new()
+                .name(format!("{:?}-worker-{}", Arc::as_ptr(&this), i))
+                .spawn(|| worker(ctx, barrier))
+                .expect("Failed to spawn worker thread");
         }

         let ctx = this.clone();
-        std::thread::spawn(|| heartbeat_worker(ctx));
+        std::thread::Builder::new()
+            .name(format!("{:?}-heartbeat", Arc::as_ptr(&this)))
+            .spawn(|| heartbeat_worker(ctx))
+            .expect("Failed to spawn heartbeat thread");

         barrier.wait();
@@ -1714,10 +1783,14 @@ impl Context {
     pub fn inject_job(&self, job: NonNull<Job>) {
         let mut guard = self.shared.lock();
         guard.injected_jobs.push(job);
+        self.notify_shared_job();
+    }
+
+    fn notify_shared_job(&self) {
         self.shared_job.notify_one();
     }

-    /// Runs closure in this context, processing the worker's jobs while waiting for the result.
+    /// Runs closure in this context, processing the other context's worker's jobs while waiting for the result.
     fn run_in_worker_cross<T, F>(self: &Arc<Self>, worker: &WorkerThread, f: F) -> T
     where
         F: FnOnce(&WorkerThread) -> T + Send,
@@ -1738,8 +1811,10 @@ impl Context {
         );

         let job = job.as_job();
+        job.set_pending();
         self.inject_job(Into::into(&job));

+        // no need to wait for latch to signal, because we're waiting on the job anyway
         worker.wait_until_latch(&latch);

         let t = unsafe { job.transmute_ref::<T>().wait().into_result() };
@@ -1769,6 +1844,7 @@ impl Context {
         );

         let job = job.as_job();
+        job.set_pending();
         self.inject_job(Into::into(&job));

         latch.wait();
@@ -1788,12 +1864,19 @@ impl Context {
             Some(worker) => {
                 // check if worker is in the same context
                 if Arc::ptr_eq(&worker.context, self) {
+                    tracing::trace!("run_in_worker: current thread");
                     f(worker)
                 } else {
+                    // current thread is a worker for a different context
+                    tracing::trace!("run_in_worker: cross-context");
                     self.run_in_worker_cross(worker, f)
                 }
             }
-            None => self.run_in_worker_cold(f),
+            None => {
+                // current thread is not a worker for any context
+                tracing::trace!("run_in_worker: inject into context");
+                self.run_in_worker_cold(f)
+            }
         }
     }
 }
@@ -1826,18 +1909,18 @@ fn worker(ctx: Arc<Context>, barrier: Arc<std::sync::Barrier>) {
     }

     let _guard = DropGuard::new(|| unsafe {
+        tracing::trace!("worker thread dropping {:?}", std::thread::current());
         WorkerThread::drop_in_place_and_dealloc(WorkerThread::unset_current().unwrap());
     });

-    let scope = WorkerThread::current_ref().unwrap();
+    let worker = WorkerThread::current_ref().unwrap();
     barrier.wait();

     let mut job = ctx.shared.lock().pop_job();
     loop {
+        tracing::trace!("worker({:?}): new job {:?}", std::thread::current(), job);
         if let Some(job) = job {
-            scope.execute(job);
+            worker.execute(job);
         }

         let mut guard = ctx.shared.lock();
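
This worker loop pairs with inject_job/notify_shared_job above: producers push under the lock and notify_one, workers pop under the same lock and sleep on the condvar when the pool is empty. A minimal, self-contained sketch of that handshake with parking_lot (u32 standing in for job pointers, a done flag for shutdown):

use std::collections::VecDeque;
use std::sync::Arc;

use parking_lot::{Condvar, Mutex};

struct Shared {
    jobs: VecDeque<u32>,
    done: bool,
}

fn main() {
    let shared = Arc::new((
        Mutex::new(Shared { jobs: VecDeque::new(), done: false }),
        Condvar::new(),
    ));

    let worker = {
        let shared = shared.clone();
        std::thread::spawn(move || {
            let (lock, cv) = &*shared;
            let mut guard = lock.lock();
            loop {
                if let Some(job) = guard.jobs.pop_front() {
                    println!("executing job {job}");
                } else if guard.done {
                    return;
                } else {
                    cv.wait(&mut guard); // releases the lock while sleeping
                }
            }
        })
    };

    let (lock, cv) = &*shared;
    for job in 0..4 {
        lock.lock().jobs.push_back(job); // like inject_job
        cv.notify_one();                 // like notify_shared_job
    }
    lock.lock().done = true;
    cv.notify_one();
    worker.join().unwrap();
}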


@@ -3,6 +3,8 @@ use std::{
     pin::{pin, Pin},
 };

+use tracing_test::traced_test;
+
 use super::{util::TaggedAtomicPtr, *};

 fn pin_ptr<T>(pin: &Pin<&mut T>) -> NonNull<T> {
@@ -446,6 +448,35 @@ fn join() {
     eprintln!("x: {x}");
 }

+#[test]
+#[traced_test]
+fn join_many() {
+    use crate::util::tree::{Tree, TREE_SIZE};
+
+    let pool = ThreadPool::new();
+    let tree = Tree::new(16, 1u32);
+
+    fn sum(tree: &Tree<u32>, node: usize, scope: &Scope) -> u32 {
+        let node = tree.get(node);
+        let (l, r) = scope.join_heartbeat(
+            |s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(),
+            |s| {
+                node.right
+                    .map(|node| sum(tree, node, s))
+                    .unwrap_or_default()
+            },
+        );
+
+        // eprintln!("node: {node:?}, l: {l}, r: {r}");
+        node.leaf + l + r
+    }
+
+    let sum = pool.scope(|s| sum(&tree, tree.root().unwrap(), s));
+    eprintln!("{sum}");
+}
+
 #[test]
 fn rebox() {
     struct A(u32);


@@ -65,3 +65,71 @@ impl XorShift64Star {
         (self.next() % n as u64) as usize
     }
 }
+
+#[macro_export]
+macro_rules! cfg_miri {
+    (
+        @miri => { $($tokens:tt)* }$(,)?
+        _ => { $($tokens2:tt)* }
+    ) => {
+        #[cfg(miri)]
+        {
+            $($tokens)*
+        }
+
+        #[cfg(not(miri))]
+        {
+            $($tokens2)*
+        }
+    };
+}
+
+pub mod tree {
+    pub struct Tree<T> {
+        nodes: Box<[Node<T>]>,
+        root: Option<usize>,
+    }
+
+    #[derive(Debug, Clone)]
+    pub struct Node<T> {
+        pub leaf: T,
+        pub left: Option<usize>,
+        pub right: Option<usize>,
+    }
+
+    impl<T> Tree<T> {
+        pub fn new(depth: usize, t: T) -> Tree<T>
+        where
+            T: Copy,
+        {
+            let mut nodes = Vec::with_capacity((0..depth).sum());
+            let root = Self::build_node(&mut nodes, depth, t);
+            Self {
+                nodes: nodes.into_boxed_slice(),
+                root: Some(root),
+            }
+        }
+
+        pub fn root(&self) -> Option<usize> {
+            self.root
+        }
+
+        pub fn get(&self, index: usize) -> &Node<T> {
+            &self.nodes[index]
+        }
+
+        pub fn build_node(nodes: &mut Vec<Node<T>>, depth: usize, t: T) -> usize
+        where
+            T: Copy,
+        {
+            let node = Node {
+                leaf: t,
+                left: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)),
+                right: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)),
+            };
+
+            nodes.push(node);
+            nodes.len() - 1
+        }
+    }
+
+    pub const TREE_SIZE: usize = 16;
+}
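
cfg_miri! expands to two #[cfg]-gated blocks, so it is used in statement position. A hypothetical use, shrinking a stress test's workload under Miri (run_iterations is an illustrative stand-in, not part of the crate):

// Hypothetical usage, assuming the macro is imported at crate root.
fn stress() {
    cfg_miri! {
        @miri => {
            // Miri is orders of magnitude slower: shrink the workload
            run_iterations(100);
        }
        _ => {
            run_iterations(1_000_000);
        }
    }
}

fn run_iterations(n: usize) {
    for i in 0..n {
        std::hint::black_box(i);
    }
}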