use std::{
    any::Any,
    marker::{PhantomData, PhantomPinned},
    panic::{AssertUnwindSafe, catch_unwind},
    pin::{self, Pin},
    ptr::{self, NonNull},
    sync::{
        Arc,
        atomic::{AtomicPtr, AtomicUsize, Ordering},
    },
};

use async_task::Runnable;
use werkzeug::util;

use crate::{
    channel::Sender,
    context::{Context, Message},
    job::{
        HeapJob, Job2 as Job, SharedJob,
        traits::{InlineJob, IntoJob},
    },
    latch::{CountLatch, Probe},
    queue::ReceiverToken,
    util::{DropGuard, SendPtr},
    workerthread::WorkerThread,
};

// thinking:

// The scope needs to keep track of any spawn() and spawn_async() calls across all worker threads.
// That means that for every spawn() or spawn_async() call we have to share a counter across all
// worker threads. We want to minimise the number of atomic operations in general.
// Atomic operations occur in the following cases:
// - when we spawn() or spawn_async() a job, we increment the counter
// - when the same job finishes, we decrement the counter
// - when a join() job finishes, its latch is set
// - when we wait for a join() job, we loop over the latch until it is set

// A Scope must keep track of:
// - The number of async jobs spawned, which is used to determine when the scope
//   is complete.
// - A panic box, which is set when a job panics and is used to resume the panic
//   when the scope is completed.
// - The Parker of the worker on which the scope was created, which is signaled
//   when the last outstanding async job finishes.
// - The current worker thread, in order to avoid having to query
//   thread-local storage.

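// To make the bookkeeping above concrete, a rough sketch of what a single spawn() does with
// the shared counter (illustrative pseudocode only; the real logic lives in `ScopeInner`
// and `Scope::spawn` below):
//
//     outstanding_jobs += 1;                        // Scope::spawn, before injecting the job
//     run the job, stashing any panic in the panic box;
//     if (outstanding_jobs -= 1) == 0 {
//         unicast(Message::ScopeFinished, parker);  // wake the worker that owns the scope
//     }
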
struct ScopeInner {
    outstanding_jobs: AtomicUsize,
    parker: ReceiverToken,
    panic: AtomicPtr<Box<dyn Any + Send + 'static>>,
}

unsafe impl Send for ScopeInner {}
unsafe impl Sync for ScopeInner {}

#[derive(Clone, Copy)]
pub struct Scope<'scope, 'env: 'scope> {
    inner: SendPtr<ScopeInner>,
    worker: SendPtr<WorkerThread>,
    _scope: PhantomData<&'scope mut &'scope ()>,
    _env: PhantomData<&'env mut &'env ()>,
}

impl ScopeInner {
    fn from_worker(worker: &WorkerThread) -> Self {
        Self {
            outstanding_jobs: AtomicUsize::new(0),
            parker: worker.receiver.get_token(),
            panic: AtomicPtr::new(ptr::null_mut()),
        }
    }

    fn increment(&self) {
        self.outstanding_jobs.fetch_add(1, Ordering::Relaxed);
    }

    fn decrement(&self, worker: &WorkerThread) {
        if self.outstanding_jobs.fetch_sub(1, Ordering::Relaxed) == 1 {
            worker
                .context
                .queue
                .as_sender()
                .unicast(Message::ScopeFinished, self.parker);
        }
    }

    fn panicked(&self, err: Box<dyn Any + Send + 'static>) {
        unsafe {
            let err = Box::into_raw(Box::new(err));
            if self
                .panic
                .compare_exchange(ptr::null_mut(), err, Ordering::AcqRel, Ordering::Acquire)
                .is_err()
            {
                // someone else already set the panic, so we drop the error
                _ = Box::from_raw(err);
            }
        }
    }

    fn maybe_propagate_panic(&self) {
        let err = self.panic.swap(ptr::null_mut(), Ordering::AcqRel);

        if err.is_null() {
            return;
        } else {
            // SAFETY: we have exclusive access to the panic error, so we can safely resume it.
            unsafe {
                let err = *Box::from_raw(err);
                std::panic::resume_unwind(err);
            }
        }
    }
}

// find below a sketch of an unbalanced tree:
//
//                  []
//                 /  \
//               []    []
//              /  \   /  \
//            []   [] []   []
//           /  \      /  \
//         []   []   []   []
//        /  \       /  \
//      []   []    []   []
//     /  \       /  \
//   []   []    []   []
//  /  \
// []   []

// in this tree of join() calls, it is possible to wait for a long time, so it is necessary to keep waking up when a job is shared.

// the worker waits on its latch, which may be woken by:
// - a job finishing
// - another thread sharing a job
// - the heartbeat waking up the worker // does this make sense? if the thread was sleeping, it didn't have any work to share.

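// A recursive divide-and-conquer split produces exactly this kind of join() tree (a
// balanced one, in this case). A minimal sketch, mirroring the `scope_join_many` test
// at the bottom of this file:
//
//     fn sum<'scope, 'env>(scope: Scope<'scope, 'env>, n: usize) -> usize {
//         if n == 0 {
//             return 0;
//         }
//         let (l, r) = scope.join(|s| sum(s, n - 1), |s| sum(s, n - 1));
//         l + r + 1
//     }
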
pub struct Scope2<'scope, 'env: 'scope> {
    // latch to wait on before the scope finishes
    job_counter: CountLatch,
    // local threadpool
    context: Arc<Context>,
    // panic error
    panic: AtomicPtr<Box<dyn Any + Send + 'static>>,
    // variant lifetime
    _scope: PhantomData<&'scope mut &'scope ()>,
    _env: PhantomData<&'env mut &'env ()>,
}

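/// Runs `f` with a new scope on the global [`Context`]; the call returns only after
/// every job spawned from the scope has finished.
///
/// A minimal usage sketch follows. It is marked `ignore` because `threadpool::` is only
/// a stand-in for whatever path this crate's `scope` function is actually exported under.
///
/// ```ignore
/// use std::sync::atomic::{AtomicU8, Ordering};
///
/// let count = AtomicU8::new(0);
/// threadpool::scope(|scope| {
///     for _ in 0..10 {
///         scope.spawn(|_| {
///             count.fetch_add(1, Ordering::SeqCst);
///         });
///     }
/// });
/// assert_eq!(count.load(Ordering::SeqCst), 10);
/// ```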
pub fn scope<'env, F, R>(f: F) -> R
where
    F: for<'scope> FnOnce(Scope<'scope, 'env>) -> R + Send,
    R: Send,
{
    scope_with_context(Context::global_context(), f)
}

pub fn scope_with_context<'env, F, R>(context: &Arc<Context>, f: F) -> R
where
    F: for<'scope> FnOnce(Scope<'scope, 'env>) -> R + Send,
    R: Send,
{
    context.run_in_worker(|worker| {
        // SAFETY: we call complete() after creating this scope, which
        // ensures that any jobs spawned from the scope exit before the
        // scope closes.
        let inner = pin::pin!(ScopeInner::from_worker(worker));
        let this = Scope::<'_, 'env>::new(worker, inner.as_ref());
        this.complete(|| f(this))
    })
}

impl<'scope, 'env> Scope<'scope, 'env> {
    /// should be called from within a worker thread.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    fn complete<F, R>(&self, f: F) -> R
    where
        F: FnOnce() -> R,
    {
        use std::panic::{AssertUnwindSafe, catch_unwind};

        let result = match catch_unwind(AssertUnwindSafe(|| f())) {
            Ok(val) => Some(val),
            Err(payload) => {
                self.panicked(payload);
                None
            }
        };

        self.wait_for_jobs();
        let inner = self.inner();
        inner.maybe_propagate_panic();

        // SAFETY: if f panicked, we would have propagated the panic above.
        result.unwrap()
    }

    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    fn wait_for_jobs(&self) {
        loop {
            let count = self.inner().outstanding_jobs.load(Ordering::Relaxed);
            #[cfg(feature = "tracing")]
            tracing::trace!("waiting for {} jobs to finish.", count);
            if count == 0 {
                break;
            }

            match self.worker().receiver.recv() {
                Message::Shared(shared_job) => unsafe {
                    SharedJob::execute(shared_job, self.worker());
                },
                Message::Finished(util::Send(result)) => {
                    #[cfg(feature = "tracing")]
                    tracing::error!(
                        "received result when waiting for jobs to finish: {:p}.",
                        result
                    );
                }
                Message::Exit => {}
                Message::ScopeFinished => {
                    #[cfg(feature = "tracing")]
                    tracing::trace!("scope finished, decrementing outstanding jobs.");
                    assert_eq!(self.inner().outstanding_jobs.load(Ordering::Acquire), 0);
                    break;
                }
            }
        }
    }

    fn decrement(&self) {
        self.inner().decrement(self.worker());
    }

    fn inner(&self) -> &ScopeInner {
        unsafe { self.inner.as_ref() }
    }

    /// stores the first panic that happened in this scope.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    fn panicked(&self, err: Box<dyn Any + Send + 'static>) {
        self.inner().panicked(err);
    }

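    /// Spawns a job that runs detached from the caller; the enclosing scope waits for it
    /// before returning. The closure receives the scope again, so it can spawn further jobs.
    ///
    /// A minimal usage sketch follows. It is marked `ignore` because `threadpool::` is only
    /// a stand-in for this crate's actual path.
    ///
    /// ```ignore
    /// use std::sync::atomic::{AtomicU8, Ordering};
    ///
    /// let count = AtomicU8::new(0);
    /// threadpool::scope(|scope| {
    ///     scope.spawn(|inner| {
    ///         count.fetch_add(1, Ordering::SeqCst);
    ///         // spawn a nested job from within the first one
    ///         inner.spawn(|_| {
    ///             count.fetch_add(1, Ordering::SeqCst);
    ///         });
    ///     });
    /// });
    /// assert_eq!(count.load(Ordering::SeqCst), 2);
    /// ```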
    pub fn spawn<F>(&self, f: F)
    where
        F: FnOnce(Self) + Send,
    {
        struct SpawnedJob<F> {
            f: F,
            inner: SendPtr<ScopeInner>,
        }

        impl<F> SpawnedJob<F> {
            fn new<'scope, 'env, T>(f: F, inner: SendPtr<ScopeInner>) -> Job
            where
                F: FnOnce(Scope<'scope, 'env>) -> T + Send,
                'env: 'scope,
                T: Send,
            {
                Job::from_harness(
                    Self::harness,
                    Box::into_non_null(Box::new(Self { f, inner })).cast(),
                )
            }

            unsafe fn harness<'scope, 'env, T>(
                worker: &WorkerThread,
                this: NonNull<()>,
                _: Option<ReceiverToken>,
            ) where
                F: FnOnce(Scope<'scope, 'env>) -> T + Send,
                'env: 'scope,
                T: Send,
            {
                let Self { f, inner } =
                    unsafe { *Box::<SpawnedJob<F>>::from_non_null(this.cast()) };
                // SAFETY: we are in a worker thread, so the inner is valid.
                let scope = unsafe { Scope::<'scope, 'env>::new_unchecked(worker, inner) };

                (f)(scope);
            }
        }

        self.inner().increment();
        let job = SpawnedJob::new(
            move |scope| {
                if let Err(payload) = catch_unwind(AssertUnwindSafe(|| f(scope))) {
                    scope.inner().panicked(payload);
                }

                scope.decrement();
            },
            self.inner,
        );

        self.context().inject_job(job.share(None));
    }

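    /// Spawns `future` onto the scope and returns an [`async_task::Task`] handle to it.
    /// The scope waits for the future to finish before returning.
    ///
    /// A minimal usage sketch, mirroring the `scope_spawn_future` test below. It is marked
    /// `ignore` because `threadpool::` is only a stand-in for this crate's actual path.
    ///
    /// ```ignore
    /// let mut x = 0;
    /// threadpool::scope(|scope| {
    ///     let task = scope.spawn_future(async {
    ///         x += 1;
    ///     });
    ///     // detach the handle; the scope still waits for the future to complete
    ///     task.detach();
    /// });
    /// assert_eq!(x, 1);
    /// ```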
    pub fn spawn_future<T, F>(&self, future: F) -> async_task::Task<T>
    where
        F: Future<Output = T> + Send + 'scope,
        T: Send + 'scope,
    {
        self.spawn_async_internal(move |_| future)
    }

    #[allow(dead_code)]
    pub fn spawn_async<T, Fut, Fn>(&self, f: Fn) -> async_task::Task<T>
    where
        Fn: FnOnce(Self) -> Fut + Send + 'scope,
        Fut: Future<Output = T> + Send + 'scope,
        T: Send + 'scope,
    {
        self.spawn_async_internal(f)
    }

    #[inline]
    fn spawn_async_internal<T, Fut, Fn>(&self, f: Fn) -> async_task::Task<T>
    where
        Fn: FnOnce(Self) -> Fut + Send + 'scope,
        Fut: Future<Output = T> + Send + 'scope,
        T: Send + 'scope,
    {
        self.inner().increment();

        // TODO: make sure this worker lasts long enough for the
        // reference to remain valid for the duration of the future.
        let scope = unsafe { Self::new_unchecked(self.worker.as_ref(), self.inner) };

        let future = async move {
            let _guard = DropGuard::new(move || {
                scope.decrement();
            });

            // TODO: handle panics here
            f(scope).await
        };

        let schedule = move |runnable: Runnable| {
            #[align(8)]
            unsafe fn harness(_: &WorkerThread, this: NonNull<()>, _: Option<ReceiverToken>) {
                unsafe {
                    let runnable = Runnable::<()>::from_raw(this.cast());
                    runnable.run();
                }
            }

            // casting into Job<()> here
            let job = Job::<()>::from_harness(harness, runnable.into_raw());

            self.context().inject_job(job.share(None));
            // WorkerThread::current_ref()
            //     .expect("spawn_async_internal is run in workerthread.")
            //     .push_front(job);
        };

        let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) };

        runnable.schedule();

        task
    }

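    /// Runs `a` and `b`, potentially in parallel, and returns both results once both
    /// have finished.
    ///
    /// A minimal usage sketch, mirroring the `scope_join_many` test below. It is marked
    /// `ignore` because `threadpool::` is only a stand-in for this crate's actual path.
    ///
    /// ```ignore
    /// fn sum(scope: threadpool::Scope<'_, '_>, n: usize) -> usize {
    ///     if n == 0 {
    ///         return 0;
    ///     }
    ///     let (l, r) = scope.join(|s| sum(s, n - 1), |s| sum(s, n - 1));
    ///     l + r + 1
    /// }
    ///
    /// threadpool::scope(|scope| {
    ///     assert_eq!(sum(scope, 5), 31);
    /// });
    /// ```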
    pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        RA: Send,
        A: FnOnce(Self) -> RA + Send,
        B: FnOnce(Self) -> RB,
    {
        use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
        use std::{
            cell::UnsafeCell,
            mem::{self, ManuallyDrop},
        };

        let worker = self.worker();

        struct ScopeJob<F> {
            f: UnsafeCell<ManuallyDrop<F>>,
            inner: SendPtr<ScopeInner>,
            _pin: PhantomPinned,
        }

        impl<F> ScopeJob<F> {
            fn new(f: F, inner: SendPtr<ScopeInner>) -> Self {
                Self {
                    f: UnsafeCell::new(ManuallyDrop::new(f)),
                    inner,
                    _pin: PhantomPinned,
                }
            }

            fn into_job<'scope, 'env, T>(self: Pin<&Self>) -> Job<T>
            where
                F: FnOnce(Scope<'scope, 'env>) -> T + Send,
                'env: 'scope,
                T: Send,
            {
                Job::from_harness(Self::harness, NonNull::from(&*self).cast())
            }

            unsafe fn unwrap(&self) -> F {
                unsafe { ManuallyDrop::take(&mut *self.f.get()) }
            }

            unsafe fn harness<'scope, 'env, T>(
                worker: &WorkerThread,
                this: NonNull<()>,
                sender: Option<ReceiverToken>,
            ) where
                F: FnOnce(Scope<'scope, 'env>) -> T + Send,
                'env: 'scope,
                T: Send,
            {
                let this: &ScopeJob<F> = unsafe { this.cast().as_ref() };
                let f = unsafe { this.unwrap() };
                let scope = unsafe { Scope::<'scope, 'env>::new_unchecked(worker, this.inner) };

                _ = worker.context.queue.as_sender().unicast(
                    Message::Finished(werkzeug::util::Send(
                        Box::into_non_null(Box::new(catch_unwind(AssertUnwindSafe(|| f(scope)))))
                            .cast(),
                    )),
                    sender.unwrap(),
                );
            }
        }

        impl<'scope, 'env, F, T> IntoJob<T> for Pin<&ScopeJob<F>>
        where
            F: FnOnce(Scope<'scope, 'env>) -> T + Send,
            'env: 'scope,
            T: Send,
        {
            fn into_job(self) -> Job<T> {
                self.into_job()
            }
        }

        impl<'scope, 'env, F, T> InlineJob<T> for Pin<&ScopeJob<F>>
        where
            F: FnOnce(Scope<'scope, 'env>) -> T + Send,
            'env: 'scope,
            T: Send,
        {
            fn run_inline(self, worker: &WorkerThread) -> T {
                unsafe { self.unwrap()(Scope::<'scope, 'env>::new_unchecked(worker, self.inner)) }
            }
        }

        let mut _pinned = ScopeJob::new(a, self.inner);
        let job = unsafe { Pin::new_unchecked(&_pinned) };

        let (a, b) = worker.join_heartbeat2_every::<_, _, _, _, 64>(job, |_| b(*self));

        // touch job here to ensure it is not dropped before we run the join.
        drop(_pinned);
        (a, b)

        // let stack = ScopeJob::new(a, self.inner);
        // let job = ScopeJob::into_job(&stack);

        // worker.push_back(&job);

        // worker.tick();

        // let rb = match catch_unwind(AssertUnwindSafe(|| b(*self))) {
        //     Ok(val) => val,
        //     Err(payload) => {
        //         #[cfg(feature = "tracing")]
        //         tracing::debug!("join_heartbeat: b panicked, waiting for a to finish");
        //         std::hint::cold_path();
        //
        //         // if b panicked, we need to wait for a to finish
        //         let mut receiver = job.take_receiver();
        //         worker.wait_until_pred(|| match &receiver {
        //             Some(recv) => recv.poll().is_some(),
        //             None => {
        //                 receiver = job.take_receiver();
        //                 false
        //             }
        //         });
        //
        //         resume_unwind(payload);
        //     }
        // };

        // let ra = if let Some(recv) = job.take_receiver() {
        //     match worker.wait_until_recv(recv) {
        //         Some(t) => crate::util::unwrap_or_panic(t),
        //         None => {
        //             #[cfg(feature = "tracing")]
        //             tracing::trace!(
        //                 "join_heartbeat: job was shared, but reclaimed, running a() inline"
        //             );
        //             // the job was shared, but not yet stolen, so we get to run the
        //             // job inline
        //             unsafe { stack.unwrap()(*self) }
        //         }
        //     }
        // } else {
        //     worker.pop_back();
        //
        //     unsafe {
        //         // SAFETY: we just popped the job from the queue, so it is safe to unwrap.
        //         #[cfg(feature = "tracing")]
        //         tracing::trace!("join_heartbeat: job was not shared, running a() inline");
        //         stack.unwrap()(*self)
        //     }
        // };

        // (ra, rb)
    }

    fn new(worker: &WorkerThread, inner: Pin<&'scope ScopeInner>) -> Self {
        // SAFETY: we are creating a new scope, so the inner is valid.
        unsafe { Self::new_unchecked(worker, SendPtr::new_const(&*inner).unwrap()) }
    }

    unsafe fn new_unchecked(worker: &WorkerThread, inner: SendPtr<ScopeInner>) -> Self {
        Self {
            inner,
            worker: SendPtr::new_const(worker).unwrap(),
            _scope: PhantomData,
            _env: PhantomData,
        }
    }

    pub fn context(&self) -> &Arc<Context> {
        unsafe { &self.worker.as_ref().context }
    }

    pub fn worker(&self) -> &WorkerThread {
        unsafe { self.worker.as_ref() }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicU8;

    use super::*;
    use crate::ThreadPool;

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn scope_spawn_sync() {
        let pool = ThreadPool::new_with_threads(1);
        let count = Arc::new(AtomicU8::new(0));

        scope_with_context(&pool.context, |scope| {
            scope.spawn(|_| {
                count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            });
        });

        assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 1);
    }

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn scope_join_one() {
        let pool = ThreadPool::new_with_threads(1);
        let count = AtomicU8::new(0);

        let a = pool.scope(|scope| {
            let (a, b) = scope.join(
                |_| count.fetch_add(1, Ordering::Relaxed) + 4,
                |_| count.fetch_add(2, Ordering::Relaxed) + 6,
            );
            a + b
        });

        assert_eq!(count.load(Ordering::Relaxed), 3);
    }

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn scope_join_many() {
        let pool = ThreadPool::new_with_threads(1);

        fn sum<'scope, 'env>(scope: Scope<'scope, 'env>, n: usize) -> usize {
            if n == 0 {
                return 0;
            }

            let (l, r) = scope.join(|s| sum(s, n - 1), |s| sum(s, n - 1));

            l + r + 1
        }

        pool.scope(|scope| {
            let total = sum(scope, 5);
            // assert_eq!(total, 1023);
            eprintln!("Total sum: {}", total);
        });
    }

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn scope_spawn_future() {
        let pool = ThreadPool::new_with_threads(1);
        let mut x = 0;
        pool.scope(|scope| {
            let task = scope.spawn_async(|_| async {
                x += 1;
            });

            task.detach();
        });

        assert_eq!(x, 1);
    }

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn scope_spawn_many() {
        let pool = ThreadPool::new_with_threads(1);
        let count = Arc::new(AtomicU8::new(0));

        pool.scope(|scope| {
            for _ in 0..10 {
                let count = count.clone();
                scope.spawn(move |_| {
                    count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                });
            }
        });

        assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 10);
    }
}