executor/distaff/src/latch.rs

use core::{
marker::PhantomData,
sync::atomic::{AtomicUsize, Ordering},
};
use std::{
cell::UnsafeCell,
mem,
ops::DerefMut,
sync::{
Arc,
atomic::{AtomicPtr, AtomicU8},
},
};
use parking_lot::{Condvar, Mutex};
use crate::{WorkerThread, context::Context};
pub trait Latch {
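    /// Sets the latch.
    ///
    /// # Safety
    ///
    /// `this` must point to a latch that is still alive. Setting the latch
    /// may allow a waiter to proceed and free it, so `this` must not be
    /// dereferenced again after the latch has been set.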
unsafe fn set_raw(this: *const Self);
}
pub trait Probe {
fn probe(&self) -> bool;
}
pub type CoreLatch = AtomicLatch;
pub trait AsCoreLatch {
fn as_core_latch(&self) -> &CoreLatch;
}
#[derive(Debug)]
pub struct AtomicLatch {
inner: AtomicU8,
}
impl AtomicLatch {
    // The inner `AtomicU8` holds a combination of these bitflags.
    pub const UNSET: u8 = 0;
    pub const SET: u8 = 1;
    pub const SLEEPING: u8 = 2;
    pub const WAKEUP: u8 = 4;
    pub const HEARTBEAT: u8 = 8;
#[inline]
pub const fn new() -> Self {
Self {
inner: AtomicU8::new(Self::UNSET),
}
}
pub const fn new_set() -> Self {
Self {
inner: AtomicU8::new(Self::SET),
}
}
#[inline]
pub fn unset(&self) {
self.inner.fetch_and(!Self::SET, Ordering::Release);
}
pub fn reset(&self) -> u8 {
self.inner.swap(Self::UNSET, Ordering::Release)
}
pub fn get(&self) -> u8 {
self.inner.load(Ordering::Acquire)
}
    /// Clears the `HEARTBEAT` flag, returning `true` if it was set.
    pub fn poll_heartbeat(&self) -> bool {
        self.inner.fetch_and(!Self::HEARTBEAT, Ordering::Relaxed) & Self::HEARTBEAT
            == Self::HEARTBEAT
    }
    /// Marks the latch as sleeping. Returns `true` if the latch was already set.
pub fn set_sleeping(&self) -> bool {
self.inner.fetch_or(Self::SLEEPING, Ordering::Relaxed) & Self::SET == Self::SET
}
pub fn is_sleeping(&self) -> bool {
self.inner.load(Ordering::Relaxed) & Self::SLEEPING == Self::SLEEPING
}
pub fn is_heartbeat(&self) -> bool {
self.inner.load(Ordering::Relaxed) & Self::HEARTBEAT == Self::HEARTBEAT
}
pub fn is_wakeup(&self) -> bool {
self.inner.load(Ordering::Relaxed) & Self::WAKEUP == Self::WAKEUP
}
pub fn is_set(&self) -> bool {
self.inner.load(Ordering::Relaxed) & Self::SET == Self::SET
}
pub fn set_wakeup(&self) {
self.inner.fetch_or(Self::WAKEUP, Ordering::Relaxed);
}
pub fn set_heartbeat(&self) {
self.inner.fetch_or(Self::HEARTBEAT, Ordering::Relaxed);
}
    /// Sets the latch. Returns `true` if the latch was previously marked as sleeping.
#[inline]
pub unsafe fn set(this: *const Self) -> bool {
unsafe {
let old = (*this).inner.fetch_or(Self::SET, Ordering::Relaxed);
old & Self::SLEEPING == Self::SLEEPING
}
}
}
impl Latch for AtomicLatch {
#[inline]
unsafe fn set_raw(this: *const Self) {
unsafe {
Self::set(this);
}
}
}
impl Probe for AtomicLatch {
#[inline]
fn probe(&self) -> bool {
self.inner.load(Ordering::Relaxed) & Self::SET != 0
}
}
impl AsCoreLatch for AtomicLatch {
#[inline]
fn as_core_latch(&self) -> &CoreLatch {
self
}
}
pub struct LatchRef<'a, L: Latch> {
inner: *const L,
_marker: PhantomData<&'a L>,
}
impl<'a, L: Latch> LatchRef<'a, L> {
#[inline]
pub const fn new(latch: &'a L) -> Self {
Self {
inner: latch,
_marker: PhantomData,
}
}
}
impl<'a, L: Latch> Latch for LatchRef<'a, L> {
#[inline]
unsafe fn set_raw(this: *const Self) {
unsafe {
let this = &*this;
Latch::set_raw(this.inner);
}
}
}
impl<'a, L: Latch + Probe> Probe for LatchRef<'a, L> {
#[inline]
fn probe(&self) -> bool {
unsafe {
let this = &*self.inner;
Probe::probe(this)
}
}
}
impl<'a, L> AsCoreLatch for LatchRef<'a, L>
where
L: Latch + AsCoreLatch,
{
#[inline]
fn as_core_latch(&self) -> &CoreLatch {
unsafe {
let this = &*self.inner;
this.as_core_latch()
}
}
}
pub struct NopLatch;
impl Latch for NopLatch {
#[inline]
unsafe fn set_raw(_this: *const Self) {
// do nothing
}
}
impl Probe for NopLatch {
#[inline]
    fn probe(&self) -> bool {
        false // a `NopLatch` is never set
    }
}
pub struct CountLatch {
count: AtomicUsize,
inner: AtomicPtr<WorkerLatch>,
}
impl CountLatch {
#[inline]
pub const fn new(inner: *const WorkerLatch) -> Self {
Self {
count: AtomicUsize::new(0),
inner: AtomicPtr::new(inner as *mut WorkerLatch),
}
}
pub fn set_inner(&self, inner: *const WorkerLatch) {
self.inner
.store(inner as *mut WorkerLatch, Ordering::Relaxed);
}
pub fn count(&self) -> usize {
self.count.load(Ordering::Relaxed)
}
#[inline]
pub fn increment(&self) {
self.count.fetch_add(1, Ordering::Release);
}
    /// Decrements the count; the decrement that brings it to zero wakes the
    /// inner latch.
    #[inline]
    pub fn decrement(&self) {
        unsafe {
            Latch::set_raw(self);
        }
    }
}
impl Latch for CountLatch {
#[inline]
    unsafe fn set_raw(this: *const Self) {
        unsafe {
            if (*this).count.fetch_sub(1, Ordering::Relaxed) == 1 {
                // The count reached zero, so wake the inner latch if one is set.
                tracing::trace!("CountLatch set_raw: count was 1, setting inner latch");
                let inner = (*this).inner.load(Ordering::Relaxed);
                if !inner.is_null() {
                    (*inner).wake();
                }
            }
        }
    }
}
impl Probe for CountLatch {
#[inline]
fn probe(&self) -> bool {
self.count.load(Ordering::Relaxed) == 0
}
}
pub struct MutexLatch {
inner: AtomicLatch,
lock: Mutex<()>,
condvar: Condvar,
}
unsafe impl Send for MutexLatch {}
unsafe impl Sync for MutexLatch {}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum WakeResult {
Wake,
Heartbeat,
Set,
}
impl MutexLatch {
#[inline]
pub const fn new() -> Self {
Self {
inner: AtomicLatch::new(),
lock: Mutex::new(()),
condvar: Condvar::new(),
}
}
#[inline]
pub fn reset(&self) {
        let _guard = self.lock.lock();
        // Hold the lock while resetting so waiters observe a consistent state.
        self.inner.reset();
}
    pub fn wait_and_reset(&self) {
        // The condvar is notified by `set_raw`, which takes the same lock.
        let mut guard = self.lock.lock();
while !self.inner.probe() {
self.condvar.wait(&mut guard);
}
self.inner.reset();
}
pub fn set(&self) {
unsafe {
Latch::set_raw(self);
}
}
}
impl Latch for MutexLatch {
#[inline]
unsafe fn set_raw(this: *const Self) {
// SAFETY: `this` is valid until the guard is dropped.
unsafe {
let this = &*this;
let _guard = this.lock.lock();
Latch::set_raw(&this.inner);
this.condvar.notify_all();
}
}
}
impl Probe for MutexLatch {
#[inline]
fn probe(&self) -> bool {
        let _guard = self.lock.lock();
        // Probe under the lock for a consistent snapshot.
        self.inner.probe()
}
}
impl AsCoreLatch for MutexLatch {
#[inline]
    fn as_core_latch(&self) -> &CoreLatch {
        // Note: setting the latch through this reference bypasses the
        // condvar notification performed by `MutexLatch::set_raw`.
        self.inner.as_core_latch()
    }
}
/// The latch a worker thread parks on whenever it has nothing to do.
pub struct WorkerLatch {
    // This flag is set while the worker is waiting.
    mutex: Mutex<bool>,
    // Parking-lot key; stores the address of `mutex` while any thread is
    // parked on it (see `wait_internal`).
    condvar: AtomicUsize,
}
impl WorkerLatch {
pub fn new() -> Self {
Self {
mutex: Mutex::new(false),
condvar: AtomicUsize::new(0),
}
}
    /// Locks the internal mutex without returning a guard; must be paired
    /// with a matching call to `unlock`.
    pub fn lock(&self) {
        mem::forget(self.mutex.lock());
    }
    /// Unlocks the internal mutex. Must only be called after `lock`.
    pub fn unlock(&self) {
        unsafe {
            self.mutex.force_unlock();
        }
    }
pub fn wait(&self) {
let condvar = &self.condvar;
let mut guard = self.mutex.lock();
Self::wait_internal(condvar, &mut guard);
}
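    // Hand-rolled condvar wait built on `parking_lot_core`: `condvar` holds
    // the address of the mutex its waiters hold (zero when none are parked),
    // mirroring the protocol of `parking_lot::Condvar`. The `validate`
    // closure refuses to park if a different mutex is already associated
    // with this condvar.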
fn wait_internal(condvar: &AtomicUsize, guard: &mut parking_lot::MutexGuard<'_, bool>) {
let mutex = parking_lot::MutexGuard::mutex(guard);
let key = condvar as *const _ as usize;
let lock_addr = mutex as *const _ as usize;
let mut requeued = false;
let state = unsafe { AtomicUsize::from_ptr(condvar as *const _ as *mut usize) };
**guard = true; // set the mutex to true to indicate that the worker is waiting
unsafe {
parking_lot_core::park(
key,
|| {
let old = state.load(Ordering::Relaxed);
if old == 0 {
state.store(lock_addr, Ordering::Relaxed);
} else if old != lock_addr {
return false;
}
true
},
|| {
mutex.force_unlock();
},
|k, was_last_thread| {
requeued = k != key;
if !requeued && was_last_thread {
state.store(0, Ordering::Relaxed);
}
},
parking_lot_core::DEFAULT_PARK_TOKEN,
None,
);
}
        // Relock, swap the fresh guard into `guard`, and forget the stale
        // guard (its lock was already released inside `park`).
        let mut new = mutex.lock();
        mem::swap(&mut new, guard);
        mem::forget(new);
        **guard = false; // clear the waiting flag now that we are awake
}
fn wait_with_lock_internal<T>(&self, other: &mut parking_lot::MutexGuard<'_, T>) {
let key = &self.condvar as *const _ as usize;
let lock_addr = &self.mutex as *const _ as usize;
let mut requeued = false;
let mut guard = self.mutex.lock();
let state = unsafe { AtomicUsize::from_ptr(&self.condvar as *const _ as *mut usize) };
*guard = true; // set the mutex to true to indicate that the worker is waiting
unsafe {
let token = parking_lot_core::park(
key,
|| {
let old = state.load(Ordering::Relaxed);
if old == 0 {
state.store(lock_addr, Ordering::Relaxed);
} else if old != lock_addr {
return false;
}
true
},
|| {
drop(guard); // drop the guard to release the lock
parking_lot::MutexGuard::mutex(&other).force_unlock();
},
|k, was_last_thread| {
requeued = k != key;
if !requeued && was_last_thread {
state.store(0, Ordering::Relaxed);
}
},
parking_lot_core::DEFAULT_PARK_TOKEN,
None,
);
tracing::trace!(
"WorkerLatch wait_with_lock_internal: unparked with token {:?}",
token
);
}
        // Relock `other`: take a fresh guard on the same mutex, swap it into
        // `other`, and forget the stale guard, which was already unlocked
        // inside `park` (dropping it would unlock the mutex a second time).
        let mut other2 = parking_lot::MutexGuard::mutex(&other).lock();
        tracing::trace!("WorkerLatch wait_with_lock_internal: relocked other");
        core::mem::swap(&mut other2, other);
        core::mem::forget(other2);
        let mut guard = self.mutex.lock();
        tracing::trace!("WorkerLatch wait_with_lock_internal: relocked self");
        *guard = false; // clear the waiting flag now that we are awake
}
pub fn wait_with_lock<T>(&self, other: &mut parking_lot::MutexGuard<'_, T>) {
self.wait_with_lock_internal(other);
}
pub fn wait_with_lock_while<T, F>(&self, other: &mut parking_lot::MutexGuard<'_, T>, mut f: F)
where
F: FnMut(&mut T) -> bool,
{
while f(other.deref_mut()) {
self.wait_with_lock_internal(other);
}
}
pub fn wait_until<F, T>(&self, mut f: F) -> T
where
F: FnMut() -> Option<T>,
{
let mut guard = self.mutex.lock();
loop {
if let Some(result) = f() {
return result;
}
Self::wait_internal(&self.condvar, &mut guard);
}
}
pub fn is_waiting(&self) -> bool {
*self.mutex.lock()
}
    fn notify(&self) {
        let key = &self.condvar as *const _ as usize;
        unsafe {
            let n = parking_lot_core::unpark_all(key, parking_lot_core::DEFAULT_UNPARK_TOKEN);
            tracing::trace!("WorkerLatch notify: unparked {} threads", n);
        }
    }
pub fn wake(&self) {
self.notify();
}
}
#[cfg(test)]
mod tests {
use std::{ptr, sync::Barrier};
use tracing_test::traced_test;
use super::*;
#[test]
#[cfg_attr(not(miri), traced_test)]
fn worker_latch() {
let latch = Arc::new(WorkerLatch::new());
let barrier = Arc::new(Barrier::new(2));
let mutex = Arc::new(parking_lot::Mutex::new(false));
let count = Arc::new(AtomicUsize::new(0));
let thread = std::thread::spawn({
let latch = latch.clone();
let mutex = mutex.clone();
let barrier = barrier.clone();
let count = count.clone();
move || {
tracing::info!("Thread waiting on barrier");
let mut guard = mutex.lock();
barrier.wait();
tracing::info!("Thread waiting on latch");
latch.wait_with_lock(&mut guard);
count.fetch_add(1, Ordering::Relaxed);
tracing::info!("Thread woke up from latch");
barrier.wait();
tracing::info!("Thread finished waiting on barrier");
count.fetch_add(1, Ordering::Relaxed);
}
});
assert!(!latch.is_waiting(), "Latch should not be waiting yet");
barrier.wait();
tracing::info!("Main thread finished waiting on barrier");
        // Acquiring the mutex guarantees the thread has parked: it only
        // releases the lock from inside `wait_with_lock`.
        {
            let _guard = mutex.lock();
            tracing::info!("Main thread acquired mutex, waking up thread");
            assert!(latch.is_waiting(), "Latch should be waiting now");
            latch.wake();
            tracing::info!("Main thread woke up thread");
        }
assert_eq!(count.load(Ordering::Relaxed), 0, "Count should still be 0");
barrier.wait();
assert_eq!(
count.load(Ordering::Relaxed),
1,
"Count should be 1 after waking up"
);
thread.join().expect("Thread should join successfully");
assert_eq!(
count.load(Ordering::Relaxed),
2,
"Count should be 2 after thread has finished"
);
}
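    // Illustrative sketch (not part of the original suite): `wait_until`
    // parks until the closure yields a value. The waker publishes the value,
    // then takes and releases the internal lock so the store is ordered
    // before the waiter's re-check, and finally calls `wake`.
    #[test]
    fn worker_latch_wait_until() {
        let latch = Arc::new(WorkerLatch::new());
        let value = Arc::new(AtomicUsize::new(0));
        let thread = std::thread::spawn({
            let latch = latch.clone();
            let value = value.clone();
            move || {
                latch.wait_until(|| {
                    let v = value.load(Ordering::Relaxed);
                    (v != 0).then_some(v)
                })
            }
        });
        // Publish the value, synchronize on the internal mutex, then wake.
        value.store(42, Ordering::Relaxed);
        latch.lock();
        latch.unlock();
        latch.wake();
        assert_eq!(thread.join().unwrap(), 42);
    }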
#[test]
fn test_atomic_latch() {
let latch = AtomicLatch::new();
assert_eq!(latch.get(), AtomicLatch::UNSET);
unsafe {
assert!(!latch.probe());
AtomicLatch::set_raw(&latch);
}
assert_eq!(latch.get(), AtomicLatch::SET);
assert!(latch.probe());
latch.unset();
assert_eq!(latch.get(), AtomicLatch::UNSET);
}
#[test]
fn core_latch_sleep() {
let latch = AtomicLatch::new();
assert_eq!(latch.get(), AtomicLatch::UNSET);
latch.set_sleeping();
assert_eq!(latch.get(), AtomicLatch::SLEEPING);
unsafe {
assert!(!latch.probe());
assert!(AtomicLatch::set(&latch));
}
assert_eq!(latch.get(), AtomicLatch::SET | AtomicLatch::SLEEPING);
assert!(latch.probe());
latch.reset();
assert_eq!(latch.get(), AtomicLatch::UNSET);
}
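    // Illustrative sketch (not part of the original suite): exercises the
    // HEARTBEAT and WAKEUP flags on `AtomicLatch`.
    #[test]
    fn atomic_latch_heartbeat_and_wakeup() {
        let latch = AtomicLatch::new();
        latch.set_heartbeat();
        assert!(latch.is_heartbeat());
        // `poll_heartbeat` clears the flag and reports whether it was set.
        assert!(latch.poll_heartbeat());
        assert!(!latch.is_heartbeat());
        assert!(!latch.poll_heartbeat());
        latch.set_wakeup();
        assert!(latch.is_wakeup());
        // `reset` clears all flags and returns the previous state.
        assert_eq!(latch.reset(), AtomicLatch::WAKEUP);
        assert_eq!(latch.get(), AtomicLatch::UNSET);
    }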
#[test]
fn nop_latch() {
assert!(
core::mem::size_of::<NopLatch>() == 0,
"NopLatch should be zero-sized"
);
}
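    // Illustrative sketch (not part of the original suite): `LatchRef`
    // forwards `set_raw` and `probe` to the latch it borrows.
    #[test]
    fn latch_ref_forwards() {
        let latch = AtomicLatch::new();
        let latch_ref = LatchRef::new(&latch);
        assert!(!latch_ref.probe());
        unsafe {
            Latch::set_raw(&latch_ref);
        }
        assert!(latch_ref.probe());
        assert!(latch.probe());
    }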
#[test]
fn count_latch() {
let latch = CountLatch::new(ptr::null());
assert_eq!(latch.count(), 0);
latch.increment();
assert_eq!(latch.count(), 1);
assert!(!latch.probe());
latch.increment();
assert_eq!(latch.count(), 2);
assert!(!latch.probe());
unsafe {
Latch::set_raw(&latch);
}
assert!(!latch.probe());
assert_eq!(latch.count(), 1);
unsafe {
Latch::set_raw(&latch);
}
assert!(latch.probe());
assert_eq!(latch.count(), 0);
}
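    // Illustrative sketch (not part of the original suite): when an inner
    // `WorkerLatch` is registered via `set_inner`, the decrement that takes
    // the count to zero calls `wake` on it (a no-op here, with no waiter).
    #[test]
    fn count_latch_wakes_inner() {
        let worker = WorkerLatch::new();
        let latch = CountLatch::new(ptr::null());
        latch.set_inner(&worker);
        latch.increment();
        assert!(!latch.probe());
        latch.decrement();
        assert!(latch.probe());
        assert_eq!(latch.count(), 0);
    }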
#[test]
#[traced_test]
fn mutex_latch() {
let latch = Arc::new(MutexLatch::new());
assert!(!latch.probe());
latch.set();
assert!(latch.probe());
latch.reset();
assert!(!latch.probe());
// Test wait functionality
let latch_clone = latch.clone();
let handle = std::thread::spawn(move || {
tracing::info!("Thread waiting on latch");
latch_clone.wait_and_reset();
tracing::info!("Thread woke up from latch");
});
// Give the thread time to block
std::thread::sleep(std::time::Duration::from_millis(100));
assert!(!latch.probe());
tracing::info!("Setting latch from main thread");
latch.set();
tracing::info!("Latch set, joining waiting thread");
handle.join().expect("Thread should join successfully");
}
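    // Illustrative sketch (not part of the original suite): setting the
    // latch through `as_core_latch` flips the flag but bypasses the condvar
    // notification in `MutexLatch::set_raw`, so it only affects `probe`.
    #[test]
    fn mutex_latch_core_latch_bypass() {
        let latch = MutexLatch::new();
        unsafe {
            AtomicLatch::set(latch.as_core_latch());
        }
        assert!(latch.probe());
        latch.reset();
        assert!(!latch.probe());
    }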
}