diff --git a/src/praetor/mod.rs b/src/praetor/mod.rs index 4e437ca..4fe77a8 100644 --- a/src/praetor/mod.rs +++ b/src/praetor/mod.rs @@ -888,20 +888,52 @@ mod job { use std::{ cell::{Cell, UnsafeCell}, collections::BTreeMap, - mem, + future::Future, + mem::{self, MaybeUninit}, pin::{pin, Pin}, ptr::NonNull, sync::{ - atomic::{AtomicBool, Ordering}, + atomic::{AtomicBool, AtomicUsize, Ordering}, Arc, OnceLock, Weak, }, time::Duration, }; +use async_task::Runnable; use crossbeam::utils::CachePadded; use job::*; use parking_lot::{Condvar, Mutex}; -use util::DropGuard; +use util::{DropGuard, SendPtr}; + +#[derive(Debug, Default)] +pub struct JobCounter { + jobs_pending: AtomicUsize, + waker: Mutex>, +} + +impl JobCounter { + pub fn increment(&self) { + self.jobs_pending.fetch_add(1, Ordering::Relaxed); + } + + pub fn decrement(&self) { + if self.jobs_pending.fetch_sub(1, Ordering::SeqCst) == 1 { + if let Some(thread) = self.waker.lock().take() { + thread.unpark(); + } + } + } + + /// must only be called once + pub unsafe fn wait(&self) { + _ = self.waker.lock().insert(std::thread::current()); + + let count = self.jobs_pending.load(Ordering::SeqCst); + if count > 0 { + std::thread::park(); + } + } +} pub struct Scope { join_count: Cell, @@ -909,12 +941,23 @@ pub struct Scope { index: usize, heartbeat: Arc>, queue: UnsafeCell, + + job_counter: JobCounter, } thread_local! { static SCOPE: UnsafeCell>> = const { UnsafeCell::new(None) }; } +impl Drop for Scope { + fn drop(&mut self) { + self.complete_jobs(); + unsafe { + self.job_counter.wait(); + } + } +} + impl Scope { /// locks shared context #[allow(dead_code)] @@ -933,20 +976,32 @@ impl Scope { heartbeat, join_count: Cell::new(0), queue: UnsafeCell::new(JobList::new()), + job_counter: JobCounter::default(), + } + } + + unsafe fn drop_in_place_and_dealloc(this: NonNull) { + unsafe { + let ptr = this.as_ptr(); + ptr.drop_in_place(); + + _ = Box::>::from_raw(ptr.cast()); } } fn with_in T>(ctx: &Arc, f: F) -> T { - let mut guard = Option::>>::None; + let mut _guard = Option::>>::None; let scope = match Self::current_ref() { Some(scope) if Arc::ptr_eq(&scope.context, ctx) => scope, Some(_) => { let old = unsafe { Self::unset_current().unwrap().as_ptr() }; - guard = Some(DropGuard::new(Box::new(move || unsafe { - _ = Box::from_raw(Self::unset_current().unwrap().as_ptr()); + _guard = Some(DropGuard::new(Box::new(move || unsafe { + Self::drop_in_place_and_dealloc(Self::unset_current().unwrap()); + Self::set_current(old.cast_const()); }))); + let current = Box::into_raw(Box::new(Self::new_in(ctx.clone()))); unsafe { Self::set_current(current.cast_const()); @@ -956,8 +1011,8 @@ impl Scope { None => { let current = Box::into_raw(Box::new(Self::new_in(ctx.clone()))); - guard = Some(DropGuard::new(Box::new(|| unsafe { - _ = Box::from_raw(Self::unset_current().unwrap().as_ptr()); + _guard = Some(DropGuard::new(Box::new(|| unsafe { + Self::drop_in_place_and_dealloc(Self::unset_current().unwrap()); }))); unsafe { @@ -969,7 +1024,6 @@ impl Scope { }; let t = f(scope); - drop(guard); t } @@ -1015,6 +1069,14 @@ impl Scope { unsafe { self.queue.as_mut_unchecked().pop_front() } } + fn complete_jobs(&self) { + while let Some(job) = self.pop_front() { + unsafe { + job.as_ref().set_pending(); + } + self.execute(job); + } + } #[inline] pub fn join(&self, a: A, b: B) -> (RA, RB) where @@ -1273,8 +1335,10 @@ fn worker(ctx: Arc, barrier: Arc) { unsafe { Scope::set_current(Box::into_raw(Box::new(Scope::new_in(ctx.clone()))).cast_const()); } - let _guard = - DropGuard::new(|| unsafe { drop(Box::from_raw(Scope::unset_current().unwrap().as_ptr())) }); + + let _guard = DropGuard::new(|| unsafe { + Scope::drop_in_place_and_dealloc(Scope::unset_current().unwrap()); + }); let scope = Scope::current_ref().unwrap();