reduce arc clones, pass references to scope to join functions

This commit is contained in:
Janis 2025-02-20 21:50:42 +01:00
parent e2d5208025
commit eb43c29389

View file

@ -457,7 +457,10 @@ mod job {
unsafe impl<T> Send for Job<T> {}
impl<T> Job<T> {
pub fn new(harness: unsafe fn(*const (), *const Job<T>), this: NonNull<()>) -> Job<T> {
pub fn new(
harness: unsafe fn(*const (), *const Job<T>, &super::Scope),
this: NonNull<()>,
) -> Job<T> {
Self {
harness_and_state: TaggedAtomicPtr::new(
unsafe { mem::transmute(harness) },
@ -473,6 +476,7 @@ mod job {
_phantom: PhantomPinned,
}
}
pub fn empty() -> Job<T> {
Self {
harness_and_state: TaggedAtomicPtr::new(
@ -492,6 +496,7 @@ mod job {
}
}
#[inline]
unsafe fn link_mut(&self) -> &mut Link<Job> {
unsafe { &mut (&mut *self.err_or_link.get()).link }
}
@ -594,16 +599,17 @@ mod job {
}
}
pub fn execute(&self) {
pub fn execute(&self, scope: &super::Scope) {
// SAFETY: self is non-null
unsafe {
let (ptr, state) = self.harness_and_state.ptr_and_tag(Ordering::Relaxed);
debug_assert_eq!(state, JobState::Pending as usize);
let harness: unsafe fn(*const (), *const Self) = mem::transmute(ptr.as_ptr());
let harness: unsafe fn(*const (), *const Self, scope: &super::Scope) =
mem::transmute(ptr.as_ptr());
let this = (*self.val_or_this.get()).this;
harness(this.as_ptr().cast(), (self as *const Self).cast());
harness(this.as_ptr().cast(), (self as *const Self).cast(), scope);
}
}
@ -667,21 +673,20 @@ mod job {
#[allow(dead_code)]
pub fn into_boxed_job<T>(self: Box<Self>) -> Box<Job<()>>
where
F: FnOnce() -> T + Send,
F: FnOnce(&super::Scope) -> T + Send,
T: Send,
{
unsafe fn harness<F, T>(this: *const (), job: *const Job<()>)
unsafe fn harness<F, T>(this: *const (), job: *const Job<()>, scope: &super::Scope)
where
F: FnOnce() -> T + Send,
F: FnOnce(&super::Scope) -> T + Send,
T: Sized + Send,
{
let job = unsafe { &*job.cast::<Job<T>>() };
let this = unsafe { Box::from_raw(this.cast::<HeapJob<F>>().cast_mut()) };
let f = this.f;
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f));
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(scope)));
let job = unsafe { &*job.cast::<Job<T>>() };
job.complete(result);
}
@ -708,18 +713,18 @@ mod job {
pub fn as_job<T>(&self) -> Job<()>
where
F: FnOnce() -> T + Send,
F: FnOnce(&super::Scope) -> T + Send,
T: Send,
{
unsafe fn harness<F, T>(this: *const (), job: *const Job<()>)
unsafe fn harness<F, T>(this: *const (), job: *const Job<()>, scope: &super::Scope)
where
F: FnOnce() -> T + Send,
F: FnOnce(&super::Scope) -> T + Send,
T: Sized + Send,
{
let this = unsafe { &*this.cast::<StackJob<F>>() };
let f = unsafe { this.unwrap() };
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f));
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(scope)));
let job_ref = unsafe { &*job.cast::<Job<T>>() };
job_ref.complete(result);
@ -782,25 +787,25 @@ impl Scope {
}
}
fn with_in<T, F: FnOnce(&Scope) -> T>(ctx: Arc<Context>, f: F) -> T {
fn with_in<T, F: FnOnce(&Scope) -> T>(ctx: &Arc<Context>, f: F) -> T {
let mut guard = Option::<DropGuard<Box<dyn FnOnce()>>>::None;
let scope = match Self::current_ref() {
Some(scope) if Arc::ptr_eq(&scope.context, &ctx) => scope,
Some(scope) if Arc::ptr_eq(&scope.context, ctx) => scope,
Some(_) => {
let old = unsafe { Self::unset_current().unwrap().as_ptr() };
guard = Some(DropGuard::new(Box::new(move || unsafe {
_ = Box::from_raw(Self::unset_current().unwrap().as_ptr());
Self::set_current(old.cast_const());
})));
let current = Box::into_raw(Box::new(Self::new_in(ctx)));
let current = Box::into_raw(Box::new(Self::new_in(ctx.clone())));
unsafe {
Self::set_current(current.cast_const());
&*current
}
}
None => {
let current = Box::into_raw(Box::new(Self::new_in(ctx)));
let current = Box::into_raw(Box::new(Self::new_in(ctx.clone())));
guard = Some(DropGuard::new(Box::new(|| unsafe {
_ = Box::from_raw(Self::unset_current().unwrap().as_ptr());
@ -820,7 +825,7 @@ impl Scope {
}
pub fn with<T, F: FnOnce(&Scope) -> T>(f: F) -> T {
Self::with_in(Context::global().clone(), f)
Self::with_in(Context::global(), f)
}
unsafe fn set_current(scope: *const Scope) {
@ -861,37 +866,43 @@ impl Scope {
unsafe { self.queue.as_mut_unchecked().pop_front() }
}
#[inline]
pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
where
RA: Send,
RB: Send,
A: FnOnce() -> RA + Send,
B: FnOnce() -> RB + Send,
A: FnOnce(&Self) -> RA + Send,
B: FnOnce(&Self) -> RB + Send,
{
self.join_heartbeat_every::<_, _, _, _, 64>(a, b)
// self.join_heartbeat(a, b)
}
pub fn join_seq<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
where
RA: Send,
RB: Send,
A: FnOnce() -> RA + Send,
B: FnOnce() -> RB + Send,
A: FnOnce(&Self) -> RA + Send,
B: FnOnce(&Self) -> RB + Send,
{
(a(), b())
let rb = b(&self);
let ra = a(&self);
(ra, rb)
}
pub fn join_heartbeat_every<A, B, RA, RB, const TIMES: usize>(&self, a: A, b: B) -> (RA, RB)
where
RA: Send,
RB: Send,
A: FnOnce() -> RA + Send,
B: FnOnce() -> RB + Send,
A: FnOnce(&Self) -> RA + Send,
B: FnOnce(&Self) -> RB + Send,
{
let count = self.join_count.get();
self.join_count.set(count.wrapping_add(1) % TIMES);
// let count = self.join_count.get();
// self.join_count.set(count.wrapping_add(1) % TIMES);
let count = self.join_count.update(|n| n.wrapping_add(1) % TIMES);
if count == 0 {
if count == 1 {
self.join_heartbeat(a, b)
} else {
self.join_seq(a, b)
@ -902,30 +913,33 @@ impl Scope {
where
RA: Send,
RB: Send,
A: FnOnce() -> RA + Send,
B: FnOnce() -> RB + Send,
A: FnOnce(&Self) -> RA + Send,
B: FnOnce(&Self) -> RB + Send,
{
let a = StackJob::new(a);
let a = StackJob::new(move |scope: &Scope| {
scope.tick();
a(scope)
});
let job = pin!(a.as_job());
self.push_front(job.as_ref());
let rb = b();
let rb = b(self);
let ra = if job.state() == JobState::Empty as u8 {
unsafe {
job.unlink();
}
self.tick();
unsafe { a.unwrap()() }
unsafe { a.unwrap()(self) }
} else {
match self.wait_until::<RA>(unsafe {
mem::transmute::<Pin<&Job<()>>, Pin<&Job<RA>>>(job.as_ref())
}) {
Some(Ok(t)) => t,
Some(Err(payload)) => std::panic::resume_unwind(payload),
None => unsafe { a.unwrap()() },
None => unsafe { a.unwrap()(self) },
}
};
@ -933,7 +947,7 @@ impl Scope {
(ra, rb)
}
#[inline]
#[inline(always)]
fn tick(&self) {
if self.heartbeat.load(Ordering::Relaxed) {
self.heartbeat_cold();
@ -943,7 +957,7 @@ impl Scope {
#[inline]
fn execute(&self, job: &Job) {
self.tick();
job.execute();
job.execute(self);
}
#[cold]
@ -1010,7 +1024,7 @@ where
A: FnOnce() -> RA + Send,
B: FnOnce() -> RB + Send,
{
Scope::with(|scope| scope.join(a, b))
Scope::with(|scope| scope.join(|_| a(), |_| b()))
}
pub struct ThreadPool {
@ -1031,7 +1045,7 @@ impl ThreadPool {
}
pub fn scope<T, F: FnOnce(&Scope) -> T>(&self, f: F) -> T {
Scope::with_in(self.context.clone(), f)
Scope::with_in(&self.context, f)
}
}