moved join methods to workerthread
This commit is contained in:
parent
a2112b9ef5
commit
448d2d02b4
|
@ -2,7 +2,6 @@
|
||||||
vec_deque_pop_if,
|
vec_deque_pop_if,
|
||||||
unsafe_cell_access,
|
unsafe_cell_access,
|
||||||
debug_closure_helpers,
|
debug_closure_helpers,
|
||||||
cell_update,
|
|
||||||
cold_path,
|
cold_path,
|
||||||
fn_align,
|
fn_align,
|
||||||
box_vec_non_null,
|
box_vec_non_null,
|
||||||
|
|
|
@ -87,7 +87,7 @@ mod util {
|
||||||
// Miri doesn't like tagging pointers that it doesn't know the alignment of.
|
// Miri doesn't like tagging pointers that it doesn't know the alignment of.
|
||||||
// This includes function pointers, which aren't guaranteed to be aligned to
|
// This includes function pointers, which aren't guaranteed to be aligned to
|
||||||
// anything, but generally have an alignment of 8, and can be specified to
|
// anything, but generally have an alignment of 8, and can be specified to
|
||||||
// be aligned to `n` with `#[repr(align(n))]`.
|
// be aligned to `n` with `#[align(n)]`.
|
||||||
#[repr(transparent)]
|
#[repr(transparent)]
|
||||||
pub struct TaggedAtomicPtr<T, const BITS: usize> {
|
pub struct TaggedAtomicPtr<T, const BITS: usize> {
|
||||||
ptr: AtomicPtr<()>,
|
ptr: AtomicPtr<()>,
|
||||||
|
@ -910,7 +910,7 @@ mod job {
|
||||||
F: FnOnce() -> T + Send,
|
F: FnOnce() -> T + Send,
|
||||||
T: Send,
|
T: Send,
|
||||||
{
|
{
|
||||||
#[repr(align(8))]
|
#[align(8)]
|
||||||
unsafe fn harness<F, T>(this: *const (), job: *const Job<()>)
|
unsafe fn harness<F, T>(this: *const (), job: *const Job<()>)
|
||||||
where
|
where
|
||||||
F: FnOnce() -> T + Send,
|
F: FnOnce() -> T + Send,
|
||||||
|
@ -963,7 +963,7 @@ mod job {
|
||||||
F: FnOnce() -> T + Send,
|
F: FnOnce() -> T + Send,
|
||||||
T: Send,
|
T: Send,
|
||||||
{
|
{
|
||||||
#[repr(align(8))]
|
#[align(8)]
|
||||||
unsafe fn harness<F, T, L: Latch>(this: *const (), job: *const Job<()>)
|
unsafe fn harness<F, T, L: Latch>(this: *const (), job: *const Job<()>)
|
||||||
where
|
where
|
||||||
F: FnOnce() -> T + Send,
|
F: FnOnce() -> T + Send,
|
||||||
|
@ -1225,6 +1225,96 @@ impl WorkerThread {
|
||||||
self.heartbeat.store(false, Ordering::Relaxed);
|
self.heartbeat.store(false, Ordering::Relaxed);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn join_seq<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
|
||||||
|
where
|
||||||
|
RA: Send,
|
||||||
|
RB: Send,
|
||||||
|
A: FnOnce() -> RA + Send,
|
||||||
|
B: FnOnce() -> RB + Send,
|
||||||
|
{
|
||||||
|
let rb = b();
|
||||||
|
let ra = a();
|
||||||
|
|
||||||
|
(ra, rb)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This function must be called from a worker thread.
|
||||||
|
#[inline]
|
||||||
|
fn join_heartbeat_every<A, B, RA, RB, const TIMES: usize>(&self, a: A, b: B) -> (RA, RB)
|
||||||
|
where
|
||||||
|
RA: Send,
|
||||||
|
RB: Send,
|
||||||
|
A: FnOnce() -> RA + Send,
|
||||||
|
B: FnOnce() -> RB + Send,
|
||||||
|
{
|
||||||
|
// SAFETY: each worker is only ever used by one thread, so this is safe.
|
||||||
|
let count = self.join_count.get();
|
||||||
|
self.join_count.set(count.wrapping_add(1) % TIMES as u8);
|
||||||
|
|
||||||
|
// TODO: add counter to job queue, check for low job count to decide whether to use heartbeat or seq.
|
||||||
|
// see: chili
|
||||||
|
if self.join_count.get() == 1 {
|
||||||
|
self.join_heartbeat(a, b)
|
||||||
|
} else {
|
||||||
|
self.join_seq(a, b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This function must be called from a worker thread.
|
||||||
|
fn join_heartbeat<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
|
||||||
|
where
|
||||||
|
RA: Send,
|
||||||
|
RB: Send,
|
||||||
|
A: FnOnce() -> RA + Send,
|
||||||
|
B: FnOnce() -> RB + Send,
|
||||||
|
{
|
||||||
|
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
|
||||||
|
|
||||||
|
let a = StackJob::new(
|
||||||
|
move || {
|
||||||
|
// TODO: bench whether tick'ing here is good.
|
||||||
|
// turns out this actually costs a lot of time, likely because of the thread local check.
|
||||||
|
// WorkerThread::current_ref()
|
||||||
|
// .expect("stackjob is run in workerthread.")
|
||||||
|
// .tick();
|
||||||
|
|
||||||
|
a()
|
||||||
|
},
|
||||||
|
NopLatch,
|
||||||
|
);
|
||||||
|
|
||||||
|
let job = a.as_job();
|
||||||
|
self.push_front(&job);
|
||||||
|
|
||||||
|
let rb = match catch_unwind(AssertUnwindSafe(|| b())) {
|
||||||
|
Ok(val) => val,
|
||||||
|
Err(payload) => {
|
||||||
|
cold_path();
|
||||||
|
// if b panicked, we need to wait for a to finish
|
||||||
|
self.wait_until_job::<RA>(unsafe { job.transmute_ref::<RA>() });
|
||||||
|
resume_unwind(payload);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let ra = if job.state() == JobState::Empty as u8 {
|
||||||
|
unsafe {
|
||||||
|
job.unlink();
|
||||||
|
}
|
||||||
|
|
||||||
|
// a is allowed to panic here, because we already finished b.
|
||||||
|
unsafe { a.unwrap()() }
|
||||||
|
} else {
|
||||||
|
match self.wait_until_job::<RA>(unsafe { job.transmute_ref::<RA>() }) {
|
||||||
|
Some(t) => t.into_result(), // propagate panic here
|
||||||
|
None => unsafe { a.unwrap()() },
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
drop(a);
|
||||||
|
(ra, rb)
|
||||||
|
}
|
||||||
|
|
||||||
#[cold]
|
#[cold]
|
||||||
fn wait_until_latch_cold<Latch: crate::Probe>(&self, latch: &Latch) {
|
fn wait_until_latch_cold<Latch: crate::Probe>(&self, latch: &Latch) {
|
||||||
// does this optimise?
|
// does this optimise?
|
||||||
|
@ -1355,7 +1445,7 @@ impl<'scope> Scope<'scope> {
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
fn make_job<F: FnOnce() -> T, T>(f: F) -> Job<T> {
|
fn make_job<F: FnOnce() -> T, T>(f: F) -> Job<T> {
|
||||||
#[repr(align(8))]
|
#[align(8)]
|
||||||
unsafe fn harness<F: FnOnce() -> T, T>(this: *const (), job: *const Job<T>) {
|
unsafe fn harness<F: FnOnce() -> T, T>(this: *const (), job: *const Job<T>) {
|
||||||
let f = unsafe { Box::from_raw(this.cast::<F>().cast_mut()) };
|
let f = unsafe { Box::from_raw(this.cast::<F>().cast_mut()) };
|
||||||
|
|
||||||
|
@ -1462,7 +1552,7 @@ impl<'scope> Scope<'scope> {
|
||||||
};
|
};
|
||||||
|
|
||||||
let schedule = move |runnable: Runnable| {
|
let schedule = move |runnable: Runnable| {
|
||||||
#[repr(align(8))]
|
#[align(8)]
|
||||||
unsafe fn harness<T>(this: *const (), job: *const Job<T>) {
|
unsafe fn harness<T>(this: *const (), job: *const Job<T>) {
|
||||||
let runnable =
|
let runnable =
|
||||||
Runnable::<()>::from_raw(NonNull::new_unchecked(this.cast_mut()));
|
Runnable::<()>::from_raw(NonNull::new_unchecked(this.cast_mut()));
|
||||||
|
@ -1522,12 +1612,10 @@ impl<'scope> Scope<'scope> {
|
||||||
let worker = WorkerThread::current_ref().expect("join is run in workerthread.");
|
let worker = WorkerThread::current_ref().expect("join is run in workerthread.");
|
||||||
let this = SendPtr::new_const(self).unwrap();
|
let this = SendPtr::new_const(self).unwrap();
|
||||||
|
|
||||||
unsafe {
|
worker.join_heartbeat_every::<_, _, _, _, 64>(
|
||||||
worker.context.join_heartbeat_every::<_, _, _, _, 64>(
|
make_scope_closure(this, a),
|
||||||
make_scope_closure(this, a),
|
make_scope_closure(this, b),
|
||||||
make_scope_closure(this, b),
|
)
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn from_context(ctx: Arc<Context>) -> Self {
|
fn from_context(ctx: Arc<Context>) -> Self {
|
||||||
|
@ -1708,101 +1796,7 @@ impl Context {
|
||||||
B: FnOnce() -> RB + Send,
|
B: FnOnce() -> RB + Send,
|
||||||
{
|
{
|
||||||
// SAFETY: join_heartbeat_every is safe to call from a worker thread.
|
// SAFETY: join_heartbeat_every is safe to call from a worker thread.
|
||||||
self.run_in_worker(move |_| unsafe { self.join_heartbeat_every::<_, _, _, _, 64>(a, b) })
|
self.run_in_worker(move |worker| worker.join_heartbeat_every::<_, _, _, _, 64>(a, b))
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
fn join_seq<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
|
|
||||||
where
|
|
||||||
RA: Send,
|
|
||||||
RB: Send,
|
|
||||||
A: FnOnce() -> RA + Send,
|
|
||||||
B: FnOnce() -> RB + Send,
|
|
||||||
{
|
|
||||||
let rb = b();
|
|
||||||
let ra = a();
|
|
||||||
|
|
||||||
(ra, rb)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// This function must be called from a worker thread.
|
|
||||||
#[inline]
|
|
||||||
unsafe fn join_heartbeat_every<A, B, RA, RB, const TIMES: usize>(&self, a: A, b: B) -> (RA, RB)
|
|
||||||
where
|
|
||||||
RA: Send,
|
|
||||||
RB: Send,
|
|
||||||
A: FnOnce() -> RA + Send,
|
|
||||||
B: FnOnce() -> RB + Send,
|
|
||||||
{
|
|
||||||
let worker = WorkerThread::current_ref().expect("join is run in workerthread.");
|
|
||||||
|
|
||||||
// SAFETY: each worker is only ever used by one thread, so this is safe.
|
|
||||||
worker
|
|
||||||
.join_count
|
|
||||||
.set(worker.join_count.get().wrapping_add(1) % TIMES as u8);
|
|
||||||
|
|
||||||
// TODO: add counter to job queue, check for low job count to decide whether to use heartbeat or seq.
|
|
||||||
// see: chili
|
|
||||||
if worker.join_count.get() == 0 {
|
|
||||||
self.join_heartbeat(a, b)
|
|
||||||
} else {
|
|
||||||
self.join_seq(a, b)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// This function must be called from a worker thread.
|
|
||||||
unsafe fn join_heartbeat<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
|
|
||||||
where
|
|
||||||
RA: Send,
|
|
||||||
RB: Send,
|
|
||||||
A: FnOnce() -> RA + Send,
|
|
||||||
B: FnOnce() -> RB + Send,
|
|
||||||
{
|
|
||||||
let worker = WorkerThread::current_ref().expect("join is run in workerthread.");
|
|
||||||
|
|
||||||
let a = StackJob::new(
|
|
||||||
move || {
|
|
||||||
// TODO: bench whether tick'ing here is good.
|
|
||||||
// turns out this actually costs a lot of time, likely because of the thread local check.
|
|
||||||
// WorkerThread::current_ref()
|
|
||||||
// .expect("stackjob is run in workerthread.")
|
|
||||||
// .tick();
|
|
||||||
|
|
||||||
a()
|
|
||||||
},
|
|
||||||
NopLatch,
|
|
||||||
);
|
|
||||||
|
|
||||||
let job = a.as_job();
|
|
||||||
worker.push_front(&job);
|
|
||||||
|
|
||||||
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
|
|
||||||
let rb = match catch_unwind(AssertUnwindSafe(|| b())) {
|
|
||||||
Ok(val) => val,
|
|
||||||
Err(payload) => {
|
|
||||||
cold_path();
|
|
||||||
// if b panicked, we need to wait for a to finish
|
|
||||||
worker.wait_until_job::<RA>(unsafe { job.transmute_ref::<RA>() });
|
|
||||||
resume_unwind(payload);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let ra = if job.state() == JobState::Empty as u8 {
|
|
||||||
unsafe {
|
|
||||||
job.unlink();
|
|
||||||
}
|
|
||||||
|
|
||||||
// a is allowed to panic here, because we already finished b.
|
|
||||||
unsafe { a.unwrap()() }
|
|
||||||
} else {
|
|
||||||
match worker.wait_until_job::<RA>(unsafe { job.transmute_ref::<RA>() }) {
|
|
||||||
Some(t) => t.into_result(), // propagate panic here
|
|
||||||
None => unsafe { a.unwrap()() },
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
drop(a);
|
|
||||||
(ra, rb)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Runs closure in this context, processing the other context's worker's jobs while waiting for the result.
|
/// Runs closure in this context, processing the other context's worker's jobs while waiting for the result.
|
||||||
|
|
|
@ -459,7 +459,7 @@ fn join_many() {
|
||||||
|
|
||||||
fn sum(tree: &Tree<u32>, node: usize, scope: &Scope) -> u32 {
|
fn sum(tree: &Tree<u32>, node: usize, scope: &Scope) -> u32 {
|
||||||
let node = tree.get(node);
|
let node = tree.get(node);
|
||||||
let (l, r) = scope.join_heartbeat(
|
let (l, r) = scope.join(
|
||||||
|s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(),
|
|s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(),
|
||||||
|s| {
|
|s| {
|
||||||
node.right
|
node.right
|
||||||
|
|
Loading…
Reference in a new issue