cleanup and fix race

This commit is contained in:
Janis 2025-07-01 11:32:30 +02:00
parent 38ce1de3ac
commit 19ef21e2ef
8 changed files with 96 additions and 95 deletions

View file

@ -1,3 +1,5 @@
// This file is taken from [`chili`]
use std::{ use std::{
cell::UnsafeCell, cell::UnsafeCell,
ptr::NonNull, ptr::NonNull,
@ -15,6 +17,7 @@ enum State {
Taken, Taken,
} }
// taken from `std`
#[derive(Debug)] #[derive(Debug)]
#[repr(transparent)] #[repr(transparent)]
pub struct Parker { pub struct Parker {
@ -100,27 +103,45 @@ impl<T: Send> Receiver<T> {
} }
pub fn wait(&self) { pub fn wait(&self) {
match self.0.state.compare_exchange( loop {
State::Pending as u8, match self.0.state.compare_exchange(
State::Waiting as u8, State::Pending as u8,
Ordering::AcqRel, State::Waiting as u8,
Ordering::Acquire, Ordering::AcqRel,
) { Ordering::Acquire,
Ok(_) => { ) {
// SAFETY: Ok(_) => {
// The `waiting_thread` is set to the current thread's parker // SAFETY:
// before we park it. // The `waiting_thread` is set to the current thread's parker
unsafe { // before we park it.
let thread = self.0.waiting_thread.as_ref(); unsafe {
thread.park(); let thread = self.0.waiting_thread.as_ref();
thread.park();
}
// we might have been woken up because of a shared job.
// In that case, we need to check the state again.
if self
.0
.state
.compare_exchange(
State::Waiting as u8,
State::Pending as u8,
Ordering::AcqRel,
Ordering::Acquire,
)
.is_ok()
{
continue;
}
}
Err(state) if state == State::Ready as u8 => {
// The channel is ready, so we can return immediately.
return;
}
_ => {
panic!("Receiver is already waiting or consumed.");
} }
}
Err(state) if state == State::Ready as u8 => {
// The channel is ready, so we can return immediately.
return;
}
_ => {
panic!("Receiver is already waiting or consumed.");
} }
} }
} }
@ -144,22 +165,7 @@ impl<T: Send> Receiver<T> {
} }
pub fn recv(self) -> thread::Result<T> { pub fn recv(self) -> thread::Result<T> {
if self self.wait();
.0
.state
.compare_exchange(
State::Pending as u8,
State::Waiting as u8,
Ordering::AcqRel,
Ordering::Acquire,
)
.is_ok()
{
unsafe {
let thread = self.0.waiting_thread.as_ref();
thread.park();
}
}
// SAFETY: // SAFETY:
// To arrive here, either `state` is `State::Ready` or the above // To arrive here, either `state` is `State::Ready` or the above
@ -172,7 +178,14 @@ impl<T: Send> Receiver<T> {
} }
unsafe fn take(&self) -> thread::Result<T> { unsafe fn take(&self) -> thread::Result<T> {
unsafe { (*self.0.val.get()).take().map(|b| *b).unwrap() } let result = unsafe { (*self.0.val.get()).take().map(|b| *b).unwrap() };
assert_eq!(
self.0.state.swap(State::Taken as u8, Ordering::Release),
State::Ready as u8
);
result
} }
} }

View file

@ -9,40 +9,17 @@ use std::{
use alloc::collections::BTreeMap; use alloc::collections::BTreeMap;
use async_task::Runnable; use async_task::Runnable;
use crossbeam_utils::CachePadded;
use parking_lot::{Condvar, Mutex}; use parking_lot::{Condvar, Mutex};
use crate::{ use crate::{
channel::{Parker, Sender}, channel::{Parker, Sender},
heartbeat::HeartbeatList, heartbeat::HeartbeatList,
job::{HeapJob, Job2 as Job, SharedJob, StackJob}, job::{HeapJob, Job2 as Job, SharedJob, StackJob},
latch::{AsCoreLatch, MutexLatch, NopLatch}, latch::NopLatch,
util::DropGuard, util::DropGuard,
workerthread::{HeartbeatThread, WorkerThread}, workerthread::{HeartbeatThread, WorkerThread},
}; };
pub struct Heartbeat {
pub latch: MutexLatch,
}
impl Heartbeat {
pub fn new() -> NonNull<CachePadded<Self>> {
let ptr = Box::new(CachePadded::new(Self {
latch: MutexLatch::new(),
}));
Box::into_non_null(ptr)
}
pub fn is_pending(&self) -> bool {
self.latch.as_core_latch().poll_heartbeat()
}
pub fn is_sleeping(&self) -> bool {
self.latch.as_core_latch().is_sleeping()
}
}
pub struct Context { pub struct Context {
shared: Mutex<Shared>, shared: Mutex<Shared>,
pub shared_job: Condvar, pub shared_job: Condvar,

View file

@ -12,7 +12,7 @@ use std::{
use parking_lot::Mutex; use parking_lot::Mutex;
use crate::{channel::Parker, latch::WorkerLatch}; use crate::channel::Parker;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct HeartbeatList { pub struct HeartbeatList {

View file

@ -7,23 +7,15 @@ use core::{
ptr::{self, NonNull}, ptr::{self, NonNull},
sync::atomic::Ordering, sync::atomic::Ordering,
}; };
use std::{ use std::cell::Cell;
cell::Cell,
marker::PhantomData,
mem::MaybeUninit,
ops::DerefMut,
sync::atomic::{AtomicU8, AtomicU32, AtomicUsize},
};
use alloc::boxed::Box; use alloc::boxed::Box;
use parking_lot::{Condvar, Mutex};
use parking_lot_core::SpinWait; use parking_lot_core::SpinWait;
use crate::{ use crate::{
WorkerThread, WorkerThread,
channel::{Parker, Sender}, channel::{Parker, Sender},
latch::{Probe, WorkerLatch}, util::{SmallBox, TaggedAtomicPtr},
util::{DropGuard, SmallBox, TaggedAtomicPtr},
}; };
#[repr(u8)] #[repr(u8)]

View file

@ -103,15 +103,8 @@ impl WorkerThread {
} }
}; };
let ra = if !job.is_shared() { let ra = if let Some(recv) = job.take_receiver() {
tracing::trace!("join_heartbeat: job is not shared, running a() inline"); match self.wait_until_recv(recv) {
// we pushed the job to the back of the queue, any `join`s called by `b` on this worker thread will have already popped their job, or seen it be executed.
self.pop_back();
// a is allowed to panic here, because we already finished b.
unsafe { a.unwrap()() }
} else {
match self.wait_until_shared_job(&job) {
Some(t) => crate::util::unwrap_or_panic(t), Some(t) => crate::util::unwrap_or_panic(t),
None => { None => {
tracing::trace!( tracing::trace!(
@ -122,6 +115,14 @@ impl WorkerThread {
unsafe { a.unwrap()() } unsafe { a.unwrap()() }
} }
} }
} else {
self.pop_back();
unsafe {
// SAFETY: we just popped the job from the queue, so it is safe to unwrap.
tracing::trace!("join_heartbeat: job was not shared, running a() inline");
a.unwrap()()
}
}; };
(ra, rb) (ra, rb)

View file

@ -2,19 +2,11 @@ use core::{
marker::PhantomData, marker::PhantomData,
sync::atomic::{AtomicUsize, Ordering}, sync::atomic::{AtomicUsize, Ordering},
}; };
use std::{ use std::sync::atomic::{AtomicPtr, AtomicU8};
cell::UnsafeCell,
mem,
ops::DerefMut,
sync::{
Arc,
atomic::{AtomicPtr, AtomicU8},
},
};
use parking_lot::{Condvar, Mutex}; use parking_lot::{Condvar, Mutex};
use crate::{WorkerThread, channel::Parker, context::Context}; use crate::channel::Parker;
pub trait Latch { pub trait Latch {
unsafe fn set_raw(this: *const Self); unsafe fn set_raw(this: *const Self);
@ -430,7 +422,7 @@ impl WorkerLatch {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::{ptr, sync::Barrier}; use std::{ptr, sync::Arc};
use tracing_test::traced_test; use tracing_test::traced_test;

View file

@ -14,7 +14,7 @@ use crate::{
channel::Sender, channel::Sender,
context::Context, context::Context,
job::{HeapJob, Job2 as Job}, job::{HeapJob, Job2 as Job},
latch::{CountLatch, Probe, WorkerLatch}, latch::{CountLatch, Probe},
util::{DropGuard, SendPtr}, util::{DropGuard, SendPtr},
workerthread::WorkerThread, workerthread::WorkerThread,
}; };

View file

@ -10,10 +10,10 @@ use std::{
use crossbeam_utils::CachePadded; use crossbeam_utils::CachePadded;
use crate::{ use crate::{
channel::Receiver,
context::Context, context::Context,
heartbeat::OwnedHeartbeatReceiver, heartbeat::OwnedHeartbeatReceiver,
job::{Job2 as Job, JobQueue as JobList, SharedJob}, job::{Job2 as Job, JobQueue as JobList, SharedJob},
latch::Probe,
util::DropGuard, util::DropGuard,
}; };
@ -304,7 +304,7 @@ impl HeartbeatThread {
impl WorkerThread { impl WorkerThread {
#[tracing::instrument(level = "trace", skip(self))] #[tracing::instrument(level = "trace", skip(self))]
pub fn wait_until_shared_job<T: Send>(&self, job: &Job<T>) -> Option<std::thread::Result<T>> { pub fn wait_until_shared_job<T: Send>(&self, job: &Job<T>) -> Option<std::thread::Result<T>> {
let recv = (*job).take_receiver()?; let recv = (*job).take_receiver().unwrap();
let mut out = recv.poll(); let mut out = recv.poll();
@ -321,6 +321,32 @@ impl WorkerThread {
out out
} }
#[tracing::instrument(level = "trace", skip_all)]
pub fn wait_until_recv<T: Send>(&self, recv: Receiver<T>) -> Option<std::thread::Result<T>> {
if self
.context
.shared()
.jobs
.remove(&self.heartbeat.id())
.is_some()
{
tracing::trace!("reclaiming shared job");
return None;
}
while recv.is_empty() {
if let Some(job) = self.find_work() {
unsafe {
SharedJob::execute(job, self);
}
} else {
break;
}
}
Some(recv.recv())
}
#[tracing::instrument(level = "trace", skip_all)] #[tracing::instrument(level = "trace", skip_all)]
pub fn wait_until_pred<F>(&self, mut pred: F) pub fn wait_until_pred<F>(&self, mut pred: F)
where where