does stuff, doesn't deadlock, faster than rayon (maybe?)

Janis 2025-06-19 18:00:06 +02:00
parent d611535994
commit 9b0cc41834
4 changed files with 271 additions and 87 deletions

View file

@@ -56,7 +56,7 @@ mod tree {
}
}
const TREE_SIZE: usize = 16;
const TREE_SIZE: usize = 13;
#[bench]
fn join_melange(b: &mut Bencher) {
@@ -93,6 +93,7 @@ fn join_melange(b: &mut Bencher) {
#[bench]
fn join_praetor(b: &mut Bencher) {
// try_init instead of init: another bench in this process may already have installed a global subscriber
let _ = tracing_subscriber::fmt().with_test_writer().try_init();
use executor::praetor::Scope;
let pool = executor::praetor::ThreadPool::global();
@@ -121,6 +122,7 @@ fn join_praetor(b: &mut Bencher) {
#[bench]
fn join_sync(b: &mut Bencher) {
// try_init instead of init: another bench in this process may already have installed a global subscriber
let _ = tracing_subscriber::fmt().with_test_writer().try_init();
let tree = tree::Tree::new(TREE_SIZE, 1u32);
fn sum(tree: &tree::Tree<u32>, node: usize) -> u32 {

View file

@@ -426,12 +426,46 @@ mod job {
}
}
#[derive(Debug)]
// for some reason I confused head and tail here and the list is something like this:
// tail <-> job1 <-> job2 <-> ... <-> head
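// (the Debug impl below walks the list the same way: it starts at head's
// `prev` pointer and follows `prev` links until it reaches `tail`)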
pub struct JobList {
head: Box<Job>,
tail: Box<Job>,
}
impl Debug for JobList {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("JobList")
.field("head", &self.head)
.field("tail", &self.tail)
.field_with("jobs", |f| {
let mut list = f.debug_list();
// SAFETY: head always has prev
let mut job = unsafe { self.head().as_ref().link_mut().prev.unwrap() };
loop {
if job == self.tail() {
break;
}
let job_ref = unsafe { job.as_ref() };
list.entry(&job_ref);
// SAFETY: we are iterating over the linked list
if let Some(next) = unsafe { job_ref.link_mut().prev } {
job = next;
} else {
tracing::trace!("prev job is none?");
break;
};
}
list.finish()
})
.finish()
}
}
impl JobList {
pub fn new() -> JobList {
let head = Box::new(Job::empty());
@@ -743,6 +777,7 @@ mod job {
return JobResult::new(result);
} else {
// spin until lock is released.
tracing::trace!("spin-waiting for job: {:?}", self);
spin.spin();
}
}
@@ -779,6 +814,12 @@ mod job {
}
pub fn execute(job: NonNull<Self>) {
tracing::trace!(
"thread {:?}: executing job: {:?}",
std::thread::current().name(),
job
);
// SAFETY: self is non-null
unsafe {
let this = job.as_ref();
@@ -990,9 +1031,10 @@ use async_task::Runnable;
use crossbeam::utils::CachePadded;
use job::*;
use parking_lot::{Condvar, Mutex};
use parking_lot_core::SpinWait;
use util::{DropGuard, SendPtr};
use crate::latch::{AtomicLatch, LatchRef, NopLatch};
use crate::latch::{AtomicLatch, LatchRef, NopLatch, Probe};
#[derive(Debug, Default)]
pub struct JobCounter {
@@ -1212,11 +1254,12 @@ impl WorkerThread {
if !guard.jobs.contains_key(&self.index) {
if let Some(job) = self.pop_back() {
tracing::trace!("heartbeat: sharing job: {:?}", job);
unsafe {
job.as_ref().set_pending();
}
guard.jobs.insert(self.index, job);
self.context.shared_job.notify_one();
self.context.notify_shared_job();
}
}
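// heartbeat sharing in a nutshell: when its heartbeat fires, a worker
// promotes at most one job from the cold end of its deque (pop_back)
// into the shared map, keyed by its own index, and wakes one waiting
// thread; idle workers then pick shared jobs up while they wait.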
@@ -1228,69 +1271,72 @@ impl WorkerThread {
// does this optimise?
assert!(!latch.probe());
'outer: while !latch.probe() {
// take the shared job, if it exists
if let Some(shared_job) = self.context.shared.lock().jobs.remove(&self.index) {
self.execute(shared_job);
}
while !latch.probe() {
let mut guard = self.context.shared.lock();
match guard.jobs.pop_first().map(|(_, job)| job) {
Some(job) => {
drop(guard);
self.execute(job);
continue 'outer;
}
None => {
// TODO: spin2win
self.context.shared_job.wait(&mut guard);
}
}
}
}
self.wait_until_predicate(|| latch.probe())
}
pub fn wait_until_latch<Latch: crate::Probe>(&self, latch: &Latch) {
if !latch.probe() {
self.wait_until_latch_cold(latch);
self.wait_until_latch_cold(latch)
}
}
pub fn wait_until_job<T>(&self, job: &Job<T>) -> Option<JobResult<T>> {
// take the shared job and check if it is our job
let shared_job = self.context.shared.lock().jobs.remove(&self.index);
#[inline]
fn wait_until_predicate<F>(&self, pred: F)
where
F: Fn() -> bool,
{
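// strategy, roughly: while the predicate is false, first run the job
// that was shared under our own index (if any), then drain our local
// deque, and only then fall back to popping arbitrary shared jobs from
// the context; if even that comes up empty, yield to the OS scheduler.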
'outer: while !pred() {
// take a shared job, if it exists
if let Some(shared_job) = self.context.shared.lock().jobs.remove(&self.index) {
self.execute(shared_job);
}
if let Some(ptr) = shared_job {
if ptr.as_ptr() == &*job as *const _ as *mut _ {
// we can more efficiently run the job inline
return None;
} else {
// execute this job since it hasn't been taken
self.execute(ptr);
// process local jobs before locking shared context
while let Some(job) = self.pop_front() {
unsafe {
job.as_ref().set_pending();
}
self.execute(job);
}
while !pred() {
let mut guard = self.context.shared.lock();
let mut spin = SpinWait::new();
match guard.pop_job() {
Some(job) => {
drop(guard);
self.execute(job);
continue 'outer;
}
None => {
// TODO: spin2win
tracing::trace!("waiting for shared job, thread id: {:?}", self.index);
// TODO: wait on latch? if we have something that can
// signal being done, e.g. can be waited on instead of
// shared jobs, we should wait on it instead, but we
// would also want to receive shared jobs still?
// self.context.shared_job.wait(&mut guard);
// if spin.spin() {
// // wait for more shared jobs.
// // self.context.shared_job.wait(&mut guard);
// return;
// }
std::thread::yield_now();
}
}
}
}
return;
}
while job.state() != JobState::Finished as u8 {
let Some(job) = self
.context
.shared
.lock()
.jobs
.pop_first()
.map(|(_, job)| job)
// .or_else(|| {
// self.pop_front().inspect(|job| unsafe {
// job.as_ref().set_pending();
// })
// })
else {
break;
};
self.execute(job);
}
pub fn wait_until_job<T>(&self, job: &Job<T>) -> Option<JobResult<T>> {
self.wait_until_predicate(|| {
// check if job is finished
job.state() == JobState::Finished as u8
});
// someone else has this job and is working on it,
// while job isn't done, suspend thread.
@@ -1308,10 +1354,14 @@ where
impl<'scope> Scope<'scope> {
fn wait_for_jobs(&self) {
tracing::trace!("waiting for {} jobs to finish.", self.job_counter.count());
let thread = WorkerThread::current_ref().unwrap();
tracing::trace!("waiting for {} jobs to finish.", self.job_counter.count());
tracing::trace!("thread id: {:?}, jobs: {:?}", thread.index, unsafe {
thread.queue.as_ref_unchecked()
});
thread.wait_until_latch(&self.job_counter);
unsafe { self.job_counter.wait() };
}
pub fn scope<F, R>(f: F) -> R
@@ -1325,7 +1375,7 @@ impl<'scope> Scope<'scope> {
})
}
pub fn scope_with_context<F, R>(context: Arc<Context>, f: F) -> R
fn scope_with_context<F, R>(context: Arc<Context>, f: F) -> R
where
F: FnOnce(&Self) -> R + Send,
R: Send,
@@ -1416,7 +1466,7 @@ impl<'scope> Scope<'scope> {
where
F: FnOnce(&Scope<'scope>) + Send,
{
WorkerThread::with_in(&self.context, |worker| {
self.context.run_in_worker(|worker| {
self.job_counter.increment();
let this = SendPtr::new_const(self).unwrap();
@@ -1522,13 +1572,14 @@ impl<'scope> Scope<'scope> {
A: FnOnce(&Self) -> RA + Send,
B: FnOnce(&Self) -> RB + Send,
{
// let count = self.join_count.get();
// self.join_count.set(count.wrapping_add(1) % TIMES);
let count = self
.join_count
.update(Ordering::Relaxed, Ordering::Relaxed, |n| {
n.wrapping_add(1) % TIMES
});
let count = self.join_count.load(Ordering::Relaxed);
self.join_count
.store(count.wrapping_add(1) % TIMES, Ordering::Relaxed);
// let count = self
// .join_count
// .update(Ordering::Relaxed, Ordering::Relaxed, |n| {
// n.wrapping_add(1) % TIMES
// });
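// only roughly one in TIMES calls takes the join_heartbeat path below,
// which materialises a stealable Job; the other calls presumably run
// both closures inline, keeping the common (non-stolen) case cheap.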
if count == 1 {
self.join_heartbeat(a, b)
@@ -1562,13 +1613,23 @@ impl<'scope> Scope<'scope> {
let job = a.as_job();
worker.push_front(&job);
let rb = b(self);
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
let rb = match catch_unwind(AssertUnwindSafe(|| b(self))) {
Ok(val) => val,
Err(payload) => {
cold_path();
// if b panicked, we need to wait for a to finish
worker.wait_until_job::<RA>(unsafe { job.transmute_ref::<RA>() });
resume_unwind(payload);
}
};
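// why the catch_unwind: the job for `a` was pushed onto this worker's
// deque above and may already have been stolen, in which case another
// thread holds a pointer into this stack frame; unwinding past it
// before that job finishes would be unsound, hence the wait before
// resume_unwind.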
let ra = if job.state() == JobState::Empty as u8 {
unsafe {
job.unlink();
}
// a is allowed to panic here, because we already finished b.
unsafe { a.unwrap()() }
} else {
match worker.wait_until_job::<RA>(unsafe { job.transmute_ref::<RA>() }) {
@@ -1592,16 +1653,17 @@ impl<'scope> Scope<'scope> {
}
}
// #[allow(dead_code)]
// pub fn join<A, B, RA, RB>(a: A, b: B) -> (RA, RB)
// where
// RA: Send,
// RB: Send,
// A: FnOnce() -> RA + Send,
// B: FnOnce() -> RB + Send,
// {
// Scope::with(|scope| scope.join(|_| a(), |_| b()))
// }
/// Runs two closures, potentially in parallel, on the global thread pool.
#[allow(dead_code)]
pub fn join<A, B, RA, RB>(a: A, b: B) -> (RA, RB)
where
RA: Send,
RB: Send,
A: FnOnce() -> RA + Send,
B: FnOnce() -> RB + Send,
{
Scope::scope(|scope| scope.join(|_| a(), |_| b()))
}
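// usage sketch, assuming this module is reachable as `executor::praetor`
// (the path used by the benches):
//
//     let (a, b) = executor::praetor::join(
//         || (0..1_000u64).sum::<u64>(),
//         || (1_000..2_000u64).sum::<u64>(),
//     );
//     assert_eq!(a + b, (0..2_000u64).sum::<u64>());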
pub struct ThreadPool {
context: Arc<Context>,
@@ -1648,7 +1710,7 @@ unsafe impl Send for SharedContext {}
impl SharedContext {
fn new_heartbeat(&mut self) -> (Arc<CachePadded<AtomicBool>>, usize) {
let index = self.heartbeats_id;
self.heartbeats_id.checked_add(1).unwrap();
self.heartbeats_id = self.heartbeats_id.checked_add(1).unwrap();
let is_set = Arc::new(CachePadded::new(AtomicBool::new(false)));
let weak = Arc::downgrade(&is_set);
@@ -1693,14 +1755,21 @@ impl Context {
// let num_threads = 2;
let barrier = Arc::new(std::sync::Barrier::new(num_threads + 1));
for _ in 0..num_threads {
for i in 0..num_threads {
let ctx = this.clone();
let barrier = barrier.clone();
std::thread::spawn(|| worker(ctx, barrier));
std::thread::Builder::new()
.name(format!("{:?}-worker-{}", Arc::as_ptr(&this), i))
.spawn(|| worker(ctx, barrier))
.expect("Failed to spawn worker thread");
}
let ctx = this.clone();
std::thread::spawn(|| heartbeat_worker(ctx));
std::thread::Builder::new()
.name(format!("{:?}-heartbeat", Arc::as_ptr(&this)))
.spawn(|| heartbeat_worker(ctx))
.expect("Failed to spawn heartbeat thread");
barrier.wait();
@@ -1714,10 +1783,14 @@ impl Context {
pub fn inject_job(&self, job: NonNull<Job>) {
let mut guard = self.shared.lock();
guard.injected_jobs.push(job);
self.notify_shared_job();
}
fn notify_shared_job(&self) {
self.shared_job.notify_one();
}
/// Runs closure in this context, processing the worker's jobs while waiting for the result.
/// Runs the closure in this context; the calling worker, which belongs to a different context, keeps processing its own context's jobs while waiting for the result.
fn run_in_worker_cross<T, F>(self: &Arc<Self>, worker: &WorkerThread, f: F) -> T
where
F: FnOnce(&WorkerThread) -> T + Send,
@@ -1738,8 +1811,10 @@ impl Context {
);
let job = job.as_job();
job.set_pending();
self.inject_job(Into::into(&job));
// no need to wait for latch to signal, because we're waiting on the job anyway
worker.wait_until_latch(&latch);
let t = unsafe { job.transmute_ref::<T>().wait().into_result() };
@@ -1769,6 +1844,7 @@ impl Context {
);
let job = job.as_job();
job.set_pending();
self.inject_job(Into::into(&job));
latch.wait();
@@ -1788,12 +1864,19 @@ impl Context {
Some(worker) => {
// check if worker is in the same context
if Arc::ptr_eq(&worker.context, self) {
tracing::trace!("run_in_worker: current thread");
f(worker)
} else {
// current thread is a worker for a different context
tracing::trace!("run_in_worker: cross-context");
self.run_in_worker_cross(worker, f)
}
}
None => self.run_in_worker_cold(f),
None => {
// current thread is not a worker for any context
tracing::trace!("run_in_worker: inject into context");
self.run_in_worker_cold(f)
}
}
}
}
@@ -1826,18 +1909,18 @@ fn worker(ctx: Arc<Context>, barrier: Arc<std::sync::Barrier>) {
}
let _guard = DropGuard::new(|| unsafe {
tracing::trace!("worker thread dropping {:?}", std::thread::current());
WorkerThread::drop_in_place_and_dealloc(WorkerThread::unset_current().unwrap());
});
let scope = WorkerThread::current_ref().unwrap();
let worker = WorkerThread::current_ref().unwrap();
barrier.wait();
let mut job = ctx.shared.lock().pop_job();
loop {
tracing::trace!("worker({:?}): new job {:?}", std::thread::current(), job);
if let Some(job) = job {
scope.execute(job);
worker.execute(job);
}
let mut guard = ctx.shared.lock();

View file

@@ -3,6 +3,8 @@ use std::{
pin::{pin, Pin},
};
use tracing_test::traced_test;
use super::{util::TaggedAtomicPtr, *};
fn pin_ptr<T>(pin: &Pin<&mut T>) -> NonNull<T> {
@@ -446,6 +448,35 @@ fn join() {
eprintln!("x: {x}");
}
#[test]
#[traced_test]
fn join_many() {
use crate::util::tree::{Tree, TREE_SIZE};
let pool = ThreadPool::new();
let tree = Tree::new(TREE_SIZE, 1u32);
fn sum(tree: &Tree<u32>, node: usize, scope: &Scope) -> u32 {
let node = tree.get(node);
let (l, r) = scope.join_heartbeat(
|s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(),
|s| {
node.right
.map(|node| sum(tree, node, s))
.unwrap_or_default()
},
);
// eprintln!("node: {node:?}, l: {l}, r: {r}");
node.leaf + l + r
}
let sum = pool.scope(|s| sum(&tree, tree.root().unwrap(), s));
eprintln!("{sum}");
}
#[test]
fn rebox() {
struct A(u32);

View file

@@ -65,3 +65,71 @@ impl XorShift64Star {
(self.next() % n as u64) as usize
}
}
#[macro_export]
macro_rules! cfg_miri {
(
@miri => { $($tokens:tt)* }$(,)?
_ => { $($tokens2:tt)* }
) => {
#[cfg(miri)]
{
$($tokens)*
}
#[cfg(not(miri))]
{
$($tokens2)*
}
};
}
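// usage sketch, matching the two arms above (both bodies are placeholders):
//
//     cfg_miri! {
//         @miri => { /* conservative code for running under miri */ }
//         _ => { /* the normal, faster code path */ }
//     }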
pub mod tree {
pub struct Tree<T> {
nodes: Box<[Node<T>]>,
root: Option<usize>,
}
#[derive(Debug, Clone)]
pub struct Node<T> {
pub leaf: T,
pub left: Option<usize>,
pub right: Option<usize>,
}
impl<T> Tree<T> {
pub fn new(depth: usize, t: T) -> Tree<T>
where
T: Copy,
{
// a full binary tree of depth `depth` has 2^(depth + 1) - 1 nodes
let mut nodes = Vec::with_capacity((1usize << (depth + 1)) - 1);
let root = Self::build_node(&mut nodes, depth, t);
Self {
nodes: nodes.into_boxed_slice(),
root: Some(root),
}
}
pub fn root(&self) -> Option<usize> {
self.root
}
pub fn get(&self, index: usize) -> &Node<T> {
&self.nodes[index]
}
pub fn build_node(nodes: &mut Vec<Node<T>>, depth: usize, t: T) -> usize
where
T: Copy,
{
let node = Node {
leaf: t,
left: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)),
right: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)),
};
nodes.push(node);
nodes.len() - 1
}
}
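// depth of the tree built by the benches and tests; a tree of depth d
// holds 2^(d + 1) - 1 nodes, i.e. 131071 nodes here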
pub const TREE_SIZE: usize = 16;
}