does stuff, doesn't deadlock, faster than rayon (maybe?),
This commit is contained in:
parent d611535994
commit 9b0cc41834
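As context for the claim above, a minimal sketch of how the join/scope API exercised by this commit is driven, modeled on the join_many test and the join_praetor bench further down in the diff; the module paths, the Tree import, and the wrapper functions are assumptions for illustration, not part of the commit:

    // Sketch only: recursive tree sum over the join/scope API, mirroring the join_many test.
    use executor::praetor::{Scope, ThreadPool}; // paths as used in the benches
    use executor::util::tree::Tree; // path assumed; the test reaches it as crate::util::tree

    fn sum(tree: &Tree<u32>, node: usize, scope: &Scope) -> u32 {
        let node = tree.get(node);
        // join runs the second closure inline; the first is pushed onto the local
        // queue and may be shared with another worker on a heartbeat (see join_heartbeat).
        let (l, r) = scope.join(
            |s| node.left.map(|n| sum(tree, n, s)).unwrap_or_default(),
            |s| node.right.map(|n| sum(tree, n, s)).unwrap_or_default(),
        );
        node.leaf + l + r
    }

    fn tree_sum(tree: &Tree<u32>) -> u32 {
        ThreadPool::global().scope(|s| sum(tree, tree.root().unwrap(), s))
    }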
@@ -56,7 +56,7 @@ mod tree {
}
}

const TREE_SIZE: usize = 16;
const TREE_SIZE: usize = 13;

#[bench]
fn join_melange(b: &mut Bencher) {
@@ -93,6 +93,7 @@ fn join_melange(b: &mut Bencher) {

#[bench]
fn join_praetor(b: &mut Bencher) {
tracing_subscriber::fmt().with_test_writer().init();
use executor::praetor::Scope;
let pool = executor::praetor::ThreadPool::global();

@@ -121,6 +122,7 @@ fn join_praetor(b: &mut Bencher) {

#[bench]
fn join_sync(b: &mut Bencher) {
tracing_subscriber::fmt().with_test_writer().init();
let tree = tree::Tree::new(TREE_SIZE, 1u32);

fn sum(tree: &tree::Tree<u32>, node: usize) -> u32 {
@@ -426,12 +426,46 @@ mod job {
}
}

#[derive(Debug)]
// for some reason I confused head and tail here and the list is something like this:
// tail <-> job1 <-> job2 <-> ... <-> head
pub struct JobList {
head: Box<Job>,
tail: Box<Job>,
}

impl Debug for JobList {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("JobList")
.field("head", &self.head)
.field("tail", &self.tail)
.field_with("jobs", |f| {
let mut list = f.debug_list();

// SAFETY: head always has prev
let mut job = unsafe { self.head().as_ref().link_mut().prev.unwrap() };
loop {
if job == self.tail() {
break;
}

let job_ref = unsafe { job.as_ref() };
list.entry(&job_ref);

// SAFETY: we are iterating over the linked list
if let Some(next) = unsafe { job_ref.link_mut().prev } {
job = next;
} else {
tracing::trace!("prev job is none?");
break;
};
}

list.finish()
})
.finish()
}
}

impl JobList {
pub fn new() -> JobList {
let head = Box::new(Job::empty());
@@ -743,6 +777,7 @@ mod job {
return JobResult::new(result);
} else {
// spin until lock is released.
tracing::trace!("spin-waiting for job: {:?}", self);
spin.spin();
}
}
@@ -779,6 +814,12 @@ mod job {
}

pub fn execute(job: NonNull<Self>) {
tracing::trace!(
"thread {:?}: executing job: {:?}",
std::thread::current().name(),
job
);

// SAFETY: self is non-null
unsafe {
let this = job.as_ref();
@@ -990,9 +1031,10 @@ use async_task::Runnable;
use crossbeam::utils::CachePadded;
use job::*;
use parking_lot::{Condvar, Mutex};
use parking_lot_core::SpinWait;
use util::{DropGuard, SendPtr};

use crate::latch::{AtomicLatch, LatchRef, NopLatch};
use crate::latch::{AtomicLatch, LatchRef, NopLatch, Probe};

#[derive(Debug, Default)]
pub struct JobCounter {
@@ -1212,11 +1254,12 @@ impl WorkerThread {

if !guard.jobs.contains_key(&self.index) {
if let Some(job) = self.pop_back() {
tracing::trace!("heartbeat: sharing job: {:?}", job);
unsafe {
job.as_ref().set_pending();
}
guard.jobs.insert(self.index, job);
self.context.shared_job.notify_one();
self.context.notify_shared_job();
}
}

@@ -1228,69 +1271,72 @@ impl WorkerThread {
// does this optimise?
assert!(!latch.probe());

'outer: while !latch.probe() {
// take the shared job, if it exists
if let Some(shared_job) = self.context.shared.lock().jobs.remove(&self.index) {
self.execute(shared_job);
}

while !latch.probe() {
let mut guard = self.context.shared.lock();

match guard.jobs.pop_first().map(|(_, job)| job) {
Some(job) => {
drop(guard);
self.execute(job);
continue 'outer;
}
None => {
// TODO: spin2win
self.context.shared_job.wait(&mut guard);
}
}
}
}
self.wait_until_predicate(|| latch.probe())
}

pub fn wait_until_latch<Latch: crate::Probe>(&self, latch: &Latch) {
if !latch.probe() {
self.wait_until_latch_cold(latch);
self.wait_until_latch_cold(latch)
}
}

pub fn wait_until_job<T>(&self, job: &Job<T>) -> Option<JobResult<T>> {
// take the shared job and check if it is our job
let shared_job = self.context.shared.lock().jobs.remove(&self.index);
#[inline]
fn wait_until_predicate<F>(&self, pred: F)
where
F: Fn() -> bool,
{
'outer: while !pred() {
// take a shared job, if it exists
if let Some(shared_job) = self.context.shared.lock().jobs.remove(&self.index) {
self.execute(shared_job);
}

if let Some(ptr) = shared_job {
if ptr.as_ptr() == &*job as *const _ as *mut _ {
// we can more efficiently run the job inline
return None;
} else {
// execute this job since it hasn't been taken
self.execute(ptr);
// process local jobs before locking shared context
while let Some(job) = self.pop_front() {
unsafe {
job.as_ref().set_pending();
}
self.execute(job);
}

while !pred() {
let mut guard = self.context.shared.lock();
let mut spin = SpinWait::new();

match guard.pop_job() {
Some(job) => {
drop(guard);
self.execute(job);

continue 'outer;
}
None => {
// TODO: spin2win
tracing::trace!("waiting for shared job, thread id: {:?}", self.index);

// TODO: wait on latch? if we have something that can
// signal being done, e.g. can be waited on instead of
// shared jobs, we should wait on it instead, but we
// would also want to receive shared jobs still?
// self.context.shared_job.wait(&mut guard);
// if spin.spin() {
// // wait for more shared jobs.
// // self.context.shared_job.wait(&mut guard);
// return;
// }
std::thread::yield_now();
}
}
}
}
return;
}

while job.state() != JobState::Finished as u8 {
let Some(job) = self
.context
.shared
.lock()
.jobs
.pop_first()
.map(|(_, job)| job)
// .or_else(|| {
// self.pop_front().inspect(|job| unsafe {
// job.as_ref().set_pending();
// })
// })
else {
break;
};

self.execute(job);
}
pub fn wait_until_job<T>(&self, job: &Job<T>) -> Option<JobResult<T>> {
self.wait_until_predicate(|| {
// check if job is finished
job.state() == JobState::Finished as u8
});

// someone else has this job and is working on it,
// while job isn't done, suspend thread.
@@ -1308,10 +1354,14 @@ where

impl<'scope> Scope<'scope> {
fn wait_for_jobs(&self) {
tracing::trace!("waiting for {} jobs to finish.", self.job_counter.count());

let thread = WorkerThread::current_ref().unwrap();
tracing::trace!("waiting for {} jobs to finish.", self.job_counter.count());
tracing::trace!("thread id: {:?}, jobs: {:?}", thread.index, unsafe {
thread.queue.as_ref_unchecked()
});

thread.wait_until_latch(&self.job_counter);
unsafe { self.job_counter.wait() };
}

pub fn scope<F, R>(f: F) -> R
@@ -1325,7 +1375,7 @@ impl<'scope> Scope<'scope> {
})
}

pub fn scope_with_context<F, R>(context: Arc<Context>, f: F) -> R
fn scope_with_context<F, R>(context: Arc<Context>, f: F) -> R
where
F: FnOnce(&Self) -> R + Send,
R: Send,
@@ -1416,7 +1466,7 @@ impl<'scope> Scope<'scope> {
where
F: FnOnce(&Scope<'scope>) + Send,
{
WorkerThread::with_in(&self.context, |worker| {
self.context.run_in_worker(|worker| {
self.job_counter.increment();

let this = SendPtr::new_const(self).unwrap();
@@ -1522,13 +1572,14 @@ impl<'scope> Scope<'scope> {
A: FnOnce(&Self) -> RA + Send,
B: FnOnce(&Self) -> RB + Send,
{
// let count = self.join_count.get();
// self.join_count.set(count.wrapping_add(1) % TIMES);
let count = self
.join_count
.update(Ordering::Relaxed, Ordering::Relaxed, |n| {
n.wrapping_add(1) % TIMES
});
let count = self.join_count.load(Ordering::Relaxed);
self.join_count
.store(count.wrapping_add(1) % TIMES, Ordering::Relaxed);
// let count = self
// .join_count
// .update(Ordering::Relaxed, Ordering::Relaxed, |n| {
// n.wrapping_add(1) % TIMES
// });

if count == 1 {
self.join_heartbeat(a, b)
@@ -1562,13 +1613,23 @@ impl<'scope> Scope<'scope> {
let job = a.as_job();
worker.push_front(&job);

let rb = b(self);
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
let rb = match catch_unwind(AssertUnwindSafe(|| b(self))) {
Ok(val) => val,
Err(payload) => {
cold_path();
// if b panicked, we need to wait for a to finish
worker.wait_until_job::<RA>(unsafe { job.transmute_ref::<RA>() });
resume_unwind(payload);
}
};

let ra = if job.state() == JobState::Empty as u8 {
unsafe {
job.unlink();
}

// a is allowed to panic here, because we already finished b.
unsafe { a.unwrap()() }
} else {
match worker.wait_until_job::<RA>(unsafe { job.transmute_ref::<RA>() }) {
@@ -1592,16 +1653,17 @@ impl<'scope> Scope<'scope> {
}
}

// #[allow(dead_code)]
// pub fn join<A, B, RA, RB>(a: A, b: B) -> (RA, RB)
// where
// RA: Send,
// RB: Send,
// A: FnOnce() -> RA + Send,
// B: FnOnce() -> RB + Send,
// {
// Scope::with(|scope| scope.join(|_| a(), |_| b()))
// }
/// run two closures potentially in parallel, in the global threadpool.
#[allow(dead_code)]
pub fn join<A, B, RA, RB>(a: A, b: B) -> (RA, RB)
where
RA: Send,
RB: Send,
A: FnOnce() -> RA + Send,
B: FnOnce() -> RB + Send,
{
Scope::scope(|scope| scope.join(|_| a(), |_| b()))
}

pub struct ThreadPool {
context: Arc<Context>,
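Since the hunk above makes the free join function public, a minimal usage sketch; the string-length closures are placeholders and the executor::praetor path is assumed from the benches:

    // Sketch: run two independent closures, potentially in parallel, on the global pool.
    use executor::praetor::join; // path assumed from the benches' executor::praetor module

    fn both_lengths(a: &str, b: &str) -> (usize, usize) {
        join(|| a.len(), || b.len())
    }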
@@ -1648,7 +1710,7 @@ unsafe impl Send for SharedContext {}
impl SharedContext {
fn new_heartbeat(&mut self) -> (Arc<CachePadded<AtomicBool>>, usize) {
let index = self.heartbeats_id;
self.heartbeats_id.checked_add(1).unwrap();
self.heartbeats_id = self.heartbeats_id.checked_add(1).unwrap();

let is_set = Arc::new(CachePadded::new(AtomicBool::new(false)));
let weak = Arc::downgrade(&is_set);
@@ -1693,14 +1755,21 @@ impl Context {
// let num_threads = 2;
let barrier = Arc::new(std::sync::Barrier::new(num_threads + 1));

for _ in 0..num_threads {
for i in 0..num_threads {
let ctx = this.clone();
let barrier = barrier.clone();
std::thread::spawn(|| worker(ctx, barrier));
std::thread::Builder::new()
.name(format!("{:?}-worker-{}", Arc::as_ptr(&this), i))
.spawn(|| worker(ctx, barrier))
.expect("Failed to spawn worker thread");
}

let ctx = this.clone();
std::thread::spawn(|| heartbeat_worker(ctx));

std::thread::Builder::new()
.name(format!("{:?}-heartbeat", Arc::as_ptr(&this)))
.spawn(|| heartbeat_worker(ctx))
.expect("Failed to spawn heartbeat thread");

barrier.wait();

@@ -1714,10 +1783,14 @@ impl Context {
pub fn inject_job(&self, job: NonNull<Job>) {
let mut guard = self.shared.lock();
guard.injected_jobs.push(job);
self.notify_shared_job();
}

fn notify_shared_job(&self) {
self.shared_job.notify_one();
}

/// Runs closure in this context, processing the worker's jobs while waiting for the result.
/// Runs closure in this context, processing the other context's worker's jobs while waiting for the result.
fn run_in_worker_cross<T, F>(self: &Arc<Self>, worker: &WorkerThread, f: F) -> T
where
F: FnOnce(&WorkerThread) -> T + Send,
@@ -1738,8 +1811,10 @@ impl Context {
);

let job = job.as_job();
job.set_pending();

self.inject_job(Into::into(&job));
// no need to wait for latch to signal, because we're waiting on the job anyway
worker.wait_until_latch(&latch);

let t = unsafe { job.transmute_ref::<T>().wait().into_result() };
@@ -1769,6 +1844,7 @@ impl Context {
);

let job = job.as_job();
job.set_pending();

self.inject_job(Into::into(&job));
latch.wait();
@@ -1788,12 +1864,19 @@ impl Context {
Some(worker) => {
// check if worker is in the same context
if Arc::ptr_eq(&worker.context, self) {
tracing::trace!("run_in_worker: current thread");
f(worker)
} else {
// current thread is a worker for a different context
tracing::trace!("run_in_worker: cross-context");
self.run_in_worker_cross(worker, f)
}
}
None => self.run_in_worker_cold(f),
None => {
// current thread is not a worker for any context
tracing::trace!("run_in_worker: inject into context");
self.run_in_worker_cold(f)
}
}
}
}
@@ -1826,18 +1909,18 @@ fn worker(ctx: Arc<Context>, barrier: Arc<std::sync::Barrier>) {
}

let _guard = DropGuard::new(|| unsafe {
tracing::trace!("worker thread dropping {:?}", std::thread::current());
WorkerThread::drop_in_place_and_dealloc(WorkerThread::unset_current().unwrap());
});

let scope = WorkerThread::current_ref().unwrap();
let worker = WorkerThread::current_ref().unwrap();

barrier.wait();

let mut job = ctx.shared.lock().pop_job();
loop {
tracing::trace!("worker({:?}): new job {:?}", std::thread::current(), job);
if let Some(job) = job {
scope.execute(job);
worker.execute(job);
}

let mut guard = ctx.shared.lock();
@@ -3,6 +3,8 @@ use std::{
pin::{pin, Pin},
};

use tracing_test::traced_test;

use super::{util::TaggedAtomicPtr, *};

fn pin_ptr<T>(pin: &Pin<&mut T>) -> NonNull<T> {
@@ -446,6 +448,35 @@ fn join() {
eprintln!("x: {x}");
}

#[test]
#[traced_test]
fn join_many() {
use crate::util::tree::{Tree, TREE_SIZE};

let pool = ThreadPool::new();

let tree = Tree::new(16, 1u32);

fn sum(tree: &Tree<u32>, node: usize, scope: &Scope) -> u32 {
let node = tree.get(node);
let (l, r) = scope.join_heartbeat(
|s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(),
|s| {
node.right
.map(|node| sum(tree, node, s))
.unwrap_or_default()
},
);

// eprintln!("node: {node:?}, l: {l}, r: {r}");

node.leaf + l + r
}

let sum = pool.scope(|s| sum(&tree, tree.root().unwrap(), s));
eprintln!("{sum}");
}

#[test]
fn rebox() {
struct A(u32);
src/util.rs
@@ -65,3 +65,71 @@ impl XorShift64Star {
(self.next() % n as u64) as usize
}
}

#[macro_export]
macro_rules! cfg_miri {
(
@miri => { $($tokens:tt)* }$(,)?
_ => { $($tokens2:tt)* }
) => {
#[cfg(miri)]
{
$($tokens)*
}
#[cfg(not(miri))]
{
$($tokens2)*
}
};
}

pub mod tree {

pub struct Tree<T> {
nodes: Box<[Node<T>]>,
root: Option<usize>,
}

#[derive(Debug, Clone)]
pub struct Node<T> {
pub leaf: T,
pub left: Option<usize>,
pub right: Option<usize>,
}

impl<T> Tree<T> {
pub fn new(depth: usize, t: T) -> Tree<T>
where
T: Copy,
{
let mut nodes = Vec::with_capacity((0..depth).sum());
let root = Self::build_node(&mut nodes, depth, t);
Self {
nodes: nodes.into_boxed_slice(),
root: Some(root),
}
}

pub fn root(&self) -> Option<usize> {
self.root
}

pub fn get(&self, index: usize) -> &Node<T> {
&self.nodes[index]
}

pub fn build_node(nodes: &mut Vec<Node<T>>, depth: usize, t: T) -> usize
where
T: Copy,
{
let node = Node {
leaf: t,
left: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)),
right: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)),
};
nodes.push(node);
nodes.len() - 1
}
}
pub const TREE_SIZE: usize = 16;
}
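To illustrate the cfg_miri! macro added in this hunk: each arm expands to a #[cfg]-gated block, so it is used in statement position to pick a Miri-sized workload versus the full one. The checksum function below is a hypothetical example, not part of the commit:

    // Sketch only: assumes the #[macro_export]ed cfg_miri! macro above is in scope.
    fn checksum(data: &[u8]) -> u64 {
        let mut acc = 0u64;
        cfg_miri! {
            @miri => {
                // under Miri, only touch a small prefix so the interpreted run stays fast
                for &byte in data.iter().take(64) {
                    acc = acc.wrapping_add(byte as u64);
                }
            },
            _ => {
                for &byte in data {
                    acc = acc.wrapping_add(byte as u64);
                }
            }
        }
        acc
    }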