diff --git a/distaff/src/context.rs b/distaff/src/context.rs index d17ce6d..5902a7e 100644 --- a/distaff/src/context.rs +++ b/distaff/src/context.rs @@ -19,28 +19,25 @@ use crate::{ }; pub struct Heartbeat { - heartbeat: AtomicU8, pub latch: HeartbeatLatch, } impl Heartbeat { - pub const CLEAR: u8 = 0; - pub const PENDING: u8 = 1; - pub const SLEEPING: u8 = 2; - - pub fn new() -> (Arc>, Weak>) { - let strong = Arc::new(CachePadded::new(Self { - heartbeat: AtomicU8::new(Self::CLEAR), + pub fn new() -> NonNull> { + let ptr = Box::new(CachePadded::new(Self { latch: HeartbeatLatch::new(), })); - let weak = Arc::downgrade(&strong); - (strong, weak) + Box::into_non_null(ptr) } pub fn is_pending(&self) -> bool { self.latch.as_core_latch().poll_heartbeat() } + + pub fn is_sleeping(&self) -> bool { + self.latch.as_core_latch().is_sleeping() + } } pub struct Context { @@ -50,7 +47,7 @@ pub struct Context { pub(crate) struct Shared { pub jobs: BTreeMap>, - pub heartbeats: BTreeMap>>, + pub heartbeats: BTreeMap>>, injected_jobs: Vec>, heartbeat_count: usize, should_exit: bool, @@ -59,15 +56,15 @@ pub(crate) struct Shared { unsafe impl Send for Shared {} impl Shared { - pub fn new_heartbeat(&mut self) -> (Arc>, usize) { + pub fn new_heartbeat(&mut self) -> (NonNull>, usize) { let index = self.heartbeat_count; self.heartbeat_count = index.wrapping_add(1); - let (strong, weak) = Heartbeat::new(); + let heatbeat = Heartbeat::new(); - self.heartbeats.insert(index, weak); + self.heartbeats.insert(index, heatbeat); - (strong, index) + (heatbeat, index) } pub(crate) fn remove_heartbeat(&mut self, index: usize) { @@ -91,12 +88,12 @@ impl Shared { } pub fn notify_job_shared(&self) { - _ = self.heartbeats.iter().find(|(_, heartbeat)| { - if let Some(heartbeat) = heartbeat.upgrade() { - heartbeat.latch.signal_job_shared(); - true + _ = self.heartbeats.iter().find(|(_, heartbeat)| unsafe { + if heartbeat.as_ref().is_sleeping() { + heartbeat.as_ref().latch.signal_job_shared(); + return true; } else { - false + return false; } }); } @@ -166,11 +163,10 @@ impl Context { let mut shared = self.shared.lock(); shared.should_exit = true; for (_, heartbeat) in shared.heartbeats.iter() { - if let Some(heartbeat) = heartbeat.upgrade() { - heartbeat.latch.signal_job_shared(); + unsafe { + heartbeat.as_ref().latch.signal_job_shared(); } } - self.shared_job.notify_all(); } pub fn new() -> Arc { @@ -199,7 +195,7 @@ impl Context { // current thread is not in the same context, create a job and inject it into the other thread's context, then wait while working on our jobs. // SAFETY: we are waiting on this latch in this thread. - let latch = unsafe { UnsafeWakeLatch::new(&raw const worker.heartbeat.latch) }; + let latch = unsafe { UnsafeWakeLatch::new(&raw const worker.heartbeat().latch) }; let job = StackJob::new( move || { diff --git a/distaff/src/join.rs b/distaff/src/join.rs index ce2f2d0..11f4642 100644 --- a/distaff/src/join.rs +++ b/distaff/src/join.rs @@ -65,7 +65,7 @@ impl WorkerThread { // SAFETY: this thread's heartbeat latch is valid until the job sets it // because we will be waiting on it. - let latch = unsafe { UnsafeWakeLatch::new(&raw const (*self.heartbeat).latch) }; + let latch = unsafe { UnsafeWakeLatch::new(&raw const self.heartbeat().latch) }; let a = StackJob::new(a, LatchRef::new(&latch)); diff --git a/distaff/src/latch.rs b/distaff/src/latch.rs index a3b6740..989fed0 100644 --- a/distaff/src/latch.rs +++ b/distaff/src/latch.rs @@ -414,12 +414,10 @@ impl Latch for WakeLatch { if CoreLatch::set(&(&*this).inner) { let ctx = WorkerThread::current_ref().unwrap().context.clone(); // If the latch was sleeping, wake the worker thread - ctx.shared().heartbeats.get(&worker_index).and_then(|weak| { - weak.upgrade().map(|heartbeat| { - // we set the latch to wake the worker so it knows to check the heartbeat - heartbeat.latch.signal_job_finished() - }) - }); + ctx.shared() + .heartbeats + .get(&worker_index) + .map(|ptr| ptr.as_ref().latch.signal_job_finished()); } } } @@ -591,7 +589,7 @@ mod tests { tracing::info!("worker thread started: {:?}", worker.index); let latch = Arc::new(WakeLatch::new(worker.index)); worker.context.spawn({ - let heartbeat = worker.heartbeat.clone(); + let heartbeat = unsafe { crate::util::Send::new(worker.heartbeat) }; let barrier = barrier.clone(); let count = count.clone(); let latch = latch.clone(); @@ -599,7 +597,9 @@ mod tests { tracing::info!("sleeping workerthread"); latch.as_core_latch().set_sleeping(); - heartbeat.latch.wait_and_reset(); + unsafe { + heartbeat.as_ref().latch.wait_and_reset(); + } tracing::info!("woken up workerthread"); count.fetch_add(1, Ordering::SeqCst); tracing::info!("waiting on barrier"); diff --git a/distaff/src/util.rs b/distaff/src/util.rs index 93ac342..ddc9ec1 100644 --- a/distaff/src/util.rs +++ b/distaff/src/util.rs @@ -54,7 +54,7 @@ impl core::fmt::Pointer for SendPtr { } } -unsafe impl Send for SendPtr {} +unsafe impl core::marker::Send for SendPtr {} impl Deref for SendPtr { type Target = NonNull; @@ -408,6 +408,30 @@ pub fn available_parallelism() -> usize { .unwrap_or(1) } +#[repr(transparent)] +pub struct Send(pub(self) T); + +unsafe impl core::marker::Send for Send {} + +impl Deref for Send { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} +impl DerefMut for Send { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl Send { + pub unsafe fn new(value: T) -> Self { + Self(value) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/distaff/src/workerthread.rs b/distaff/src/workerthread.rs index b497ac3..ac66786 100644 --- a/distaff/src/workerthread.rs +++ b/distaff/src/workerthread.rs @@ -18,7 +18,7 @@ pub struct WorkerThread { pub(crate) context: Arc, pub(crate) index: usize, pub(crate) queue: UnsafeCell, - pub(crate) heartbeat: Arc>, + pub(crate) heartbeat: NonNull>, pub(crate) join_count: Cell, } @@ -26,6 +26,11 @@ impl Drop for WorkerThread { fn drop(&mut self) { // remove the current worker thread from the heartbeat list self.context.shared().remove_heartbeat(self.index); + + // SAFETY: we removed the heartbeat from the context, so we can safely drop it. + unsafe { + _ = Box::from_non_null(self.heartbeat); + } } } @@ -45,6 +50,11 @@ impl WorkerThread { join_count: Cell::new(0), } } + + pub(crate) fn heartbeat(&self) -> &CachePadded { + // SAFETY: the heartbeat is always set when the worker thread is created + unsafe { self.heartbeat.as_ref() } + } } impl WorkerThread { @@ -88,7 +98,7 @@ impl WorkerThread { } // no more jobs, wait to be notified of a new job or a heartbeat. - match self.heartbeat.latch.wait_and_reset() { + match self.heartbeat().latch.wait_and_reset() { crate::latch::WakeResult::Wake => { let mut guard = self.context.shared(); if guard.should_exit() { @@ -111,7 +121,7 @@ impl WorkerThread { impl WorkerThread { #[inline(always)] pub(crate) fn tick(&self) { - if self.heartbeat.is_pending() { + if self.heartbeat().is_pending() { tracing::trace!("received heartbeat, thread id: {:?}", self.index); self.heartbeat_cold(); } @@ -220,22 +230,16 @@ impl HeartbeatThread { let mut i = 0; loop { let sleep_for = { - let mut guard = self.ctx.shared(); + let guard = self.ctx.shared(); if guard.should_exit() { break; } - let mut n = 0; - guard.heartbeats.retain(|_, b| { - b.upgrade() - .inspect(|heartbeat| { - if n == i { - heartbeat.latch.signal_heartbeat(); - } - n += 1; - }) - .is_some() - }); + if let Some((_, heartbeat)) = guard.heartbeats.iter().nth(i) { + unsafe { + heartbeat.as_ref().latch.signal_heartbeat(); + } + } let num_heartbeats = guard.heartbeats.len(); drop(guard); @@ -313,7 +317,7 @@ impl WorkerThread { tracing::trace!("thread {:?} is sleeping", self.index); - match self.heartbeat.latch.wait_and_reset() { + match self.heartbeat().latch.wait_and_reset() { // why were we woken up? // 1. the heartbeat thread ticked and set the // latch, so we should see if we have any work