reduce arc clones, pass references to scope to join functions

This commit is contained in:
Janis 2025-02-20 21:50:42 +01:00
parent e2d5208025
commit eb43c29389

View file

@ -457,7 +457,10 @@ mod job {
unsafe impl<T> Send for Job<T> {} unsafe impl<T> Send for Job<T> {}
impl<T> Job<T> { impl<T> Job<T> {
pub fn new(harness: unsafe fn(*const (), *const Job<T>), this: NonNull<()>) -> Job<T> { pub fn new(
harness: unsafe fn(*const (), *const Job<T>, &super::Scope),
this: NonNull<()>,
) -> Job<T> {
Self { Self {
harness_and_state: TaggedAtomicPtr::new( harness_and_state: TaggedAtomicPtr::new(
unsafe { mem::transmute(harness) }, unsafe { mem::transmute(harness) },
@ -473,6 +476,7 @@ mod job {
_phantom: PhantomPinned, _phantom: PhantomPinned,
} }
} }
pub fn empty() -> Job<T> { pub fn empty() -> Job<T> {
Self { Self {
harness_and_state: TaggedAtomicPtr::new( harness_and_state: TaggedAtomicPtr::new(
@ -492,6 +496,7 @@ mod job {
} }
} }
#[inline]
unsafe fn link_mut(&self) -> &mut Link<Job> { unsafe fn link_mut(&self) -> &mut Link<Job> {
unsafe { &mut (&mut *self.err_or_link.get()).link } unsafe { &mut (&mut *self.err_or_link.get()).link }
} }
@ -594,16 +599,17 @@ mod job {
} }
} }
pub fn execute(&self) { pub fn execute(&self, scope: &super::Scope) {
// SAFETY: self is non-null // SAFETY: self is non-null
unsafe { unsafe {
let (ptr, state) = self.harness_and_state.ptr_and_tag(Ordering::Relaxed); let (ptr, state) = self.harness_and_state.ptr_and_tag(Ordering::Relaxed);
debug_assert_eq!(state, JobState::Pending as usize); debug_assert_eq!(state, JobState::Pending as usize);
let harness: unsafe fn(*const (), *const Self) = mem::transmute(ptr.as_ptr()); let harness: unsafe fn(*const (), *const Self, scope: &super::Scope) =
mem::transmute(ptr.as_ptr());
let this = (*self.val_or_this.get()).this; let this = (*self.val_or_this.get()).this;
harness(this.as_ptr().cast(), (self as *const Self).cast()); harness(this.as_ptr().cast(), (self as *const Self).cast(), scope);
} }
} }
@ -667,21 +673,20 @@ mod job {
#[allow(dead_code)] #[allow(dead_code)]
pub fn into_boxed_job<T>(self: Box<Self>) -> Box<Job<()>> pub fn into_boxed_job<T>(self: Box<Self>) -> Box<Job<()>>
where where
F: FnOnce() -> T + Send, F: FnOnce(&super::Scope) -> T + Send,
T: Send, T: Send,
{ {
unsafe fn harness<F, T>(this: *const (), job: *const Job<()>) unsafe fn harness<F, T>(this: *const (), job: *const Job<()>, scope: &super::Scope)
where where
F: FnOnce() -> T + Send, F: FnOnce(&super::Scope) -> T + Send,
T: Sized + Send, T: Sized + Send,
{ {
let job = unsafe { &*job.cast::<Job<T>>() };
let this = unsafe { Box::from_raw(this.cast::<HeapJob<F>>().cast_mut()) }; let this = unsafe { Box::from_raw(this.cast::<HeapJob<F>>().cast_mut()) };
let f = this.f; let f = this.f;
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)); let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(scope)));
let job = unsafe { &*job.cast::<Job<T>>() };
job.complete(result); job.complete(result);
} }
@ -708,18 +713,18 @@ mod job {
pub fn as_job<T>(&self) -> Job<()> pub fn as_job<T>(&self) -> Job<()>
where where
F: FnOnce() -> T + Send, F: FnOnce(&super::Scope) -> T + Send,
T: Send, T: Send,
{ {
unsafe fn harness<F, T>(this: *const (), job: *const Job<()>) unsafe fn harness<F, T>(this: *const (), job: *const Job<()>, scope: &super::Scope)
where where
F: FnOnce() -> T + Send, F: FnOnce(&super::Scope) -> T + Send,
T: Sized + Send, T: Sized + Send,
{ {
let this = unsafe { &*this.cast::<StackJob<F>>() }; let this = unsafe { &*this.cast::<StackJob<F>>() };
let f = unsafe { this.unwrap() }; let f = unsafe { this.unwrap() };
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)); let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(scope)));
let job_ref = unsafe { &*job.cast::<Job<T>>() }; let job_ref = unsafe { &*job.cast::<Job<T>>() };
job_ref.complete(result); job_ref.complete(result);
@ -782,25 +787,25 @@ impl Scope {
} }
} }
fn with_in<T, F: FnOnce(&Scope) -> T>(ctx: Arc<Context>, f: F) -> T { fn with_in<T, F: FnOnce(&Scope) -> T>(ctx: &Arc<Context>, f: F) -> T {
let mut guard = Option::<DropGuard<Box<dyn FnOnce()>>>::None; let mut guard = Option::<DropGuard<Box<dyn FnOnce()>>>::None;
let scope = match Self::current_ref() { let scope = match Self::current_ref() {
Some(scope) if Arc::ptr_eq(&scope.context, &ctx) => scope, Some(scope) if Arc::ptr_eq(&scope.context, ctx) => scope,
Some(_) => { Some(_) => {
let old = unsafe { Self::unset_current().unwrap().as_ptr() }; let old = unsafe { Self::unset_current().unwrap().as_ptr() };
guard = Some(DropGuard::new(Box::new(move || unsafe { guard = Some(DropGuard::new(Box::new(move || unsafe {
_ = Box::from_raw(Self::unset_current().unwrap().as_ptr()); _ = Box::from_raw(Self::unset_current().unwrap().as_ptr());
Self::set_current(old.cast_const()); Self::set_current(old.cast_const());
}))); })));
let current = Box::into_raw(Box::new(Self::new_in(ctx))); let current = Box::into_raw(Box::new(Self::new_in(ctx.clone())));
unsafe { unsafe {
Self::set_current(current.cast_const()); Self::set_current(current.cast_const());
&*current &*current
} }
} }
None => { None => {
let current = Box::into_raw(Box::new(Self::new_in(ctx))); let current = Box::into_raw(Box::new(Self::new_in(ctx.clone())));
guard = Some(DropGuard::new(Box::new(|| unsafe { guard = Some(DropGuard::new(Box::new(|| unsafe {
_ = Box::from_raw(Self::unset_current().unwrap().as_ptr()); _ = Box::from_raw(Self::unset_current().unwrap().as_ptr());
@ -820,7 +825,7 @@ impl Scope {
} }
pub fn with<T, F: FnOnce(&Scope) -> T>(f: F) -> T { pub fn with<T, F: FnOnce(&Scope) -> T>(f: F) -> T {
Self::with_in(Context::global().clone(), f) Self::with_in(Context::global(), f)
} }
unsafe fn set_current(scope: *const Scope) { unsafe fn set_current(scope: *const Scope) {
@ -861,37 +866,43 @@ impl Scope {
unsafe { self.queue.as_mut_unchecked().pop_front() } unsafe { self.queue.as_mut_unchecked().pop_front() }
} }
#[inline]
pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB) pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
where where
RA: Send, RA: Send,
RB: Send, RB: Send,
A: FnOnce() -> RA + Send, A: FnOnce(&Self) -> RA + Send,
B: FnOnce() -> RB + Send, B: FnOnce(&Self) -> RB + Send,
{ {
self.join_heartbeat_every::<_, _, _, _, 64>(a, b) self.join_heartbeat_every::<_, _, _, _, 64>(a, b)
// self.join_heartbeat(a, b)
} }
pub fn join_seq<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB) pub fn join_seq<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
where where
RA: Send, RA: Send,
RB: Send, RB: Send,
A: FnOnce() -> RA + Send, A: FnOnce(&Self) -> RA + Send,
B: FnOnce() -> RB + Send, B: FnOnce(&Self) -> RB + Send,
{ {
(a(), b()) let rb = b(&self);
let ra = a(&self);
(ra, rb)
} }
pub fn join_heartbeat_every<A, B, RA, RB, const TIMES: usize>(&self, a: A, b: B) -> (RA, RB) pub fn join_heartbeat_every<A, B, RA, RB, const TIMES: usize>(&self, a: A, b: B) -> (RA, RB)
where where
RA: Send, RA: Send,
RB: Send, RB: Send,
A: FnOnce() -> RA + Send, A: FnOnce(&Self) -> RA + Send,
B: FnOnce() -> RB + Send, B: FnOnce(&Self) -> RB + Send,
{ {
let count = self.join_count.get(); // let count = self.join_count.get();
self.join_count.set(count.wrapping_add(1) % TIMES); // self.join_count.set(count.wrapping_add(1) % TIMES);
let count = self.join_count.update(|n| n.wrapping_add(1) % TIMES);
if count == 0 { if count == 1 {
self.join_heartbeat(a, b) self.join_heartbeat(a, b)
} else { } else {
self.join_seq(a, b) self.join_seq(a, b)
@ -902,30 +913,33 @@ impl Scope {
where where
RA: Send, RA: Send,
RB: Send, RB: Send,
A: FnOnce() -> RA + Send, A: FnOnce(&Self) -> RA + Send,
B: FnOnce() -> RB + Send, B: FnOnce(&Self) -> RB + Send,
{ {
let a = StackJob::new(a); let a = StackJob::new(move |scope: &Scope| {
scope.tick();
a(scope)
});
let job = pin!(a.as_job()); let job = pin!(a.as_job());
self.push_front(job.as_ref()); self.push_front(job.as_ref());
let rb = b(); let rb = b(self);
let ra = if job.state() == JobState::Empty as u8 { let ra = if job.state() == JobState::Empty as u8 {
unsafe { unsafe {
job.unlink(); job.unlink();
} }
self.tick(); unsafe { a.unwrap()(self) }
unsafe { a.unwrap()() }
} else { } else {
match self.wait_until::<RA>(unsafe { match self.wait_until::<RA>(unsafe {
mem::transmute::<Pin<&Job<()>>, Pin<&Job<RA>>>(job.as_ref()) mem::transmute::<Pin<&Job<()>>, Pin<&Job<RA>>>(job.as_ref())
}) { }) {
Some(Ok(t)) => t, Some(Ok(t)) => t,
Some(Err(payload)) => std::panic::resume_unwind(payload), Some(Err(payload)) => std::panic::resume_unwind(payload),
None => unsafe { a.unwrap()() }, None => unsafe { a.unwrap()(self) },
} }
}; };
@ -933,7 +947,7 @@ impl Scope {
(ra, rb) (ra, rb)
} }
#[inline] #[inline(always)]
fn tick(&self) { fn tick(&self) {
if self.heartbeat.load(Ordering::Relaxed) { if self.heartbeat.load(Ordering::Relaxed) {
self.heartbeat_cold(); self.heartbeat_cold();
@ -943,7 +957,7 @@ impl Scope {
#[inline] #[inline]
fn execute(&self, job: &Job) { fn execute(&self, job: &Job) {
self.tick(); self.tick();
job.execute(); job.execute(self);
} }
#[cold] #[cold]
@ -1010,7 +1024,7 @@ where
A: FnOnce() -> RA + Send, A: FnOnce() -> RA + Send,
B: FnOnce() -> RB + Send, B: FnOnce() -> RB + Send,
{ {
Scope::with(|scope| scope.join(a, b)) Scope::with(|scope| scope.join(|_| a(), |_| b()))
} }
pub struct ThreadPool { pub struct ThreadPool {
@ -1031,7 +1045,7 @@ impl ThreadPool {
} }
pub fn scope<T, F: FnOnce(&Scope) -> T>(&self, f: F) -> T { pub fn scope<T, F: FnOnce(&Scope) -> T>(&self, f: F) -> T {
Scope::with_in(self.context.clone(), f) Scope::with_in(&self.context, f)
} }
} }