cleanup and fix race
This commit is contained in:
parent
38ce1de3ac
commit
19ef21e2ef
|
@ -1,3 +1,5 @@
|
||||||
|
// This file is taken from [`chili`]
|
||||||
|
|
||||||
use std::{
|
use std::{
|
||||||
cell::UnsafeCell,
|
cell::UnsafeCell,
|
||||||
ptr::NonNull,
|
ptr::NonNull,
|
||||||
|
@ -15,6 +17,7 @@ enum State {
|
||||||
Taken,
|
Taken,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// taken from `std`
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
#[repr(transparent)]
|
#[repr(transparent)]
|
||||||
pub struct Parker {
|
pub struct Parker {
|
||||||
|
@ -100,27 +103,45 @@ impl<T: Send> Receiver<T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn wait(&self) {
|
pub fn wait(&self) {
|
||||||
match self.0.state.compare_exchange(
|
loop {
|
||||||
State::Pending as u8,
|
match self.0.state.compare_exchange(
|
||||||
State::Waiting as u8,
|
State::Pending as u8,
|
||||||
Ordering::AcqRel,
|
State::Waiting as u8,
|
||||||
Ordering::Acquire,
|
Ordering::AcqRel,
|
||||||
) {
|
Ordering::Acquire,
|
||||||
Ok(_) => {
|
) {
|
||||||
// SAFETY:
|
Ok(_) => {
|
||||||
// The `waiting_thread` is set to the current thread's parker
|
// SAFETY:
|
||||||
// before we park it.
|
// The `waiting_thread` is set to the current thread's parker
|
||||||
unsafe {
|
// before we park it.
|
||||||
let thread = self.0.waiting_thread.as_ref();
|
unsafe {
|
||||||
thread.park();
|
let thread = self.0.waiting_thread.as_ref();
|
||||||
|
thread.park();
|
||||||
|
}
|
||||||
|
|
||||||
|
// we might have been woken up because of a shared job.
|
||||||
|
// In that case, we need to check the state again.
|
||||||
|
if self
|
||||||
|
.0
|
||||||
|
.state
|
||||||
|
.compare_exchange(
|
||||||
|
State::Waiting as u8,
|
||||||
|
State::Pending as u8,
|
||||||
|
Ordering::AcqRel,
|
||||||
|
Ordering::Acquire,
|
||||||
|
)
|
||||||
|
.is_ok()
|
||||||
|
{
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(state) if state == State::Ready as u8 => {
|
||||||
|
// The channel is ready, so we can return immediately.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
panic!("Receiver is already waiting or consumed.");
|
||||||
}
|
}
|
||||||
}
|
|
||||||
Err(state) if state == State::Ready as u8 => {
|
|
||||||
// The channel is ready, so we can return immediately.
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
panic!("Receiver is already waiting or consumed.");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -144,22 +165,7 @@ impl<T: Send> Receiver<T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn recv(self) -> thread::Result<T> {
|
pub fn recv(self) -> thread::Result<T> {
|
||||||
if self
|
self.wait();
|
||||||
.0
|
|
||||||
.state
|
|
||||||
.compare_exchange(
|
|
||||||
State::Pending as u8,
|
|
||||||
State::Waiting as u8,
|
|
||||||
Ordering::AcqRel,
|
|
||||||
Ordering::Acquire,
|
|
||||||
)
|
|
||||||
.is_ok()
|
|
||||||
{
|
|
||||||
unsafe {
|
|
||||||
let thread = self.0.waiting_thread.as_ref();
|
|
||||||
thread.park();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SAFETY:
|
// SAFETY:
|
||||||
// To arrive here, either `state` is `State::Ready` or the above
|
// To arrive here, either `state` is `State::Ready` or the above
|
||||||
|
@ -172,7 +178,14 @@ impl<T: Send> Receiver<T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe fn take(&self) -> thread::Result<T> {
|
unsafe fn take(&self) -> thread::Result<T> {
|
||||||
unsafe { (*self.0.val.get()).take().map(|b| *b).unwrap() }
|
let result = unsafe { (*self.0.val.get()).take().map(|b| *b).unwrap() };
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
self.0.state.swap(State::Taken as u8, Ordering::Release),
|
||||||
|
State::Ready as u8
|
||||||
|
);
|
||||||
|
|
||||||
|
result
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,40 +9,17 @@ use std::{
|
||||||
use alloc::collections::BTreeMap;
|
use alloc::collections::BTreeMap;
|
||||||
|
|
||||||
use async_task::Runnable;
|
use async_task::Runnable;
|
||||||
use crossbeam_utils::CachePadded;
|
|
||||||
use parking_lot::{Condvar, Mutex};
|
use parking_lot::{Condvar, Mutex};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
channel::{Parker, Sender},
|
channel::{Parker, Sender},
|
||||||
heartbeat::HeartbeatList,
|
heartbeat::HeartbeatList,
|
||||||
job::{HeapJob, Job2 as Job, SharedJob, StackJob},
|
job::{HeapJob, Job2 as Job, SharedJob, StackJob},
|
||||||
latch::{AsCoreLatch, MutexLatch, NopLatch},
|
latch::NopLatch,
|
||||||
util::DropGuard,
|
util::DropGuard,
|
||||||
workerthread::{HeartbeatThread, WorkerThread},
|
workerthread::{HeartbeatThread, WorkerThread},
|
||||||
};
|
};
|
||||||
|
|
||||||
pub struct Heartbeat {
|
|
||||||
pub latch: MutexLatch,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Heartbeat {
|
|
||||||
pub fn new() -> NonNull<CachePadded<Self>> {
|
|
||||||
let ptr = Box::new(CachePadded::new(Self {
|
|
||||||
latch: MutexLatch::new(),
|
|
||||||
}));
|
|
||||||
|
|
||||||
Box::into_non_null(ptr)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn is_pending(&self) -> bool {
|
|
||||||
self.latch.as_core_latch().poll_heartbeat()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn is_sleeping(&self) -> bool {
|
|
||||||
self.latch.as_core_latch().is_sleeping()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct Context {
|
pub struct Context {
|
||||||
shared: Mutex<Shared>,
|
shared: Mutex<Shared>,
|
||||||
pub shared_job: Condvar,
|
pub shared_job: Condvar,
|
||||||
|
|
|
@ -12,7 +12,7 @@ use std::{
|
||||||
|
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
|
|
||||||
use crate::{channel::Parker, latch::WorkerLatch};
|
use crate::channel::Parker;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct HeartbeatList {
|
pub struct HeartbeatList {
|
||||||
|
|
|
@ -7,23 +7,15 @@ use core::{
|
||||||
ptr::{self, NonNull},
|
ptr::{self, NonNull},
|
||||||
sync::atomic::Ordering,
|
sync::atomic::Ordering,
|
||||||
};
|
};
|
||||||
use std::{
|
use std::cell::Cell;
|
||||||
cell::Cell,
|
|
||||||
marker::PhantomData,
|
|
||||||
mem::MaybeUninit,
|
|
||||||
ops::DerefMut,
|
|
||||||
sync::atomic::{AtomicU8, AtomicU32, AtomicUsize},
|
|
||||||
};
|
|
||||||
|
|
||||||
use alloc::boxed::Box;
|
use alloc::boxed::Box;
|
||||||
use parking_lot::{Condvar, Mutex};
|
|
||||||
use parking_lot_core::SpinWait;
|
use parking_lot_core::SpinWait;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
WorkerThread,
|
WorkerThread,
|
||||||
channel::{Parker, Sender},
|
channel::{Parker, Sender},
|
||||||
latch::{Probe, WorkerLatch},
|
util::{SmallBox, TaggedAtomicPtr},
|
||||||
util::{DropGuard, SmallBox, TaggedAtomicPtr},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#[repr(u8)]
|
#[repr(u8)]
|
||||||
|
|
|
@ -103,15 +103,8 @@ impl WorkerThread {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let ra = if !job.is_shared() {
|
let ra = if let Some(recv) = job.take_receiver() {
|
||||||
tracing::trace!("join_heartbeat: job is not shared, running a() inline");
|
match self.wait_until_recv(recv) {
|
||||||
// we pushed the job to the back of the queue, any `join`s called by `b` on this worker thread will have already popped their job, or seen it be executed.
|
|
||||||
self.pop_back();
|
|
||||||
|
|
||||||
// a is allowed to panic here, because we already finished b.
|
|
||||||
unsafe { a.unwrap()() }
|
|
||||||
} else {
|
|
||||||
match self.wait_until_shared_job(&job) {
|
|
||||||
Some(t) => crate::util::unwrap_or_panic(t),
|
Some(t) => crate::util::unwrap_or_panic(t),
|
||||||
None => {
|
None => {
|
||||||
tracing::trace!(
|
tracing::trace!(
|
||||||
|
@ -122,6 +115,14 @@ impl WorkerThread {
|
||||||
unsafe { a.unwrap()() }
|
unsafe { a.unwrap()() }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
self.pop_back();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
// SAFETY: we just popped the job from the queue, so it is safe to unwrap.
|
||||||
|
tracing::trace!("join_heartbeat: job was not shared, running a() inline");
|
||||||
|
a.unwrap()()
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
(ra, rb)
|
(ra, rb)
|
||||||
|
|
|
@ -2,19 +2,11 @@ use core::{
|
||||||
marker::PhantomData,
|
marker::PhantomData,
|
||||||
sync::atomic::{AtomicUsize, Ordering},
|
sync::atomic::{AtomicUsize, Ordering},
|
||||||
};
|
};
|
||||||
use std::{
|
use std::sync::atomic::{AtomicPtr, AtomicU8};
|
||||||
cell::UnsafeCell,
|
|
||||||
mem,
|
|
||||||
ops::DerefMut,
|
|
||||||
sync::{
|
|
||||||
Arc,
|
|
||||||
atomic::{AtomicPtr, AtomicU8},
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
use parking_lot::{Condvar, Mutex};
|
use parking_lot::{Condvar, Mutex};
|
||||||
|
|
||||||
use crate::{WorkerThread, channel::Parker, context::Context};
|
use crate::channel::Parker;
|
||||||
|
|
||||||
pub trait Latch {
|
pub trait Latch {
|
||||||
unsafe fn set_raw(this: *const Self);
|
unsafe fn set_raw(this: *const Self);
|
||||||
|
@ -430,7 +422,7 @@ impl WorkerLatch {
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::{ptr, sync::Barrier};
|
use std::{ptr, sync::Arc};
|
||||||
|
|
||||||
use tracing_test::traced_test;
|
use tracing_test::traced_test;
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@ use crate::{
|
||||||
channel::Sender,
|
channel::Sender,
|
||||||
context::Context,
|
context::Context,
|
||||||
job::{HeapJob, Job2 as Job},
|
job::{HeapJob, Job2 as Job},
|
||||||
latch::{CountLatch, Probe, WorkerLatch},
|
latch::{CountLatch, Probe},
|
||||||
util::{DropGuard, SendPtr},
|
util::{DropGuard, SendPtr},
|
||||||
workerthread::WorkerThread,
|
workerthread::WorkerThread,
|
||||||
};
|
};
|
||||||
|
|
|
@ -10,10 +10,10 @@ use std::{
|
||||||
use crossbeam_utils::CachePadded;
|
use crossbeam_utils::CachePadded;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
channel::Receiver,
|
||||||
context::Context,
|
context::Context,
|
||||||
heartbeat::OwnedHeartbeatReceiver,
|
heartbeat::OwnedHeartbeatReceiver,
|
||||||
job::{Job2 as Job, JobQueue as JobList, SharedJob},
|
job::{Job2 as Job, JobQueue as JobList, SharedJob},
|
||||||
latch::Probe,
|
|
||||||
util::DropGuard,
|
util::DropGuard,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -304,7 +304,7 @@ impl HeartbeatThread {
|
||||||
impl WorkerThread {
|
impl WorkerThread {
|
||||||
#[tracing::instrument(level = "trace", skip(self))]
|
#[tracing::instrument(level = "trace", skip(self))]
|
||||||
pub fn wait_until_shared_job<T: Send>(&self, job: &Job<T>) -> Option<std::thread::Result<T>> {
|
pub fn wait_until_shared_job<T: Send>(&self, job: &Job<T>) -> Option<std::thread::Result<T>> {
|
||||||
let recv = (*job).take_receiver()?;
|
let recv = (*job).take_receiver().unwrap();
|
||||||
|
|
||||||
let mut out = recv.poll();
|
let mut out = recv.poll();
|
||||||
|
|
||||||
|
@ -321,6 +321,32 @@ impl WorkerThread {
|
||||||
out
|
out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(level = "trace", skip_all)]
|
||||||
|
pub fn wait_until_recv<T: Send>(&self, recv: Receiver<T>) -> Option<std::thread::Result<T>> {
|
||||||
|
if self
|
||||||
|
.context
|
||||||
|
.shared()
|
||||||
|
.jobs
|
||||||
|
.remove(&self.heartbeat.id())
|
||||||
|
.is_some()
|
||||||
|
{
|
||||||
|
tracing::trace!("reclaiming shared job");
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
while recv.is_empty() {
|
||||||
|
if let Some(job) = self.find_work() {
|
||||||
|
unsafe {
|
||||||
|
SharedJob::execute(job, self);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(recv.recv())
|
||||||
|
}
|
||||||
|
|
||||||
#[tracing::instrument(level = "trace", skip_all)]
|
#[tracing::instrument(level = "trace", skip_all)]
|
||||||
pub fn wait_until_pred<F>(&self, mut pred: F)
|
pub fn wait_until_pred<F>(&self, mut pred: F)
|
||||||
where
|
where
|
||||||
|
|
Loading…
Reference in a new issue