From eb43c293895ea613f292d32be8373d6d5742db34 Mon Sep 17 00:00:00 2001
From: Janis <janis@nirgendwo.xyz>
Date: Thu, 20 Feb 2025 21:50:42 +0100
Subject: [PATCH] reduce arc clones, pass references to scope to join functions

---
 src/praetor/mod.rs | 94 ++++++++++++++++++++++++++--------------------
 1 file changed, 54 insertions(+), 40 deletions(-)

diff --git a/src/praetor/mod.rs b/src/praetor/mod.rs
index 76cb588..2871990 100644
--- a/src/praetor/mod.rs
+++ b/src/praetor/mod.rs
@@ -457,7 +457,10 @@ mod job {
     unsafe impl<T> Send for Job<T> {}
 
     impl<T> Job<T> {
-        pub fn new(harness: unsafe fn(*const (), *const Job<T>), this: NonNull<()>) -> Job<T> {
+        pub fn new(
+            harness: unsafe fn(*const (), *const Job<T>, &super::Scope),
+            this: NonNull<()>,
+        ) -> Job<T> {
             Self {
                 harness_and_state: TaggedAtomicPtr::new(
                     unsafe { mem::transmute(harness) },
@@ -473,6 +476,7 @@ mod job {
                 _phantom: PhantomPinned,
             }
         }
+
         pub fn empty() -> Job<T> {
             Self {
                 harness_and_state: TaggedAtomicPtr::new(
@@ -492,6 +496,7 @@ mod job {
             }
         }
 
+        #[inline]
         unsafe fn link_mut(&self) -> &mut Link<Job> {
             unsafe { &mut (&mut *self.err_or_link.get()).link }
         }
@@ -594,16 +599,17 @@ mod job {
             }
         }
 
-        pub fn execute(&self) {
+        pub fn execute(&self, scope: &super::Scope) {
             // SAFETY: self is non-null
             unsafe {
                 let (ptr, state) = self.harness_and_state.ptr_and_tag(Ordering::Relaxed);
                 debug_assert_eq!(state, JobState::Pending as usize);
 
-                let harness: unsafe fn(*const (), *const Self) = mem::transmute(ptr.as_ptr());
+                let harness: unsafe fn(*const (), *const Self, scope: &super::Scope) =
+                    mem::transmute(ptr.as_ptr());
                 let this = (*self.val_or_this.get()).this;
 
-                harness(this.as_ptr().cast(), (self as *const Self).cast());
+                harness(this.as_ptr().cast(), (self as *const Self).cast(), scope);
             }
         }
 
@@ -667,21 +673,20 @@ mod job {
         #[allow(dead_code)]
         pub fn into_boxed_job<T>(self: Box<Self>) -> Box<Job<()>>
         where
-            F: FnOnce() -> T + Send,
+            F: FnOnce(&super::Scope) -> T + Send,
             T: Send,
         {
-            unsafe fn harness<F, T>(this: *const (), job: *const Job<()>)
+            unsafe fn harness<F, T>(this: *const (), job: *const Job<()>, scope: &super::Scope)
             where
-                F: FnOnce() -> T + Send,
+                F: FnOnce(&super::Scope) -> T + Send,
                 T: Sized + Send,
             {
-                let job = unsafe { &*job.cast::<Job<T>>() };
-
                 let this = unsafe { Box::from_raw(this.cast::<HeapJob<F>>().cast_mut()) };
                 let f = this.f;
 
-                let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f));
+                let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(scope)));
 
+                let job = unsafe { &*job.cast::<Job<T>>() };
                 job.complete(result);
             }
 
@@ -708,18 +713,18 @@ mod job {
 
         pub fn as_job<T>(&self) -> Job<()>
         where
-            F: FnOnce() -> T + Send,
+            F: FnOnce(&super::Scope) -> T + Send,
             T: Send,
         {
-            unsafe fn harness<F, T>(this: *const (), job: *const Job<()>)
+            unsafe fn harness<F, T>(this: *const (), job: *const Job<()>, scope: &super::Scope)
             where
-                F: FnOnce() -> T + Send,
+                F: FnOnce(&super::Scope) -> T + Send,
                 T: Sized + Send,
             {
                 let this = unsafe { &*this.cast::<StackJob<F>>() };
                 let f = unsafe { this.unwrap() };
 
-                let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f));
+                let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(scope)));
 
                 let job_ref = unsafe { &*job.cast::<Job<T>>() };
                 job_ref.complete(result);
@@ -782,25 +787,25 @@ impl Scope {
         }
     }
 
-    fn with_in<T, F: FnOnce(&Scope) -> T>(ctx: Arc<Context>, f: F) -> T {
+    fn with_in<T, F: FnOnce(&Scope) -> T>(ctx: &Arc<Context>, f: F) -> T {
         let mut guard = Option::<DropGuard<Box<dyn FnOnce()>>>::None;
 
         let scope = match Self::current_ref() {
-            Some(scope) if Arc::ptr_eq(&scope.context, &ctx) => scope,
+            Some(scope) if Arc::ptr_eq(&scope.context, ctx) => scope,
             Some(_) => {
                 let old = unsafe { Self::unset_current().unwrap().as_ptr() };
                 guard = Some(DropGuard::new(Box::new(move || unsafe {
                     _ = Box::from_raw(Self::unset_current().unwrap().as_ptr());
                     Self::set_current(old.cast_const());
                 })));
-                let current = Box::into_raw(Box::new(Self::new_in(ctx)));
+                let current = Box::into_raw(Box::new(Self::new_in(ctx.clone())));
                 unsafe {
                     Self::set_current(current.cast_const());
                     &*current
                 }
             }
             None => {
-                let current = Box::into_raw(Box::new(Self::new_in(ctx)));
+                let current = Box::into_raw(Box::new(Self::new_in(ctx.clone())));
 
                 guard = Some(DropGuard::new(Box::new(|| unsafe {
                     _ = Box::from_raw(Self::unset_current().unwrap().as_ptr());
@@ -820,7 +825,7 @@ impl Scope {
     }
 
     pub fn with<T, F: FnOnce(&Scope) -> T>(f: F) -> T {
-        Self::with_in(Context::global().clone(), f)
+        Self::with_in(Context::global(), f)
     }
 
     unsafe fn set_current(scope: *const Scope) {
@@ -861,37 +866,43 @@ impl Scope {
         unsafe { self.queue.as_mut_unchecked().pop_front() }
     }
 
+    #[inline]
     pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
     where
         RA: Send,
         RB: Send,
-        A: FnOnce() -> RA + Send,
-        B: FnOnce() -> RB + Send,
+        A: FnOnce(&Self) -> RA + Send,
+        B: FnOnce(&Self) -> RB + Send,
     {
         self.join_heartbeat_every::<_, _, _, _, 64>(a, b)
+        // One call in 64 goes through `join_heartbeat`; the rest use `join_seq`.
     }
 
     pub fn join_seq<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
     where
         RA: Send,
         RB: Send,
-        A: FnOnce() -> RA + Send,
-        B: FnOnce() -> RB + Send,
+        A: FnOnce(&Self) -> RA + Send,
+        B: FnOnce(&Self) -> RB + Send,
     {
-        (a(), b())
+        // Runs `b` before `a` on the calling thread, the order `join_heartbeat` uses.
+        let rb = b(self);
+        let ra = a(self);
+        (ra, rb)
     }
 
     pub fn join_heartbeat_every<A, B, RA, RB, const TIMES: usize>(&self, a: A, b: B) -> (RA, RB)
     where
         RA: Send,
         RB: Send,
-        A: FnOnce() -> RA + Send,
-        B: FnOnce() -> RB + Send,
+        A: FnOnce(&Self) -> RA + Send,
+        B: FnOnce(&Self) -> RB + Send,
     {
+        // `count` is the pre-increment value, so `TIMES == 1` still hits the heartbeat path.
         let count = self.join_count.get();
         self.join_count.set(count.wrapping_add(1) % TIMES);
 
         if count == 0 {
             self.join_heartbeat(a, b)
         } else {
             self.join_seq(a, b)
@@ -902,30 +913,33 @@ impl Scope {
     where
         RA: Send,
         RB: Send,
-        A: FnOnce() -> RA + Send,
-        B: FnOnce() -> RB + Send,
+        A: FnOnce(&Self) -> RA + Send,
+        B: FnOnce(&Self) -> RB + Send,
     {
-        let a = StackJob::new(a);
+        let a = StackJob::new(move |scope: &Scope| {
+            scope.tick();
+
+            a(scope)
+        });
 
         let job = pin!(a.as_job());
         self.push_front(job.as_ref());
 
-        let rb = b();
+        let rb = b(self);
 
         let ra = if job.state() == JobState::Empty as u8 {
             unsafe {
                 job.unlink();
             }
 
-            self.tick();
-            unsafe { a.unwrap()() }
+            unsafe { a.unwrap()(self) }
         } else {
             match self.wait_until::<RA>(unsafe {
                 mem::transmute::<Pin<&Job<()>>, Pin<&Job<RA>>>(job.as_ref())
             }) {
                 Some(Ok(t)) => t,
                 Some(Err(payload)) => std::panic::resume_unwind(payload),
-                None => unsafe { a.unwrap()() },
+                None => unsafe { a.unwrap()(self) },
             }
         };
 
@@ -933,7 +947,7 @@ impl Scope {
         (ra, rb)
     }
 
-    #[inline]
+    #[inline(always)]
     fn tick(&self) {
         if self.heartbeat.load(Ordering::Relaxed) {
             self.heartbeat_cold();
@@ -943,7 +957,7 @@ impl Scope {
     #[inline]
     fn execute(&self, job: &Job) {
         self.tick();
-        job.execute();
+        job.execute(self);
     }
 
     #[cold]
@@ -1010,7 +1024,7 @@ where
     A: FnOnce() -> RA + Send,
     B: FnOnce() -> RB + Send,
 {
-    Scope::with(|scope| scope.join(a, b))
+    Scope::with(|scope| scope.join(|_| a(), |_| b()))
 }
 
 pub struct ThreadPool {
@@ -1031,7 +1045,7 @@ impl ThreadPool {
     }
 
     pub fn scope<T, F: FnOnce(&Scope) -> T>(&self, f: F) -> T {
-        Scope::with_in(self.context.clone(), f)
+        Scope::with_in(&self.context, f)
     }
 }