From 19ef21e2efcf8a08ec2ff6178068bc49bd26079c Mon Sep 17 00:00:00 2001 From: Janis Date: Tue, 1 Jul 2025 11:32:30 +0200 Subject: [PATCH] cleanup and fix race --- distaff/src/channel.rs | 87 +++++++++++++++++++++---------------- distaff/src/context.rs | 25 +---------- distaff/src/heartbeat.rs | 2 +- distaff/src/job.rs | 12 +---- distaff/src/join.rs | 19 ++++---- distaff/src/latch.rs | 14 ++---- distaff/src/scope.rs | 2 +- distaff/src/workerthread.rs | 30 ++++++++++++- 8 files changed, 96 insertions(+), 95 deletions(-) diff --git a/distaff/src/channel.rs b/distaff/src/channel.rs index a6e74c8..082610b 100644 --- a/distaff/src/channel.rs +++ b/distaff/src/channel.rs @@ -1,3 +1,5 @@ +// This file is taken from [`chili`] + use std::{ cell::UnsafeCell, ptr::NonNull, @@ -15,6 +17,7 @@ enum State { Taken, } +// taken from `std` #[derive(Debug)] #[repr(transparent)] pub struct Parker { @@ -100,27 +103,45 @@ impl Receiver { } pub fn wait(&self) { - match self.0.state.compare_exchange( - State::Pending as u8, - State::Waiting as u8, - Ordering::AcqRel, - Ordering::Acquire, - ) { - Ok(_) => { - // SAFETY: - // The `waiting_thread` is set to the current thread's parker - // before we park it. - unsafe { - let thread = self.0.waiting_thread.as_ref(); - thread.park(); + loop { + match self.0.state.compare_exchange( + State::Pending as u8, + State::Waiting as u8, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => { + // SAFETY: + // The `waiting_thread` is set to the current thread's parker + // before we park it. + unsafe { + let thread = self.0.waiting_thread.as_ref(); + thread.park(); + } + + // we might have been woken up because of a shared job. + // In that case, we need to check the state again. + if self + .0 + .state + .compare_exchange( + State::Waiting as u8, + State::Pending as u8, + Ordering::AcqRel, + Ordering::Acquire, + ) + .is_ok() + { + continue; + } + } + Err(state) if state == State::Ready as u8 => { + // The channel is ready, so we can return immediately. + return; + } + _ => { + panic!("Receiver is already waiting or consumed."); } - } - Err(state) if state == State::Ready as u8 => { - // The channel is ready, so we can return immediately. - return; - } - _ => { - panic!("Receiver is already waiting or consumed."); } } } @@ -144,22 +165,7 @@ impl Receiver { } pub fn recv(self) -> thread::Result { - if self - .0 - .state - .compare_exchange( - State::Pending as u8, - State::Waiting as u8, - Ordering::AcqRel, - Ordering::Acquire, - ) - .is_ok() - { - unsafe { - let thread = self.0.waiting_thread.as_ref(); - thread.park(); - } - } + self.wait(); // SAFETY: // To arrive here, either `state` is `State::Ready` or the above @@ -172,7 +178,14 @@ impl Receiver { } unsafe fn take(&self) -> thread::Result { - unsafe { (*self.0.val.get()).take().map(|b| *b).unwrap() } + let result = unsafe { (*self.0.val.get()).take().map(|b| *b).unwrap() }; + + assert_eq!( + self.0.state.swap(State::Taken as u8, Ordering::Release), + State::Ready as u8 + ); + + result } } diff --git a/distaff/src/context.rs b/distaff/src/context.rs index 28fc791..40742d6 100644 --- a/distaff/src/context.rs +++ b/distaff/src/context.rs @@ -9,40 +9,17 @@ use std::{ use alloc::collections::BTreeMap; use async_task::Runnable; -use crossbeam_utils::CachePadded; use parking_lot::{Condvar, Mutex}; use crate::{ channel::{Parker, Sender}, heartbeat::HeartbeatList, job::{HeapJob, Job2 as Job, SharedJob, StackJob}, - latch::{AsCoreLatch, MutexLatch, NopLatch}, + latch::NopLatch, util::DropGuard, workerthread::{HeartbeatThread, WorkerThread}, }; -pub struct Heartbeat { - pub latch: MutexLatch, -} - -impl Heartbeat { - pub fn new() -> NonNull> { - let ptr = Box::new(CachePadded::new(Self { - latch: MutexLatch::new(), - })); - - Box::into_non_null(ptr) - } - - pub fn is_pending(&self) -> bool { - self.latch.as_core_latch().poll_heartbeat() - } - - pub fn is_sleeping(&self) -> bool { - self.latch.as_core_latch().is_sleeping() - } -} - pub struct Context { shared: Mutex, pub shared_job: Condvar, diff --git a/distaff/src/heartbeat.rs b/distaff/src/heartbeat.rs index 0c83202..41108e9 100644 --- a/distaff/src/heartbeat.rs +++ b/distaff/src/heartbeat.rs @@ -12,7 +12,7 @@ use std::{ use parking_lot::Mutex; -use crate::{channel::Parker, latch::WorkerLatch}; +use crate::channel::Parker; #[derive(Debug, Clone)] pub struct HeartbeatList { diff --git a/distaff/src/job.rs b/distaff/src/job.rs index e9dab83..47fe44e 100644 --- a/distaff/src/job.rs +++ b/distaff/src/job.rs @@ -7,23 +7,15 @@ use core::{ ptr::{self, NonNull}, sync::atomic::Ordering, }; -use std::{ - cell::Cell, - marker::PhantomData, - mem::MaybeUninit, - ops::DerefMut, - sync::atomic::{AtomicU8, AtomicU32, AtomicUsize}, -}; +use std::cell::Cell; use alloc::boxed::Box; -use parking_lot::{Condvar, Mutex}; use parking_lot_core::SpinWait; use crate::{ WorkerThread, channel::{Parker, Sender}, - latch::{Probe, WorkerLatch}, - util::{DropGuard, SmallBox, TaggedAtomicPtr}, + util::{SmallBox, TaggedAtomicPtr}, }; #[repr(u8)] diff --git a/distaff/src/join.rs b/distaff/src/join.rs index 49edaf5..5989ade 100644 --- a/distaff/src/join.rs +++ b/distaff/src/join.rs @@ -103,15 +103,8 @@ impl WorkerThread { } }; - let ra = if !job.is_shared() { - tracing::trace!("join_heartbeat: job is not shared, running a() inline"); - // we pushed the job to the back of the queue, any `join`s called by `b` on this worker thread will have already popped their job, or seen it be executed. - self.pop_back(); - - // a is allowed to panic here, because we already finished b. - unsafe { a.unwrap()() } - } else { - match self.wait_until_shared_job(&job) { + let ra = if let Some(recv) = job.take_receiver() { + match self.wait_until_recv(recv) { Some(t) => crate::util::unwrap_or_panic(t), None => { tracing::trace!( @@ -122,6 +115,14 @@ impl WorkerThread { unsafe { a.unwrap()() } } } + } else { + self.pop_back(); + + unsafe { + // SAFETY: we just popped the job from the queue, so it is safe to unwrap. + tracing::trace!("join_heartbeat: job was not shared, running a() inline"); + a.unwrap()() + } }; (ra, rb) diff --git a/distaff/src/latch.rs b/distaff/src/latch.rs index 39b9b7c..dafea62 100644 --- a/distaff/src/latch.rs +++ b/distaff/src/latch.rs @@ -2,19 +2,11 @@ use core::{ marker::PhantomData, sync::atomic::{AtomicUsize, Ordering}, }; -use std::{ - cell::UnsafeCell, - mem, - ops::DerefMut, - sync::{ - Arc, - atomic::{AtomicPtr, AtomicU8}, - }, -}; +use std::sync::atomic::{AtomicPtr, AtomicU8}; use parking_lot::{Condvar, Mutex}; -use crate::{WorkerThread, channel::Parker, context::Context}; +use crate::channel::Parker; pub trait Latch { unsafe fn set_raw(this: *const Self); @@ -430,7 +422,7 @@ impl WorkerLatch { #[cfg(test)] mod tests { - use std::{ptr, sync::Barrier}; + use std::{ptr, sync::Arc}; use tracing_test::traced_test; diff --git a/distaff/src/scope.rs b/distaff/src/scope.rs index f334e0e..beb5803 100644 --- a/distaff/src/scope.rs +++ b/distaff/src/scope.rs @@ -14,7 +14,7 @@ use crate::{ channel::Sender, context::Context, job::{HeapJob, Job2 as Job}, - latch::{CountLatch, Probe, WorkerLatch}, + latch::{CountLatch, Probe}, util::{DropGuard, SendPtr}, workerthread::WorkerThread, }; diff --git a/distaff/src/workerthread.rs b/distaff/src/workerthread.rs index c3b12c3..79918b2 100644 --- a/distaff/src/workerthread.rs +++ b/distaff/src/workerthread.rs @@ -10,10 +10,10 @@ use std::{ use crossbeam_utils::CachePadded; use crate::{ + channel::Receiver, context::Context, heartbeat::OwnedHeartbeatReceiver, job::{Job2 as Job, JobQueue as JobList, SharedJob}, - latch::Probe, util::DropGuard, }; @@ -304,7 +304,7 @@ impl HeartbeatThread { impl WorkerThread { #[tracing::instrument(level = "trace", skip(self))] pub fn wait_until_shared_job(&self, job: &Job) -> Option> { - let recv = (*job).take_receiver()?; + let recv = (*job).take_receiver().unwrap(); let mut out = recv.poll(); @@ -321,6 +321,32 @@ impl WorkerThread { out } + #[tracing::instrument(level = "trace", skip_all)] + pub fn wait_until_recv(&self, recv: Receiver) -> Option> { + if self + .context + .shared() + .jobs + .remove(&self.heartbeat.id()) + .is_some() + { + tracing::trace!("reclaiming shared job"); + return None; + } + + while recv.is_empty() { + if let Some(job) = self.find_work() { + unsafe { + SharedJob::execute(job, self); + } + } else { + break; + } + } + + Some(recv.recv()) + } + #[tracing::instrument(level = "trace", skip_all)] pub fn wait_until_pred(&self, mut pred: F) where