executor/distaff/src/latch.rs

529 lines
12 KiB
Rust

use core::{
marker::PhantomData,
sync::atomic::{AtomicUsize, Ordering},
};
use std::{
cell::UnsafeCell,
mem,
ops::DerefMut,
sync::{
Arc,
atomic::{AtomicPtr, AtomicU8},
},
};
use parking_lot::{Condvar, Mutex};
use crate::{WorkerThread, channel::Parker, context::Context};
/// A synchronization primitive that can be signalled through a raw pointer.
///
/// Implemented in this file by `AtomicLatch`, `LatchRef`, `NopLatch`,
/// `CountLatch` and `MutexLatch`.
pub trait Latch {
/// Signals the latch that `this` points to.
///
/// # Safety
///
/// Most implementations in this file dereference `this`, so it must point to
/// a live value of the implementing type for as long as the call touches it.
unsafe fn set_raw(this: *const Self);
}
/// Non-blocking query of a latch's state.
pub trait Probe {
/// Returns `true` when the latch currently reports itself as set.
///
/// What "set" means is implementation-defined: `AtomicLatch` tests its `SET`
/// bit, `CountLatch` tests for a zero count, and `NopLatch` always answers
/// `false`.
fn probe(&self) -> bool;
}
/// The latch type handed out by [`AsCoreLatch`]; currently an alias for
/// [`AtomicLatch`].
pub type CoreLatch = AtomicLatch;
/// Access to the [`CoreLatch`] that backs a latch.
pub trait AsCoreLatch {
/// Borrows the underlying [`CoreLatch`].
fn as_core_latch(&self) -> &CoreLatch;
}
/// A latch stored as a set of flag bits in a single atomic byte.
///
/// The individual bits are the associated constants `SET`, `SLEEPING`,
/// `WAKEUP` and `HEARTBEAT`; `UNSET` is the all-zero state.
#[derive(Debug)]
pub struct AtomicLatch {
// Bitwise OR of the flag constants defined on `AtomicLatch`.
inner: AtomicU8,
}
impl AtomicLatch {
    /// Value of the flag byte when no flag is raised.
    pub const UNSET: u8 = 0;
    /// Flag raised by [`AtomicLatch::set`]; this is the bit `probe` tests.
    pub const SET: u8 = 1;
    /// Flag raised by [`AtomicLatch::set_sleeping`].
    pub const SLEEPING: u8 = 2;
    /// Flag raised by [`AtomicLatch::set_wakeup`].
    pub const WAKEUP: u8 = 4;
    /// Flag raised by [`AtomicLatch::set_heartbeat`] and consumed by
    /// [`AtomicLatch::poll_heartbeat`].
    pub const HEARTBEAT: u8 = 8;

    /// Creates a latch with no flags raised.
    #[inline]
    pub const fn new() -> Self {
        Self {
            inner: AtomicU8::new(Self::UNSET),
        }
    }

    /// Creates a latch whose `SET` flag is already raised.
    pub const fn new_set() -> Self {
        Self {
            inner: AtomicU8::new(Self::SET),
        }
    }

    /// Clears the `SET` flag and leaves every other flag untouched.
    #[inline]
    pub fn unset(&self) {
        self.inner.fetch_and(!Self::SET, Ordering::Release);
    }

    /// Clears every flag and returns the flag byte as it was before the call.
    pub fn reset(&self) -> u8 {
        self.inner.swap(Self::UNSET, Ordering::Release)
    }

    /// Returns the current flag byte.
    pub fn get(&self) -> u8 {
        self.inner.load(Ordering::Acquire)
    }

    /// Clears the `HEARTBEAT` flag and reports whether it had been raised.
    pub fn poll_heartbeat(&self) -> bool {
        let previous = self.inner.fetch_and(!Self::HEARTBEAT, Ordering::Relaxed);
        previous & Self::HEARTBEAT == Self::HEARTBEAT
    }

    /// Raises the `SLEEPING` flag.
    ///
    /// Returns true if the latch was already set.
    pub fn set_sleeping(&self) -> bool {
        // `AcqRel`: a `true` result tells the caller the latch was set, so the
        // setter's earlier writes must be visible (acquire); the release half
        // mirrors `set` so the setter in turn sees what preceded this call.
        let previous = self.inner.fetch_or(Self::SLEEPING, Ordering::AcqRel);
        previous & Self::SET == Self::SET
    }

    /// Reports whether the `SLEEPING` flag is raised.
    pub fn is_sleeping(&self) -> bool {
        self.inner.load(Ordering::Relaxed) & Self::SLEEPING == Self::SLEEPING
    }

    /// Reports whether the `HEARTBEAT` flag is raised, without clearing it.
    pub fn is_heartbeat(&self) -> bool {
        self.inner.load(Ordering::Relaxed) & Self::HEARTBEAT == Self::HEARTBEAT
    }

    /// Reports whether the `WAKEUP` flag is raised.
    pub fn is_wakeup(&self) -> bool {
        self.inner.load(Ordering::Relaxed) & Self::WAKEUP == Self::WAKEUP
    }

    /// Reports whether the `SET` flag is raised.
    pub fn is_set(&self) -> bool {
        // `Acquire` pairs with the release in `set`: once this returns `true`,
        // everything the setter wrote beforehand is visible to the caller.
        self.inner.load(Ordering::Acquire) & Self::SET == Self::SET
    }

    /// Raises the `WAKEUP` flag.
    pub fn set_wakeup(&self) {
        self.inner.fetch_or(Self::WAKEUP, Ordering::Relaxed);
    }

    /// Raises the `HEARTBEAT` flag.
    pub fn set_heartbeat(&self) {
        self.inner.fetch_or(Self::HEARTBEAT, Ordering::Relaxed);
    }

    /// Raises the `SET` flag.
    ///
    /// Returns true if the latch was previously sleeping.
    ///
    /// # Safety
    ///
    /// `this` is dereferenced, so it must point to a live `AtomicLatch`.
    #[inline]
    pub unsafe fn set(this: *const Self) -> bool {
        // `AcqRel` instead of `Relaxed`: the release half publishes whatever
        // the setter produced before signalling to any thread that observes
        // `SET` through an acquire operation (`probe`, `is_set`, `get`,
        // `set_sleeping`). With `Relaxed`, observing the flag would not
        // guarantee that the data it announces is visible.
        let previous = unsafe { (*this).inner.fetch_or(Self::SET, Ordering::AcqRel) };
        previous & Self::SLEEPING == Self::SLEEPING
    }
}

impl Latch for AtomicLatch {
    /// Raises the `SET` flag, discarding the "was sleeping" result of
    /// [`AtomicLatch::set`].
    #[inline]
    unsafe fn set_raw(this: *const Self) {
        unsafe {
            Self::set(this);
        }
    }
}

impl Probe for AtomicLatch {
    /// Returns `true` when the `SET` flag is raised.
    #[inline]
    fn probe(&self) -> bool {
        // `Acquire` pairs with the release in `set`, so a `true` answer also
        // makes the setter's preceding writes visible to the caller.
        self.inner.load(Ordering::Acquire) & Self::SET != 0
    }
}
impl AsCoreLatch for AtomicLatch {
/// `AtomicLatch` is the core latch type itself, so this returns `self`.
#[inline]
fn as_core_latch(&self) -> &CoreLatch {
self
}
}
/// A handle to a latch of type `L`, held as a raw pointer.
///
/// The `PhantomData<&'a L>` marker ties the handle to the lifetime `'a` of the
/// reference it was built from (see `LatchRef::new`).
pub struct LatchRef<'a, L: Latch> {
// Pointer to the referenced latch; dereferenced by the trait impls below.
inner: *const L,
// Carries the borrow's lifetime, since the raw pointer itself has none.
_marker: PhantomData<&'a L>,
}
impl<'a, L: Latch> LatchRef<'a, L> {
    /// Wraps a borrowed latch, remembering its address and the lifetime of
    /// the borrow.
    #[inline]
    pub const fn new(latch: &'a L) -> Self {
        let inner = latch as *const L;
        Self {
            inner,
            _marker: PhantomData,
        }
    }
}
impl<'a, L: Latch> Latch for LatchRef<'a, L> {
    /// Forwards the signal to the latch this handle points at.
    #[inline]
    unsafe fn set_raw(this: *const Self) {
        // Read the target pointer out of the handle, then delegate to the
        // target's own `set_raw`. Both steps rely on the caller's guarantee
        // that `this` (and the latch it refers to) are still alive.
        let target: *const L = unsafe { (*this).inner };
        unsafe { L::set_raw(target) }
    }
}
impl<'a, L: Latch + Probe> Probe for LatchRef<'a, L> {
    /// Answers with whatever the referenced latch's own `probe` reports.
    #[inline]
    fn probe(&self) -> bool {
        // The pointer was taken from a `&'a L` in `LatchRef::new`.
        let target: &L = unsafe { &*self.inner };
        target.probe()
    }
}
impl<'a, L> AsCoreLatch for LatchRef<'a, L>
where
L: Latch + AsCoreLatch,
{
#[inline]
fn as_core_latch(&self) -> &CoreLatch {
unsafe {
let this = &*self.inner;
this.as_core_latch()
}
}
}
/// A latch that ignores every signal and never reports itself as set.
///
/// Zero-sized; the common derives are provided so it can be freely copied,
/// default-constructed and printed in diagnostics.
#[derive(Debug, Default, Clone, Copy)]
pub struct NopLatch;

impl Latch for NopLatch {
    /// Does nothing; the pointer is never dereferenced.
    #[inline]
    unsafe fn set_raw(_this: *const Self) {
        // do nothing
    }
}

impl Probe for NopLatch {
    /// Always returns `false`.
    #[inline]
    fn probe(&self) -> bool {
        false // always returns false
    }
}
/// A counting latch: it is considered set while its count is zero, and the
/// decrement that brings the count to zero unparks the registered `Parker`.
pub struct CountLatch {
// Number of outstanding increments not yet matched by a decrement.
count: AtomicUsize,
// Parker to unpark when the count reaches zero; may be null, in which case
// nobody is unparked.
inner: AtomicPtr<Parker>,
}
impl CountLatch {
    /// Creates a latch with a count of zero that will unpark `inner` whenever
    /// a decrement brings the count back to zero. `inner` may be null.
    #[inline]
    pub const fn new(inner: *const Parker) -> Self {
        Self {
            count: AtomicUsize::new(0),
            inner: AtomicPtr::new(inner as *mut Parker),
        }
    }

    /// Replaces the parker that is unparked when the count reaches zero.
    pub fn set_inner(&self, inner: *const Parker) {
        // `SeqCst`, together with the `SeqCst` operations in `set_raw` and
        // `probe`: if a waiter registers its parker here and then checks
        // `probe`, while another thread decrements and then reads the parker,
        // at least one of them must observe the other. Weaker orderings allow
        // both to miss, which would lose the wake-up. The release semantics
        // implied by `SeqCst` also publish the `Parker` the pointer refers to.
        self.inner.store(inner as *mut Parker, Ordering::SeqCst);
    }

    /// Returns a snapshot of the current count.
    pub fn count(&self) -> usize {
        self.count.load(Ordering::Relaxed)
    }

    /// Adds one to the count.
    #[inline]
    pub fn increment(&self) {
        self.count.fetch_add(1, Ordering::Release);
    }

    /// Subtracts one from the count, unparking the registered parker if the
    /// count reaches zero. Safe wrapper around [`Latch::set_raw`].
    #[inline]
    pub fn decrement(&self) {
        unsafe {
            Latch::set_raw(self);
        }
    }
}

impl Latch for CountLatch {
    /// Subtracts one from the count; the call that observes a previous count
    /// of one unparks the registered parker, if any.
    ///
    /// The count is not checked for underflow: decrementing a zero count
    /// wraps around.
    #[inline]
    unsafe fn set_raw(this: *const Self) {
        unsafe {
            // `SeqCst` (previously `Relaxed`): the release half publishes the
            // work finished before this decrement to whoever observes a zero
            // count, and the sequentially consistent total order closes the
            // lost-wake-up window described in `set_inner`.
            if (*this).count.fetch_sub(1, Ordering::SeqCst) == 1 {
                tracing::trace!("CountLatch set_raw: count was 1, setting inner latch");
                // If the count was 1, we need to set the inner latch.
                // NOTE(review): `this` is read again after the count hit zero;
                // this assumes the owner keeps the latch alive until the
                // unpark below has happened — confirm against callers.
                let inner = (*this).inner.load(Ordering::SeqCst);
                if !inner.is_null() {
                    (*inner).unpark();
                }
            }
        }
    }
}

impl Probe for CountLatch {
    /// Returns `true` while the count is zero.
    #[inline]
    fn probe(&self) -> bool {
        // `SeqCst` (previously `Relaxed`): seeing zero must also make the
        // decrementing threads' earlier writes visible, and the load takes
        // part in the handshake described in `set_inner`.
        self.count.load(Ordering::SeqCst) == 0
    }
}
/// A latch whose waiters block on a condition variable until it is set.
///
/// `inner` holds the latch state; `lock` and `condvar` are used by the
/// methods of `MutexLatch` to block and wake waiters.
#[derive(Debug)]
pub struct MutexLatch {
    inner: AtomicLatch,
    lock: Mutex<()>,
    condvar: Condvar,
}

// Every field is already `Send + Sync`, so the compiler derives both auto
// traits for `MutexLatch`; hand-written `unsafe impl`s would only hide a
// future field that is not thread-safe. This assertion keeps the guarantee
// explicit and compiler-checked instead.
const _: fn() = || {
    fn assert_send_sync<T: Send + Sync>() {}
    assert_send_sync::<MutexLatch>();
};
/// Outcome of a wake-up.
///
/// NOTE(review): this enum is not referenced anywhere else in this file, so
/// the meaning of each variant below is inferred from its name and the
/// similarly named `AtomicLatch` flags — confirm against the callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum WakeResult {
// Presumably: woken by an explicit wake-up request.
Wake,
// Presumably: woken by a heartbeat.
Heartbeat,
// Presumably: woken because the awaited latch was set.
Set,
}
impl MutexLatch {
#[inline]
pub const fn new() -> Self {
Self {
inner: AtomicLatch::new(),
lock: Mutex::new(()),
condvar: Condvar::new(),
}
}
#[inline]
pub fn reset(&self) {
let _guard = self.lock.lock();
// SAFETY: inner is atomic, so we can safely access it.
self.inner.reset();
}
pub fn wait_and_reset(&self) {
// SAFETY: inner is locked by the mutex, so we can safely access it.
let mut guard = self.lock.lock();
while !self.inner.probe() {
self.condvar.wait(&mut guard);
}
self.inner.reset();
}
pub fn set(&self) {
unsafe {
Latch::set_raw(self);
}
}
}
impl Latch for MutexLatch {
#[inline]
unsafe fn set_raw(this: *const Self) {
// SAFETY: `this` is valid until the guard is dropped.
unsafe {
let this = &*this;
let _guard = this.lock.lock();
Latch::set_raw(&this.inner);
this.condvar.notify_all();
}
}
}
impl Probe for MutexLatch {
    /// Reports whether the inner latch is set, reading it under the lock.
    #[inline]
    fn probe(&self) -> bool {
        let guard = self.lock.lock();
        let is_set = self.inner.probe();
        drop(guard);
        is_set
    }
}
impl AsCoreLatch for MutexLatch {
    /// Borrows the inner atomic latch, bypassing the lock.
    #[inline]
    fn as_core_latch(&self) -> &CoreLatch {
        // `CoreLatch` is an alias for `AtomicLatch`, whose own
        // `as_core_latch` returns itself, so the field can be borrowed
        // directly.
        &self.inner
    }
}
/// The worker waits on this latch whenever it has nothing to do.
///
/// Pairs a mutex-protected "waiting" flag with the condition variable that
/// waiters block on.
#[derive(Debug)]
pub struct WorkerLatch {
// This boolean is `true` while a thread is blocked inside `wait_internal`
// and `false` otherwise.
mutex: Mutex<bool>,
// Condition variable that waiters block on and `notify` signals.
condvar: Condvar,
}
impl WorkerLatch {
    /// Creates a latch with the "waiting" flag cleared.
    pub fn new() -> Self {
        let mutex = Mutex::new(false);
        let condvar = Condvar::new();
        Self { mutex, condvar }
    }

    /// Locks the internal mutex and hands the guard to the caller.
    ///
    /// The guarded `bool` is the "waiting" flag.
    #[tracing::instrument(level = "trace", skip_all, fields(
        this = self as *const Self as usize,
    ))]
    pub fn lock(&self) -> parking_lot::MutexGuard<'_, bool> {
        tracing::trace!("aquiring mutex..");
        let held = self.mutex.lock();
        tracing::trace!("mutex acquired.");
        held
    }

    /// Unlocks the internal mutex without a guard.
    ///
    /// # Safety
    ///
    /// Forwards to the mutex's `force_unlock`; the caller must uphold that
    /// method's contract.
    pub unsafe fn force_unlock(&self) {
        unsafe { self.mutex.force_unlock() }
    }

    /// Blocks the calling thread until the latch is notified.
    pub fn wait(&self) {
        let mut held = self.lock();
        Self::wait_internal(&self.condvar, &mut held);
    }

    /// Marks the latch as waiting, blocks on `condvar`, and clears the mark
    /// again once woken. The guard is held again when this returns.
    fn wait_internal(condvar: &Condvar, guard: &mut parking_lot::MutexGuard<'_, bool>) {
        // Raise the flag so that `is_waiting` can see a waiter is parked.
        **guard = true;
        condvar.wait(guard);
        // Woken up again: lower the flag.
        **guard = false;
    }

    /// Takes the lock and blocks once, unless `f` returns `true`.
    ///
    /// `f` is evaluated exactly once, with the lock held.
    #[tracing::instrument(level = "trace", skip_all, fields(
        this = self as *const Self as usize,
    ))]
    pub fn wait_unless<F>(&self, mut f: F)
    where
        F: FnMut() -> bool,
    {
        let mut held = self.lock();
        if f() {
            return;
        }
        Self::wait_internal(&self.condvar, &mut held);
    }

    /// Takes the lock and blocks repeatedly until `f` yields `Some`, whose
    /// payload is returned.
    ///
    /// `f` is evaluated with the lock held, once up front and once after
    /// every wake-up.
    #[tracing::instrument(level = "trace", skip_all, fields(
        this = self as *const Self as usize,
    ))]
    pub fn wait_until<F, T>(&self, mut f: F) -> T
    where
        F: FnMut() -> Option<T>,
    {
        let mut held = self.lock();
        loop {
            match f() {
                Some(value) => return value,
                None => Self::wait_internal(&self.condvar, &mut held),
            }
        }
    }

    /// Reports whether the "waiting" flag is currently raised.
    pub fn is_waiting(&self) -> bool {
        let held = self.mutex.lock();
        *held
    }

    /// Notifies every thread blocked on the condition variable and traces
    /// how many were woken.
    #[tracing::instrument(level = "trace", skip_all, fields(
        this = self as *const Self as usize,
    ))]
    fn notify(&self) {
        let woken = self.condvar.notify_all();
        tracing::trace!("WorkerLatch notify: notified {} threads", woken);
    }

    /// Wakes all waiting threads.
    pub fn wake(&self) {
        self.notify();
    }
}
// Unit tests for the latch types defined in this file.
#[cfg(test)]
mod tests {
use std::{ptr, sync::Barrier};
use tracing_test::traced_test;
use super::*;
// Setting through `Latch::set_raw` raises exactly the `SET` flag, and
// `unset` lowers it again.
#[test]
fn test_atomic_latch() {
let latch = AtomicLatch::new();
assert_eq!(latch.get(), AtomicLatch::UNSET);
unsafe {
assert!(!latch.probe());
AtomicLatch::set_raw(&latch);
}
assert_eq!(latch.get(), AtomicLatch::SET);
assert!(latch.probe());
latch.unset();
assert_eq!(latch.get(), AtomicLatch::UNSET);
}
// `set` reports that the latch was sleeping, both flags coexist, and
// `reset` clears them together.
#[test]
fn core_latch_sleep() {
let latch = AtomicLatch::new();
assert_eq!(latch.get(), AtomicLatch::UNSET);
latch.set_sleeping();
assert_eq!(latch.get(), AtomicLatch::SLEEPING);
unsafe {
assert!(!latch.probe());
assert!(AtomicLatch::set(&latch));
}
assert_eq!(latch.get(), AtomicLatch::SET | AtomicLatch::SLEEPING);
assert!(latch.probe());
latch.reset();
assert_eq!(latch.get(), AtomicLatch::UNSET);
}
// `NopLatch` carries no state.
#[test]
fn nop_latch() {
assert!(
core::mem::size_of::<NopLatch>() == 0,
"NopLatch should be zero-sized"
);
}
// With a null parker, the latch only tracks the count: it probes as set
// exactly when the count is back at zero.
#[test]
fn count_latch() {
let latch = CountLatch::new(ptr::null());
assert_eq!(latch.count(), 0);
latch.increment();
assert_eq!(latch.count(), 1);
assert!(!latch.probe());
latch.increment();
assert_eq!(latch.count(), 2);
assert!(!latch.probe());
unsafe {
Latch::set_raw(&latch);
}
assert!(!latch.probe());
assert_eq!(latch.count(), 1);
unsafe {
Latch::set_raw(&latch);
}
assert!(latch.probe());
assert_eq!(latch.count(), 0);
}
// Covers set/probe/reset on one thread, then a second thread blocking in
// `wait_and_reset` until the main thread sets the latch.
#[test]
#[cfg_attr(not(miri), traced_test)]
fn mutex_latch() {
let latch = Arc::new(MutexLatch::new());
assert!(!latch.probe());
latch.set();
assert!(latch.probe());
latch.reset();
assert!(!latch.probe());
// Test wait functionality
let latch_clone = latch.clone();
let handle = std::thread::spawn(move || {
tracing::info!("Thread waiting on latch");
latch_clone.wait_and_reset();
tracing::info!("Thread woke up from latch");
});
// Give the thread time to block
std::thread::sleep(std::time::Duration::from_millis(100));
// Nothing has set the latch yet, so it must still probe as unset.
assert!(!latch.probe());
tracing::info!("Setting latch from main thread");
latch.set();
tracing::info!("Latch set, joining waiting thread");
handle.join().expect("Thread should join successfully");
}
}