From 3458a900ee8fbd536f74a7201963dee06933b76f Mon Sep 17 00:00:00 2001
From: Janis <janis@nirgendwo.xyz>
Date: Sat, 8 Mar 2025 12:24:02 +0100
Subject: [PATCH] todo: separate workerthread and scope logic, add scope type
 with lifetime

---
 .cargo/config.toml |   0
 src/praetor/mod.rs | 111 +++++++++++++++++++++++++--------------------
 2 files changed, 63 insertions(+), 48 deletions(-)
 create mode 100644 .cargo/config.toml

diff --git a/.cargo/config.toml b/.cargo/config.toml
new file mode 100644
index 0000000..e69de29
diff --git a/src/praetor/mod.rs b/src/praetor/mod.rs
index 5e74683..03b5b90 100644
--- a/src/praetor/mod.rs
+++ b/src/praetor/mod.rs
@@ -3,7 +3,6 @@ mod util {
         cell::UnsafeCell,
         marker::PhantomData,
         mem::ManuallyDrop,
-        num::NonZero,
         ops::{Deref, DerefMut},
         ptr::NonNull,
         sync::atomic::{AtomicPtr, Ordering},
@@ -58,12 +57,23 @@ mod util {
     }
 
     impl<T> SendPtr<T> {
-        pub fn new(ptr: *mut T) -> Option<Self> {
-            NonNull::new(ptr).map(Self)
+        pub const fn new(ptr: *mut T) -> Option<Self> {
+            match NonNull::new(ptr) {
+                Some(ptr) => Some(Self(ptr)),
+                None => None,
+            }
         }
-        pub unsafe fn new_unchecked(ptr: *mut T) -> Self {
+        pub const unsafe fn new_unchecked(ptr: *mut T) -> Self {
             unsafe { Self(NonNull::new_unchecked(ptr)) }
         }
+
+        pub const fn new_const(ptr: *const T) -> Option<Self> {
+            Self::new(ptr.cast_mut())
+        }
+
+        pub const unsafe fn new_const_unchecked(ptr: *const T) -> Self {
+            Self::new_unchecked(ptr.cast_mut())
+        }
     }
 
     // Miri doesn't like tagging pointers that it doesn't know the alignment of.
@@ -174,7 +184,7 @@ mod util {
             let ptr = ptr.cast::<()>();
             loop {
                 let old = self.0.load(failure);
-                let new = ptr.with_addr((ptr.addr() & !mask) | (old.addr() & mask));
+                let new = ptr.map_addr(|addr| (addr & !mask) | (old.addr() & mask));
                 if self
                     .0
                     .compare_exchange_weak(old, new, success, failure)
@@ -603,10 +613,7 @@ mod job {
     unsafe impl<T> Send for Job<T> {}
 
     impl<T> Job<T> {
-        pub fn new(
-            harness: unsafe fn(*const (), *const Job<T>, &super::Scope),
-            this: NonNull<()>,
-        ) -> Job<T> {
+        pub fn new(harness: unsafe fn(*const (), *const Job<T>), this: NonNull<()>) -> Job<T> {
             Self {
                 harness_and_state: TaggedAtomicPtr::new(
                     unsafe { mem::transmute(harness) },
@@ -748,19 +755,18 @@ mod job {
             }
         }
 
-        pub fn execute(job: NonNull<Self>, scope: &super::Scope) {
+        pub fn execute(job: NonNull<Self>) {
             // SAFETY: self is non-null
             unsafe {
                 let this = job.as_ref();
                 let (ptr, state) = this.harness_and_state.ptr_and_tag(Ordering::Relaxed);
 
                 debug_assert_eq!(state, JobState::Pending as usize);
-                let harness: unsafe fn(*const (), *const Self, scope: &super::Scope) =
-                    mem::transmute(ptr.as_ptr());
+                let harness: unsafe fn(*const (), *const Self) = mem::transmute(ptr.as_ptr());
 
                 let this = (*this.val_or_this.get()).this;
 
-                harness(this.as_ptr().cast(), job.as_ptr(), scope);
+                harness(this.as_ptr().cast(), job.as_ptr());
             }
         }
 
@@ -829,19 +835,19 @@ mod job {
         #[allow(dead_code)]
         pub fn into_boxed_job<T>(self: Box<Self>) -> Pin<Box<Job<()>>>
         where
-            F: FnOnce(&super::Scope) -> T + Send,
+            F: FnOnce() -> T + Send,
             T: Send,
         {
             #[repr(align(8))]
-            unsafe fn harness<F, T>(this: *const (), job: *const Job<()>, scope: &super::Scope)
+            unsafe fn harness<F, T>(this: *const (), job: *const Job<()>)
             where
-                F: FnOnce(&super::Scope) -> T + Send,
+                F: FnOnce() -> T + Send,
                 T: Sized + Send,
             {
                 let this = unsafe { Box::from_raw(this.cast::<HeapJob<F>>().cast_mut()) };
                 let f = this.f;
 
-                _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(scope)));
+                _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f()));
 
                 _ = unsafe { Box::from_raw(job.cast_mut()) };
             }
@@ -871,19 +877,19 @@ mod job {
 
         pub fn as_job<T>(self: Pin<&Self>) -> Job<()>
         where
-            F: FnOnce(&super::Scope) -> T + Send,
+            F: FnOnce() -> T + Send,
             T: Send,
         {
             #[repr(align(8))]
-            unsafe fn harness<F, T>(this: *const (), job: *const Job<()>, scope: &super::Scope)
+            unsafe fn harness<F, T>(this: *const (), job: *const Job<()>)
             where
-                F: FnOnce(&super::Scope) -> T + Send,
+                F: FnOnce() -> T + Send,
                 T: Sized + Send,
             {
                 let this = unsafe { &*this.cast::<StackJob<F>>() };
                 let f = unsafe { this.unwrap() };
 
-                let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(scope)));
+                let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f()));
 
                 let job_ref = unsafe { &*job.cast::<Job<T>>() };
                 job_ref.complete(result);
@@ -900,6 +906,7 @@ use std::{
     cell::{Cell, UnsafeCell},
     collections::BTreeMap,
     future::Future,
+    marker::PhantomData,
     mem::{self, MaybeUninit},
     pin::{pin, Pin},
     ptr::NonNull,
@@ -946,7 +953,14 @@ impl JobCounter {
     }
 }
 
-pub struct Scope {
+struct WorkerThread {
+    context: Arc<Context>,
+    index: usize,
+    heartbeat: Arc<CachePadded<AtomicBool>>,
+    queue: UnsafeCell<JobList>,
+}
+
+pub struct Scope<'scope> {
     join_count: Cell<usize>,
     context: Arc<Context>,
     index: usize,
@@ -954,13 +968,14 @@ pub struct Scope {
     queue: UnsafeCell<JobList>,
 
     job_counter: JobCounter,
+    _pd: PhantomData<&'scope ()>,
 }
 
 thread_local! {
-    static SCOPE: UnsafeCell<Option<NonNull<Scope>>> = const { UnsafeCell::new(None) };
+    static SCOPE: UnsafeCell<Option<NonNull<Scope<'static>>>> = const { UnsafeCell::new(None) };
 }
 
-impl Scope {
+impl<'scope> Scope<'scope> {
     /// locks shared context
     #[allow(dead_code)]
     fn new() -> Self {
@@ -979,6 +994,7 @@ impl Scope {
             join_count: Cell::new(0),
             queue: UnsafeCell::new(JobList::new()),
             job_counter: JobCounter::default(),
+            _pd: PhantomData,
         }
     }
 
@@ -1033,22 +1049,22 @@ impl Scope {
         Self::with_in(Context::global(), f)
     }
 
-    unsafe fn set_current(scope: *const Scope) {
+    unsafe fn set_current(scope: *const Scope<'static>) {
         SCOPE.with(|ptr| unsafe {
             _ = (&mut *ptr.get()).insert(NonNull::new_unchecked(scope.cast_mut()));
         })
     }
 
-    unsafe fn unset_current() -> Option<NonNull<Scope>> {
+    unsafe fn unset_current() -> Option<NonNull<Scope<'static>>> {
         SCOPE.with(|ptr| unsafe { (&mut *ptr.get()).take() })
     }
 
     #[allow(dead_code)]
-    fn current() -> Option<NonNull<Scope>> {
+    fn current() -> Option<NonNull<Scope<'scope>>> {
         SCOPE.with(|ptr| unsafe { *ptr.get() })
     }
 
-    fn current_ref<'a>() -> Option<&'a Scope> {
+    fn current_ref<'a>() -> Option<&'a Scope<'scope>> {
         SCOPE.with(|ptr| unsafe { (&*ptr.get()).map(|ptr| ptr.as_ref()) })
     }
 
@@ -1084,19 +1100,17 @@ impl Scope {
         }
     }
 
-    pub fn spawn<'a, F>(&self, f: F)
+    pub fn spawn<F>(&self, f: F)
     where
-        F: FnOnce(&Scope) + Send + 'a,
+        F: FnOnce(&Scope<'scope>) + Send + 'scope,
     {
         self.job_counter.increment();
 
-        let this = unsafe {
-            SendPtr::new_unchecked(&self.job_counter as *const JobCounter as *mut JobCounter)
-        };
+        let this = SendPtr::new_const(self).unwrap();
 
-        let job = HeapJob::new(move |scope: &Scope| unsafe {
-            f(scope);
-            this.as_ref().decrement();
+        let job = HeapJob::new(move || unsafe {
+            f(this.as_ref());
+            this.as_ref().job_counter.decrement();
         })
         .into_boxed_job();
 
@@ -1104,15 +1118,14 @@ impl Scope {
         mem::forget(job);
     }
 
-    pub fn spawn_future<'a, T, F>(&'a self, future: F) -> async_task::Task<T>
+    pub fn spawn_future<T, F>(&self, future: F) -> async_task::Task<T>
     where
-        F: Future<Output = T> + Send + 'a,
-        T: Send + 'a,
+        F: Future<Output = T> + Send + 'scope,
+        T: Send + 'scope,
     {
         self.job_counter.increment();
 
-        let this =
-            unsafe { SendPtr::new_unchecked(&raw const self.job_counter as *mut JobCounter) };
+        let this = SendPtr::new_const(&self.job_counter).unwrap();
 
         let future = async move {
             let _guard = DropGuard::new(move || unsafe {
@@ -1121,10 +1134,10 @@ impl Scope {
             future.await
         };
 
-        let this = SendPtr::new(&raw const *self as *mut Self).unwrap();
+        let this = SendPtr::new_const(self).unwrap();
         let schedule = move |runnable: Runnable| {
             #[repr(align(8))]
-            unsafe fn harness<T>(this: *const (), job: *const Job<T>, _: &Scope) {
+            unsafe fn harness<T>(this: *const (), job: *const Job<T>) {
                 let runnable = Runnable::<()>::from_raw(NonNull::new_unchecked(this.cast_mut()));
                 runnable.run();
 
@@ -1153,7 +1166,7 @@ impl Scope {
         Fut: Future<Output = T> + Send + 'static,
         T: Send + 'static,
     {
-        let this = SendPtr::new(self as *const Self as *mut Self).unwrap();
+        let this = SendPtr::new_const(self).unwrap();
         let future = async move { f(unsafe { this.as_ref() }).await };
 
         self.spawn_future(future)
@@ -1209,7 +1222,9 @@ impl Scope {
         A: FnOnce(&Self) -> RA + Send,
         B: FnOnce(&Self) -> RB + Send,
     {
-        let a = pin!(StackJob::new(move |scope: &Scope| {
+        let this = SendPtr::new_const(self).unwrap();
+        let a = pin!(StackJob::new(move || unsafe {
+            let scope = this.as_ref();
             scope.tick();
 
             a(scope)
@@ -1225,14 +1240,14 @@ impl Scope {
                 job.unlink();
             }
 
-            unsafe { a.unwrap()(self) }
+            unsafe { a.unwrap()() }
         } else {
             match self.wait_until::<RA>(unsafe {
                 mem::transmute::<Pin<&Job<()>>, Pin<&Job<RA>>>(job.as_ref())
             }) {
                 Some(Ok(t)) => t,
                 Some(Err(payload)) => std::panic::resume_unwind(payload),
-                None => unsafe { a.unwrap()(self) },
+                None => unsafe { a.unwrap()() },
             }
         };
 
@@ -1250,7 +1265,7 @@ impl Scope {
     #[inline]
     fn execute(&self, job: NonNull<Job>) {
         self.tick();
-        Job::execute(job, self);
+        Job::execute(job);
     }
 
     #[cold]