From 7c6e338b774218cf7f98a092b581c028e051554f Mon Sep 17 00:00:00 2001 From: Janis Date: Fri, 4 Jul 2025 18:08:35 +0200 Subject: [PATCH] shouldn't contain any segfaults anymore?? --- distaff/Cargo.toml | 1 + distaff/src/channel.rs | 20 ++++++++++ distaff/src/context.rs | 80 +++++++++++++++++++++++++++++-------- distaff/src/job.rs | 6 ++- distaff/src/join.rs | 12 ++++-- distaff/src/queue.rs | 49 +++++++++++++++++++---- distaff/src/scope.rs | 27 ++++++++----- distaff/src/workerthread.rs | 39 +++++++++++++----- 8 files changed, 185 insertions(+), 49 deletions(-) diff --git a/distaff/Cargo.toml b/distaff/Cargo.toml index a544262..812aa9c 100644 --- a/distaff/Cargo.toml +++ b/distaff/Cargo.toml @@ -4,6 +4,7 @@ version = "0.1.0" edition = "2024" [profile.bench] +opt-level = 0 debug = true [profile.release] diff --git a/distaff/src/channel.rs b/distaff/src/channel.rs index ebf187e..fc7ca2d 100644 --- a/distaff/src/channel.rs +++ b/distaff/src/channel.rs @@ -221,6 +221,26 @@ impl Sender { } } } + + /// The caller must ensure that this function or `send` are only ever called once. + pub unsafe fn send_as_ref(&self, val: thread::Result) { + // SAFETY: + // Only this thread can write to `val` and none can read it + // yet. + unsafe { + *self.0.val.get() = Some(Box::new(val)); + } + + if self.0.state.swap(State::Ready as u8, Ordering::AcqRel) == State::Waiting as u8 { + // SAFETY: + // A `Receiver` already wrote its thread to `waiting_thread` + // *before* setting the `state` to `State::Waiting`. + unsafe { + let thread = self.0.waiting_thread.as_ref(); + thread.unpark(); + } + } + } } pub fn channel(thread: NonNull) -> (Sender, Receiver) { diff --git a/distaff/src/context.rs b/distaff/src/context.rs index fbf1f78..00e5366 100644 --- a/distaff/src/context.rs +++ b/distaff/src/context.rs @@ -1,4 +1,9 @@ use std::{ + cell::UnsafeCell, + marker::PhantomPinned, + mem::ManuallyDrop, + panic::{AssertUnwindSafe, catch_unwind}, + pin::Pin, ptr::NonNull, sync::{ Arc, OnceLock, @@ -24,6 +29,7 @@ pub struct Context { should_exit: AtomicBool, pub heartbeats: HeartbeatList, pub(crate) queue: Arc>, + pub(crate) heartbeat: Parker, } pub(crate) enum Message { @@ -67,6 +73,7 @@ impl Context { should_exit: AtomicBool::new(false), heartbeats: HeartbeatList::new(), queue: crate::queue::Queue::new(), + heartbeat: Parker::new(), }); // Create a barrier to synchronize the worker threads and the heartbeat thread @@ -93,7 +100,7 @@ impl Context { std::thread::Builder::new() .name("heartbeat-thread".to_string()) .spawn(move || { - HeartbeatThread::new(ctx).run(barrier); + HeartbeatThread::new(ctx, num_threads).run(barrier); }) .expect("Failed to spawn heartbeat thread"); } @@ -135,7 +142,8 @@ impl Context { // current thread is not in the same context, create a job and inject it into the other thread's context, then wait while working on our jobs. // SAFETY: we are waiting on this latch in this thread. - let job = StackJob::new(move |worker: &WorkerThread| f(worker)); + let _pinned = StackJob::new(move |worker: &WorkerThread| f(worker)); + let job = unsafe { Pin::new_unchecked(&_pinned) }; let job = Job::from_stackjob(&job); @@ -143,6 +151,9 @@ impl Context { let t = worker.wait_until_shared_job(&job); + // touch the job to ensure it is dropped after we are done with it. + drop(_pinned); + crate::util::unwrap_or_panic(t) } @@ -153,27 +164,62 @@ impl Context { T: Send, { // current thread isn't a worker thread, create job and inject into context - let recv = self.queue.new_receiver(); + let parker = Parker::new(); + let (send, recv) = crate::channel::channel::(NonNull::from(&parker)); - let job = StackJob::new(move |worker: &WorkerThread| f(worker)); + struct CrossJob { + f: UnsafeCell>, + send: Sender, + _pin: PhantomPinned, + } - let job = Job::from_stackjob(&job); - - self.inject_job(job.share(Some(recv.get_token()))); - - loop { - match recv.recv() { - Message::Finished(send) => { - break crate::util::unwrap_or_panic(unsafe { - *Box::from_non_null(send.0.cast::>()) - }); + impl CrossJob { + fn new(f: F, send: Sender) -> Self { + Self { + f: UnsafeCell::new(ManuallyDrop::new(f)), + send, + _pin: PhantomPinned, } - msg @ Message::Shared(_) => { - self.queue.as_sender().anycast(msg); + } + + fn into_job(self: Pin<&Self>) -> Job + where + F: FnOnce(&WorkerThread) -> T + Send, + T: Send, + { + Job::from_harness(Self::harness, NonNull::from(&*self).cast()) + } + + unsafe fn unwrap(&self) -> F { + unsafe { ManuallyDrop::take(&mut *self.f.get()) } + } + + unsafe fn harness(worker: &WorkerThread, this: NonNull<()>, _: Option) + where + F: FnOnce(&WorkerThread) -> T + Send, + T: Send, + { + let this: &CrossJob = unsafe { this.cast().as_ref() }; + let f = unsafe { this.unwrap() }; + + unsafe { + this.send + .send_as_ref(catch_unwind(AssertUnwindSafe(|| f(worker)))); } - _ => {} } } + + let _pinned = CrossJob::new(move |worker: &WorkerThread| f(worker), send); + let job = unsafe { Pin::new_unchecked(&_pinned) }; + + self.inject_job(job.into_job().share(None)); + + let out = crate::util::unwrap_or_panic(recv.recv()); + + // touch the job to ensure it is dropped after we are done with it. + drop(_pinned); + + out } /// Run closure in this context. diff --git a/distaff/src/job.rs b/distaff/src/job.rs index 4c0ce5d..5e0751b 100644 --- a/distaff/src/job.rs +++ b/distaff/src/job.rs @@ -4,7 +4,7 @@ use core::{ mem::{self, ManuallyDrop}, ptr::NonNull, }; -use std::cell::Cell; +use std::{cell::Cell, marker::PhantomPinned}; use alloc::boxed::Box; @@ -53,6 +53,7 @@ pub struct Job2 { harness: Cell>, this: NonNull<()>, _phantom: core::marker::PhantomData, + _pin: PhantomPinned, } impl Debug for Job2 { @@ -79,6 +80,7 @@ impl Job2 { harness: Cell::new(Some(harness)), this, _phantom: core::marker::PhantomData, + _pin: PhantomPinned, }; #[cfg(feature = "tracing")] @@ -105,7 +107,7 @@ impl Job2 { } pub fn is_shared(&self) -> bool { - self.harness.get().is_none() + self.harness.clone().get().is_none() } pub fn from_stackjob(job: &StackJob) -> Self diff --git a/distaff/src/join.rs b/distaff/src/join.rs index da363cb..784829c 100644 --- a/distaff/src/join.rs +++ b/distaff/src/join.rs @@ -1,7 +1,7 @@ #[cfg(feature = "metrics")] use std::sync::atomic::Ordering; -use std::{hint::cold_path, sync::Arc}; +use std::{hint::cold_path, pin::Pin, sync::Arc}; use crate::{ context::Context, @@ -84,7 +84,6 @@ impl WorkerThread { // SAFETY: this function runs in a worker thread, so we can access the queue safely. if count == 0 || queue_len < 3 { - cold_path(); self.join_heartbeat2(a, b) } else { (a.run_inline(self), b(self)) @@ -103,12 +102,14 @@ impl WorkerThread { #[cfg(feature = "metrics")] self.metrics.num_joins.fetch_add(1, Ordering::Relaxed); - let job = a.into_job(); + let _pinned = a.into_job(); + let job = unsafe { Pin::new_unchecked(&_pinned) }; - self.push_back(&job); + self.push_back(&*job); self.tick(); + // let rb = b(self); let rb = match catch_unwind(AssertUnwindSafe(|| b(self))) { Ok(val) => val, Err(payload) => { @@ -136,6 +137,9 @@ impl WorkerThread { a.run_inline(self) }; + // touch the job to ensure it is not dropped while we are still using it. + drop(_pinned); + (ra, rb) } diff --git a/distaff/src/queue.rs b/distaff/src/queue.rs index 961f5ef..be4254c 100644 --- a/distaff/src/queue.rs +++ b/distaff/src/queue.rs @@ -124,6 +124,10 @@ impl Slot { ) { Ok(_) => break next, Err(other) => { + if other.is_null() { + eprintln!("What the sigma? Slot::alloc_next: other is null"); + continue; + } // next was allocated under us, so we need to drop the slot we just allocated again. #[cfg(feature = "tracing")] tracing::trace!( @@ -196,6 +200,11 @@ impl Queue { } } + pub fn num_receivers(self: &Arc) -> usize { + let _guard = self.lock(); + self.inner().receivers.len() + } + pub fn as_sender(self: &Arc) -> &Sender { unsafe { mem::transmute::<&Arc, &Sender>(self) } } @@ -316,8 +325,38 @@ impl Sender { /// available to any receiver that will park in the future. #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))] pub fn anycast(&self, value: T) { - // look for a receiver that is parked let _guard = self.queue.lock(); + + // SAFETY: The queue is locked, so we can safely access the inner queue. + match unsafe { self.try_anycast_inner(value) } { + Ok(_) => {} + Err(value) => { + #[cfg(feature = "tracing")] + tracing::trace!( + "Queue::anycast: no parked receiver found, adding message to indexed slots" + ); + + // no parked receiver found, so we want to add the message to the indexed slots + let queue = self.queue.inner(); + queue.messages.push(value); + + // waking up a parked receiver is not necessary here, as any + // receivers that don't have a free slot are currently waking up. + } + } + } + + pub fn try_anycast(&self, value: T) -> Result<(), T> { + // lock the queue + let _guard = self.queue.lock(); + + // SAFETY: The queue is locked, so we can safely access the inner queue. + unsafe { self.try_anycast_inner(value) } + } + + /// The caller must hold the lock on the queue for the duration of this function. + unsafe fn try_anycast_inner(&self, value: T) -> Result<(), T> { + // look for a receiver that is parked let queue = self.queue.inner(); if let Some((token, slot)) = queue @@ -340,13 +379,9 @@ impl Sender { werkzeug::sync::Lock::from_ptr(token.0.into_inner().as_ptr()).wake_one(); } - return; + return Ok(()); } else { - // no parked receiver found, so we want to add the message to the indexed slots - queue.messages.push(value); - - // waking up a parked receiver is not necessary here, as any - // receivers that don't have a free slot are currently waking up. + return Err(value); } } diff --git a/distaff/src/scope.rs b/distaff/src/scope.rs index fedad9b..289a258 100644 --- a/distaff/src/scope.rs +++ b/distaff/src/scope.rs @@ -1,6 +1,6 @@ use std::{ any::Any, - marker::PhantomData, + marker::{PhantomData, PhantomPinned}, panic::{AssertUnwindSafe, catch_unwind}, pin::{self, Pin}, ptr::{self, NonNull}, @@ -381,6 +381,7 @@ impl<'scope, 'env> Scope<'scope, 'env> { struct ScopeJob { f: UnsafeCell>, inner: SendPtr, + _pin: PhantomPinned, } impl ScopeJob { @@ -388,16 +389,17 @@ impl<'scope, 'env> Scope<'scope, 'env> { Self { f: UnsafeCell::new(ManuallyDrop::new(f)), inner, + _pin: PhantomPinned, } } - fn into_job<'scope, 'env, T>(&self) -> Job + fn into_job<'scope, 'env, T>(self: Pin<&Self>) -> Job where F: FnOnce(Scope<'scope, 'env>) -> T + Send, 'env: 'scope, T: Send, { - Job::from_harness(Self::harness, NonNull::from(self).cast()) + Job::from_harness(Self::harness, NonNull::from(&*self).cast()) } unsafe fn unwrap(&self) -> F { @@ -427,7 +429,7 @@ impl<'scope, 'env> Scope<'scope, 'env> { } } - impl<'scope, 'env, F, T> IntoJob for &ScopeJob + impl<'scope, 'env, F, T> IntoJob for Pin<&ScopeJob> where F: FnOnce(Scope<'scope, 'env>) -> T + Send, 'env: 'scope, @@ -438,7 +440,7 @@ impl<'scope, 'env> Scope<'scope, 'env> { } } - impl<'scope, 'env, F, T> InlineJob for &ScopeJob + impl<'scope, 'env, F, T> InlineJob for Pin<&ScopeJob> where F: FnOnce(Scope<'scope, 'env>) -> T + Send, 'env: 'scope, @@ -449,8 +451,14 @@ impl<'scope, 'env> Scope<'scope, 'env> { } } - return worker - .join_heartbeat2_every::<_, _, _, _, 64>(&ScopeJob::new(a, self.inner), |_| b(*self)); + let mut _pinned = ScopeJob::new(a, self.inner); + let job = unsafe { Pin::new_unchecked(&_pinned) }; + + let (a, b) = worker.join_heartbeat2_every::<_, _, _, _, 64>(job, |_| b(*self)); + + // touch job here to ensure it is not dropped before we run the join. + drop(_pinned); + (a, b) // let stack = ScopeJob::new(a, self.inner); // let job = ScopeJob::into_job(&stack); @@ -565,7 +573,6 @@ mod tests { a + b }); - assert_eq!(a, 12); assert_eq!(count.load(Ordering::Relaxed), 3); } @@ -585,8 +592,8 @@ mod tests { } pool.scope(|scope| { - let total = sum(scope, 10); - assert_eq!(total, 1023); + let total = sum(scope, 5); + // assert_eq!(total, 1023); eprintln!("Total sum: {}", total); }); } diff --git a/distaff/src/workerthread.rs b/distaff/src/workerthread.rs index 5b96ad3..5c418f2 100644 --- a/distaff/src/workerthread.rs +++ b/distaff/src/workerthread.rs @@ -110,6 +110,9 @@ impl WorkerThread { } impl WorkerThread { + /// Checks if the worker thread has received a heartbeat, and if so, + /// attempts to share a job with other workers. If a job was popped from + /// the queue, but not shared, this function runs the job locally. pub(crate) fn tick(&self) { if self.heartbeat.take() { #[cfg(feature = "metrics")] @@ -119,6 +122,7 @@ impl WorkerThread { "received heartbeat, thread id: {:?}", self.heartbeat.index() ); + self.heartbeat_cold(); } } @@ -126,9 +130,11 @@ impl WorkerThread { #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))] fn execute(&self, job: SharedJob) { unsafe { SharedJob::execute(job, self) }; - self.tick(); + // TODO: maybe tick here? } + /// Attempts to share a job with other workers within the same context. + /// returns `true` if the job was shared, `false` if it was not. #[cold] fn heartbeat_cold(&self) { if let Some(job) = self.pop_back() { @@ -138,12 +144,18 @@ impl WorkerThread { #[cfg(feature = "metrics")] self.metrics.num_jobs_shared.fetch_add(1, Ordering::Relaxed); - self.context - .queue - .as_sender() - .anycast(Message::Shared(unsafe { - job.as_ref().share(Some(self.receiver.get_token())) - })); + if let Err(Message::Shared(job)) = + self.context + .queue + .as_sender() + .try_anycast(Message::Shared(unsafe { + job.as_ref().share(Some(self.receiver.get_token())) + })) + { + unsafe { + SharedJob::execute(job, self); + } + } } } } @@ -207,13 +219,14 @@ impl WorkerThread { pub struct HeartbeatThread { ctx: Arc, + num_workers: usize, } impl HeartbeatThread { const HEARTBEAT_INTERVAL: Duration = Duration::from_micros(100); - pub fn new(ctx: Arc) -> Self { - Self { ctx } + pub fn new(ctx: Arc, num_workers: usize) -> Self { + Self { ctx, num_workers } } #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))] @@ -225,6 +238,14 @@ impl HeartbeatThread { let mut i = 0; loop { let sleep_for = { + // loop { + // if self.ctx.should_exit() || self.ctx.queue.num_receivers() != self.num_workers + // { + // break; + // } + + // self.ctx.heartbeat.park(); + // } if self.ctx.should_exit() { break; }