use core::{ marker::PhantomData, sync::atomic::{AtomicUsize, Ordering}, }; use std::{ cell::UnsafeCell, mem, ops::DerefMut, sync::{ Arc, atomic::{AtomicPtr, AtomicU8}, }, }; use parking_lot::{Condvar, Mutex}; use crate::{WorkerThread, context::Context}; pub trait Latch { unsafe fn set_raw(this: *const Self); } pub trait Probe { fn probe(&self) -> bool; } pub type CoreLatch = AtomicLatch; pub trait AsCoreLatch { fn as_core_latch(&self) -> &CoreLatch; } #[derive(Debug)] pub struct AtomicLatch { inner: AtomicU8, } impl AtomicLatch { pub const UNSET: u8 = 0; pub const SET: u8 = 1; pub const SLEEPING: u8 = 2; pub const WAKEUP: u8 = 4; pub const HEARTBEAT: u8 = 8; #[inline] pub const fn new() -> Self { Self { inner: AtomicU8::new(Self::UNSET), } } pub const fn new_set() -> Self { Self { inner: AtomicU8::new(Self::SET), } } #[inline] pub fn unset(&self) { self.inner.fetch_and(!Self::SET, Ordering::Release); } pub fn reset(&self) -> u8 { self.inner.swap(Self::UNSET, Ordering::Release) } pub fn get(&self) -> u8 { self.inner.load(Ordering::Acquire) } pub fn poll_heartbeat(&self) -> bool { self.inner.fetch_and(!Self::HEARTBEAT, Ordering::Relaxed) & Self::HEARTBEAT == Self::HEARTBEAT } /// returns true if the latch was already set. pub fn set_sleeping(&self) -> bool { self.inner.fetch_or(Self::SLEEPING, Ordering::Relaxed) & Self::SET == Self::SET } pub fn is_sleeping(&self) -> bool { self.inner.load(Ordering::Relaxed) & Self::SLEEPING == Self::SLEEPING } pub fn is_heartbeat(&self) -> bool { self.inner.load(Ordering::Relaxed) & Self::HEARTBEAT == Self::HEARTBEAT } pub fn is_wakeup(&self) -> bool { self.inner.load(Ordering::Relaxed) & Self::WAKEUP == Self::WAKEUP } pub fn is_set(&self) -> bool { self.inner.load(Ordering::Relaxed) & Self::SET == Self::SET } pub fn set_wakeup(&self) { self.inner.fetch_or(Self::WAKEUP, Ordering::Relaxed); } pub fn set_heartbeat(&self) { self.inner.fetch_or(Self::HEARTBEAT, Ordering::Relaxed); } /// returns true if the latch was previously sleeping. #[inline] pub unsafe fn set(this: *const Self) -> bool { unsafe { let old = (*this).inner.fetch_or(Self::SET, Ordering::Relaxed); old & Self::SLEEPING == Self::SLEEPING } } } impl Latch for AtomicLatch { #[inline] unsafe fn set_raw(this: *const Self) { unsafe { Self::set(this); } } } impl Probe for AtomicLatch { #[inline] fn probe(&self) -> bool { self.inner.load(Ordering::Relaxed) & Self::SET != 0 } } impl AsCoreLatch for AtomicLatch { #[inline] fn as_core_latch(&self) -> &CoreLatch { self } } pub struct LatchRef<'a, L: Latch> { inner: *const L, _marker: PhantomData<&'a L>, } impl<'a, L: Latch> LatchRef<'a, L> { #[inline] pub const fn new(latch: &'a L) -> Self { Self { inner: latch, _marker: PhantomData, } } } impl<'a, L: Latch> Latch for LatchRef<'a, L> { #[inline] unsafe fn set_raw(this: *const Self) { unsafe { let this = &*this; Latch::set_raw(this.inner); } } } impl<'a, L: Latch + Probe> Probe for LatchRef<'a, L> { #[inline] fn probe(&self) -> bool { unsafe { let this = &*self.inner; Probe::probe(this) } } } impl<'a, L> AsCoreLatch for LatchRef<'a, L> where L: Latch + AsCoreLatch, { #[inline] fn as_core_latch(&self) -> &CoreLatch { unsafe { let this = &*self.inner; this.as_core_latch() } } } pub struct NopLatch; impl Latch for NopLatch { #[inline] unsafe fn set_raw(_this: *const Self) { // do nothing } } impl Probe for NopLatch { #[inline] fn probe(&self) -> bool { false // always returns false } } pub struct CountLatch { count: AtomicUsize, inner: AtomicPtr, } impl CountLatch { #[inline] pub const fn new(inner: *const WorkerLatch) -> Self { Self { count: AtomicUsize::new(0), inner: AtomicPtr::new(inner as *mut WorkerLatch), } } pub fn set_inner(&self, inner: *const WorkerLatch) { self.inner .store(inner as *mut WorkerLatch, Ordering::Relaxed); } pub fn count(&self) -> usize { self.count.load(Ordering::Relaxed) } #[inline] pub fn increment(&self) { self.count.fetch_add(1, Ordering::Release); } #[inline] pub fn decrement(&self) { unsafe { Latch::set_raw(self); } } } impl Latch for CountLatch { #[inline] unsafe fn set_raw(this: *const Self) { unsafe { if (&*this).count.fetch_sub(1, Ordering::Relaxed) == 1 { tracing::trace!("CountLatch set_raw: count was 1, setting inner latch"); // If the count was 1, we need to set the inner latch. let inner = (*this).inner.load(Ordering::Relaxed); if !inner.is_null() { (&*inner).wake(); } } } } } impl Probe for CountLatch { #[inline] fn probe(&self) -> bool { self.count.load(Ordering::Relaxed) == 0 } } pub struct MutexLatch { inner: AtomicLatch, lock: Mutex<()>, condvar: Condvar, } unsafe impl Send for MutexLatch {} unsafe impl Sync for MutexLatch {} #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) enum WakeResult { Wake, Heartbeat, Set, } impl MutexLatch { #[inline] pub const fn new() -> Self { Self { inner: AtomicLatch::new(), lock: Mutex::new(()), condvar: Condvar::new(), } } #[inline] pub fn reset(&self) { let _guard = self.lock.lock(); // SAFETY: inner is atomic, so we can safely access it. self.inner.reset(); } pub fn wait_and_reset(&self) { // SAFETY: inner is locked by the mutex, so we can safely access it. let mut guard = self.lock.lock(); while !self.inner.probe() { self.condvar.wait(&mut guard); } self.inner.reset(); } pub fn set(&self) { unsafe { Latch::set_raw(self); } } } impl Latch for MutexLatch { #[inline] unsafe fn set_raw(this: *const Self) { // SAFETY: `this` is valid until the guard is dropped. unsafe { let this = &*this; let _guard = this.lock.lock(); Latch::set_raw(&this.inner); this.condvar.notify_all(); } } } impl Probe for MutexLatch { #[inline] fn probe(&self) -> bool { let _guard = self.lock.lock(); // SAFETY: inner is atomic, so we can safely access it. self.inner.probe() } } impl AsCoreLatch for MutexLatch { #[inline] fn as_core_latch(&self) -> &CoreLatch { // SAFETY: inner is atomic, so we can safely access it. self.inner.as_core_latch() } } // The worker waits on this latch whenever it has nothing to do. pub struct WorkerLatch { // this boolean is set when the worker is waiting. mutex: Mutex, condvar: AtomicUsize, } impl WorkerLatch { pub fn new() -> Self { Self { mutex: Mutex::new(false), condvar: AtomicUsize::new(0), } } pub fn lock(&self) { mem::forget(self.mutex.lock()); } pub fn unlock(&self) { unsafe { self.mutex.force_unlock(); } } pub fn wait(&self) { let condvar = &self.condvar; let mut guard = self.mutex.lock(); Self::wait_internal(condvar, &mut guard); } fn wait_internal(condvar: &AtomicUsize, guard: &mut parking_lot::MutexGuard<'_, bool>) { let mutex = parking_lot::MutexGuard::mutex(guard); let key = condvar as *const _ as usize; let lock_addr = mutex as *const _ as usize; let mut requeued = false; let state = unsafe { AtomicUsize::from_ptr(condvar as *const _ as *mut usize) }; **guard = true; // set the mutex to true to indicate that the worker is waiting unsafe { parking_lot_core::park( key, || { let old = state.load(Ordering::Relaxed); if old == 0 { state.store(lock_addr, Ordering::Relaxed); } else if old != lock_addr { return false; } true }, || { mutex.force_unlock(); }, |k, was_last_thread| { requeued = k != key; if !requeued && was_last_thread { state.store(0, Ordering::Relaxed); } }, parking_lot_core::DEFAULT_PARK_TOKEN, None, ); } // relock let mut new = mutex.lock(); mem::swap(&mut new, guard); mem::forget(new); // forget the new guard to avoid dropping it **guard = false; // reset the mutex to false after waking up } fn wait_with_lock_internal(&self, other: &mut parking_lot::MutexGuard<'_, T>) { let key = &self.condvar as *const _ as usize; let lock_addr = &self.mutex as *const _ as usize; let mut requeued = false; let mut guard = self.mutex.lock(); let state = unsafe { AtomicUsize::from_ptr(&self.condvar as *const _ as *mut usize) }; *guard = true; // set the mutex to true to indicate that the worker is waiting unsafe { let token = parking_lot_core::park( key, || { let old = state.load(Ordering::Relaxed); if old == 0 { state.store(lock_addr, Ordering::Relaxed); } else if old != lock_addr { return false; } true }, || { drop(guard); // drop the guard to release the lock parking_lot::MutexGuard::mutex(&other).force_unlock(); }, |k, was_last_thread| { requeued = k != key; if !requeued && was_last_thread { state.store(0, Ordering::Relaxed); } }, parking_lot_core::DEFAULT_PARK_TOKEN, None, ); tracing::trace!( "WorkerLatch wait_with_lock_internal: unparked with token {:?}", token ); } // relock let mut other2 = parking_lot::MutexGuard::mutex(&other).lock(); tracing::trace!("WorkerLatch wait_with_lock_internal: relocked other"); // because `other` is logically unlocked, we swap it with `other2` and then forget `other2` core::mem::swap(&mut *other2, &mut *other); core::mem::forget(other2); let mut guard = self.mutex.lock(); tracing::trace!("WorkerLatch wait_with_lock_internal: relocked self"); *guard = false; // reset the mutex to false after waking up } pub fn wait_with_lock(&self, other: &mut parking_lot::MutexGuard<'_, T>) { self.wait_with_lock_internal(other); } pub fn wait_with_lock_while(&self, other: &mut parking_lot::MutexGuard<'_, T>, mut f: F) where F: FnMut(&mut T) -> bool, { while f(other.deref_mut()) { self.wait_with_lock_internal(other); } } pub fn wait_until(&self, mut f: F) -> T where F: FnMut() -> Option, { let mut guard = self.mutex.lock(); loop { if let Some(result) = f() { return result; } Self::wait_internal(&self.condvar, &mut guard); } } pub fn is_waiting(&self) -> bool { *self.mutex.lock() } fn notify(&self) { let key = &self.condvar as *const _ as usize; unsafe { let n = parking_lot_core::unpark_all(key, parking_lot_core::DEFAULT_UNPARK_TOKEN); tracing::trace!("WorkerLatch notify_one: unparked {} threads", n); } } pub fn wake(&self) { self.notify(); } } #[cfg(test)] mod tests { use std::{ptr, sync::Barrier}; use tracing_test::traced_test; use super::*; #[test] #[cfg_attr(not(miri), traced_test)] fn worker_latch() { let latch = Arc::new(WorkerLatch::new()); let barrier = Arc::new(Barrier::new(2)); let mutex = Arc::new(parking_lot::Mutex::new(false)); let count = Arc::new(AtomicUsize::new(0)); let thread = std::thread::spawn({ let latch = latch.clone(); let mutex = mutex.clone(); let barrier = barrier.clone(); let count = count.clone(); move || { tracing::info!("Thread waiting on barrier"); let mut guard = mutex.lock(); barrier.wait(); tracing::info!("Thread waiting on latch"); latch.wait_with_lock(&mut guard); count.fetch_add(1, Ordering::Relaxed); tracing::info!("Thread woke up from latch"); barrier.wait(); tracing::info!("Thread finished waiting on barrier"); count.fetch_add(1, Ordering::Relaxed); } }); assert!(!latch.is_waiting(), "Latch should not be waiting yet"); barrier.wait(); tracing::info!("Main thread finished waiting on barrier"); // lock mutex and notify the thread that isn't yet waiting. { let guard = mutex.lock(); tracing::info!("Main thread acquired mutex, waking up thread"); assert!(latch.is_waiting(), "Latch should be waiting now"); latch.wake(); tracing::info!("Main thread woke up thread"); } assert_eq!(count.load(Ordering::Relaxed), 0, "Count should still be 0"); barrier.wait(); assert_eq!( count.load(Ordering::Relaxed), 1, "Count should be 1 after waking up" ); thread.join().expect("Thread should join successfully"); assert_eq!( count.load(Ordering::Relaxed), 2, "Count should be 2 after thread has finished" ); } #[test] fn test_atomic_latch() { let latch = AtomicLatch::new(); assert_eq!(latch.get(), AtomicLatch::UNSET); unsafe { assert!(!latch.probe()); AtomicLatch::set_raw(&latch); } assert_eq!(latch.get(), AtomicLatch::SET); assert!(latch.probe()); latch.unset(); assert_eq!(latch.get(), AtomicLatch::UNSET); } #[test] fn core_latch_sleep() { let latch = AtomicLatch::new(); assert_eq!(latch.get(), AtomicLatch::UNSET); latch.set_sleeping(); assert_eq!(latch.get(), AtomicLatch::SLEEPING); unsafe { assert!(!latch.probe()); assert!(AtomicLatch::set(&latch)); } assert_eq!(latch.get(), AtomicLatch::SET | AtomicLatch::SLEEPING); assert!(latch.probe()); latch.reset(); assert_eq!(latch.get(), AtomicLatch::UNSET); } #[test] fn nop_latch() { assert!( core::mem::size_of::() == 0, "NopLatch should be zero-sized" ); } #[test] fn count_latch() { let latch = CountLatch::new(ptr::null()); assert_eq!(latch.count(), 0); latch.increment(); assert_eq!(latch.count(), 1); assert!(!latch.probe()); latch.increment(); assert_eq!(latch.count(), 2); assert!(!latch.probe()); unsafe { Latch::set_raw(&latch); } assert!(!latch.probe()); assert_eq!(latch.count(), 1); unsafe { Latch::set_raw(&latch); } assert!(latch.probe()); assert_eq!(latch.count(), 0); } #[test] #[traced_test] fn mutex_latch() { let latch = Arc::new(MutexLatch::new()); assert!(!latch.probe()); latch.set(); assert!(latch.probe()); latch.reset(); assert!(!latch.probe()); // Test wait functionality let latch_clone = latch.clone(); let handle = std::thread::spawn(move || { tracing::info!("Thread waiting on latch"); latch_clone.wait_and_reset(); tracing::info!("Thread woke up from latch"); }); // Give the thread time to block std::thread::sleep(std::time::Duration::from_millis(100)); assert!(!latch.probe()); tracing::info!("Setting latch from main thread"); latch.set(); tracing::info!("Latch set, joining waiting thread"); handle.join().expect("Thread should join successfully"); } }