diff --git a/distaff/src/job.rs b/distaff/src/job.rs index db82b2d..84b0aa5 100644 --- a/distaff/src/job.rs +++ b/distaff/src/job.rs @@ -1086,15 +1086,16 @@ impl JobSender { // // This concludes my TED talk on why we need to lock here. - let _guard = (!mutex.is_null()).then(|| { - // SAFETY: mutex is a valid pointer to a WorkerLatch - unsafe { - (&*mutex).lock(); - DropGuard::new(|| { - (&*mutex).wake(); - (&*mutex).unlock() - }) - } + let _guard = unsafe { mutex.as_ref() }.map(|mutex| { + let guard = mutex.lock(); + DropGuard::new(move || { + // // SAFETY: we forget the guard here so we no longer borrow the mutex. + // mem::forget(guard); + _ = guard; + mutex.wake(); + // // SAFETY: we can safely unlock the mutex here, as we are the only ones holding it. + // mutex.force_unlock(); + }) }); assert!(self.channel.tag.tag(Ordering::Acquire) & FINISHED == 0); diff --git a/distaff/src/latch.rs b/distaff/src/latch.rs index 0d56a6e..80a745c 100644 --- a/distaff/src/latch.rs +++ b/distaff/src/latch.rs @@ -341,20 +341,29 @@ impl AsCoreLatch for MutexLatch { pub struct WorkerLatch { // this boolean is set when the worker is waiting. mutex: Mutex, - condvar: AtomicUsize, + condvar: Condvar, } impl WorkerLatch { pub fn new() -> Self { Self { mutex: Mutex::new(false), - condvar: AtomicUsize::new(0), + condvar: Condvar::new(), } } - pub fn lock(&self) { - mem::forget(self.mutex.lock()); + + #[tracing::instrument(level = "trace", skip_all, fields( + this = self as *const Self as usize, + ))] + pub fn lock(&self) -> parking_lot::MutexGuard<'_, bool> { + tracing::trace!("aquiring mutex.."); + let guard = self.mutex.lock(); + tracing::trace!("mutex acquired."); + + guard } - pub fn unlock(&self) { + + pub unsafe fn force_unlock(&self) { unsafe { self.mutex.force_unlock(); } @@ -362,144 +371,15 @@ impl WorkerLatch { pub fn wait(&self) { let condvar = &self.condvar; - let mut guard = self.mutex.lock(); + let mut guard = self.lock(); Self::wait_internal(condvar, &mut guard); } - fn wait_internal(condvar: &AtomicUsize, guard: &mut parking_lot::MutexGuard<'_, bool>) { - let mutex = parking_lot::MutexGuard::mutex(guard); - let key = condvar as *const _ as usize; - let lock_addr = mutex as *const _ as usize; - let mut requeued = false; - - let state = unsafe { AtomicUsize::from_ptr(condvar as *const _ as *mut usize) }; - + fn wait_internal(condvar: &Condvar, guard: &mut parking_lot::MutexGuard<'_, bool>) { **guard = true; // set the mutex to true to indicate that the worker is waiting - - unsafe { - parking_lot_core::park( - key, - || { - let old = state.load(Ordering::Relaxed); - if old == 0 { - state.store(lock_addr, Ordering::Relaxed); - } else if old != lock_addr { - return false; - } - - true - }, - || { - mutex.force_unlock(); - }, - |k, was_last_thread| { - requeued = k != key; - if !requeued && was_last_thread { - state.store(0, Ordering::Relaxed); - } - }, - parking_lot_core::DEFAULT_PARK_TOKEN, - None, - ); - } - // relock - - let mut new = mutex.lock(); - mem::swap(&mut new, guard); - mem::forget(new); // forget the new guard to avoid dropping it - - **guard = false; // reset the mutex to false after waking up - } - - fn wait_with_lock_internal( - condvar: &AtomicUsize, - mutex: &mut parking_lot::MutexGuard<'_, bool>, - other: &mut parking_lot::MutexGuard<'_, T>, - ) { - **mutex = true; - let key = condvar as *const _ as usize; - let lock_addr = parking_lot::MutexGuard::mutex(mutex) as *const _ as usize; - let mut requeued = false; - - let state = condvar; - - unsafe { - let token = parking_lot_core::park( - key, - || { - let old = state.load(Ordering::Relaxed); - if old == 0 { - state.store(lock_addr, Ordering::Relaxed); - } else if old != lock_addr { - return false; - } - - true - }, - || { - parking_lot::MutexGuard::mutex(&mutex).force_unlock(); - parking_lot::MutexGuard::mutex(&other).force_unlock(); - }, - |k, was_last_thread| { - requeued = k != key; - if !requeued && was_last_thread { - state.store(0, Ordering::Relaxed); - } - }, - parking_lot_core::DEFAULT_PARK_TOKEN, - None, - ); - - tracing::trace!( - "WorkerLatch wait_with_lock_internal: unparked with token {:?}", - token - ); - } - // because `other` is logically unlocked, we swap it with `other2` and then forget `other2` - let mut other2 = parking_lot::MutexGuard::mutex(&other).lock(); - core::mem::swap(&mut other2, other); - core::mem::forget(other2); - - // because `other` is logically unlocked, we swap it with `other2` and then forget `other2` - let mut mutex2 = parking_lot::MutexGuard::mutex(&mutex).lock(); - core::mem::swap(&mut mutex2, mutex); - core::mem::forget(mutex2); - - **mutex = false; - } - - #[tracing::instrument(level = "trace", skip_all, fields( - this = self as *const Self as usize, - ))] - pub fn wait_with_lock(&self, other: &mut parking_lot::MutexGuard<'_, T>) { - Self::wait_with_lock_internal(&self.condvar, &mut self.mutex.lock(), other); - } - - #[tracing::instrument(level = "trace", skip_all, fields( - this = self as *const Self as usize, - ))] - pub fn wait_with_lock_unless( - &self, - other: &mut parking_lot::MutexGuard<'_, T>, - mut pred: F, - ) where - F: FnMut(&mut T) -> bool, - { - let mut guard = self.mutex.lock(); - if !pred(other.deref_mut()) { - Self::wait_with_lock_internal(&self.condvar, &mut guard, other); - } - } - - pub fn wait_with_lock_while(&self, other: &mut parking_lot::MutexGuard<'_, T>, mut f: F) - where - F: FnMut(&mut T) -> bool, - { - let mut guard = self.mutex.lock(); - while f(other.deref_mut()) { - Self::wait_with_lock_internal(&self.condvar, &mut guard, other); - } + condvar.wait(guard); + **guard = false; } #[tracing::instrument(level = "trace", skip_all, fields( @@ -509,7 +389,7 @@ impl WorkerLatch { where F: FnMut() -> bool, { - let mut guard = self.mutex.lock(); + let mut guard = self.lock(); if !f() { Self::wait_internal(&self.condvar, &mut guard); } @@ -522,7 +402,7 @@ impl WorkerLatch { where F: FnMut() -> Option, { - let mut guard = self.mutex.lock(); + let mut guard = self.lock(); loop { if let Some(result) = f() { return result; @@ -539,29 +419,8 @@ impl WorkerLatch { this = self as *const Self as usize, ))] fn notify(&self) { - let from = &self.condvar as *const _ as usize; - let to = &self.mutex as *const _ as usize; - - let validate = || { - if self.condvar.load(Ordering::Relaxed) != to { - return parking_lot_core::RequeueOp::Abort; - } - - self.condvar.store(0, Ordering::Relaxed); - - parking_lot_core::RequeueOp::UnparkOneRequeueRest - }; - - let callback = |_op: parking_lot_core::RequeueOp, - _result: parking_lot_core::UnparkResult| { - parking_lot_core::DEFAULT_UNPARK_TOKEN - }; - - unsafe { - //let n = parking_lot_core::unpark_requeue(from, to, validate, callback); - let n = parking_lot_core::unpark_all(from, parking_lot_core::DEFAULT_UNPARK_TOKEN); - tracing::trace!("WorkerLatch notify_one: unparked {:?}", n); - } + let n = self.condvar.notify_all(); + tracing::trace!("WorkerLatch notify: notified {} threads", n); } pub fn wake(&self) { @@ -598,7 +457,7 @@ mod tests { barrier.wait(); tracing::info!("Thread waiting on latch"); - latch.wait_with_lock(&mut guard); + latch.wait(); count.fetch_add(1, Ordering::SeqCst); tracing::info!("Thread woke up from latch"); barrier.wait(); diff --git a/distaff/src/workerthread.rs b/distaff/src/workerthread.rs index 758aeb4..1cef6f6 100644 --- a/distaff/src/workerthread.rs +++ b/distaff/src/workerthread.rs @@ -118,15 +118,15 @@ impl WorkerThread { if let Some(job) = self.find_work_inner() { return Some(job); } - // check the predicate while holding the lock // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - // this is very important, because the lock must be held when - // notifying us of the result of a job we scheduled. + // Check the predicate while holding the lock. This is very important, + // because the lock must be held when notifying us of the result of a + // job we scheduled. // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // no jobs found, wait for a heartbeat or a new job - // tracing::trace!(worker = self.heartbeat.index(), "waiting for new job"); + tracing::trace!(worker = self.heartbeat.index(), "waiting for new job"); self.heartbeat.latch().wait_unless(pred); - // tracing::trace!(worker = self.heartbeat.index(), "woken up from wait"); + tracing::trace!(worker = self.heartbeat.index(), "woken up from wait"); None } @@ -341,9 +341,7 @@ impl WorkerThread { } } - self.heartbeat.latch().lock(); out = recv.poll(); - self.heartbeat.latch().unlock(); } out