cleanup and fix race

This commit is contained in:
Janis 2025-07-01 11:32:30 +02:00
parent 38ce1de3ac
commit 19ef21e2ef
8 changed files with 96 additions and 95 deletions

View file

@ -1,3 +1,5 @@
// This file is taken from [`chili`]
use std::{
cell::UnsafeCell,
ptr::NonNull,
@ -15,6 +17,7 @@ enum State {
Taken,
}
// taken from `std`
#[derive(Debug)]
#[repr(transparent)]
pub struct Parker {
@ -100,27 +103,45 @@ impl<T: Send> Receiver<T> {
}
pub fn wait(&self) {
match self.0.state.compare_exchange(
State::Pending as u8,
State::Waiting as u8,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => {
// SAFETY:
// The `waiting_thread` is set to the current thread's parker
// before we park it.
unsafe {
let thread = self.0.waiting_thread.as_ref();
thread.park();
loop {
match self.0.state.compare_exchange(
State::Pending as u8,
State::Waiting as u8,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => {
// SAFETY:
// The `waiting_thread` is set to the current thread's parker
// before we park it.
unsafe {
let thread = self.0.waiting_thread.as_ref();
thread.park();
}
// we might have been woken up because of a shared job.
// In that case, we need to check the state again.
if self
.0
.state
.compare_exchange(
State::Waiting as u8,
State::Pending as u8,
Ordering::AcqRel,
Ordering::Acquire,
)
.is_ok()
{
continue;
}
}
Err(state) if state == State::Ready as u8 => {
// The channel is ready, so we can return immediately.
return;
}
_ => {
panic!("Receiver is already waiting or consumed.");
}
}
Err(state) if state == State::Ready as u8 => {
// The channel is ready, so we can return immediately.
return;
}
_ => {
panic!("Receiver is already waiting or consumed.");
}
}
}
@ -144,22 +165,7 @@ impl<T: Send> Receiver<T> {
}
pub fn recv(self) -> thread::Result<T> {
if self
.0
.state
.compare_exchange(
State::Pending as u8,
State::Waiting as u8,
Ordering::AcqRel,
Ordering::Acquire,
)
.is_ok()
{
unsafe {
let thread = self.0.waiting_thread.as_ref();
thread.park();
}
}
self.wait();
// SAFETY:
// To arrive here, either `state` is `State::Ready` or the above
@ -172,7 +178,14 @@ impl<T: Send> Receiver<T> {
}
unsafe fn take(&self) -> thread::Result<T> {
unsafe { (*self.0.val.get()).take().map(|b| *b).unwrap() }
let result = unsafe { (*self.0.val.get()).take().map(|b| *b).unwrap() };
assert_eq!(
self.0.state.swap(State::Taken as u8, Ordering::Release),
State::Ready as u8
);
result
}
}

View file

@ -9,40 +9,17 @@ use std::{
use alloc::collections::BTreeMap;
use async_task::Runnable;
use crossbeam_utils::CachePadded;
use parking_lot::{Condvar, Mutex};
use crate::{
channel::{Parker, Sender},
heartbeat::HeartbeatList,
job::{HeapJob, Job2 as Job, SharedJob, StackJob},
latch::{AsCoreLatch, MutexLatch, NopLatch},
latch::NopLatch,
util::DropGuard,
workerthread::{HeartbeatThread, WorkerThread},
};
pub struct Heartbeat {
pub latch: MutexLatch,
}
impl Heartbeat {
pub fn new() -> NonNull<CachePadded<Self>> {
let ptr = Box::new(CachePadded::new(Self {
latch: MutexLatch::new(),
}));
Box::into_non_null(ptr)
}
pub fn is_pending(&self) -> bool {
self.latch.as_core_latch().poll_heartbeat()
}
pub fn is_sleeping(&self) -> bool {
self.latch.as_core_latch().is_sleeping()
}
}
pub struct Context {
shared: Mutex<Shared>,
pub shared_job: Condvar,

View file

@ -12,7 +12,7 @@ use std::{
use parking_lot::Mutex;
use crate::{channel::Parker, latch::WorkerLatch};
use crate::channel::Parker;
#[derive(Debug, Clone)]
pub struct HeartbeatList {

View file

@ -7,23 +7,15 @@ use core::{
ptr::{self, NonNull},
sync::atomic::Ordering,
};
use std::{
cell::Cell,
marker::PhantomData,
mem::MaybeUninit,
ops::DerefMut,
sync::atomic::{AtomicU8, AtomicU32, AtomicUsize},
};
use std::cell::Cell;
use alloc::boxed::Box;
use parking_lot::{Condvar, Mutex};
use parking_lot_core::SpinWait;
use crate::{
WorkerThread,
channel::{Parker, Sender},
latch::{Probe, WorkerLatch},
util::{DropGuard, SmallBox, TaggedAtomicPtr},
util::{SmallBox, TaggedAtomicPtr},
};
#[repr(u8)]

View file

@ -103,15 +103,8 @@ impl WorkerThread {
}
};
let ra = if !job.is_shared() {
tracing::trace!("join_heartbeat: job is not shared, running a() inline");
// we pushed the job to the back of the queue, any `join`s called by `b` on this worker thread will have already popped their job, or seen it be executed.
self.pop_back();
// a is allowed to panic here, because we already finished b.
unsafe { a.unwrap()() }
} else {
match self.wait_until_shared_job(&job) {
let ra = if let Some(recv) = job.take_receiver() {
match self.wait_until_recv(recv) {
Some(t) => crate::util::unwrap_or_panic(t),
None => {
tracing::trace!(
@ -122,6 +115,14 @@ impl WorkerThread {
unsafe { a.unwrap()() }
}
}
} else {
self.pop_back();
unsafe {
// SAFETY: we just popped the job from the queue, so it is safe to unwrap.
tracing::trace!("join_heartbeat: job was not shared, running a() inline");
a.unwrap()()
}
};
(ra, rb)

View file

@ -2,19 +2,11 @@ use core::{
marker::PhantomData,
sync::atomic::{AtomicUsize, Ordering},
};
use std::{
cell::UnsafeCell,
mem,
ops::DerefMut,
sync::{
Arc,
atomic::{AtomicPtr, AtomicU8},
},
};
use std::sync::atomic::{AtomicPtr, AtomicU8};
use parking_lot::{Condvar, Mutex};
use crate::{WorkerThread, channel::Parker, context::Context};
use crate::channel::Parker;
pub trait Latch {
unsafe fn set_raw(this: *const Self);
@ -430,7 +422,7 @@ impl WorkerLatch {
#[cfg(test)]
mod tests {
use std::{ptr, sync::Barrier};
use std::{ptr, sync::Arc};
use tracing_test::traced_test;

View file

@ -14,7 +14,7 @@ use crate::{
channel::Sender,
context::Context,
job::{HeapJob, Job2 as Job},
latch::{CountLatch, Probe, WorkerLatch},
latch::{CountLatch, Probe},
util::{DropGuard, SendPtr},
workerthread::WorkerThread,
};

View file

@ -10,10 +10,10 @@ use std::{
use crossbeam_utils::CachePadded;
use crate::{
channel::Receiver,
context::Context,
heartbeat::OwnedHeartbeatReceiver,
job::{Job2 as Job, JobQueue as JobList, SharedJob},
latch::Probe,
util::DropGuard,
};
@ -304,7 +304,7 @@ impl HeartbeatThread {
impl WorkerThread {
#[tracing::instrument(level = "trace", skip(self))]
pub fn wait_until_shared_job<T: Send>(&self, job: &Job<T>) -> Option<std::thread::Result<T>> {
let recv = (*job).take_receiver()?;
let recv = (*job).take_receiver().unwrap();
let mut out = recv.poll();
@ -321,6 +321,32 @@ impl WorkerThread {
out
}
#[tracing::instrument(level = "trace", skip_all)]
pub fn wait_until_recv<T: Send>(&self, recv: Receiver<T>) -> Option<std::thread::Result<T>> {
if self
.context
.shared()
.jobs
.remove(&self.heartbeat.id())
.is_some()
{
tracing::trace!("reclaiming shared job");
return None;
}
while recv.is_empty() {
if let Some(job) = self.find_work() {
unsafe {
SharedJob::execute(job, self);
}
} else {
break;
}
}
Some(recv.recv())
}
#[tracing::instrument(level = "trace", skip_all)]
pub fn wait_until_pred<F>(&self, mut pred: F)
where