move join stuff to context, but should be moved to workerthread
This commit is contained in:
parent
f6f8095440
commit
a2112b9ef5
|
@ -1081,6 +1081,7 @@ struct WorkerThread {
|
||||||
index: usize,
|
index: usize,
|
||||||
heartbeat: Arc<CachePadded<AtomicBool>>,
|
heartbeat: Arc<CachePadded<AtomicBool>>,
|
||||||
queue: UnsafeCell<JobList>,
|
queue: UnsafeCell<JobList>,
|
||||||
|
join_count: Cell<u8>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Scope<'scope> {
|
pub struct Scope<'scope> {
|
||||||
|
@ -1115,7 +1116,7 @@ impl WorkerThread {
|
||||||
index,
|
index,
|
||||||
heartbeat,
|
heartbeat,
|
||||||
queue: UnsafeCell::new(JobList::new()),
|
queue: UnsafeCell::new(JobList::new()),
|
||||||
// join_count: Cell::new(0),
|
join_count: Cell::new(0),
|
||||||
// _pd: PhantomData,
|
// _pd: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1505,109 +1506,28 @@ impl<'scope> Scope<'scope> {
|
||||||
A: FnOnce(&Self) -> RA + Send,
|
A: FnOnce(&Self) -> RA + Send,
|
||||||
B: FnOnce(&Self) -> RB + Send,
|
B: FnOnce(&Self) -> RB + Send,
|
||||||
{
|
{
|
||||||
self.join_heartbeat_every::<_, _, _, _, 64>(a, b)
|
#[inline(always)]
|
||||||
// self.join_heartbeat(a, b)
|
fn make_scope_closure<'scope, A, RA>(
|
||||||
}
|
this: SendPtr<Scope<'scope>>,
|
||||||
|
a: A,
|
||||||
pub fn join_seq<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
|
) -> impl FnOnce() -> RA + use<'scope, RA, A>
|
||||||
where
|
where
|
||||||
|
A: FnOnce(&Scope<'scope>) -> RA + Send,
|
||||||
RA: Send,
|
RA: Send,
|
||||||
RB: Send,
|
|
||||||
A: FnOnce(&Self) -> RA + Send,
|
|
||||||
B: FnOnce(&Self) -> RB + Send,
|
|
||||||
{
|
{
|
||||||
let rb = b(&self);
|
let scope = unsafe { this.as_ref() };
|
||||||
let ra = a(&self);
|
move || a(scope)
|
||||||
|
|
||||||
(ra, rb)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn join_heartbeat_every<A, B, RA, RB, const TIMES: usize>(&self, a: A, b: B) -> (RA, RB)
|
|
||||||
where
|
|
||||||
RA: Send,
|
|
||||||
RB: Send,
|
|
||||||
A: FnOnce(&Self) -> RA + Send,
|
|
||||||
B: FnOnce(&Self) -> RB + Send,
|
|
||||||
{
|
|
||||||
thread_local! {
|
|
||||||
static JOIN_COUNT: Cell<usize> = Cell::new(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
// a threadlocal counter is much faster than a sync atomic counter
|
|
||||||
let count = JOIN_COUNT.with(|count| {
|
|
||||||
count.set(count.get().wrapping_add(1) % TIMES);
|
|
||||||
count.get()
|
|
||||||
});
|
|
||||||
|
|
||||||
// let count = self.join_count.load(Ordering::Relaxed);
|
|
||||||
// self.join_count
|
|
||||||
// .store(count.wrapping_add(1) % TIMES, Ordering::Relaxed);
|
|
||||||
|
|
||||||
// let count = self
|
|
||||||
// .join_count
|
|
||||||
// .update(Ordering::Relaxed, Ordering::Relaxed, |n| {
|
|
||||||
// n.wrapping_add(1) % TIMES
|
|
||||||
// });
|
|
||||||
|
|
||||||
if count == 1 {
|
|
||||||
self.join_heartbeat(a, b)
|
|
||||||
} else {
|
|
||||||
self.join_seq(a, b)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn join_heartbeat<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
|
|
||||||
where
|
|
||||||
RA: Send,
|
|
||||||
RB: Send,
|
|
||||||
A: FnOnce(&Self) -> RA + Send,
|
|
||||||
B: FnOnce(&Self) -> RB + Send,
|
|
||||||
{
|
|
||||||
let worker = WorkerThread::current_ref().expect("join is run in workerthread.");
|
let worker = WorkerThread::current_ref().expect("join is run in workerthread.");
|
||||||
|
|
||||||
let this = SendPtr::new_const(self).unwrap();
|
let this = SendPtr::new_const(self).unwrap();
|
||||||
let a = StackJob::new(
|
|
||||||
move || unsafe {
|
|
||||||
WorkerThread::current_ref()
|
|
||||||
.expect("stackjob is run in workerthread.")
|
|
||||||
.tick();
|
|
||||||
|
|
||||||
let scope = this.as_ref();
|
|
||||||
a(scope)
|
|
||||||
},
|
|
||||||
NopLatch,
|
|
||||||
);
|
|
||||||
|
|
||||||
let job = a.as_job();
|
|
||||||
worker.push_front(&job);
|
|
||||||
|
|
||||||
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
|
|
||||||
let rb = match catch_unwind(AssertUnwindSafe(|| b(self))) {
|
|
||||||
Ok(val) => val,
|
|
||||||
Err(payload) => {
|
|
||||||
cold_path();
|
|
||||||
// if b panicked, we need to wait for a to finish
|
|
||||||
worker.wait_until_job::<RA>(unsafe { job.transmute_ref::<RA>() });
|
|
||||||
resume_unwind(payload);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let ra = if job.state() == JobState::Empty as u8 {
|
|
||||||
unsafe {
|
unsafe {
|
||||||
job.unlink();
|
worker.context.join_heartbeat_every::<_, _, _, _, 64>(
|
||||||
|
make_scope_closure(this, a),
|
||||||
|
make_scope_closure(this, b),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// a is allowed to panic here, because we already finished b.
|
|
||||||
unsafe { a.unwrap()() }
|
|
||||||
} else {
|
|
||||||
match worker.wait_until_job::<RA>(unsafe { job.transmute_ref::<RA>() }) {
|
|
||||||
Some(t) => t.into_result(), // propagate panic here
|
|
||||||
None => unsafe { a.unwrap()() },
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
drop(a);
|
|
||||||
(ra, rb)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn from_context(ctx: Arc<Context>) -> Self {
|
fn from_context(ctx: Arc<Context>) -> Self {
|
||||||
|
@ -1629,7 +1549,19 @@ where
|
||||||
A: FnOnce() -> RA + Send,
|
A: FnOnce() -> RA + Send,
|
||||||
B: FnOnce() -> RB + Send,
|
B: FnOnce() -> RB + Send,
|
||||||
{
|
{
|
||||||
Scope::scope(|scope| scope.join(|_| a(), |_| b()))
|
join_in(Context::global().clone(), a, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// run two closures potentially in parallel, in the global threadpool.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn join_in<A, B, RA, RB>(context: Arc<Context>, a: A, b: B) -> (RA, RB)
|
||||||
|
where
|
||||||
|
RA: Send,
|
||||||
|
RB: Send,
|
||||||
|
A: FnOnce() -> RA + Send,
|
||||||
|
B: FnOnce() -> RB + Send,
|
||||||
|
{
|
||||||
|
context.join(a, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ThreadPool {
|
pub struct ThreadPool {
|
||||||
|
@ -1649,6 +1581,16 @@ impl ThreadPool {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
|
||||||
|
where
|
||||||
|
RA: Send,
|
||||||
|
RB: Send,
|
||||||
|
A: FnOnce() -> RA + Send,
|
||||||
|
B: FnOnce() -> RB + Send,
|
||||||
|
{
|
||||||
|
self.context.join(a, b)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn scope<'scope, R, F>(&self, f: F) -> R
|
pub fn scope<'scope, R, F>(&self, f: F) -> R
|
||||||
where
|
where
|
||||||
F: FnOnce(&Scope<'scope>) -> R + Send,
|
F: FnOnce(&Scope<'scope>) -> R + Send,
|
||||||
|
@ -1757,6 +1699,112 @@ impl Context {
|
||||||
self.shared_job.notify_one();
|
self.shared_job.notify_one();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn join<A, B, RA, RB>(self: &Arc<Self>, a: A, b: B) -> (RA, RB)
|
||||||
|
where
|
||||||
|
RA: Send,
|
||||||
|
RB: Send,
|
||||||
|
A: FnOnce() -> RA + Send,
|
||||||
|
B: FnOnce() -> RB + Send,
|
||||||
|
{
|
||||||
|
// SAFETY: join_heartbeat_every is safe to call from a worker thread.
|
||||||
|
self.run_in_worker(move |_| unsafe { self.join_heartbeat_every::<_, _, _, _, 64>(a, b) })
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn join_seq<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
|
||||||
|
where
|
||||||
|
RA: Send,
|
||||||
|
RB: Send,
|
||||||
|
A: FnOnce() -> RA + Send,
|
||||||
|
B: FnOnce() -> RB + Send,
|
||||||
|
{
|
||||||
|
let rb = b();
|
||||||
|
let ra = a();
|
||||||
|
|
||||||
|
(ra, rb)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This function must be called from a worker thread.
|
||||||
|
#[inline]
|
||||||
|
unsafe fn join_heartbeat_every<A, B, RA, RB, const TIMES: usize>(&self, a: A, b: B) -> (RA, RB)
|
||||||
|
where
|
||||||
|
RA: Send,
|
||||||
|
RB: Send,
|
||||||
|
A: FnOnce() -> RA + Send,
|
||||||
|
B: FnOnce() -> RB + Send,
|
||||||
|
{
|
||||||
|
let worker = WorkerThread::current_ref().expect("join is run in workerthread.");
|
||||||
|
|
||||||
|
// SAFETY: each worker is only ever used by one thread, so this is safe.
|
||||||
|
worker
|
||||||
|
.join_count
|
||||||
|
.set(worker.join_count.get().wrapping_add(1) % TIMES as u8);
|
||||||
|
|
||||||
|
// TODO: add counter to job queue, check for low job count to decide whether to use heartbeat or seq.
|
||||||
|
// see: chili
|
||||||
|
if worker.join_count.get() == 0 {
|
||||||
|
self.join_heartbeat(a, b)
|
||||||
|
} else {
|
||||||
|
self.join_seq(a, b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This function must be called from a worker thread.
|
||||||
|
unsafe fn join_heartbeat<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
|
||||||
|
where
|
||||||
|
RA: Send,
|
||||||
|
RB: Send,
|
||||||
|
A: FnOnce() -> RA + Send,
|
||||||
|
B: FnOnce() -> RB + Send,
|
||||||
|
{
|
||||||
|
let worker = WorkerThread::current_ref().expect("join is run in workerthread.");
|
||||||
|
|
||||||
|
let a = StackJob::new(
|
||||||
|
move || {
|
||||||
|
// TODO: bench whether tick'ing here is good.
|
||||||
|
// turns out this actually costs a lot of time, likely because of the thread local check.
|
||||||
|
// WorkerThread::current_ref()
|
||||||
|
// .expect("stackjob is run in workerthread.")
|
||||||
|
// .tick();
|
||||||
|
|
||||||
|
a()
|
||||||
|
},
|
||||||
|
NopLatch,
|
||||||
|
);
|
||||||
|
|
||||||
|
let job = a.as_job();
|
||||||
|
worker.push_front(&job);
|
||||||
|
|
||||||
|
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
|
||||||
|
let rb = match catch_unwind(AssertUnwindSafe(|| b())) {
|
||||||
|
Ok(val) => val,
|
||||||
|
Err(payload) => {
|
||||||
|
cold_path();
|
||||||
|
// if b panicked, we need to wait for a to finish
|
||||||
|
worker.wait_until_job::<RA>(unsafe { job.transmute_ref::<RA>() });
|
||||||
|
resume_unwind(payload);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let ra = if job.state() == JobState::Empty as u8 {
|
||||||
|
unsafe {
|
||||||
|
job.unlink();
|
||||||
|
}
|
||||||
|
|
||||||
|
// a is allowed to panic here, because we already finished b.
|
||||||
|
unsafe { a.unwrap()() }
|
||||||
|
} else {
|
||||||
|
match worker.wait_until_job::<RA>(unsafe { job.transmute_ref::<RA>() }) {
|
||||||
|
Some(t) => t.into_result(), // propagate panic here
|
||||||
|
None => unsafe { a.unwrap()() },
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
drop(a);
|
||||||
|
(ra, rb)
|
||||||
|
}
|
||||||
|
|
||||||
/// Runs closure in this context, processing the other context's worker's jobs while waiting for the result.
|
/// Runs closure in this context, processing the other context's worker's jobs while waiting for the result.
|
||||||
fn run_in_worker_cross<T, F>(self: &Arc<Self>, worker: &WorkerThread, f: F) -> T
|
fn run_in_worker_cross<T, F>(self: &Arc<Self>, worker: &WorkerThread, f: F) -> T
|
||||||
where
|
where
|
||||||
|
|
Loading…
Reference in a new issue