diff --git a/src/praetor/mod.rs b/src/praetor/mod.rs index 76cb588..2871990 100644 --- a/src/praetor/mod.rs +++ b/src/praetor/mod.rs @@ -457,7 +457,10 @@ mod job { unsafe impl Send for Job {} impl Job { - pub fn new(harness: unsafe fn(*const (), *const Job), this: NonNull<()>) -> Job { + pub fn new( + harness: unsafe fn(*const (), *const Job, &super::Scope), + this: NonNull<()>, + ) -> Job { Self { harness_and_state: TaggedAtomicPtr::new( unsafe { mem::transmute(harness) }, @@ -473,6 +476,7 @@ mod job { _phantom: PhantomPinned, } } + pub fn empty() -> Job { Self { harness_and_state: TaggedAtomicPtr::new( @@ -492,6 +496,7 @@ mod job { } } + #[inline] unsafe fn link_mut(&self) -> &mut Link { unsafe { &mut (&mut *self.err_or_link.get()).link } } @@ -594,16 +599,17 @@ mod job { } } - pub fn execute(&self) { + pub fn execute(&self, scope: &super::Scope) { // SAFETY: self is non-null unsafe { let (ptr, state) = self.harness_and_state.ptr_and_tag(Ordering::Relaxed); debug_assert_eq!(state, JobState::Pending as usize); - let harness: unsafe fn(*const (), *const Self) = mem::transmute(ptr.as_ptr()); + let harness: unsafe fn(*const (), *const Self, scope: &super::Scope) = + mem::transmute(ptr.as_ptr()); let this = (*self.val_or_this.get()).this; - harness(this.as_ptr().cast(), (self as *const Self).cast()); + harness(this.as_ptr().cast(), (self as *const Self).cast(), scope); } } @@ -667,21 +673,20 @@ mod job { #[allow(dead_code)] pub fn into_boxed_job(self: Box) -> Box> where - F: FnOnce() -> T + Send, + F: FnOnce(&super::Scope) -> T + Send, T: Send, { - unsafe fn harness(this: *const (), job: *const Job<()>) + unsafe fn harness(this: *const (), job: *const Job<()>, scope: &super::Scope) where - F: FnOnce() -> T + Send, + F: FnOnce(&super::Scope) -> T + Send, T: Sized + Send, { - let job = unsafe { &*job.cast::>() }; - let this = unsafe { Box::from_raw(this.cast::>().cast_mut()) }; let f = this.f; - let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)); + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(scope))); + let job = unsafe { &*job.cast::>() }; job.complete(result); } @@ -708,18 +713,18 @@ mod job { pub fn as_job(&self) -> Job<()> where - F: FnOnce() -> T + Send, + F: FnOnce(&super::Scope) -> T + Send, T: Send, { - unsafe fn harness(this: *const (), job: *const Job<()>) + unsafe fn harness(this: *const (), job: *const Job<()>, scope: &super::Scope) where - F: FnOnce() -> T + Send, + F: FnOnce(&super::Scope) -> T + Send, T: Sized + Send, { let this = unsafe { &*this.cast::>() }; let f = unsafe { this.unwrap() }; - let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)); + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(scope))); let job_ref = unsafe { &*job.cast::>() }; job_ref.complete(result); @@ -782,25 +787,25 @@ impl Scope { } } - fn with_in T>(ctx: Arc, f: F) -> T { + fn with_in T>(ctx: &Arc, f: F) -> T { let mut guard = Option::>>::None; let scope = match Self::current_ref() { - Some(scope) if Arc::ptr_eq(&scope.context, &ctx) => scope, + Some(scope) if Arc::ptr_eq(&scope.context, ctx) => scope, Some(_) => { let old = unsafe { Self::unset_current().unwrap().as_ptr() }; guard = Some(DropGuard::new(Box::new(move || unsafe { _ = Box::from_raw(Self::unset_current().unwrap().as_ptr()); Self::set_current(old.cast_const()); }))); - let current = Box::into_raw(Box::new(Self::new_in(ctx))); + let current = Box::into_raw(Box::new(Self::new_in(ctx.clone()))); unsafe { Self::set_current(current.cast_const()); &*current } } None => { - let current = Box::into_raw(Box::new(Self::new_in(ctx))); + let current = Box::into_raw(Box::new(Self::new_in(ctx.clone()))); guard = Some(DropGuard::new(Box::new(|| unsafe { _ = Box::from_raw(Self::unset_current().unwrap().as_ptr()); @@ -820,7 +825,7 @@ impl Scope { } pub fn with T>(f: F) -> T { - Self::with_in(Context::global().clone(), f) + Self::with_in(Context::global(), f) } unsafe fn set_current(scope: *const Scope) { @@ -861,37 +866,43 @@ impl Scope { unsafe { self.queue.as_mut_unchecked().pop_front() } } + #[inline] pub fn join(&self, a: A, b: B) -> (RA, RB) where RA: Send, RB: Send, - A: FnOnce() -> RA + Send, - B: FnOnce() -> RB + Send, + A: FnOnce(&Self) -> RA + Send, + B: FnOnce(&Self) -> RB + Send, { self.join_heartbeat_every::<_, _, _, _, 64>(a, b) + // self.join_heartbeat(a, b) } pub fn join_seq(&self, a: A, b: B) -> (RA, RB) where RA: Send, RB: Send, - A: FnOnce() -> RA + Send, - B: FnOnce() -> RB + Send, + A: FnOnce(&Self) -> RA + Send, + B: FnOnce(&Self) -> RB + Send, { - (a(), b()) + let rb = b(&self); + let ra = a(&self); + + (ra, rb) } pub fn join_heartbeat_every(&self, a: A, b: B) -> (RA, RB) where RA: Send, RB: Send, - A: FnOnce() -> RA + Send, - B: FnOnce() -> RB + Send, + A: FnOnce(&Self) -> RA + Send, + B: FnOnce(&Self) -> RB + Send, { - let count = self.join_count.get(); - self.join_count.set(count.wrapping_add(1) % TIMES); + // let count = self.join_count.get(); + // self.join_count.set(count.wrapping_add(1) % TIMES); + let count = self.join_count.update(|n| n.wrapping_add(1) % TIMES); - if count == 0 { + if count == 1 { self.join_heartbeat(a, b) } else { self.join_seq(a, b) @@ -902,30 +913,33 @@ impl Scope { where RA: Send, RB: Send, - A: FnOnce() -> RA + Send, - B: FnOnce() -> RB + Send, + A: FnOnce(&Self) -> RA + Send, + B: FnOnce(&Self) -> RB + Send, { - let a = StackJob::new(a); + let a = StackJob::new(move |scope: &Scope| { + scope.tick(); + + a(scope) + }); let job = pin!(a.as_job()); self.push_front(job.as_ref()); - let rb = b(); + let rb = b(self); let ra = if job.state() == JobState::Empty as u8 { unsafe { job.unlink(); } - self.tick(); - unsafe { a.unwrap()() } + unsafe { a.unwrap()(self) } } else { match self.wait_until::(unsafe { mem::transmute::>, Pin<&Job>>(job.as_ref()) }) { Some(Ok(t)) => t, Some(Err(payload)) => std::panic::resume_unwind(payload), - None => unsafe { a.unwrap()() }, + None => unsafe { a.unwrap()(self) }, } }; @@ -933,7 +947,7 @@ impl Scope { (ra, rb) } - #[inline] + #[inline(always)] fn tick(&self) { if self.heartbeat.load(Ordering::Relaxed) { self.heartbeat_cold(); @@ -943,7 +957,7 @@ impl Scope { #[inline] fn execute(&self, job: &Job) { self.tick(); - job.execute(); + job.execute(self); } #[cold] @@ -1010,7 +1024,7 @@ where A: FnOnce() -> RA + Send, B: FnOnce() -> RB + Send, { - Scope::with(|scope| scope.join(a, b)) + Scope::with(|scope| scope.join(|_| a(), |_| b())) } pub struct ThreadPool { @@ -1031,7 +1045,7 @@ impl ThreadPool { } pub fn scope T>(&self, f: F) -> T { - Scope::with_in(self.context.clone(), f) + Scope::with_in(&self.context, f) } }