use core::{
    marker::PhantomData,
    sync::atomic::{AtomicUsize, Ordering},
};

use std::{
    cell::UnsafeCell,
    mem,
    ops::DerefMut,
    sync::{
        Arc,
        atomic::{AtomicPtr, AtomicU8},
    },
};

use parking_lot::{Condvar, Mutex};

use crate::{WorkerThread, channel::Parker, context::Context};

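/// A latch that is set once to release whoever is waiting on it.
///
/// `set_raw` takes a raw pointer rather than `&self` because setting a
/// latch may wake code that immediately invalidates the latch's memory;
/// implementations must not touch `this` after that can happen.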
pub trait Latch {
    unsafe fn set_raw(this: *const Self);
}

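/// A latch whose state can be polled without blocking.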
pub trait Probe {
    fn probe(&self) -> bool;
}

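/// The core latch type; wrapper latches expose a reference to one of
/// these via [`AsCoreLatch`].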
pub type CoreLatch = AtomicLatch;

pub trait AsCoreLatch {
    fn as_core_latch(&self) -> &CoreLatch;
}

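/// A small state machine packed into a single `AtomicU8`.
///
/// `SET`, `SLEEPING`, `WAKEUP`, and `HEARTBEAT` are independent bits, so
/// the latch can, for example, be set while still recording that its
/// owner was sleeping when that happened.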
#[derive(Debug)]
pub struct AtomicLatch {
    inner: AtomicU8,
}

impl AtomicLatch {
    pub const UNSET: u8 = 0;
    pub const SET: u8 = 1;
    pub const SLEEPING: u8 = 2;
    pub const WAKEUP: u8 = 4;
    pub const HEARTBEAT: u8 = 8;

    #[inline]
    pub const fn new() -> Self {
        Self {
            inner: AtomicU8::new(Self::UNSET),
        }
    }

    pub const fn new_set() -> Self {
        Self {
            inner: AtomicU8::new(Self::SET),
        }
    }

    #[inline]
    pub fn unset(&self) {
        self.inner.fetch_and(!Self::SET, Ordering::Release);
    }

    pub fn reset(&self) -> u8 {
        self.inner.swap(Self::UNSET, Ordering::Release)
    }

    pub fn get(&self) -> u8 {
        self.inner.load(Ordering::Acquire)
    }

    pub fn poll_heartbeat(&self) -> bool {
        self.inner.fetch_and(!Self::HEARTBEAT, Ordering::Relaxed) & Self::HEARTBEAT
            == Self::HEARTBEAT
    }

    /// Marks the latch as sleeping. Returns `true` if the latch was
    /// already set.
    pub fn set_sleeping(&self) -> bool {
        self.inner.fetch_or(Self::SLEEPING, Ordering::Relaxed) & Self::SET == Self::SET
    }

    pub fn is_sleeping(&self) -> bool {
        self.inner.load(Ordering::Relaxed) & Self::SLEEPING == Self::SLEEPING
    }

    pub fn is_heartbeat(&self) -> bool {
        self.inner.load(Ordering::Relaxed) & Self::HEARTBEAT == Self::HEARTBEAT
    }

    pub fn is_wakeup(&self) -> bool {
        self.inner.load(Ordering::Relaxed) & Self::WAKEUP == Self::WAKEUP
    }

    pub fn is_set(&self) -> bool {
        self.inner.load(Ordering::Relaxed) & Self::SET == Self::SET
    }

    pub fn set_wakeup(&self) {
        self.inner.fetch_or(Self::WAKEUP, Ordering::Relaxed);
    }

    pub fn set_heartbeat(&self) {
        self.inner.fetch_or(Self::HEARTBEAT, Ordering::Relaxed);
    }

    /// Sets the latch. Returns `true` if the latch was previously
    /// marked as sleeping.
    #[inline]
    pub unsafe fn set(this: *const Self) -> bool {
        unsafe {
            let old = (*this).inner.fetch_or(Self::SET, Ordering::Relaxed);
            old & Self::SLEEPING == Self::SLEEPING
        }
    }
}

impl Latch for AtomicLatch {
    #[inline]
    unsafe fn set_raw(this: *const Self) {
        unsafe {
            Self::set(this);
        }
    }
}

impl Probe for AtomicLatch {
    #[inline]
    fn probe(&self) -> bool {
        self.inner.load(Ordering::Relaxed) & Self::SET != 0
    }
}

impl AsCoreLatch for AtomicLatch {
    #[inline]
    fn as_core_latch(&self) -> &CoreLatch {
        self
    }
}

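/// A borrowed latch, stored as a raw pointer; the `PhantomData` ties it
/// to the lifetime of the underlying latch.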
pub struct LatchRef<'a, L: Latch> {
    inner: *const L,
    _marker: PhantomData<&'a L>,
}

impl<'a, L: Latch> LatchRef<'a, L> {
    #[inline]
    pub const fn new(latch: &'a L) -> Self {
        Self {
            inner: latch,
            _marker: PhantomData,
        }
    }
}

impl<'a, L: Latch> Latch for LatchRef<'a, L> {
    #[inline]
    unsafe fn set_raw(this: *const Self) {
        unsafe {
            let this = &*this;
            Latch::set_raw(this.inner);
        }
    }
}

impl<'a, L: Latch + Probe> Probe for LatchRef<'a, L> {
    #[inline]
    fn probe(&self) -> bool {
        unsafe {
            let this = &*self.inner;
            Probe::probe(this)
        }
    }
}

impl<'a, L> AsCoreLatch for LatchRef<'a, L>
where
    L: Latch + AsCoreLatch,
{
    #[inline]
    fn as_core_latch(&self) -> &CoreLatch {
        unsafe {
            let this = &*self.inner;
            this.as_core_latch()
        }
    }
}

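/// A latch that ignores `set_raw` and always probes as unset.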
pub struct NopLatch;

impl Latch for NopLatch {
    #[inline]
    unsafe fn set_raw(_this: *const Self) {
        // do nothing
    }
}

impl Probe for NopLatch {
    #[inline]
    fn probe(&self) -> bool {
        false // never set
    }
}

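/// A countdown latch: `increment` raises the count and each `set_raw`
/// (or `decrement`) lowers it. When the count reaches zero, the
/// registered `Parker`, if any, is unparked.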
pub struct CountLatch {
    count: AtomicUsize,
    inner: AtomicPtr<Parker>,
}

impl CountLatch {
    #[inline]
    pub const fn new(inner: *const Parker) -> Self {
        Self {
            count: AtomicUsize::new(0),
            inner: AtomicPtr::new(inner as *mut Parker),
        }
    }

    pub fn set_inner(&self, inner: *const Parker) {
        self.inner.store(inner as *mut Parker, Ordering::Relaxed);
    }

    pub fn count(&self) -> usize {
        self.count.load(Ordering::Relaxed)
    }

    #[inline]
    pub fn increment(&self) {
        self.count.fetch_add(1, Ordering::Release);
    }

    #[inline]
    pub fn decrement(&self) {
        unsafe {
            Latch::set_raw(self);
        }
    }
}

impl Latch for CountLatch {
    #[inline]
    unsafe fn set_raw(this: *const Self) {
        unsafe {
            if (*this).count.fetch_sub(1, Ordering::Relaxed) == 1 {
                tracing::trace!("CountLatch set_raw: count hit zero, waking parker");
                // This was the last decrement: wake the parker, if one
                // has been registered.
                let inner = (*this).inner.load(Ordering::Relaxed);
                if !inner.is_null() {
                    (*inner).unpark();
                }
            }
        }
    }
}

impl Probe for CountLatch {
    #[inline]
    fn probe(&self) -> bool {
        self.count.load(Ordering::Relaxed) == 0
    }
}

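/// An `AtomicLatch` combined with a mutex and condvar so that callers
/// can block until the latch is set.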
pub struct MutexLatch {
    inner: AtomicLatch,
    lock: Mutex<()>,
    condvar: Condvar,
}

unsafe impl Send for MutexLatch {}
unsafe impl Sync for MutexLatch {}

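/// The reason a waiting worker was woken (not used within this module).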
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum WakeResult {
    Wake,
    Heartbeat,
    Set,
}

impl MutexLatch {
    #[inline]
    pub const fn new() -> Self {
        Self {
            inner: AtomicLatch::new(),
            lock: Mutex::new(()),
            condvar: Condvar::new(),
        }
    }

    #[inline]
    pub fn reset(&self) {
        // Hold the lock so the reset cannot race with a concurrent
        // `set_raw`/`notify_all`.
        let _guard = self.lock.lock();
        self.inner.reset();
    }

    pub fn wait_and_reset(&self) {
        // All transitions of `inner` happen while holding `lock`, so a
        // wakeup cannot be missed between `probe` and `wait`.
        let mut guard = self.lock.lock();
        while !self.inner.probe() {
            self.condvar.wait(&mut guard);
        }

        self.inner.reset();
    }

    pub fn set(&self) {
        unsafe {
            Latch::set_raw(self);
        }
    }
}

impl Latch for MutexLatch {
    #[inline]
    unsafe fn set_raw(this: *const Self) {
        // SAFETY: `this` is valid until the guard is dropped.
        unsafe {
            let this = &*this;
            let _guard = this.lock.lock();
            Latch::set_raw(&this.inner);
            this.condvar.notify_all();
        }
    }
}

impl Probe for MutexLatch {
    #[inline]
    fn probe(&self) -> bool {
        // Locking serializes this probe with `set_raw`, which sets the
        // latch while holding the same lock.
        let _guard = self.lock.lock();
        self.inner.probe()
    }
}

impl AsCoreLatch for MutexLatch {
    #[inline]
    fn as_core_latch(&self) -> &CoreLatch {
        self.inner.as_core_latch()
    }
}

/// The latch a worker blocks on whenever it has nothing to do.
#[derive(Debug)]
pub struct WorkerLatch {
    // This boolean is set while the worker is waiting.
    mutex: Mutex<bool>,
    condvar: Condvar,
}

impl WorkerLatch {
    pub fn new() -> Self {
        Self {
            mutex: Mutex::new(false),
            condvar: Condvar::new(),
        }
    }

    #[tracing::instrument(level = "trace", skip_all, fields(
        this = self as *const Self as usize,
    ))]
    pub fn lock(&self) -> parking_lot::MutexGuard<'_, bool> {
        tracing::trace!("acquiring mutex..");
        let guard = self.mutex.lock();
        tracing::trace!("mutex acquired.");

        guard
    }

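    /// # Safety
    ///
    /// The current thread must logically own the lock (for example, its
    /// guard was leaked), per `parking_lot::Mutex::force_unlock`.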
    pub unsafe fn force_unlock(&self) {
        unsafe {
            self.mutex.force_unlock();
        }
    }

    pub fn wait(&self) {
        let condvar = &self.condvar;
        let mut guard = self.lock();

        Self::wait_internal(condvar, &mut guard);
    }

    fn wait_internal(condvar: &Condvar, guard: &mut parking_lot::MutexGuard<'_, bool>) {
        **guard = true; // mark the worker as waiting
        condvar.wait(guard);
        **guard = false;
    }

    #[tracing::instrument(level = "trace", skip_all, fields(
        this = self as *const Self as usize,
    ))]
    pub fn wait_unless<F>(&self, mut f: F)
    where
        F: FnMut() -> bool,
    {
        let mut guard = self.lock();
        if !f() {
            Self::wait_internal(&self.condvar, &mut guard);
        }
    }

    #[tracing::instrument(level = "trace", skip_all, fields(
        this = self as *const Self as usize,
    ))]
    pub fn wait_until<F, T>(&self, mut f: F) -> T
    where
        F: FnMut() -> Option<T>,
    {
        let mut guard = self.lock();
        loop {
            if let Some(result) = f() {
                return result;
            }
            Self::wait_internal(&self.condvar, &mut guard);
        }
    }

    pub fn is_waiting(&self) -> bool {
        *self.mutex.lock()
    }

    #[tracing::instrument(level = "trace", skip_all, fields(
        this = self as *const Self as usize,
    ))]
    fn notify(&self) {
        let n = self.condvar.notify_all();
        tracing::trace!("WorkerLatch notify: notified {} threads", n);
    }

    pub fn wake(&self) {
        self.notify();
    }
}

#[cfg(test)]
mod tests {
    use std::{ptr, sync::Barrier};

    use tracing_test::traced_test;

    use super::*;

    #[test]
    fn test_atomic_latch() {
        let latch = AtomicLatch::new();
        assert_eq!(latch.get(), AtomicLatch::UNSET);
        unsafe {
            assert!(!latch.probe());
            AtomicLatch::set_raw(&latch);
        }
        assert_eq!(latch.get(), AtomicLatch::SET);
        assert!(latch.probe());
        latch.unset();
        assert_eq!(latch.get(), AtomicLatch::UNSET);
    }

    #[test]
    fn core_latch_sleep() {
        let latch = AtomicLatch::new();
        assert_eq!(latch.get(), AtomicLatch::UNSET);
        latch.set_sleeping();
        assert_eq!(latch.get(), AtomicLatch::SLEEPING);
        unsafe {
            assert!(!latch.probe());
            assert!(AtomicLatch::set(&latch));
        }
        assert_eq!(latch.get(), AtomicLatch::SET | AtomicLatch::SLEEPING);
        assert!(latch.probe());
        latch.reset();
        assert_eq!(latch.get(), AtomicLatch::UNSET);
    }

    #[test]
    fn nop_latch() {
        assert!(
            core::mem::size_of::<NopLatch>() == 0,
            "NopLatch should be zero-sized"
        );
    }

    #[test]
    fn count_latch() {
        let latch = CountLatch::new(ptr::null());
        assert_eq!(latch.count(), 0);
        latch.increment();
        assert_eq!(latch.count(), 1);
        assert!(!latch.probe());
        latch.increment();
        assert_eq!(latch.count(), 2);
        assert!(!latch.probe());

        unsafe {
            Latch::set_raw(&latch);
        }
        assert!(!latch.probe());
        assert_eq!(latch.count(), 1);

        unsafe {
            Latch::set_raw(&latch);
        }
        assert!(latch.probe());
        assert_eq!(latch.count(), 0);
    }

    #[test]
    #[cfg_attr(not(miri), traced_test)]
    fn mutex_latch() {
        let latch = Arc::new(MutexLatch::new());
        assert!(!latch.probe());
        latch.set();
        assert!(latch.probe());
        latch.reset();
        assert!(!latch.probe());

        // Test wait functionality.
        let latch_clone = latch.clone();
        let handle = std::thread::spawn(move || {
            tracing::info!("Thread waiting on latch");
            latch_clone.wait_and_reset();
            tracing::info!("Thread woke up from latch");
        });

        // Give the thread time to block.
        std::thread::sleep(std::time::Duration::from_millis(100));
        assert!(!latch.probe());

        tracing::info!("Setting latch from main thread");
        latch.set();
        tracing::info!("Latch set, joining waiting thread");
        handle.join().expect("Thread should join successfully");
    }
}