LOTS OF CHANGES: but! this works
This commit is contained in:
parent
69d3794ff1
commit
6e4f6a1285
|
@ -24,10 +24,10 @@ fn nodes() -> impl Iterator<Item = (usize, usize)> {
|
||||||
|
|
||||||
#[divan::bench(args = nodes())]
|
#[divan::bench(args = nodes())]
|
||||||
fn no_overhead(bencher: Bencher, nodes: (usize, usize)) {
|
fn no_overhead(bencher: Bencher, nodes: (usize, usize)) {
|
||||||
fn join_no_overhead<A, B, RA, RB>(scope: &Scope<'_, '_>, a: A, b: B) -> (RA, RB)
|
fn join_no_overhead<A, B, RA, RB>(scope: Scope<'_, '_>, a: A, b: B) -> (RA, RB)
|
||||||
where
|
where
|
||||||
A: FnOnce(&Scope<'_, '_>) -> RA + Send,
|
A: FnOnce(Scope<'_, '_>) -> RA + Send,
|
||||||
B: FnOnce(&Scope<'_, '_>) -> RB + Send,
|
B: FnOnce(Scope<'_, '_>) -> RB + Send,
|
||||||
RA: Send,
|
RA: Send,
|
||||||
RB: Send,
|
RB: Send,
|
||||||
{
|
{
|
||||||
|
@ -35,7 +35,7 @@ fn no_overhead(bencher: Bencher, nodes: (usize, usize)) {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn sum(node: &Node, scope: &Scope<'_, '_>) -> u64 {
|
fn sum(node: &Node, scope: Scope<'_, '_>) -> u64 {
|
||||||
let (left, right) = join_no_overhead(
|
let (left, right) = join_no_overhead(
|
||||||
scope,
|
scope,
|
||||||
|s| node.left.as_deref().map(|n| sum(n, s)).unwrap_or_default(),
|
|s| node.left.as_deref().map(|n| sum(n, s)).unwrap_or_default(),
|
||||||
|
@ -57,7 +57,7 @@ fn no_overhead(bencher: Bencher, nodes: (usize, usize)) {
|
||||||
|
|
||||||
#[divan::bench(args = nodes())]
|
#[divan::bench(args = nodes())]
|
||||||
fn distaff_overhead(bencher: Bencher, nodes: (usize, usize)) {
|
fn distaff_overhead(bencher: Bencher, nodes: (usize, usize)) {
|
||||||
fn sum<'scope, 'env>(node: &Node, scope: &'scope Scope<'scope, 'env>) -> u64 {
|
fn sum<'scope, 'env>(node: &Node, scope: Scope<'scope, 'env>) -> u64 {
|
||||||
let (left, right) = scope.join(
|
let (left, right) = scope.join(
|
||||||
|s| node.left.as_deref().map(|n| sum(n, s)).unwrap_or_default(),
|
|s| node.left.as_deref().map(|n| sum(n, s)).unwrap_or_default(),
|
||||||
|s| node.right.as_deref().map(|n| sum(n, s)).unwrap_or_default(),
|
|s| node.right.as_deref().map(|n| sum(n, s)).unwrap_or_default(),
|
||||||
|
|
|
@ -15,7 +15,6 @@ use crate::{
|
||||||
channel::{Parker, Sender},
|
channel::{Parker, Sender},
|
||||||
heartbeat::HeartbeatList,
|
heartbeat::HeartbeatList,
|
||||||
job::{HeapJob, Job2 as Job, SharedJob, StackJob},
|
job::{HeapJob, Job2 as Job, SharedJob, StackJob},
|
||||||
latch::NopLatch,
|
|
||||||
util::DropGuard,
|
util::DropGuard,
|
||||||
workerthread::{HeartbeatThread, WorkerThread},
|
workerthread::{HeartbeatThread, WorkerThread},
|
||||||
};
|
};
|
||||||
|
@ -142,6 +141,7 @@ impl Context {
|
||||||
.iter()
|
.iter()
|
||||||
.find(|(_, heartbeat)| heartbeat.is_waiting())
|
.find(|(_, heartbeat)| heartbeat.is_waiting())
|
||||||
{
|
{
|
||||||
|
_ = i;
|
||||||
#[cfg(feature = "tracing")]
|
#[cfg(feature = "tracing")]
|
||||||
tracing::trace!("Notifying worker thread {} about job sharing", i);
|
tracing::trace!("Notifying worker thread {} about job sharing", i);
|
||||||
sender.wake();
|
sender.wake();
|
||||||
|
@ -160,15 +160,7 @@ impl Context {
|
||||||
// current thread is not in the same context, create a job and inject it into the other thread's context, then wait while working on our jobs.
|
// current thread is not in the same context, create a job and inject it into the other thread's context, then wait while working on our jobs.
|
||||||
|
|
||||||
// SAFETY: we are waiting on this latch in this thread.
|
// SAFETY: we are waiting on this latch in this thread.
|
||||||
let job = StackJob::new(
|
let job = StackJob::new(move |worker: &WorkerThread| f(worker));
|
||||||
move || {
|
|
||||||
let worker = WorkerThread::current_ref()
|
|
||||||
.expect("WorkerThread::run_in_worker called outside of worker thread");
|
|
||||||
|
|
||||||
f(worker)
|
|
||||||
},
|
|
||||||
NopLatch,
|
|
||||||
);
|
|
||||||
|
|
||||||
let job = Job::from_stackjob(&job);
|
let job = Job::from_stackjob(&job);
|
||||||
|
|
||||||
|
@ -188,15 +180,7 @@ impl Context {
|
||||||
// current thread isn't a worker thread, create job and inject into context
|
// current thread isn't a worker thread, create job and inject into context
|
||||||
let parker = Parker::new();
|
let parker = Parker::new();
|
||||||
|
|
||||||
let job = StackJob::new(
|
let job = StackJob::new(move |worker: &WorkerThread| f(worker));
|
||||||
move || {
|
|
||||||
let worker = WorkerThread::current_ref()
|
|
||||||
.expect("WorkerThread::run_in_worker called outside of worker thread");
|
|
||||||
|
|
||||||
f(worker)
|
|
||||||
},
|
|
||||||
NopLatch,
|
|
||||||
);
|
|
||||||
|
|
||||||
let job = Job::from_stackjob(&job);
|
let job = Job::from_stackjob(&job);
|
||||||
|
|
||||||
|
@ -247,7 +231,7 @@ impl Context {
|
||||||
where
|
where
|
||||||
F: FnOnce() + Send + 'static,
|
F: FnOnce() + Send + 'static,
|
||||||
{
|
{
|
||||||
let job = Job::from_heapjob(Box::new(HeapJob::new(f)));
|
let job = Job::from_heapjob(HeapJob::new(|_: &WorkerThread| f()));
|
||||||
#[cfg(feature = "tracing")]
|
#[cfg(feature = "tracing")]
|
||||||
tracing::trace!("Context::spawn: spawning job: {:?}", job);
|
tracing::trace!("Context::spawn: spawning job: {:?}", job);
|
||||||
self.inject_job(job.share(None));
|
self.inject_job(job.share(None));
|
||||||
|
@ -364,19 +348,16 @@ mod tests {
|
||||||
|
|
||||||
let parker = Parker::new();
|
let parker = Parker::new();
|
||||||
|
|
||||||
let job = StackJob::new(
|
let job = StackJob::new({
|
||||||
{
|
|
||||||
let counter = counter.clone();
|
let counter = counter.clone();
|
||||||
move || {
|
move |_: &WorkerThread| {
|
||||||
#[cfg(feature = "tracing")]
|
#[cfg(feature = "tracing")]
|
||||||
tracing::info!("Job running");
|
tracing::info!("Job running");
|
||||||
counter.fetch_add(1, Ordering::SeqCst);
|
counter.fetch_add(1, Ordering::SeqCst);
|
||||||
|
|
||||||
42
|
42
|
||||||
}
|
}
|
||||||
},
|
});
|
||||||
NopLatch,
|
|
||||||
);
|
|
||||||
|
|
||||||
let job = Job::from_stackjob(&job);
|
let job = Job::from_stackjob(&job);
|
||||||
|
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -6,7 +6,6 @@ use std::{hint::cold_path, sync::Arc};
|
||||||
use crate::{
|
use crate::{
|
||||||
context::Context,
|
context::Context,
|
||||||
job::{Job2 as Job, StackJob},
|
job::{Job2 as Job, StackJob},
|
||||||
latch::NopLatch,
|
|
||||||
workerthread::WorkerThread,
|
workerthread::WorkerThread,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -14,19 +13,19 @@ impl WorkerThread {
|
||||||
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
||||||
fn join_seq<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
|
fn join_seq<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
|
||||||
where
|
where
|
||||||
A: FnOnce() -> RA,
|
A: FnOnce(&WorkerThread) -> RA,
|
||||||
B: FnOnce() -> RB,
|
B: FnOnce(&WorkerThread) -> RB,
|
||||||
{
|
{
|
||||||
let rb = b();
|
let rb = b(self);
|
||||||
let ra = a();
|
let ra = a(self);
|
||||||
|
|
||||||
(ra, rb)
|
(ra, rb)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn join_heartbeat_every<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
|
pub(crate) fn join_heartbeat_every<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
|
||||||
where
|
where
|
||||||
A: FnOnce() -> RA + Send,
|
A: FnOnce(&WorkerThread) -> RA + Send,
|
||||||
B: FnOnce() -> RB,
|
B: FnOnce(&WorkerThread) -> RB,
|
||||||
RA: Send,
|
RA: Send,
|
||||||
{
|
{
|
||||||
// self.join_heartbeat_every_inner::<A, B, RA, RB, 2>(a, b)
|
// self.join_heartbeat_every_inner::<A, B, RA, RB, 2>(a, b)
|
||||||
|
@ -34,12 +33,13 @@ impl WorkerThread {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This function must be called from a worker thread.
|
/// This function must be called from a worker thread.
|
||||||
|
#[allow(dead_code)]
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
fn join_heartbeat_every_inner<A, B, RA, RB, const TIMES: usize>(&self, a: A, b: B) -> (RA, RB)
|
fn join_heartbeat_every_inner<A, B, RA, RB, const TIMES: usize>(&self, a: A, b: B) -> (RA, RB)
|
||||||
where
|
where
|
||||||
RA: Send,
|
RA: Send,
|
||||||
A: FnOnce() -> RA + Send,
|
A: FnOnce(&WorkerThread) -> RA + Send,
|
||||||
B: FnOnce() -> RB,
|
B: FnOnce(&WorkerThread) -> RB,
|
||||||
{
|
{
|
||||||
// SAFETY: each worker is only ever used by one thread, so this is safe.
|
// SAFETY: each worker is only ever used by one thread, so this is safe.
|
||||||
let count = self.join_count.get();
|
let count = self.join_count.get();
|
||||||
|
@ -63,22 +63,22 @@ impl WorkerThread {
|
||||||
fn join_heartbeat<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
|
fn join_heartbeat<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
|
||||||
where
|
where
|
||||||
RA: Send,
|
RA: Send,
|
||||||
A: FnOnce() -> RA + Send,
|
A: FnOnce(&WorkerThread) -> RA + Send,
|
||||||
B: FnOnce() -> RB,
|
B: FnOnce(&WorkerThread) -> RB,
|
||||||
{
|
{
|
||||||
use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
|
use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
|
||||||
|
|
||||||
#[cfg(feature = "metrics")]
|
#[cfg(feature = "metrics")]
|
||||||
self.metrics.num_joins.fetch_add(1, Ordering::Relaxed);
|
self.metrics.num_joins.fetch_add(1, Ordering::Relaxed);
|
||||||
|
|
||||||
let a = StackJob::new(a, NopLatch);
|
let a = StackJob::new(a);
|
||||||
let job = Job::from_stackjob(&a);
|
let job = Job::from_stackjob(&a);
|
||||||
|
|
||||||
self.push_back(&job);
|
self.push_back(&job);
|
||||||
|
|
||||||
self.tick();
|
self.tick();
|
||||||
|
|
||||||
let rb = match catch_unwind(AssertUnwindSafe(|| b())) {
|
let rb = match catch_unwind(AssertUnwindSafe(|| b(self))) {
|
||||||
Ok(val) => val,
|
Ok(val) => val,
|
||||||
Err(payload) => {
|
Err(payload) => {
|
||||||
#[cfg(feature = "tracing")]
|
#[cfg(feature = "tracing")]
|
||||||
|
@ -109,7 +109,7 @@ impl WorkerThread {
|
||||||
);
|
);
|
||||||
// the job was shared, but not yet stolen, so we get to run the
|
// the job was shared, but not yet stolen, so we get to run the
|
||||||
// job inline
|
// job inline
|
||||||
unsafe { a.unwrap()() }
|
unsafe { a.unwrap()(self) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -119,7 +119,7 @@ impl WorkerThread {
|
||||||
// SAFETY: we just popped the job from the queue, so it is safe to unwrap.
|
// SAFETY: we just popped the job from the queue, so it is safe to unwrap.
|
||||||
#[cfg(feature = "tracing")]
|
#[cfg(feature = "tracing")]
|
||||||
tracing::trace!("join_heartbeat: job was not shared, running a() inline");
|
tracing::trace!("join_heartbeat: job was not shared, running a() inline");
|
||||||
a.unwrap()()
|
a.unwrap()(self)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -136,12 +136,13 @@ impl Context {
|
||||||
RB: Send,
|
RB: Send,
|
||||||
{
|
{
|
||||||
// SAFETY: join_heartbeat_every is safe to call from a worker thread.
|
// SAFETY: join_heartbeat_every is safe to call from a worker thread.
|
||||||
self.run_in_worker(move |worker| worker.join_heartbeat_every::<_, _, _, _>(a, b))
|
self.run_in_worker(move |worker| {
|
||||||
|
worker.join_heartbeat_every::<_, _, _, _>(|_| a(), |_| b())
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// run two closures potentially in parallel, in the global threadpool.
|
/// run two closures potentially in parallel, in the global threadpool.
|
||||||
#[allow(dead_code)]
|
|
||||||
pub fn join<A, B, RA, RB>(a: A, b: B) -> (RA, RB)
|
pub fn join<A, B, RA, RB>(a: A, b: B) -> (RA, RB)
|
||||||
where
|
where
|
||||||
A: FnOnce() -> RA + Send,
|
A: FnOnce() -> RA + Send,
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
use std::{
|
use std::{
|
||||||
any::Any,
|
any::Any,
|
||||||
marker::PhantomData,
|
marker::PhantomData,
|
||||||
|
panic::{AssertUnwindSafe, catch_unwind},
|
||||||
|
pin::{self, Pin},
|
||||||
ptr::{self, NonNull},
|
ptr::{self, NonNull},
|
||||||
sync::{
|
sync::{
|
||||||
Arc,
|
Arc,
|
||||||
atomic::{AtomicPtr, Ordering},
|
atomic::{AtomicPtr, AtomicUsize, Ordering},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -30,6 +32,102 @@ use crate::{
|
||||||
// - when a join() job finishes, it's latch is set
|
// - when a join() job finishes, it's latch is set
|
||||||
// - when we wait for a join() job, we loop over the latch until it is set
|
// - when we wait for a join() job, we loop over the latch until it is set
|
||||||
|
|
||||||
|
// a Scope must keep track of:
|
||||||
|
// - The number of async jobs spawned, which is used to determine when the scope
|
||||||
|
// is complete.
|
||||||
|
// - A panic box, which is set when a job panics and is used to resume the panic
|
||||||
|
// when the scope is completed.
|
||||||
|
// - The Parker of the worker on which the scope was created, which is signaled
|
||||||
|
// when the last outstanding async job finishes.
|
||||||
|
// - The current worker thread in order to avoid having to query the
|
||||||
|
// thread-local storage.
|
||||||
|
|
||||||
|
struct ScopeInner {
|
||||||
|
outstanding_jobs: AtomicUsize,
|
||||||
|
parker: NonNull<crate::channel::Parker>,
|
||||||
|
panic: AtomicPtr<Box<dyn Any + Send + 'static>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe impl Send for ScopeInner {}
|
||||||
|
unsafe impl Sync for ScopeInner {}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
pub struct Scope<'scope, 'env: 'scope> {
|
||||||
|
inner: SendPtr<ScopeInner>,
|
||||||
|
worker: SendPtr<WorkerThread>,
|
||||||
|
_scope: PhantomData<&'scope mut &'scope ()>,
|
||||||
|
_env: PhantomData<&'env mut &'env ()>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ScopeInner {
|
||||||
|
fn from_worker(worker: &WorkerThread) -> Self {
|
||||||
|
Self {
|
||||||
|
outstanding_jobs: AtomicUsize::new(0),
|
||||||
|
parker: worker.heartbeat.parker().into(),
|
||||||
|
panic: AtomicPtr::new(ptr::null_mut()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn increment(&self) {
|
||||||
|
self.outstanding_jobs.fetch_add(1, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn decrement(&self) {
|
||||||
|
if self.outstanding_jobs.fetch_sub(1, Ordering::Relaxed) == 1 {
|
||||||
|
unsafe {
|
||||||
|
self.parker.as_ref().unpark();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn panicked(&self, err: Box<dyn Any + Send + 'static>) {
|
||||||
|
unsafe {
|
||||||
|
let err = Box::into_raw(Box::new(err));
|
||||||
|
if !self
|
||||||
|
.panic
|
||||||
|
.compare_exchange(ptr::null_mut(), err, Ordering::AcqRel, Ordering::Acquire)
|
||||||
|
.is_ok()
|
||||||
|
{
|
||||||
|
// someone else already set the panic, so we drop the error
|
||||||
|
_ = Box::from_raw(err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn maybe_propagate_panic(&self) {
|
||||||
|
let err = self.panic.swap(ptr::null_mut(), Ordering::AcqRel);
|
||||||
|
|
||||||
|
if err.is_null() {
|
||||||
|
return;
|
||||||
|
} else {
|
||||||
|
// SAFETY: we have exclusive access to the panic error, so we can safely resume it.
|
||||||
|
unsafe {
|
||||||
|
let err = *Box::from_raw(err);
|
||||||
|
std::panic::resume_unwind(err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
||||||
|
fn wait_for_jobs(&self) {
|
||||||
|
loop {
|
||||||
|
let count = self.outstanding_jobs.load(Ordering::Relaxed);
|
||||||
|
if count == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "tracing")]
|
||||||
|
tracing::trace!("waiting for {} jobs to finish.", count);
|
||||||
|
|
||||||
|
// wait until the parker is unparked
|
||||||
|
unsafe {
|
||||||
|
self.parker.as_ref().park();
|
||||||
|
}
|
||||||
|
// parking gives us AcqRel semantics.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// find below a sketch of an unbalanced tree:
|
// find below a sketch of an unbalanced tree:
|
||||||
// []
|
// []
|
||||||
// / \
|
// / \
|
||||||
|
@ -52,7 +150,7 @@ use crate::{
|
||||||
// - another thread sharing a job
|
// - another thread sharing a job
|
||||||
// - the heartbeat waking up the worker // does this make sense? if the thread was sleeping, it didn't have any work to share.
|
// - the heartbeat waking up the worker // does this make sense? if the thread was sleeping, it didn't have any work to share.
|
||||||
|
|
||||||
pub struct Scope<'scope, 'env: 'scope> {
|
pub struct Scope2<'scope, 'env: 'scope> {
|
||||||
// latch to wait on before the scope finishes
|
// latch to wait on before the scope finishes
|
||||||
job_counter: CountLatch,
|
job_counter: CountLatch,
|
||||||
// local threadpool
|
// local threadpool
|
||||||
|
@ -66,7 +164,7 @@ pub struct Scope<'scope, 'env: 'scope> {
|
||||||
|
|
||||||
pub fn scope<'env, F, R>(f: F) -> R
|
pub fn scope<'env, F, R>(f: F) -> R
|
||||||
where
|
where
|
||||||
F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> R + Send,
|
F: for<'scope> FnOnce(Scope<'scope, 'env>) -> R + Send,
|
||||||
R: Send,
|
R: Send,
|
||||||
{
|
{
|
||||||
scope_with_context(Context::global_context(), f)
|
scope_with_context(Context::global_context(), f)
|
||||||
|
@ -74,43 +172,25 @@ where
|
||||||
|
|
||||||
pub fn scope_with_context<'env, F, R>(context: &Arc<Context>, f: F) -> R
|
pub fn scope_with_context<'env, F, R>(context: &Arc<Context>, f: F) -> R
|
||||||
where
|
where
|
||||||
F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> R + Send,
|
F: for<'scope> FnOnce(Scope<'scope, 'env>) -> R + Send,
|
||||||
R: Send,
|
R: Send,
|
||||||
{
|
{
|
||||||
context.run_in_worker(|worker| {
|
context.run_in_worker(|worker| {
|
||||||
// SAFETY: we call complete() after creating this scope, which
|
// SAFETY: we call complete() after creating this scope, which
|
||||||
// ensures that any jobs spawned from the scope exit before the
|
// ensures that any jobs spawned from the scope exit before the
|
||||||
// scope closes.
|
// scope closes.
|
||||||
let this = unsafe { Scope::from_context(context.clone()) };
|
let inner = pin::pin!(ScopeInner::from_worker(worker));
|
||||||
this.complete(worker, || f(&this))
|
let this = Scope::<'_, 'env>::new(worker, inner.as_ref());
|
||||||
|
this.complete(|| f(this))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'scope, 'env> Scope<'scope, 'env> {
|
impl<'scope, 'env> Scope<'scope, 'env> {
|
||||||
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
|
||||||
fn wait_for_jobs(&self, worker: &WorkerThread) {
|
|
||||||
self.job_counter.set_inner(worker.heartbeat.parker());
|
|
||||||
if self.job_counter.count() > 0 {
|
|
||||||
#[cfg(feature = "tracing")]
|
|
||||||
tracing::trace!("waiting for {} jobs to finish.", self.job_counter.count());
|
|
||||||
#[cfg(feature = "tracing")]
|
|
||||||
tracing::trace!(
|
|
||||||
"thread id: {:?}, jobs: {:?}",
|
|
||||||
worker.heartbeat.index(),
|
|
||||||
unsafe { worker.queue.as_ref_unchecked() }
|
|
||||||
);
|
|
||||||
|
|
||||||
// set worker index in the job counter
|
|
||||||
worker.wait_until_pred(|| self.job_counter.probe());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// should be called from within a worker thread.
|
/// should be called from within a worker thread.
|
||||||
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
||||||
fn complete<F, R>(&self, worker: &WorkerThread, f: F) -> R
|
fn complete<F, R>(&self, f: F) -> R
|
||||||
where
|
where
|
||||||
F: FnOnce() -> R + Send,
|
F: FnOnce() -> R,
|
||||||
R: Send,
|
|
||||||
{
|
{
|
||||||
use std::panic::{AssertUnwindSafe, catch_unwind};
|
use std::panic::{AssertUnwindSafe, catch_unwind};
|
||||||
|
|
||||||
|
@ -122,76 +202,90 @@ impl<'scope, 'env> Scope<'scope, 'env> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
self.wait_for_jobs(worker);
|
let inner = self.inner();
|
||||||
self.maybe_propagate_panic();
|
inner.wait_for_jobs();
|
||||||
|
inner.maybe_propagate_panic();
|
||||||
|
|
||||||
// SAFETY: if result panicked, we would have propagated the panic above.
|
// SAFETY: if result panicked, we would have propagated the panic above.
|
||||||
result.unwrap()
|
result.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// resumes the panic if one happened in this scope.
|
fn inner(&self) -> &ScopeInner {
|
||||||
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
unsafe { self.inner.as_ref() }
|
||||||
fn maybe_propagate_panic(&self) {
|
|
||||||
let err_ptr = self.panic.load(Ordering::Relaxed);
|
|
||||||
if !err_ptr.is_null() {
|
|
||||||
unsafe {
|
|
||||||
let err = Box::from_raw(err_ptr);
|
|
||||||
std::panic::resume_unwind(*err);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// stores the first panic that happened in this scope.
|
/// stores the first panic that happened in this scope.
|
||||||
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
||||||
fn panicked(&self, err: Box<dyn Any + Send + 'static>) {
|
fn panicked(&self, err: Box<dyn Any + Send + 'static>) {
|
||||||
#[cfg(feature = "tracing")]
|
self.inner().panicked(err);
|
||||||
tracing::debug!("panicked in scope, storing error: {:?}", err);
|
|
||||||
self.panic.load(Ordering::Relaxed).is_null().then(|| {
|
|
||||||
use core::mem::ManuallyDrop;
|
|
||||||
let mut boxed = ManuallyDrop::new(Box::new(err));
|
|
||||||
|
|
||||||
let err_ptr: *mut Box<dyn Any + Send + 'static> = &mut **boxed;
|
|
||||||
if self
|
|
||||||
.panic
|
|
||||||
.compare_exchange(
|
|
||||||
ptr::null_mut(),
|
|
||||||
err_ptr,
|
|
||||||
Ordering::SeqCst,
|
|
||||||
Ordering::Relaxed,
|
|
||||||
)
|
|
||||||
.is_ok()
|
|
||||||
{
|
|
||||||
// we successfully set the panic, no need to drop
|
|
||||||
} else {
|
|
||||||
// drop the error, someone else already set it
|
|
||||||
_ = ManuallyDrop::into_inner(boxed);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn spawn<F>(&'scope self, f: F)
|
pub fn spawn<F>(&self, f: F)
|
||||||
where
|
where
|
||||||
F: FnOnce(&'scope Self) + Send,
|
F: FnOnce(Self) + Send,
|
||||||
{
|
{
|
||||||
self.job_counter.increment();
|
let inner = self.inner;
|
||||||
|
|
||||||
let this = SendPtr::new_const(self).unwrap();
|
unsafe {
|
||||||
|
inner.as_ref().increment();
|
||||||
let job = Job::from_heapjob(Box::new(HeapJob::new(move || unsafe {
|
|
||||||
use std::panic::{AssertUnwindSafe, catch_unwind};
|
|
||||||
if let Err(payload) = catch_unwind(AssertUnwindSafe(|| f(this.as_ref()))) {
|
|
||||||
this.as_unchecked_ref().panicked(payload);
|
|
||||||
}
|
}
|
||||||
this.as_unchecked_ref().job_counter.decrement();
|
|
||||||
})));
|
|
||||||
|
|
||||||
self.context.inject_job(job.share(None));
|
struct SpawnedJob<F> {
|
||||||
|
f: F,
|
||||||
|
inner: SendPtr<ScopeInner>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F> SpawnedJob<F> {
|
||||||
|
fn new<'scope, 'env, T>(f: F, inner: SendPtr<ScopeInner>) -> Job
|
||||||
|
where
|
||||||
|
F: FnOnce(Scope<'scope, 'env>) -> T + Send,
|
||||||
|
'env: 'scope,
|
||||||
|
T: Send,
|
||||||
|
{
|
||||||
|
Job::from_harness(
|
||||||
|
Self::harness,
|
||||||
|
Box::into_non_null(Box::new(Self { f, inner })).cast(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe fn harness<'scope, 'env, T>(
|
||||||
|
worker: &WorkerThread,
|
||||||
|
this: NonNull<()>,
|
||||||
|
_: Option<Sender>,
|
||||||
|
) where
|
||||||
|
F: FnOnce(Scope<'scope, 'env>) -> T + Send,
|
||||||
|
'env: 'scope,
|
||||||
|
T: Send,
|
||||||
|
{
|
||||||
|
let Self { f, inner } =
|
||||||
|
unsafe { *Box::<SpawnedJob<F>>::from_non_null(this.cast()) };
|
||||||
|
let scope = unsafe { Scope::<'scope, 'env>::new_unchecked(worker, inner) };
|
||||||
|
|
||||||
|
// SAFETY: we are in a worker thread, so the inner is valid.
|
||||||
|
(f)(scope);
|
||||||
|
|
||||||
|
unsafe { inner.as_ref().decrement() };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let job = SpawnedJob::new(
|
||||||
|
move |scope| {
|
||||||
|
if let Err(payload) = catch_unwind(AssertUnwindSafe(|| f(scope))) {
|
||||||
|
scope.inner().panicked(payload);
|
||||||
|
}
|
||||||
|
|
||||||
|
scope.inner().decrement();
|
||||||
|
},
|
||||||
|
self.inner,
|
||||||
|
);
|
||||||
|
|
||||||
|
self.context().inject_job(job.share(None));
|
||||||
// WorkerThread::current_ref()
|
// WorkerThread::current_ref()
|
||||||
// .expect("spawn is run in workerthread.")
|
// .expect("spawn is run in workerthread.")
|
||||||
// .push_front(job.as_ptr());
|
// .push_front(job.as_ptr());
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn spawn_future<T, F>(&'scope self, future: F) -> async_task::Task<T>
|
pub fn spawn_future<T, F>(&self, future: F) -> async_task::Task<T>
|
||||||
where
|
where
|
||||||
F: Future<Output = T> + Send + 'scope,
|
F: Future<Output = T> + Send + 'scope,
|
||||||
T: Send + 'scope,
|
T: Send + 'scope,
|
||||||
|
@ -200,9 +294,9 @@ impl<'scope, 'env> Scope<'scope, 'env> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
pub fn spawn_async<T, Fut, Fn>(&'scope self, f: Fn) -> async_task::Task<T>
|
pub fn spawn_async<T, Fut, Fn>(&self, f: Fn) -> async_task::Task<T>
|
||||||
where
|
where
|
||||||
Fn: FnOnce(&'scope Self) -> Fut + Send + 'scope,
|
Fn: FnOnce(Self) -> Fut + Send + 'scope,
|
||||||
Fut: Future<Output = T> + Send + 'scope,
|
Fut: Future<Output = T> + Send + 'scope,
|
||||||
T: Send + 'scope,
|
T: Send + 'scope,
|
||||||
{
|
{
|
||||||
|
@ -210,25 +304,30 @@ impl<'scope, 'env> Scope<'scope, 'env> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn spawn_async_internal<T, Fut, Fn>(&'scope self, f: Fn) -> async_task::Task<T>
|
fn spawn_async_internal<T, Fut, Fn>(&self, f: Fn) -> async_task::Task<T>
|
||||||
where
|
where
|
||||||
Fn: FnOnce(&'scope Self) -> Fut + Send + 'scope,
|
Fn: FnOnce(Self) -> Fut + Send + 'scope,
|
||||||
Fut: Future<Output = T> + Send + 'scope,
|
Fut: Future<Output = T> + Send + 'scope,
|
||||||
T: Send + 'scope,
|
T: Send + 'scope,
|
||||||
{
|
{
|
||||||
self.job_counter.increment();
|
self.inner().increment();
|
||||||
|
|
||||||
let this = SendPtr::new_const(self).unwrap();
|
let this = SendPtr::new_const(self).unwrap();
|
||||||
// let this = SendPtr::new_const(&self.job_counter).unwrap();
|
// let this = SendPtr::new_const(&self.job_counter).unwrap();
|
||||||
|
|
||||||
|
// TODO: make sure this worker lasts long enough for the
|
||||||
|
// reference to remain valid for the duration of the future.
|
||||||
|
let scope = unsafe { Self::new_unchecked(self.worker.as_ref(), self.inner) };
|
||||||
|
|
||||||
let future = async move {
|
let future = async move {
|
||||||
// SAFETY: this is valid until we decrement the job counter.
|
// SAFETY: this is valid until we decrement the job counter.
|
||||||
unsafe {
|
unsafe {
|
||||||
let _guard = DropGuard::new(move || {
|
let _guard = DropGuard::new(move || {
|
||||||
this.as_unchecked_ref().job_counter.decrement();
|
this.as_ref().inner().decrement();
|
||||||
});
|
});
|
||||||
|
|
||||||
// TODO: handle panics here
|
// TODO: handle panics here
|
||||||
f(this.as_ref()).await
|
f(scope).await
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -244,7 +343,7 @@ impl<'scope, 'env> Scope<'scope, 'env> {
|
||||||
let job = Job::<()>::from_harness(harness, runnable.into_raw());
|
let job = Job::<()>::from_harness(harness, runnable.into_raw());
|
||||||
|
|
||||||
// casting into Job<()> here
|
// casting into Job<()> here
|
||||||
self.context.inject_job(job.share(None));
|
self.context().inject_job(job.share(None));
|
||||||
// WorkerThread::current_ref()
|
// WorkerThread::current_ref()
|
||||||
// .expect("spawn_async_internal is run in workerthread.")
|
// .expect("spawn_async_internal is run in workerthread.")
|
||||||
// .push_front(job);
|
// .push_front(job);
|
||||||
|
@ -257,37 +356,140 @@ impl<'scope, 'env> Scope<'scope, 'env> {
|
||||||
task
|
task
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn join<A, B, RA, RB>(&'scope self, a: A, b: B) -> (RA, RB)
|
pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
|
||||||
where
|
where
|
||||||
RA: Send,
|
RA: Send,
|
||||||
RB: Send,
|
A: FnOnce(Self) -> RA + Send,
|
||||||
A: FnOnce(&'scope Self) -> RA + Send,
|
B: FnOnce(Self) -> RB,
|
||||||
B: FnOnce(&'scope Self) -> RB + Send,
|
|
||||||
{
|
{
|
||||||
let worker = WorkerThread::current_ref().expect("join is run in workerthread.");
|
use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
|
||||||
let this = SendPtr::new_const(self).unwrap();
|
use std::{
|
||||||
|
cell::UnsafeCell,
|
||||||
|
mem::{self, ManuallyDrop},
|
||||||
|
};
|
||||||
|
|
||||||
worker.join_heartbeat_every::<_, _, _, _>(
|
let worker = self.worker();
|
||||||
{
|
|
||||||
let this = this;
|
struct ScopeJob<F> {
|
||||||
move || a(unsafe { this.as_ref() })
|
f: UnsafeCell<ManuallyDrop<F>>,
|
||||||
},
|
inner: SendPtr<ScopeInner>,
|
||||||
{
|
|
||||||
let this = this;
|
|
||||||
move || b(unsafe { this.as_ref() })
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe fn from_context(context: Arc<Context>) -> Self {
|
impl<F> ScopeJob<F> {
|
||||||
|
fn new(f: F, inner: SendPtr<ScopeInner>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
context,
|
f: UnsafeCell::new(ManuallyDrop::new(f)),
|
||||||
job_counter: CountLatch::new(ptr::null()),
|
inner,
|
||||||
panic: AtomicPtr::new(ptr::null_mut()),
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn into_job<'scope, 'env, T>(&self) -> Job<T>
|
||||||
|
where
|
||||||
|
F: FnOnce(Scope<'scope, 'env>) -> T + Send,
|
||||||
|
'env: 'scope,
|
||||||
|
T: Send,
|
||||||
|
{
|
||||||
|
Job::from_harness(Self::harness, NonNull::from(self).cast())
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe fn unwrap(&self) -> F {
|
||||||
|
unsafe { ManuallyDrop::take(&mut *self.f.get()) }
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe fn harness<'scope, 'env, T>(
|
||||||
|
worker: &WorkerThread,
|
||||||
|
this: NonNull<()>,
|
||||||
|
sender: Option<Sender>,
|
||||||
|
) where
|
||||||
|
F: FnOnce(Scope<'scope, 'env>) -> T + Send,
|
||||||
|
'env: 'scope,
|
||||||
|
T: Send,
|
||||||
|
{
|
||||||
|
let this: &ScopeJob<F> = unsafe { this.cast().as_ref() };
|
||||||
|
let f = unsafe { this.unwrap() };
|
||||||
|
let scope = unsafe { Scope::<'scope, 'env>::new_unchecked(worker, this.inner) };
|
||||||
|
let sender: Sender<T> = unsafe { mem::transmute(sender) };
|
||||||
|
|
||||||
|
// SAFETY: we are in a worker thread, so the inner is valid.
|
||||||
|
sender.send(catch_unwind(AssertUnwindSafe(|| f(scope))));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let stack = ScopeJob::new(a, self.inner);
|
||||||
|
let job = ScopeJob::into_job(&stack);
|
||||||
|
|
||||||
|
worker.push_back(&job);
|
||||||
|
|
||||||
|
worker.tick();
|
||||||
|
|
||||||
|
let rb = match catch_unwind(AssertUnwindSafe(|| b(*self))) {
|
||||||
|
Ok(val) => val,
|
||||||
|
Err(payload) => {
|
||||||
|
#[cfg(feature = "tracing")]
|
||||||
|
tracing::debug!("join_heartbeat: b panicked, waiting for a to finish");
|
||||||
|
std::hint::cold_path();
|
||||||
|
|
||||||
|
// if b panicked, we need to wait for a to finish
|
||||||
|
let mut receiver = job.take_receiver();
|
||||||
|
worker.wait_until_pred(|| match &receiver {
|
||||||
|
Some(recv) => recv.poll().is_some(),
|
||||||
|
None => {
|
||||||
|
receiver = job.take_receiver();
|
||||||
|
false
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
resume_unwind(payload);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let ra = if let Some(recv) = job.take_receiver() {
|
||||||
|
match worker.wait_until_recv(recv) {
|
||||||
|
Some(t) => crate::util::unwrap_or_panic(t),
|
||||||
|
None => {
|
||||||
|
#[cfg(feature = "tracing")]
|
||||||
|
tracing::trace!(
|
||||||
|
"join_heartbeat: job was shared, but reclaimed, running a() inline"
|
||||||
|
);
|
||||||
|
// the job was shared, but not yet stolen, so we get to run the
|
||||||
|
// job inline
|
||||||
|
unsafe { stack.unwrap()(*self) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
worker.pop_back();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
// SAFETY: we just popped the job from the queue, so it is safe to unwrap.
|
||||||
|
#[cfg(feature = "tracing")]
|
||||||
|
tracing::trace!("join_heartbeat: job was not shared, running a() inline");
|
||||||
|
stack.unwrap()(*self)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
(ra, rb)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new(worker: &WorkerThread, inner: Pin<&'scope ScopeInner>) -> Self {
|
||||||
|
// SAFETY: we are creating a new scope, so the inner is valid.
|
||||||
|
unsafe { Self::new_unchecked(worker, SendPtr::new_const(&*inner).unwrap()) }
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe fn new_unchecked(worker: &WorkerThread, inner: SendPtr<ScopeInner>) -> Self {
|
||||||
|
Self {
|
||||||
|
inner,
|
||||||
|
worker: SendPtr::new_const(worker).unwrap(),
|
||||||
_scope: PhantomData,
|
_scope: PhantomData,
|
||||||
_env: PhantomData,
|
_env: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn context(&self) -> &Arc<Context> {
|
||||||
|
unsafe { &self.worker.as_ref().context }
|
||||||
|
}
|
||||||
|
pub fn worker(&self) -> &WorkerThread {
|
||||||
|
unsafe { self.worker.as_ref() }
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
@ -330,7 +532,7 @@ mod tests {
|
||||||
fn scope_join_many() {
|
fn scope_join_many() {
|
||||||
let pool = ThreadPool::new_with_threads(1);
|
let pool = ThreadPool::new_with_threads(1);
|
||||||
|
|
||||||
fn sum<'scope, 'env>(scope: &'scope Scope<'scope, 'env>, n: usize) -> usize {
|
fn sum<'scope, 'env>(scope: Scope<'scope, 'env>, n: usize) -> usize {
|
||||||
if n == 0 {
|
if n == 0 {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,7 +32,7 @@ impl ThreadPool {
|
||||||
|
|
||||||
pub fn scope<'env, F, R>(&self, f: F) -> R
|
pub fn scope<'env, F, R>(&self, f: F) -> R
|
||||||
where
|
where
|
||||||
F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> R + Send,
|
F: for<'scope> FnOnce(Scope<'scope, 'env>) -> R + Send,
|
||||||
R: Send,
|
R: Send,
|
||||||
{
|
{
|
||||||
scope_with_context(&self.context, f)
|
scope_with_context(&self.context, f)
|
||||||
|
|
|
@ -94,8 +94,7 @@ impl<T> SendPtr<T> {
|
||||||
unsafe { Self::new_unchecked(ptr.cast_mut()) }
|
unsafe { Self::new_unchecked(ptr.cast_mut()) }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub unsafe fn as_unchecked_ref(&self) -> &T {
|
pub(crate) unsafe fn as_ref(&self) -> &T {
|
||||||
// SAFETY: `self.0` is a valid non-null pointer.
|
|
||||||
unsafe { self.0.as_ref() }
|
unsafe { self.0.as_ref() }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue