executor/distaff/src/scope.rs

635 lines
20 KiB
Rust

use std::{
any::Any,
marker::{PhantomData, PhantomPinned},
panic::{AssertUnwindSafe, catch_unwind},
pin::{self, Pin},
ptr::{self, NonNull},
sync::{
Arc,
atomic::{AtomicPtr, AtomicUsize, Ordering},
},
};
use async_task::Runnable;
use werkzeug::util;
use crate::{
channel::Sender,
context::{Context, Message},
job::{
HeapJob, Job2 as Job, SharedJob,
traits::{InlineJob, IntoJob},
},
latch::{CountLatch, Probe},
queue::ReceiverToken,
util::{DropGuard, SendPtr},
workerthread::WorkerThread,
};
// thinking:
// the scope needs to keep track of any spawn() and spawn_async() calls, across all worker threads.
// that means that for any spawn() or spawn_async() call, we have to share a counter across all worker threads.
// we want to minimise the number of atomic operations in general.
// atomic operations occur in the following cases:
// - when we spawn() or spawn_async() a job, we increment the counter
// - when the same job finishes, we decrement the counter
// - when a join() job finishes, its latch is set
// - when we wait for a join() job, we loop over the latch until it is set
// a Scope must keep track of:
// - The number of async jobs spawned, which is used to determine when the scope
// is complete.
// - A panic box, which is set when a job panics and is used to resume the panic
// when the scope is completed.
// - The receiver token (`parker`) of the worker on which the scope was created,
// which is signaled when the last outstanding async job finishes.
// - The current worker thread in order to avoid having to query the
// thread-local storage (this pointer is stored in `Scope`, not in `ScopeInner`).
/// State shared by every `Scope` handle and every job belonging to one scope.
///
/// `scope_with_context` creates it pinned on the stack of the worker that opens
/// the scope; handles and jobs reach it through a `SendPtr`.
struct ScopeInner {
    /// Number of spawned jobs (sync and async) that have not finished yet.
    outstanding_jobs: AtomicUsize,
    /// Receiver token of the worker that opened the scope; the job that brings
    /// `outstanding_jobs` to zero sends `Message::ScopeFinished` to it.
    parker: ReceiverToken,
    /// First panic payload captured in the scope, or null. Non-null values are
    /// created with `Box::into_raw` and reclaimed with `Box::from_raw`.
    panic: AtomicPtr<Box<dyn Any + Send + 'static>>,
}
// NOTE(review): the atomic fields are already `Send + Sync`; these impls are
// presumably needed because `ReceiverToken` is not. Confirm that sharing and
// copying a token across threads is sound before relying on them.
unsafe impl Send for ScopeInner {}
unsafe impl Sync for ScopeInner {}
/// A cheap, copyable handle to a running scope, passed to every closure that
/// runs inside it.
///
/// It holds a pointer to the shared `ScopeInner` and a pointer to a
/// `WorkerThread` used to reach the thread pool. The `PhantomData` fields use
/// `&'a mut &'a ()`, which makes both `'scope` and `'env` invariant.
#[derive(Clone, Copy)]
pub struct Scope<'scope, 'env: 'scope> {
    // shared scope state (counter, panic slot, owner token)
    inner: SendPtr<ScopeInner>,
    // worker this handle was created with
    worker: SendPtr<WorkerThread>,
    _scope: PhantomData<&'scope mut &'scope ()>,
    _env: PhantomData<&'env mut &'env ()>,
}
impl ScopeInner {
    /// Creates the state for a new scope opened on `worker`.
    ///
    /// The worker's receiver token is recorded so that the job finishing last
    /// can notify the owner with `Message::ScopeFinished`.
    fn from_worker(worker: &WorkerThread) -> Self {
        Self {
            outstanding_jobs: AtomicUsize::new(0),
            parker: worker.receiver.get_token(),
            panic: AtomicPtr::new(ptr::null_mut()),
        }
    }

    /// Registers one more outstanding job.
    ///
    /// `Relaxed` is enough here: the increment publishes no data, and it
    /// happens before the job is handed to the pool.
    fn increment(&self) {
        self.outstanding_jobs.fetch_add(1, Ordering::Relaxed);
    }

    /// Marks one outstanding job as finished. If it was the last one, this
    /// notifies the owning worker.
    ///
    /// Once the counter reaches zero, the owner may return from the scope and
    /// free this `ScopeInner` (it lives on the owner's stack). Nothing in
    /// `self` may therefore be read after the `fetch_sub`.
    fn decrement(&self, worker: &WorkerThread) {
        // Copy the token out *before* the decrement; see above.
        let parker = self.parker;
        // Release: publishes this job's writes (including a stored panic) to
        // the owner's Acquire load of the counter. Acquire: lets the final
        // decrementer observe every other job's writes before it notifies.
        if self.outstanding_jobs.fetch_sub(1, Ordering::AcqRel) == 1 {
            worker
                .context
                .queue
                .as_sender()
                .unicast(Message::ScopeFinished, parker);
        }
    }

    /// Records `err` as the scope's panic, unless a panic was already recorded.
    /// In that case `err` is dropped, so only the first panic is resumed.
    fn panicked(&self, err: Box<dyn Any + Send + 'static>) {
        let err = Box::into_raw(Box::new(err));
        if self
            .panic
            .compare_exchange(ptr::null_mut(), err, Ordering::AcqRel, Ordering::Acquire)
            .is_err()
        {
            // Someone else already set the panic, so we drop this error.
            // SAFETY: the CAS failed, so `err` was never published and we still
            // own the allocation produced by `Box::into_raw` above.
            drop(unsafe { Box::from_raw(err) });
        }
    }

    /// Takes the recorded panic, if any, and resumes unwinding with it.
    fn maybe_propagate_panic(&self) {
        let err = self.panic.swap(ptr::null_mut(), Ordering::AcqRel);
        if !err.is_null() {
            // SAFETY: the swap removed the pointer from the slot, so we are its
            // sole owner; it was produced by `Box::into_raw` in `panicked`.
            let payload = *unsafe { Box::from_raw(err) };
            std::panic::resume_unwind(payload);
        }
    }
}
// find below a sketch of an unbalanced tree:
// []
// / \
// [] []
// / \ / \
// [] [] [] []
// / \ / \
// [] [][] []
// / \ / \
// [] [] [] []
// / \ / \
// [] [] [] []
// / \
// [] []
// in this tree of join() calls, it is possible to wait for a long time, so it is necessary to keep waking up when a job is shared.
// the worker waits on its latch, which may be woken by:
// - a job finishing
// - another thread sharing a job
// - the heartbeat waking up the worker // does this make sense? if the thread was sleeping, it didn't have any work to share.
/// Scope state that tracks outstanding jobs with a `CountLatch` and owns a
/// handle to its thread pool `Context`.
///
/// NOTE(review): not constructed anywhere in the visible part of this file;
/// it may be a work in progress — confirm before relying on it.
pub struct Scope2<'scope, 'env: 'scope> {
    // latch to wait on before the scope finishes
    job_counter: CountLatch,
    // local threadpool
    context: Arc<Context>,
    // panic error: first recorded panic payload, or null
    panic: AtomicPtr<Box<dyn Any + Send + 'static>>,
    // invariant lifetimes (`&'a mut &'a ()` is invariant in `'a`)
    _scope: PhantomData<&'scope mut &'scope ()>,
    _env: PhantomData<&'env mut &'env ()>,
}
/// Opens a scope on the global thread pool and runs `f` inside it.
///
/// This is `scope_with_context` applied to `Context::global_context()`.
pub fn scope<'env, F, R>(f: F) -> R
where
    R: Send,
    F: for<'scope> FnOnce(Scope<'scope, 'env>) -> R + Send,
{
    let global = Context::global_context();
    scope_with_context(global, f)
}
/// Runs `f` inside a new scope on a worker of `context`.
///
/// Returns `f`'s result only after `f` and every job spawned through the scope
/// have finished. If any of them panicked, the first panic is resumed instead.
pub fn scope_with_context<'env, F, R>(context: &Arc<Context>, f: F) -> R
where
    F: for<'scope> FnOnce(Scope<'scope, 'env>) -> R + Send,
    R: Send,
{
    context.run_in_worker(|worker| {
        // The shared state is pinned on this worker's stack; spawned jobs keep
        // raw pointers to it. `complete` waits for all of those jobs before
        // returning, so the state outlives every one of them.
        let state = pin::pin!(ScopeInner::from_worker(worker));
        let handle = Scope::<'_, 'env>::new(worker, state.as_ref());
        handle.complete(|| f(handle))
    })
}
impl<'scope, 'env> Scope<'scope, 'env> {
/// should be called from within a worker thread.
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
fn complete<F, R>(&self, f: F) -> R
where
F: FnOnce() -> R,
{
use std::panic::{AssertUnwindSafe, catch_unwind};
let result = match catch_unwind(AssertUnwindSafe(|| f())) {
Ok(val) => Some(val),
Err(payload) => {
self.panicked(payload);
None
}
};
self.wait_for_jobs();
let inner = self.inner();
inner.maybe_propagate_panic();
// SAFETY: if result panicked, we would have propagated the panic above.
result.unwrap()
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
fn wait_for_jobs(&self) {
loop {
let count = self.inner().outstanding_jobs.load(Ordering::Relaxed);
#[cfg(feature = "tracing")]
tracing::trace!("waiting for {} jobs to finish.", count);
if count == 0 {
break;
}
match self.worker().receiver.recv() {
Message::Shared(shared_job) => unsafe {
SharedJob::execute(shared_job, self.worker());
},
Message::Finished(util::Send(result)) => {
#[cfg(feature = "tracing")]
tracing::error!(
"received result when waiting for jobs to finish: {:p}.",
result
);
}
Message::Exit => {}
Message::ScopeFinished => {
#[cfg(feature = "tracing")]
tracing::trace!("scope finished, decrementing outstanding jobs.");
assert_eq!(self.inner().outstanding_jobs.load(Ordering::Acquire), 0);
break;
}
}
}
}
fn decrement(&self) {
self.inner().decrement(self.worker());
}
fn inner(&self) -> &ScopeInner {
unsafe { self.inner.as_ref() }
}
/// stores the first panic that happened in this scope.
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
fn panicked(&self, err: Box<dyn Any + Send + 'static>) {
self.inner().panicked(err);
}
pub fn spawn<F>(&self, f: F)
where
F: FnOnce(Self) + Send,
{
struct SpawnedJob<F> {
f: F,
inner: SendPtr<ScopeInner>,
}
impl<F> SpawnedJob<F> {
fn new<'scope, 'env, T>(f: F, inner: SendPtr<ScopeInner>) -> Job
where
F: FnOnce(Scope<'scope, 'env>) -> T + Send,
'env: 'scope,
T: Send,
{
Job::from_harness(
Self::harness,
Box::into_non_null(Box::new(Self { f, inner })).cast(),
)
}
unsafe fn harness<'scope, 'env, T>(
worker: &WorkerThread,
this: NonNull<()>,
_: Option<ReceiverToken>,
) where
F: FnOnce(Scope<'scope, 'env>) -> T + Send,
'env: 'scope,
T: Send,
{
let Self { f, inner } =
unsafe { *Box::<SpawnedJob<F>>::from_non_null(this.cast()) };
let scope = unsafe { Scope::<'scope, 'env>::new_unchecked(worker, inner) };
// SAFETY: we are in a worker thread, so the inner is valid.
(f)(scope);
}
}
self.inner().increment();
let job = SpawnedJob::new(
move |scope| {
if let Err(payload) = catch_unwind(AssertUnwindSafe(|| f(scope))) {
scope.inner().panicked(payload);
}
scope.decrement();
},
self.inner,
);
self.context().inject_job(job.share(None));
}
pub fn spawn_future<T, F>(&self, future: F) -> async_task::Task<T>
where
F: Future<Output = T> + Send + 'scope,
T: Send + 'scope,
{
self.spawn_async_internal(move |_| future)
}
#[allow(dead_code)]
pub fn spawn_async<T, Fut, Fn>(&self, f: Fn) -> async_task::Task<T>
where
Fn: FnOnce(Self) -> Fut + Send + 'scope,
Fut: Future<Output = T> + Send + 'scope,
T: Send + 'scope,
{
self.spawn_async_internal(f)
}
#[inline]
fn spawn_async_internal<T, Fut, Fn>(&self, f: Fn) -> async_task::Task<T>
where
Fn: FnOnce(Self) -> Fut + Send + 'scope,
Fut: Future<Output = T> + Send + 'scope,
T: Send + 'scope,
{
self.inner().increment();
// TODO: make sure this worker lasts long enough for the
// reference to remain valid for the duration of the future.
let scope = unsafe { Self::new_unchecked(self.worker.as_ref(), self.inner) };
let future = async move {
let _guard = DropGuard::new(move || {
scope.decrement();
});
// TODO: handle panics here
f(scope).await
};
let schedule = move |runnable: Runnable| {
#[align(8)]
unsafe fn harness(_: &WorkerThread, this: NonNull<()>, _: Option<ReceiverToken>) {
unsafe {
let runnable = Runnable::<()>::from_raw(this.cast());
runnable.run();
}
}
let job = Job::<()>::from_harness(harness, runnable.into_raw());
// casting into Job<()> here
self.context().inject_job(job.share(None));
// WorkerThread::current_ref()
// .expect("spawn_async_internal is run in workerthread.")
// .push_front(job);
};
let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) };
runnable.schedule();
task
}
pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
where
RA: Send,
A: FnOnce(Self) -> RA + Send,
B: FnOnce(Self) -> RB,
{
use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
use std::{
cell::UnsafeCell,
mem::{self, ManuallyDrop},
};
let worker = self.worker();
struct ScopeJob<F> {
f: UnsafeCell<ManuallyDrop<F>>,
inner: SendPtr<ScopeInner>,
_pin: PhantomPinned,
}
impl<F> ScopeJob<F> {
fn new(f: F, inner: SendPtr<ScopeInner>) -> Self {
Self {
f: UnsafeCell::new(ManuallyDrop::new(f)),
inner,
_pin: PhantomPinned,
}
}
fn into_job<'scope, 'env, T>(self: Pin<&Self>) -> Job<T>
where
F: FnOnce(Scope<'scope, 'env>) -> T + Send,
'env: 'scope,
T: Send,
{
Job::from_harness(Self::harness, NonNull::from(&*self).cast())
}
unsafe fn unwrap(&self) -> F {
unsafe { ManuallyDrop::take(&mut *self.f.get()) }
}
unsafe fn harness<'scope, 'env, T>(
worker: &WorkerThread,
this: NonNull<()>,
sender: Option<ReceiverToken>,
) where
F: FnOnce(Scope<'scope, 'env>) -> T + Send,
'env: 'scope,
T: Send,
{
let this: &ScopeJob<F> = unsafe { this.cast().as_ref() };
let f = unsafe { this.unwrap() };
let scope = unsafe { Scope::<'scope, 'env>::new_unchecked(worker, this.inner) };
_ = worker.context.queue.as_sender().unicast(
Message::Finished(werkzeug::util::Send(
Box::into_non_null(Box::new(catch_unwind(AssertUnwindSafe(|| f(scope)))))
.cast(),
)),
sender.unwrap(),
);
}
}
impl<'scope, 'env, F, T> IntoJob<T> for Pin<&ScopeJob<F>>
where
F: FnOnce(Scope<'scope, 'env>) -> T + Send,
'env: 'scope,
T: Send,
{
fn into_job(self) -> Job<T> {
self.into_job()
}
}
impl<'scope, 'env, F, T> InlineJob<T> for Pin<&ScopeJob<F>>
where
F: FnOnce(Scope<'scope, 'env>) -> T + Send,
'env: 'scope,
T: Send,
{
fn run_inline(self, worker: &WorkerThread) -> T {
unsafe { self.unwrap()(Scope::<'scope, 'env>::new_unchecked(worker, self.inner)) }
}
}
let mut _pinned = ScopeJob::new(a, self.inner);
let job = unsafe { Pin::new_unchecked(&_pinned) };
let (a, b) = worker.join_heartbeat2_every::<_, _, _, _, 64>(job, |_| b(*self));
// touch job here to ensure it is not dropped before we run the join.
drop(_pinned);
(a, b)
// let stack = ScopeJob::new(a, self.inner);
// let job = ScopeJob::into_job(&stack);
// worker.push_back(&job);
// worker.tick();
// let rb = match catch_unwind(AssertUnwindSafe(|| b(*self))) {
// Ok(val) => val,
// Err(payload) => {
// #[cfg(feature = "tracing")]
// tracing::debug!("join_heartbeat: b panicked, waiting for a to finish");
// std::hint::cold_path();
// // if b panicked, we need to wait for a to finish
// let mut receiver = job.take_receiver();
// worker.wait_until_pred(|| match &receiver {
// Some(recv) => recv.poll().is_some(),
// None => {
// receiver = job.take_receiver();
// false
// }
// });
// resume_unwind(payload);
// }
// };
// let ra = if let Some(recv) = job.take_receiver() {
// match worker.wait_until_recv(recv) {
// Some(t) => crate::util::unwrap_or_panic(t),
// None => {
// #[cfg(feature = "tracing")]
// tracing::trace!(
// "join_heartbeat: job was shared, but reclaimed, running a() inline"
// );
// // the job was shared, but not yet stolen, so we get to run the
// // job inline
// unsafe { stack.unwrap()(*self) }
// }
// }
// } else {
// worker.pop_back();
// unsafe {
// // SAFETY: we just popped the job from the queue, so it is safe to unwrap.
// #[cfg(feature = "tracing")]
// tracing::trace!("join_heartbeat: job was not shared, running a() inline");
// stack.unwrap()(*self)
// }
// };
// (ra, rb)
}
fn new(worker: &WorkerThread, inner: Pin<&'scope ScopeInner>) -> Self {
// SAFETY: we are creating a new scope, so the inner is valid.
unsafe { Self::new_unchecked(worker, SendPtr::new_const(&*inner).unwrap()) }
}
unsafe fn new_unchecked(worker: &WorkerThread, inner: SendPtr<ScopeInner>) -> Self {
Self {
inner,
worker: SendPtr::new_const(worker).unwrap(),
_scope: PhantomData,
_env: PhantomData,
}
}
pub fn context(&self) -> &Arc<Context> {
unsafe { &self.worker.as_ref().context }
}
pub fn worker(&self) -> &WorkerThread {
unsafe { self.worker.as_ref() }
}
}
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicU8;

    use super::*;
    use crate::ThreadPool;

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn scope_spawn_sync() {
        let pool = ThreadPool::new_with_threads(1);
        let count = Arc::new(AtomicU8::new(0));

        scope_with_context(&pool.context, |scope| {
            scope.spawn(|_| {
                count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            });
        });

        assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 1);
    }

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn scope_join_one() {
        let pool = ThreadPool::new_with_threads(1);
        let count = AtomicU8::new(0);

        let total = pool.scope(|scope| {
            let (a, b) = scope.join(
                |_| count.fetch_add(1, Ordering::Relaxed) + 4,
                |_| count.fetch_add(2, Ordering::Relaxed) + 6,
            );
            a + b
        });

        assert_eq!(count.load(Ordering::Relaxed), 3);
        // Whichever closure runs second observes the other's increment:
        // a-then-b yields 4 + 7, b-then-a yields 6 + 6.
        assert!(matches!(total, 11 | 12), "unexpected join result: {total}");
    }

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn scope_join_many() {
        let pool = ThreadPool::new_with_threads(1);

        // Counts the nodes of a complete binary tree of depth `n`: 2^n - 1.
        fn sum<'scope, 'env>(scope: Scope<'scope, 'env>, n: usize) -> usize {
            if n == 0 {
                return 0;
            }
            let (l, r) = scope.join(|s| sum(s, n - 1), |s| sum(s, n - 1));
            l + r + 1
        }

        pool.scope(|scope| {
            let total = sum(scope, 5);
            assert_eq!(total, 31);
        });
    }

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn scope_spawn_future() {
        let pool = ThreadPool::new_with_threads(1);
        let mut x = 0;

        pool.scope(|scope| {
            let task = scope.spawn_async(|_| async {
                x += 1;
            });
            task.detach();
        });

        assert_eq!(x, 1);
    }

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn scope_spawn_many() {
        let pool = ThreadPool::new_with_threads(1);
        let count = Arc::new(AtomicU8::new(0));

        pool.scope(|scope| {
            for _ in 0..10 {
                let count = count.clone();
                scope.spawn(move |_| {
                    count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                });
            }
        });

        assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 10);
    }
}