// Source-viewer metadata: 671 lines, 17 KiB, Rust.
use core::{
|
|
marker::PhantomData,
|
|
sync::atomic::{AtomicUsize, Ordering},
|
|
};
|
|
use std::{
|
|
cell::UnsafeCell,
|
|
mem,
|
|
ops::DerefMut,
|
|
sync::{
|
|
Arc,
|
|
atomic::{AtomicPtr, AtomicU8},
|
|
},
|
|
};
|
|
|
|
use parking_lot::{Condvar, Mutex};
|
|
|
|
use crate::{WorkerThread, context::Context};
|
|
|
|
/// A signal that can be raised through a raw pointer to the latch.
///
/// Implementations receive `*const Self` instead of `&self`.
pub trait Latch {
    /// Raises the latch pointed to by `this`.
    ///
    /// # Safety
    ///
    /// `this` must point to a live, properly aligned `Self` for the whole call.
    unsafe fn set_raw(this: *const Self);
}

/// Query of whether a latch is currently in its set/complete state.
pub trait Probe {
    /// Returns `true` if the latch is currently considered set.
    fn probe(&self) -> bool;
}

pub type CoreLatch = AtomicLatch;
|
|
pub trait AsCoreLatch {
|
|
fn as_core_latch(&self) -> &CoreLatch;
|
|
}
|
|
|
|
/// A latch whose whole state is one byte of bit flags
/// (`AtomicLatch::SET`, `SLEEPING`, `WAKEUP`, `HEARTBEAT`); the flags are
/// independent and may be raised together.
#[derive(Debug)]
pub struct AtomicLatch {
    inner: AtomicU8,
}

impl AtomicLatch {
    /// No flag raised.
    pub const UNSET: u8 = 0;
    /// The latch has been set.
    pub const SET: u8 = 1;
    /// Flag raised by [`set_sleeping`](Self::set_sleeping).
    pub const SLEEPING: u8 = 2;
    /// Flag raised by [`set_wakeup`](Self::set_wakeup).
    pub const WAKEUP: u8 = 4;
    /// Flag raised by [`set_heartbeat`](Self::set_heartbeat) and consumed by
    /// [`poll_heartbeat`](Self::poll_heartbeat).
    pub const HEARTBEAT: u8 = 8;

    /// Creates a latch with no flags raised.
    #[inline]
    pub const fn new() -> Self {
        Self {
            inner: AtomicU8::new(Self::UNSET),
        }
    }

    /// Creates a latch with only the `SET` flag raised.
    pub const fn new_set() -> Self {
        Self {
            inner: AtomicU8::new(Self::SET),
        }
    }

    /// Clears the `SET` flag, leaving the other flags untouched.
    #[inline]
    pub fn unset(&self) {
        self.inner.fetch_and(!Self::SET, Ordering::Release);
    }

    /// Clears every flag and returns the previous flag byte.
    pub fn reset(&self) -> u8 {
        self.inner.swap(Self::UNSET, Ordering::Release)
    }

    /// Returns the raw flag byte.
    pub fn get(&self) -> u8 {
        self.inner.load(Ordering::Acquire)
    }

    /// Clears the `HEARTBEAT` flag and returns whether it was raised.
    pub fn poll_heartbeat(&self) -> bool {
        self.inner.fetch_and(!Self::HEARTBEAT, Ordering::Relaxed) & Self::HEARTBEAT
            == Self::HEARTBEAT
    }

    /// Raises the `SLEEPING` flag; returns `true` if the latch was already set.
    ///
    /// `AcqRel`: when this reports `true`, the caller typically proceeds
    /// instead of sleeping, so it must also observe the writes the setter made
    /// before (release-)setting the latch.
    pub fn set_sleeping(&self) -> bool {
        self.inner.fetch_or(Self::SLEEPING, Ordering::AcqRel) & Self::SET == Self::SET
    }

    /// Returns whether the `SLEEPING` flag is raised.
    pub fn is_sleeping(&self) -> bool {
        self.inner.load(Ordering::Relaxed) & Self::SLEEPING == Self::SLEEPING
    }

    /// Returns whether the `HEARTBEAT` flag is raised.
    pub fn is_heartbeat(&self) -> bool {
        self.inner.load(Ordering::Relaxed) & Self::HEARTBEAT == Self::HEARTBEAT
    }

    /// Returns whether the `WAKEUP` flag is raised.
    pub fn is_wakeup(&self) -> bool {
        self.inner.load(Ordering::Relaxed) & Self::WAKEUP == Self::WAKEUP
    }

    /// Returns whether the `SET` flag is raised.
    ///
    /// `Acquire`, so a `true` result makes the setter's prior writes visible
    /// when the latch was set with release semantics.
    pub fn is_set(&self) -> bool {
        self.inner.load(Ordering::Acquire) & Self::SET == Self::SET
    }

    /// Raises the `WAKEUP` flag.
    pub fn set_wakeup(&self) {
        self.inner.fetch_or(Self::WAKEUP, Ordering::Relaxed);
    }

    /// Raises the `HEARTBEAT` flag.
    pub fn set_heartbeat(&self) {
        self.inner.fetch_or(Self::HEARTBEAT, Ordering::Relaxed);
    }

    /// Raises the `SET` flag; returns `true` if the latch was previously sleeping.
    ///
    /// Uses `AcqRel` (previously `Relaxed`): the release half publishes every
    /// write made before setting the latch to threads that later observe
    /// `SET` with an acquire load; a relaxed RMW gives no such guarantee.
    ///
    /// # Safety
    ///
    /// `this` must point to a live `AtomicLatch`.
    #[inline]
    pub unsafe fn set(this: *const Self) -> bool {
        // SAFETY: the caller guarantees `this` points to a live latch.
        let old = unsafe { (*this).inner.fetch_or(Self::SET, Ordering::AcqRel) };
        old & Self::SLEEPING == Self::SLEEPING
    }
}

impl Latch for AtomicLatch {
|
|
#[inline]
|
|
unsafe fn set_raw(this: *const Self) {
|
|
unsafe {
|
|
Self::set(this);
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Probe for AtomicLatch {
|
|
#[inline]
|
|
fn probe(&self) -> bool {
|
|
self.inner.load(Ordering::Relaxed) & Self::SET != 0
|
|
}
|
|
}
|
|
impl AsCoreLatch for AtomicLatch {
|
|
#[inline]
|
|
fn as_core_latch(&self) -> &CoreLatch {
|
|
self
|
|
}
|
|
}
|
|
|
|
pub struct LatchRef<'a, L: Latch> {
|
|
inner: *const L,
|
|
_marker: PhantomData<&'a L>,
|
|
}
|
|
|
|
impl<'a, L: Latch> LatchRef<'a, L> {
|
|
#[inline]
|
|
pub const fn new(latch: &'a L) -> Self {
|
|
Self {
|
|
inner: latch,
|
|
_marker: PhantomData,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<'a, L: Latch> Latch for LatchRef<'a, L> {
|
|
#[inline]
|
|
unsafe fn set_raw(this: *const Self) {
|
|
unsafe {
|
|
let this = &*this;
|
|
Latch::set_raw(this.inner);
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<'a, L: Latch + Probe> Probe for LatchRef<'a, L> {
|
|
#[inline]
|
|
fn probe(&self) -> bool {
|
|
unsafe {
|
|
let this = &*self.inner;
|
|
Probe::probe(this)
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<'a, L> AsCoreLatch for LatchRef<'a, L>
|
|
where
|
|
L: Latch + AsCoreLatch,
|
|
{
|
|
#[inline]
|
|
fn as_core_latch(&self) -> &CoreLatch {
|
|
unsafe {
|
|
let this = &*self.inner;
|
|
this.as_core_latch()
|
|
}
|
|
}
|
|
}
|
|
|
|
/// A zero-sized latch whose `set_raw` does nothing and whose `probe` is always `false`.
pub struct NopLatch;

impl Latch for NopLatch {
|
|
#[inline]
|
|
unsafe fn set_raw(_this: *const Self) {
|
|
// do nothing
|
|
}
|
|
}
|
|
|
|
impl Probe for NopLatch {
|
|
#[inline]
|
|
fn probe(&self) -> bool {
|
|
false // always returns false
|
|
}
|
|
}
|
|
|
|
pub struct CountLatch {
|
|
count: AtomicUsize,
|
|
inner: AtomicPtr<WorkerLatch>,
|
|
}
|
|
|
|
impl CountLatch {
|
|
#[inline]
|
|
pub const fn new(inner: *const WorkerLatch) -> Self {
|
|
Self {
|
|
count: AtomicUsize::new(0),
|
|
inner: AtomicPtr::new(inner as *mut WorkerLatch),
|
|
}
|
|
}
|
|
|
|
pub fn set_inner(&self, inner: *const WorkerLatch) {
|
|
self.inner
|
|
.store(inner as *mut WorkerLatch, Ordering::Relaxed);
|
|
}
|
|
|
|
pub fn count(&self) -> usize {
|
|
self.count.load(Ordering::Relaxed)
|
|
}
|
|
|
|
#[inline]
|
|
pub fn increment(&self) {
|
|
self.count.fetch_add(1, Ordering::Release);
|
|
}
|
|
|
|
#[inline]
|
|
pub fn decrement(&self) {
|
|
unsafe {
|
|
Latch::set_raw(self);
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Latch for CountLatch {
|
|
#[inline]
|
|
unsafe fn set_raw(this: *const Self) {
|
|
unsafe {
|
|
if (&*this).count.fetch_sub(1, Ordering::Relaxed) == 1 {
|
|
tracing::trace!("CountLatch set_raw: count was 1, setting inner latch");
|
|
// If the count was 1, we need to set the inner latch.
|
|
let inner = (*this).inner.load(Ordering::Relaxed);
|
|
if !inner.is_null() {
|
|
(&*inner).wake();
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Probe for CountLatch {
|
|
#[inline]
|
|
fn probe(&self) -> bool {
|
|
self.count.load(Ordering::Relaxed) == 0
|
|
}
|
|
}
|
|
|
|
pub struct MutexLatch {
|
|
inner: AtomicLatch,
|
|
lock: Mutex<()>,
|
|
condvar: Condvar,
|
|
}
|
|
|
|
unsafe impl Send for MutexLatch {}
|
|
unsafe impl Sync for MutexLatch {}
|
|
|
|
/// Reason a wait ended. Presumably the options are an explicit wake-up, a
/// heartbeat, or the latch becoming set.
/// NOTE(review): not referenced in this file; confirm the meaning against callers.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum WakeResult {
    Wake,
    Heartbeat,
    Set,
}

impl MutexLatch {
|
|
#[inline]
|
|
pub const fn new() -> Self {
|
|
Self {
|
|
inner: AtomicLatch::new(),
|
|
lock: Mutex::new(()),
|
|
condvar: Condvar::new(),
|
|
}
|
|
}
|
|
|
|
#[inline]
|
|
pub fn reset(&self) {
|
|
let _guard = self.lock.lock();
|
|
// SAFETY: inner is atomic, so we can safely access it.
|
|
self.inner.reset();
|
|
}
|
|
|
|
pub fn wait_and_reset(&self) {
|
|
// SAFETY: inner is locked by the mutex, so we can safely access it.
|
|
let mut guard = self.lock.lock();
|
|
while !self.inner.probe() {
|
|
self.condvar.wait(&mut guard);
|
|
}
|
|
|
|
self.inner.reset();
|
|
}
|
|
|
|
pub fn set(&self) {
|
|
unsafe {
|
|
Latch::set_raw(self);
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Latch for MutexLatch {
|
|
#[inline]
|
|
unsafe fn set_raw(this: *const Self) {
|
|
// SAFETY: `this` is valid until the guard is dropped.
|
|
unsafe {
|
|
let this = &*this;
|
|
let _guard = this.lock.lock();
|
|
Latch::set_raw(&this.inner);
|
|
this.condvar.notify_all();
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Probe for MutexLatch {
|
|
#[inline]
|
|
fn probe(&self) -> bool {
|
|
let _guard = self.lock.lock();
|
|
// SAFETY: inner is atomic, so we can safely access it.
|
|
self.inner.probe()
|
|
}
|
|
}
|
|
|
|
impl AsCoreLatch for MutexLatch {
|
|
#[inline]
|
|
fn as_core_latch(&self) -> &CoreLatch {
|
|
// SAFETY: inner is atomic, so we can safely access it.
|
|
self.inner.as_core_latch()
|
|
}
|
|
}
|
|
|
|
// The worker waits on this latch whenever it has nothing to do.
|
|
pub struct WorkerLatch {
|
|
// this boolean is set when the worker is waiting.
|
|
mutex: Mutex<bool>,
|
|
condvar: AtomicUsize,
|
|
}
|
|
|
|
impl WorkerLatch {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
mutex: Mutex::new(false),
|
|
condvar: AtomicUsize::new(0),
|
|
}
|
|
}
|
|
pub fn lock(&self) {
|
|
mem::forget(self.mutex.lock());
|
|
}
|
|
pub fn unlock(&self) {
|
|
unsafe {
|
|
self.mutex.force_unlock();
|
|
}
|
|
}
|
|
|
|
pub fn wait(&self) {
|
|
let condvar = &self.condvar;
|
|
let mut guard = self.mutex.lock();
|
|
|
|
Self::wait_internal(condvar, &mut guard);
|
|
}
|
|
|
|
fn wait_internal(condvar: &AtomicUsize, guard: &mut parking_lot::MutexGuard<'_, bool>) {
|
|
let mutex = parking_lot::MutexGuard::mutex(guard);
|
|
let key = condvar as *const _ as usize;
|
|
let lock_addr = mutex as *const _ as usize;
|
|
let mut requeued = false;
|
|
|
|
let state = unsafe { AtomicUsize::from_ptr(condvar as *const _ as *mut usize) };
|
|
|
|
**guard = true; // set the mutex to true to indicate that the worker is waiting
|
|
|
|
unsafe {
|
|
parking_lot_core::park(
|
|
key,
|
|
|| {
|
|
let old = state.load(Ordering::Relaxed);
|
|
if old == 0 {
|
|
state.store(lock_addr, Ordering::Relaxed);
|
|
} else if old != lock_addr {
|
|
return false;
|
|
}
|
|
|
|
true
|
|
},
|
|
|| {
|
|
mutex.force_unlock();
|
|
},
|
|
|k, was_last_thread| {
|
|
requeued = k != key;
|
|
if !requeued && was_last_thread {
|
|
state.store(0, Ordering::Relaxed);
|
|
}
|
|
},
|
|
parking_lot_core::DEFAULT_PARK_TOKEN,
|
|
None,
|
|
);
|
|
}
|
|
// relock
|
|
|
|
let mut new = mutex.lock();
|
|
mem::swap(&mut new, guard);
|
|
mem::forget(new); // forget the new guard to avoid dropping it
|
|
|
|
**guard = false; // reset the mutex to false after waking up
|
|
}
|
|
|
|
fn wait_with_lock_internal<T>(&self, other: &mut parking_lot::MutexGuard<'_, T>) {
|
|
let key = &self.condvar as *const _ as usize;
|
|
let lock_addr = &self.mutex as *const _ as usize;
|
|
let mut requeued = false;
|
|
|
|
let mut guard = self.mutex.lock();
|
|
|
|
let state = unsafe { AtomicUsize::from_ptr(&self.condvar as *const _ as *mut usize) };
|
|
|
|
*guard = true; // set the mutex to true to indicate that the worker is waiting
|
|
|
|
unsafe {
|
|
let token = parking_lot_core::park(
|
|
key,
|
|
|| {
|
|
let old = state.load(Ordering::Relaxed);
|
|
if old == 0 {
|
|
state.store(lock_addr, Ordering::Relaxed);
|
|
} else if old != lock_addr {
|
|
return false;
|
|
}
|
|
|
|
true
|
|
},
|
|
|| {
|
|
drop(guard); // drop the guard to release the lock
|
|
parking_lot::MutexGuard::mutex(&other).force_unlock();
|
|
},
|
|
|k, was_last_thread| {
|
|
requeued = k != key;
|
|
if !requeued && was_last_thread {
|
|
state.store(0, Ordering::Relaxed);
|
|
}
|
|
},
|
|
parking_lot_core::DEFAULT_PARK_TOKEN,
|
|
None,
|
|
);
|
|
|
|
tracing::trace!(
|
|
"WorkerLatch wait_with_lock_internal: unparked with token {:?}",
|
|
token
|
|
);
|
|
}
|
|
// relock
|
|
let mut other2 = parking_lot::MutexGuard::mutex(&other).lock();
|
|
tracing::trace!("WorkerLatch wait_with_lock_internal: relocked other");
|
|
|
|
// because `other` is logically unlocked, we swap it with `other2` and then forget `other2`
|
|
core::mem::swap(&mut *other2, &mut *other);
|
|
core::mem::forget(other2);
|
|
|
|
let mut guard = self.mutex.lock();
|
|
tracing::trace!("WorkerLatch wait_with_lock_internal: relocked self");
|
|
|
|
*guard = false; // reset the mutex to false after waking up
|
|
}
|
|
|
|
pub fn wait_with_lock<T>(&self, other: &mut parking_lot::MutexGuard<'_, T>) {
|
|
self.wait_with_lock_internal(other);
|
|
}
|
|
|
|
pub fn wait_with_lock_while<T, F>(&self, other: &mut parking_lot::MutexGuard<'_, T>, mut f: F)
|
|
where
|
|
F: FnMut(&mut T) -> bool,
|
|
{
|
|
while f(other.deref_mut()) {
|
|
self.wait_with_lock_internal(other);
|
|
}
|
|
}
|
|
|
|
pub fn wait_until<F, T>(&self, mut f: F) -> T
|
|
where
|
|
F: FnMut() -> Option<T>,
|
|
{
|
|
let mut guard = self.mutex.lock();
|
|
loop {
|
|
if let Some(result) = f() {
|
|
return result;
|
|
}
|
|
Self::wait_internal(&self.condvar, &mut guard);
|
|
}
|
|
}
|
|
|
|
pub fn is_waiting(&self) -> bool {
|
|
*self.mutex.lock()
|
|
}
|
|
|
|
fn notify(&self) {
|
|
let key = &self.condvar as *const _ as usize;
|
|
|
|
unsafe {
|
|
let n = parking_lot_core::unpark_all(key, parking_lot_core::DEFAULT_UNPARK_TOKEN);
|
|
tracing::trace!("WorkerLatch notify_one: unparked {} threads", n);
|
|
}
|
|
}
|
|
|
|
pub fn wake(&self) {
|
|
self.notify();
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use std::{ptr, sync::Barrier};

    use tracing_test::traced_test;

    use super::*;

    #[test]
    #[cfg_attr(not(miri), traced_test)]
    fn worker_latch() {
        let latch = Arc::new(WorkerLatch::new());
        let barrier = Arc::new(Barrier::new(2));
        let mutex = Arc::new(parking_lot::Mutex::new(false));

        let count = Arc::new(AtomicUsize::new(0));

        let thread = std::thread::spawn({
            let latch = latch.clone();
            let mutex = mutex.clone();
            let barrier = barrier.clone();
            let count = count.clone();

            move || {
                tracing::info!("Thread waiting on barrier");
                let mut guard = mutex.lock();
                barrier.wait();

                tracing::info!("Thread waiting on latch");
                latch.wait_with_lock(&mut guard);
                count.fetch_add(1, Ordering::Relaxed);
                tracing::info!("Thread woke up from latch");
                barrier.wait();
                // Wait until the main thread has checked `count == 1` before
                // incrementing again; otherwise that check races with us.
                barrier.wait();
                tracing::info!("Thread finished waiting on barrier");
                count.fetch_add(1, Ordering::Relaxed);
            }
        });

        assert!(!latch.is_waiting(), "Latch should not be waiting yet");
        barrier.wait();
        tracing::info!("Main thread finished waiting on barrier");

        // `mutex` only becomes available once the thread has parked inside
        // `wait_with_lock`, so acquiring it here means the thread is waiting.
        {
            let _guard = mutex.lock();
            tracing::info!("Main thread acquired mutex, waking up thread");
            assert!(latch.is_waiting(), "Latch should be waiting now");

            latch.wake();
            tracing::info!("Main thread woke up thread");

            // Check while still holding `mutex`. The woken thread must
            // re-acquire it before it can increment `count`, whereas checking
            // after the guard is dropped would race with the thread.
            assert_eq!(count.load(Ordering::Relaxed), 0, "Count should still be 0");
        }

        barrier.wait();
        assert_eq!(
            count.load(Ordering::Relaxed),
            1,
            "Count should be 1 after waking up"
        );
        barrier.wait();

        thread.join().expect("Thread should join successfully");
        assert_eq!(
            count.load(Ordering::Relaxed),
            2,
            "Count should be 2 after thread has finished"
        );
    }

    #[test]
    fn test_atomic_latch() {
        let latch = AtomicLatch::new();
        assert_eq!(latch.get(), AtomicLatch::UNSET);
        assert!(!latch.probe());
        unsafe {
            AtomicLatch::set_raw(&latch);
        }
        assert_eq!(latch.get(), AtomicLatch::SET);
        assert!(latch.probe());
        latch.unset();
        assert_eq!(latch.get(), AtomicLatch::UNSET);
    }

    #[test]
    fn core_latch_sleep() {
        let latch = AtomicLatch::new();
        assert_eq!(latch.get(), AtomicLatch::UNSET);
        latch.set_sleeping();
        assert_eq!(latch.get(), AtomicLatch::SLEEPING);
        assert!(!latch.probe());
        unsafe {
            assert!(AtomicLatch::set(&latch));
        }
        assert_eq!(latch.get(), AtomicLatch::SET | AtomicLatch::SLEEPING);
        assert!(latch.probe());
        latch.reset();
        assert_eq!(latch.get(), AtomicLatch::UNSET);
    }

    #[test]
    fn nop_latch() {
        assert!(
            core::mem::size_of::<NopLatch>() == 0,
            "NopLatch should be zero-sized"
        );
    }

    #[test]
    fn count_latch() {
        let latch = CountLatch::new(ptr::null());
        assert_eq!(latch.count(), 0);
        latch.increment();
        assert_eq!(latch.count(), 1);
        assert!(!latch.probe());
        latch.increment();
        assert_eq!(latch.count(), 2);
        assert!(!latch.probe());

        unsafe {
            Latch::set_raw(&latch);
        }
        assert!(!latch.probe());
        assert_eq!(latch.count(), 1);

        unsafe {
            Latch::set_raw(&latch);
        }
        assert!(latch.probe());
        assert_eq!(latch.count(), 0);
    }

    #[test]
    #[traced_test]
    fn mutex_latch() {
        let latch = Arc::new(MutexLatch::new());
        assert!(!latch.probe());
        latch.set();
        assert!(latch.probe());
        latch.reset();
        assert!(!latch.probe());

        // Test wait functionality
        let latch_clone = latch.clone();
        let handle = std::thread::spawn(move || {
            tracing::info!("Thread waiting on latch");
            latch_clone.wait_and_reset();
            tracing::info!("Thread woke up from latch");
        });

        // Give the thread time to block
        std::thread::sleep(std::time::Duration::from_millis(100));
        assert!(!latch.probe());

        tracing::info!("Setting latch from main thread");
        latch.set();
        tracing::info!("Latch set, joining waiting thread");
        handle.join().expect("Thread should join successfully");
    }
}