From 8f753108ec8de3cceb443fe57ecc550ef3536a6f Mon Sep 17 00:00:00 2001
From: Janis <janis@nirgendwo.xyz>
Date: Fri, 21 Feb 2025 20:06:29 +0100
Subject: [PATCH] drop scope before deallocating

---
 src/praetor/mod.rs | 86 ++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 75 insertions(+), 11 deletions(-)

diff --git a/src/praetor/mod.rs b/src/praetor/mod.rs
index 4e437ca..4fe77a8 100644
--- a/src/praetor/mod.rs
+++ b/src/praetor/mod.rs
@@ -888,20 +888,52 @@ mod job {
 use std::{
     cell::{Cell, UnsafeCell},
     collections::BTreeMap,
-    mem,
+    future::Future,
+    mem::{self, MaybeUninit},
     pin::{pin, Pin},
     ptr::NonNull,
     sync::{
-        atomic::{AtomicBool, Ordering},
+        atomic::{AtomicBool, AtomicUsize, Ordering},
         Arc, OnceLock, Weak,
     },
     time::Duration,
 };
 
+use async_task::Runnable;
 use crossbeam::utils::CachePadded;
 use job::*;
 use parking_lot::{Condvar, Mutex};
-use util::DropGuard;
+use util::{DropGuard, SendPtr};
+
+#[derive(Debug, Default)]
+pub struct JobCounter {
+    jobs_pending: AtomicUsize,
+    waker: Mutex<Option<std::thread::Thread>>,
+}
+
+impl JobCounter {
+    pub fn increment(&self) {
+        self.jobs_pending.fetch_add(1, Ordering::Relaxed);
+    }
+
+    pub fn decrement(&self) {
+        if self.jobs_pending.fetch_sub(1, Ordering::SeqCst) == 1 {
+            if let Some(thread) = self.waker.lock().take() {
+                thread.unpark();
+            }
+        }
+    }
+
+    /// Parks the calling thread until no jobs are pending; must only be called once.
+    pub unsafe fn wait(&self) {
+        _ = self.waker.lock().insert(std::thread::current());
+        // `park` may return spuriously, so re-check the count and re-register the waker.
+        while self.jobs_pending.load(Ordering::SeqCst) > 0 {
+            std::thread::park();
+            _ = self.waker.lock().get_or_insert_with(std::thread::current);
+        }
+    }
+}
 
 pub struct Scope {
     join_count: Cell<usize>,
@@ -909,12 +941,23 @@ pub struct Scope {
     index: usize,
     heartbeat: Arc<CachePadded<AtomicBool>>,
     queue: UnsafeCell<JobList>,
+
+    job_counter: JobCounter,
 }
 
 thread_local! {
     static SCOPE: UnsafeCell<Option<NonNull<Scope>>> = const { UnsafeCell::new(None) };
 }
 
+impl Drop for Scope {
+    /// Runs every queued job, then waits on the job counter before the scope goes away.
+    fn drop(&mut self) {
+        self.complete_jobs();
+        // SAFETY: `wait` must only be called once; a value is dropped at most once.
+        unsafe { self.job_counter.wait() };
+    }
+}
+
 impl Scope {
     /// locks shared context
     #[allow(dead_code)]
@@ -933,20 +976,32 @@ impl Scope {
             heartbeat,
             join_count: Cell::new(0),
             queue: UnsafeCell::new(JobList::new()),
+            job_counter: JobCounter::default(),
+        }
+    }
+
+    /// Drops `*this` in place and frees its allocation; `this` must come from `Box::into_raw`.
+    unsafe fn drop_in_place_and_dealloc(this: NonNull<Scope>) {
+        unsafe {
+            let raw = this.as_ptr();
+            raw.drop_in_place();
+            drop(Box::<MaybeUninit<Self>>::from_raw(raw.cast()));
         }
     }
 
     fn with_in<T, F: FnOnce(&Scope) -> T>(ctx: &Arc<Context>, f: F) -> T {
-        let mut guard = Option::<DropGuard<Box<dyn FnOnce()>>>::None;
+        let mut _guard = Option::<DropGuard<Box<dyn FnOnce()>>>::None;
 
         let scope = match Self::current_ref() {
             Some(scope) if Arc::ptr_eq(&scope.context, ctx) => scope,
             Some(_) => {
                 let old = unsafe { Self::unset_current().unwrap().as_ptr() };
-                guard = Some(DropGuard::new(Box::new(move || unsafe {
-                    _ = Box::from_raw(Self::unset_current().unwrap().as_ptr());
+                _guard = Some(DropGuard::new(Box::new(move || unsafe {
+                    Self::drop_in_place_and_dealloc(Self::unset_current().unwrap());
+
                     Self::set_current(old.cast_const());
                 })));
+
                 let current = Box::into_raw(Box::new(Self::new_in(ctx.clone())));
                 unsafe {
                     Self::set_current(current.cast_const());
@@ -956,8 +1011,8 @@ impl Scope {
             None => {
                 let current = Box::into_raw(Box::new(Self::new_in(ctx.clone())));
 
-                guard = Some(DropGuard::new(Box::new(|| unsafe {
-                    _ = Box::from_raw(Self::unset_current().unwrap().as_ptr());
+                _guard = Some(DropGuard::new(Box::new(|| unsafe {
+                    Self::drop_in_place_and_dealloc(Self::unset_current().unwrap());
                 })));
 
                 unsafe {
@@ -969,7 +1024,6 @@ impl Scope {
         };
 
         let t = f(scope);
-        drop(guard);
         t
     }
 
@@ -1015,6 +1069,14 @@ impl Scope {
         unsafe { self.queue.as_mut_unchecked().pop_front() }
     }
 
+    /// Pops jobs off the front of the queue and executes each one until none are left.
+    fn complete_jobs(&self) {
+        while let Some(job) = self.pop_front() {
+            // SAFETY: assumes queued job pointers stay valid until executed — TODO confirm.
+            unsafe { job.as_ref().set_pending() };
+            self.execute(job);
+        }
+    }
     #[inline]
     pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
     where
@@ -1273,8 +1335,10 @@ fn worker(ctx: Arc<Context>, barrier: Arc<std::sync::Barrier>) {
     unsafe {
         Scope::set_current(Box::into_raw(Box::new(Scope::new_in(ctx.clone()))).cast_const());
     }
-    let _guard =
-        DropGuard::new(|| unsafe { drop(Box::from_raw(Scope::unset_current().unwrap().as_ptr())) });
+
+    let _guard = DropGuard::new(|| unsafe {
+        Scope::drop_in_place_and_dealloc(Scope::unset_current().unwrap());
+    });
 
     let scope = Scope::current_ref().unwrap();