use std::{
    any::Any,
    marker::{PhantomData, PhantomPinned},
    panic::{AssertUnwindSafe, catch_unwind},
    pin::{self, Pin},
    ptr::{self, NonNull},
    sync::{
        Arc,
        atomic::{AtomicPtr, AtomicUsize, Ordering},
    },
};

use async_task::Runnable;
use werkzeug::util;

use crate::{
    channel::Sender,
    context::{Context, Message},
    job::{
        HeapJob, Job2 as Job, SharedJob,
        traits::{InlineJob, IntoJob},
    },
    latch::{CountLatch, Probe},
    queue::ReceiverToken,
    util::{DropGuard, SendPtr},
    workerthread::WorkerThread,
};

// thinking:
//
// the scope needs to keep track of any spawn() and spawn_async() calls across
// all worker threads. that means that for every spawn() or spawn_async() call,
// we have to share a counter across all worker threads. we want to minimise
// the number of atomic operations in general.
//
// atomic operations occur in the following cases:
// - when we spawn() or spawn_async() a job, we increment the counter
// - when the same job finishes, we decrement the counter
// - when a join() job finishes, its latch is set
// - when we wait for a join() job, we loop over the latch until it is set
//
// a Scope must keep track of:
// - the number of async jobs spawned, which is used to determine when the
//   scope is complete.
// - a panic box, which is set when a job panics and is used to resume the
//   panic when the scope is completed.
// - the Parker of the worker on which the scope was created, which is
//   signaled when the last outstanding async job finishes.
// - the current worker thread, in order to avoid having to query the
//   thread-local storage.

struct ScopeInner {
    outstanding_jobs: AtomicUsize,
    parker: ReceiverToken,
    panic: AtomicPtr<Box<dyn Any + Send>>,
}

unsafe impl Send for ScopeInner {}
unsafe impl Sync for ScopeInner {}

#[derive(Clone, Copy)]
pub struct Scope<'scope, 'env: 'scope> {
    inner: SendPtr<ScopeInner>,
    worker: SendPtr<WorkerThread>,
    _scope: PhantomData<&'scope mut &'scope ()>,
    _env: PhantomData<&'env mut &'env ()>,
}

impl ScopeInner {
    fn from_worker(worker: &WorkerThread) -> Self {
        Self {
            outstanding_jobs: AtomicUsize::new(0),
            parker: worker.receiver.get_token(),
            panic: AtomicPtr::new(ptr::null_mut()),
        }
    }

    fn increment(&self) {
        self.outstanding_jobs.fetch_add(1, Ordering::Relaxed);
    }

    fn decrement(&self, worker: &WorkerThread) {
        if self.outstanding_jobs.fetch_sub(1, Ordering::Relaxed) == 1 {
            worker
                .context
                .queue
                .as_sender()
                .unicast(Message::ScopeFinished, self.parker);
        }
    }

    fn panicked(&self, err: Box<dyn Any + Send>) {
        unsafe {
            let err = Box::into_raw(Box::new(err));
            if self
                .panic
                .compare_exchange(ptr::null_mut(), err, Ordering::AcqRel, Ordering::Acquire)
                .is_err()
            {
                // someone else already set the panic, so we drop this error.
                _ = Box::from_raw(err);
            }
        }
    }

    fn maybe_propagate_panic(&self) {
        let err = self.panic.swap(ptr::null_mut(), Ordering::AcqRel);

        if err.is_null() {
            return;
        }

        // SAFETY: we have exclusive access to the panic error, so we can
        // safely resume it.
        unsafe {
            let err = *Box::from_raw(err);
            std::panic::resume_unwind(err);
        }
    }
}

// find below a sketch of an unbalanced tree of join() calls:
//
//                 []
//                /  \
//              []    []
//             /  \   /  \
//           []   [] []   []
//          /  \         /  \
//        []    []     []    []
//       /  \         /  \
//     []    []     []    []
//    /  \
//  []    []
//
// in this tree of join() calls it is possible to wait for a long time, so it
// is necessary to keep waking up when a job is shared.
//
// the worker waits on its latch, which may be woken by:
// - a job finishing
// - another thread sharing a job
// - the heartbeat waking up the worker
//
// does this make sense? if the thread was sleeping, it didn't have any work
// to share.
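// Below is a minimal, self-contained sketch (added for illustration; it is
// not part of the pool itself) of the completion protocol described above,
// modeled with `std::sync::mpsc` in place of the real `Sender`/`ReceiverToken`
// machinery: the counter is incremented once per job before any job can run,
// every finished job decrements it, and the single decrement that moves the
// counter from 1 to 0 wakes the scope owner, analogous to
// `Message::ScopeFinished`.
#[cfg(test)]
mod completion_protocol_sketch {
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
        mpsc,
    };

    #[test]
    fn last_job_wakes_the_scope_owner() {
        let outstanding = Arc::new(AtomicUsize::new(0));
        let (wake_tx, wake_rx) = mpsc::channel::<()>();

        // increment once per job up front, like `ScopeInner::increment` at
        // spawn time (front-loaded here to keep the sketch deterministic).
        outstanding.fetch_add(3, Ordering::Relaxed);

        let handles: Vec<_> = (0..3)
            .map(|_| {
                let outstanding = Arc::clone(&outstanding);
                let wake_tx = wake_tx.clone();
                std::thread::spawn(move || {
                    // ... the job body would run here ...

                    // like `ScopeInner::decrement`: only the job that brings
                    // the counter to zero notifies the scope owner.
                    if outstanding.fetch_sub(1, Ordering::Relaxed) == 1 {
                        wake_tx.send(()).unwrap(); // ~ Message::ScopeFinished
                    }
                })
            })
            .collect();

        // the scope owner blocks here, like `wait_for_jobs` further down.
        wake_rx.recv().unwrap();
        assert_eq!(outstanding.load(Ordering::Acquire), 0);

        for handle in handles {
            handle.join().unwrap();
        }
    }
}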
pub struct Scope2<'scope, 'env: 'scope> {
    // latch to wait on before the scope finishes
    job_counter: CountLatch,
    // local threadpool
    context: Arc<Context>,
    // panic error
    panic: AtomicPtr<Box<dyn Any + Send>>,
    // variant lifetime
    _scope: PhantomData<&'scope mut &'scope ()>,
    _env: PhantomData<&'env mut &'env ()>,
}

pub fn scope<'env, F, R>(f: F) -> R
where
    F: for<'scope> FnOnce(Scope<'scope, 'env>) -> R + Send,
    R: Send,
{
    scope_with_context(Context::global_context(), f)
}

pub fn scope_with_context<'env, F, R>(context: &Arc<Context>, f: F) -> R
where
    F: for<'scope> FnOnce(Scope<'scope, 'env>) -> R + Send,
    R: Send,
{
    context.run_in_worker(|worker| {
        // SAFETY: we call complete() after creating this scope, which
        // ensures that any jobs spawned from the scope exit before the
        // scope closes.
        let inner = pin::pin!(ScopeInner::from_worker(worker));
        let this = Scope::<'_, 'env>::new(worker, inner.as_ref());
        this.complete(|| f(this))
    })
}

impl<'scope, 'env> Scope<'scope, 'env> {
    /// should be called from within a worker thread.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    fn complete<F, R>(&self, f: F) -> R
    where
        F: FnOnce() -> R,
    {
        use std::panic::{AssertUnwindSafe, catch_unwind};

        let result = match catch_unwind(AssertUnwindSafe(|| f())) {
            Ok(val) => Some(val),
            Err(payload) => {
                self.panicked(payload);
                None
            }
        };

        self.wait_for_jobs();
        let inner = self.inner();
        inner.maybe_propagate_panic();

        // SAFETY: if result panicked, we would have propagated the panic above.
        result.unwrap()
    }

    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    fn wait_for_jobs(&self) {
        loop {
            let count = self.inner().outstanding_jobs.load(Ordering::Relaxed);
            #[cfg(feature = "tracing")]
            tracing::trace!("waiting for {} jobs to finish.", count);
            if count == 0 {
                break;
            }

            match self.worker().receiver.recv() {
                Message::Shared(shared_job) => unsafe {
                    SharedJob::execute(shared_job, self.worker());
                },
                Message::ScopeFinished => {
                    #[cfg(feature = "tracing")]
                    tracing::trace!("scope finished, decrementing outstanding jobs.");
                    assert_eq!(self.inner().outstanding_jobs.load(Ordering::Acquire), 0);
                    break;
                }
                Message::WakeUp | Message::Exit => {}
            }
        }
    }

    fn decrement(&self) {
        self.inner().decrement(self.worker());
    }

    fn inner(&self) -> &ScopeInner {
        unsafe { self.inner.as_ref() }
    }

    /// stores the first panic that happened in this scope.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    fn panicked(&self, err: Box<dyn Any + Send>) {
        self.inner().panicked(err);
    }
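    /// Spawns a job into the scope; the job may run on any worker thread in
    /// the scope's context. The scope will not close until the job has
    /// finished, and a panic inside the job is stored and resumed when the
    /// scope completes.
    ///
    /// A usage sketch (the pool setup mirrors the tests at the bottom of
    /// this file):
    ///
    /// ```ignore
    /// let pool = ThreadPool::new_with_threads(2);
    /// pool.scope(|scope| {
    ///     scope.spawn(|_| {
    ///         // runs on some worker thread; the scope waits for it.
    ///     });
    /// });
    /// ```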
    pub fn spawn<F>(&self, f: F)
    where
        F: FnOnce(Self) + Send,
    {
        struct SpawnedJob<F> {
            f: F,
            inner: SendPtr<ScopeInner>,
        }

        impl<F> SpawnedJob<F> {
            fn new<'scope, 'env, T>(f: F, inner: SendPtr<ScopeInner>) -> Job<T>
            where
                F: FnOnce(Scope<'scope, 'env>) -> T + Send,
                'env: 'scope,
                T: Send,
            {
                Job::from_harness(
                    Self::harness,
                    Box::into_non_null(Box::new(Self { f, inner })).cast(),
                )
            }

            #[align(8)]
            unsafe fn harness<'scope, 'env, T>(
                worker: &WorkerThread,
                this: NonNull<()>,
                _: Option<Sender>,
            ) where
                F: FnOnce(Scope<'scope, 'env>) -> T + Send,
                'env: 'scope,
                T: Send,
            {
                let Self { f, inner } = unsafe { *Box::<Self>::from_non_null(this.cast()) };
                // SAFETY: we are in a worker thread, so the inner pointer is valid.
                let scope = unsafe { Scope::<'scope, 'env>::new_unchecked(worker, inner) };

                (f)(scope);
            }
        }

        self.inner().increment();
        let job = SpawnedJob::new(
            move |scope| {
                if let Err(payload) = catch_unwind(AssertUnwindSafe(|| f(scope))) {
                    scope.inner().panicked(payload);
                }
                scope.decrement();
            },
            self.inner,
        );

        self.context().inject_job(job.share(None));
    }
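    /// Spawns a future into the scope and returns an [`async_task::Task`]
    /// handle for it. The scope waits for the future to finish even if the
    /// task is detached rather than awaited.
    ///
    /// A usage sketch (mirroring `scope_spawn_future` in the tests below):
    ///
    /// ```ignore
    /// let mut x = 0;
    /// pool.scope(|scope| {
    ///     let task = scope.spawn_future(async { x += 1; });
    ///     // detach the task; the scope still waits for it to complete.
    ///     task.detach();
    /// });
    /// assert_eq!(x, 1);
    /// ```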
    pub fn spawn_future<F, T>(&self, future: F) -> async_task::Task<T>
    where
        F: Future<Output = T> + Send + 'scope,
        T: Send + 'scope,
    {
        self.spawn_async_internal(move |_| future)
    }

    #[allow(dead_code)]
    pub fn spawn_async<Fn, Fut, T>(&self, f: Fn) -> async_task::Task<T>
    where
        Fn: FnOnce(Self) -> Fut + Send + 'scope,
        Fut: Future<Output = T> + Send + 'scope,
        T: Send + 'scope,
    {
        self.spawn_async_internal(f)
    }

    #[inline]
    fn spawn_async_internal<Fn, Fut, T>(&self, f: Fn) -> async_task::Task<T>
    where
        Fn: FnOnce(Self) -> Fut + Send + 'scope,
        Fut: Future<Output = T> + Send + 'scope,
        T: Send + 'scope,
    {
        self.inner().increment();

        // TODO: make sure this worker lasts long enough for the
        // reference to remain valid for the duration of the future.
        let scope = unsafe { Self::new_unchecked(self.worker.as_ref(), self.inner) };

        let future = async move {
            let _guard = DropGuard::new(move || {
                scope.decrement();
            });

            // TODO: handle panics here
            f(scope).await
        };

        let schedule = move |runnable: Runnable| {
            #[align(8)]
            unsafe fn harness(_: &WorkerThread, this: NonNull<()>, _: Option<Sender>) {
                unsafe {
                    let runnable = Runnable::<()>::from_raw(this.cast());
                    runnable.run();
                }
            }

            // casting into Job<()> here
            let job = Job::<()>::from_harness(harness, runnable.into_raw());

            self.context().inject_job(job.share(None));

            // WorkerThread::current_ref()
            //     .expect("spawn_async_internal is run in workerthread.")
            //     .push_front(job);
        };

        let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) };
        runnable.schedule();

        task
    }
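    /// Runs `a` and `b`, potentially in parallel, and returns both results.
    /// With the heartbeat strategy, `a` is made available for other workers
    /// to steal while `b` runs inline on the current worker; if nobody stole
    /// `a`, the current worker runs it inline as well.
    ///
    /// A usage sketch (`split_and_process` and `items` are hypothetical):
    ///
    /// ```ignore
    /// let (left, right) = items.split_at(items.len() / 2);
    /// let (a, b) = scope.join(
    ///     |s| split_and_process(s, left),
    ///     |s| split_and_process(s, right),
    /// );
    /// ```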
    pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        RA: Send,
        A: FnOnce(Self) -> RA + Send,
        B: FnOnce(Self) -> RB,
    {
        use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
        use std::{
            cell::UnsafeCell,
            mem::{self, ManuallyDrop},
        };

        let worker = self.worker();

        struct ScopeJob<F> {
            f: UnsafeCell<ManuallyDrop<F>>,
            inner: SendPtr<ScopeInner>,
            _pin: PhantomPinned,
        }

        impl<F> ScopeJob<F> {
            fn new(f: F, inner: SendPtr<ScopeInner>) -> Self {
                Self {
                    f: UnsafeCell::new(ManuallyDrop::new(f)),
                    inner,
                    _pin: PhantomPinned,
                }
            }

            fn into_job<'scope, 'env, T>(self: Pin<&Self>) -> Job<T>
            where
                F: FnOnce(Scope<'scope, 'env>) -> T + Send,
                'env: 'scope,
                T: Send,
            {
                Job::from_harness(Self::harness, NonNull::from(&*self).cast())
            }

            unsafe fn unwrap(&self) -> F {
                unsafe { ManuallyDrop::take(&mut *self.f.get()) }
            }

            #[align(8)]
            unsafe fn harness<'scope, 'env, T>(
                worker: &WorkerThread,
                this: NonNull<()>,
                sender: Option<Sender>,
            ) where
                F: FnOnce(Scope<'scope, 'env>) -> T + Send,
                'env: 'scope,
                T: Send,
            {
                let this: &ScopeJob<F> = unsafe { this.cast().as_ref() };
                let sender: Option<Sender<std::thread::Result<T>>> =
                    unsafe { mem::transmute(sender) };
                let f = unsafe { this.unwrap() };
                let scope = unsafe { Scope::<'scope, 'env>::new_unchecked(worker, this.inner) };

                let result = catch_unwind(AssertUnwindSafe(|| f(scope)));

                let sender = sender.unwrap();
                unsafe {
                    sender.send_as_ref(result);
                    worker
                        .context
                        .queue
                        .as_sender()
                        .unicast(Message::WakeUp, ReceiverToken::from_parker(sender.parker()));
                }
            }
        }

        impl<'scope, 'env, F, T> IntoJob<T> for Pin<&ScopeJob<F>>
        where
            F: FnOnce(Scope<'scope, 'env>) -> T + Send,
            'env: 'scope,
            T: Send,
        {
            fn into_job(self) -> Job<T> {
                self.into_job()
            }
        }

        impl<'scope, 'env, F, T> InlineJob<T> for Pin<&ScopeJob<F>>
        where
            F: FnOnce(Scope<'scope, 'env>) -> T + Send,
            'env: 'scope,
            T: Send,
        {
            fn run_inline(self, worker: &WorkerThread) -> T {
                unsafe { self.unwrap()(Scope::<'scope, 'env>::new_unchecked(worker, self.inner)) }
            }
        }

        let _pinned = ScopeJob::new(a, self.inner);
        let job = unsafe { Pin::new_unchecked(&_pinned) };

        let (a, b) = worker.join_heartbeat2(job, |_| b(*self));

        // touch the job here to ensure it is not dropped before we have run the join.
        drop(_pinned);

        (a, b)

        // let stack = ScopeJob::new(a, self.inner);
        // let job = ScopeJob::into_job(&stack);
        //
        // worker.push_back(&job);
        // worker.tick();
        //
        // let rb = match catch_unwind(AssertUnwindSafe(|| b(*self))) {
        //     Ok(val) => val,
        //     Err(payload) => {
        //         #[cfg(feature = "tracing")]
        //         tracing::debug!("join_heartbeat: b panicked, waiting for a to finish");
        //         std::hint::cold_path();
        //
        //         // if b panicked, we need to wait for a to finish
        //         let mut receiver = job.take_receiver();
        //         worker.wait_until_pred(|| match &receiver {
        //             Some(recv) => recv.poll().is_some(),
        //             None => {
        //                 receiver = job.take_receiver();
        //                 false
        //             }
        //         });
        //
        //         resume_unwind(payload);
        //     }
        // };
        //
        // let ra = if let Some(recv) = job.take_receiver() {
        //     match worker.wait_until_recv(recv) {
        //         Some(t) => crate::util::unwrap_or_panic(t),
        //         None => {
        //             #[cfg(feature = "tracing")]
        //             tracing::trace!(
        //                 "join_heartbeat: job was shared, but reclaimed, running a() inline"
        //             );
        //             // the job was shared, but not yet stolen, so we get to run the
        //             // job inline
        //             unsafe { stack.unwrap()(*self) }
        //         }
        //     }
        // } else {
        //     worker.pop_back();
        //
        //     unsafe {
        //         // SAFETY: we just popped the job from the queue, so it is safe to unwrap.
        //         #[cfg(feature = "tracing")]
        //         tracing::trace!("join_heartbeat: job was not shared, running a() inline");
        //         stack.unwrap()(*self)
        //     }
        // };
        //
        // (ra, rb)
    }

    fn new(worker: &WorkerThread, inner: Pin<&'scope ScopeInner>) -> Self {
        // SAFETY: we are creating a new scope, so the inner pointer is valid.
        unsafe { Self::new_unchecked(worker, SendPtr::new_const(&*inner).unwrap()) }
    }

    unsafe fn new_unchecked(worker: &WorkerThread, inner: SendPtr<ScopeInner>) -> Self {
        Self {
            inner,
            worker: SendPtr::new_const(worker).unwrap(),
            _scope: PhantomData,
            _env: PhantomData,
        }
    }

    pub fn context(&self) -> &Arc<Context> {
        unsafe { &self.worker.as_ref().context }
    }

    pub fn worker(&self) -> &WorkerThread {
        unsafe { self.worker.as_ref() }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicU8;

    use super::*;
    use crate::ThreadPool;

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn scope_spawn_sync() {
        let pool = ThreadPool::new_with_threads(1);
        let count = Arc::new(AtomicU8::new(0));

        scope_with_context(&pool.context, |scope| {
            scope.spawn(|_| {
                count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            });
        });

        assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 1);
    }

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn scope_join_one() {
        let pool = ThreadPool::new_with_threads(1);
        let count = AtomicU8::new(0);

        let _sum = pool.scope(|scope| {
            let (a, b) = scope.join(
                |_| count.fetch_add(1, Ordering::Relaxed) + 4,
                |_| count.fetch_add(2, Ordering::Relaxed) + 6,
            );
            a + b
        });

        assert_eq!(count.load(Ordering::Relaxed), 3);
    }

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn scope_join_many() {
        let pool = ThreadPool::new_with_threads(1);

        fn sum<'scope, 'env>(scope: Scope<'scope, 'env>, n: usize) -> usize {
            if n == 0 {
                return 0;
            }

            let (l, r) = scope.join(|s| sum(s, n - 1), |s| sum(s, n - 1));
            l + r + 1
        }

        pool.scope(|scope| {
            let total = sum(scope, 5);
            // a complete binary tree of depth 5 has 2^5 - 1 = 31 nodes.
            assert_eq!(total, 31);
        });
    }

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn scope_spawn_future() {
        let pool = ThreadPool::new_with_threads(1);
        let mut x = 0;

        pool.scope(|scope| {
            let task = scope.spawn_async(|_| async {
                x += 1;
            });

            task.detach();
        });

        assert_eq!(x, 1);
    }

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn scope_spawn_many() {
        let pool = ThreadPool::new_with_threads(1);
        let count = Arc::new(AtomicU8::new(0));

        pool.scope(|scope| {
            for _ in 0..10 {
                let count = count.clone();
                scope.spawn(move |_| {
                    count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                });
            }
        });

        assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 10);
    }
}
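// An extra usage sketch (added for illustration; `fib` is hypothetical and
// not part of the original test suite): the classic recursive Fibonacci
// split, exercising the heartbeat join() path with more than one worker.
#[cfg(test)]
mod usage_sketch {
    use super::*;
    use crate::ThreadPool;

    fn fib<'scope, 'env>(scope: Scope<'scope, 'env>, n: u64) -> u64 {
        if n < 2 {
            return n;
        }

        // split the work: `a` may be stolen by another worker while `b`
        // runs inline on the current one.
        let (a, b) = scope.join(|s| fib(s, n - 1), |s| fib(s, n - 2));
        a + b
    }

    #[test]
    fn join_fib() {
        let pool = ThreadPool::new_with_threads(2);
        let result = pool.scope(|scope| fib(scope, 10));
        assert_eq!(result, 55); // fib(10) = 55
    }
}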