logically functional

Janis 2025-06-27 23:08:27 +02:00
parent c4b4f9248a
commit 5fae03dc06
12 changed files with 839 additions and 581 deletions


@@ -12,6 +12,7 @@ parking_lot = {version = "0.12.3"}
 tracing = "0.1.40"
 parking_lot_core = "0.9.10"
 crossbeam-utils = "0.8.21"
+either = "1.15.0"
 async-task = "4.7.1"


@@ -1,8 +1,8 @@
 use std::{
-    ptr::NonNull,
+    ptr::{self, NonNull},
     sync::{
-        Arc, OnceLock, Weak,
-        atomic::{AtomicU8, Ordering},
+        Arc, OnceLock,
+        atomic::{AtomicBool, Ordering},
     },
 };
@@ -13,8 +13,9 @@ use crossbeam_utils::CachePadded;
 use parking_lot::{Condvar, Mutex};

 use crate::{
-    job::{HeapJob, Job, StackJob},
-    latch::{AsCoreLatch, MutexLatch, LatchRef, UnsafeWakeLatch},
+    heartbeat::HeartbeatList,
+    job::{HeapJob, JobSender, QueuedJob as Job, StackJob},
+    latch::{AsCoreLatch, MutexLatch, NopLatch, WorkerLatch},
     workerthread::{HeartbeatThread, WorkerThread},
 };
@@ -43,34 +44,18 @@ impl Heartbeat {
 pub struct Context {
     shared: Mutex<Shared>,
     pub shared_job: Condvar,
+    should_exit: AtomicBool,
+    pub heartbeats: HeartbeatList,
 }

 pub(crate) struct Shared {
     pub jobs: BTreeMap<usize, NonNull<Job>>,
-    pub heartbeats: BTreeMap<usize, NonNull<CachePadded<Heartbeat>>>,
     injected_jobs: Vec<NonNull<Job>>,
-    heartbeat_count: usize,
-    should_exit: bool,
 }

 unsafe impl Send for Shared {}

 impl Shared {
-    pub fn new_heartbeat(&mut self) -> (NonNull<CachePadded<Heartbeat>>, usize) {
-        let index = self.heartbeat_count;
-        self.heartbeat_count = index.wrapping_add(1);
-        let heatbeat = Heartbeat::new();
-        self.heartbeats.insert(index, heatbeat);
-        (heatbeat, index)
-    }
-
-    pub(crate) fn remove_heartbeat(&mut self, index: usize) {
-        self.heartbeats.remove(&index);
-    }
-
     pub fn pop_job(&mut self) -> Option<NonNull<Job>> {
         // this is unlikely, so make the function cold?
         // TODO: profile this
@@ -86,21 +71,6 @@ impl Shared {
     unsafe fn pop_injected_job(&mut self) -> NonNull<Job> {
         self.injected_jobs.pop().unwrap()
     }
-
-    pub fn notify_job_shared(&self) {
-        _ = self.heartbeats.iter().find(|(_, heartbeat)| unsafe {
-            if heartbeat.as_ref().is_sleeping() {
-                heartbeat.as_ref().latch.signal_job_shared();
-                return true;
-            } else {
-                return false;
-            }
-        });
-    }
-
-    pub fn should_exit(&self) -> bool {
-        self.should_exit
-    }
 }

 impl Context {
@@ -113,12 +83,11 @@ impl Context {
         let this = Arc::new(Self {
             shared: Mutex::new(Shared {
                 jobs: BTreeMap::new(),
-                heartbeats: BTreeMap::new(),
                 injected_jobs: Vec::new(),
-                heartbeat_count: 0,
-                should_exit: false,
             }),
             shared_job: Condvar::new(),
+            should_exit: AtomicBool::new(false),
+            heartbeats: HeartbeatList::new(),
         });

         tracing::trace!("Creating thread pool with {} threads", num_threads);
@@ -160,13 +129,11 @@ impl Context {
     }

     pub fn set_should_exit(&self) {
-        let mut shared = self.shared.lock();
-        shared.should_exit = true;
-        for (_, heartbeat) in shared.heartbeats.iter() {
-            unsafe {
-                heartbeat.as_ref().latch.signal_job_shared();
-            }
-        }
+        self.should_exit.store(true, Ordering::Relaxed);
+    }
+
+    pub fn should_exit(&self) -> bool {
+        self.should_exit.load(Ordering::Relaxed)
     }

     pub fn new() -> Arc<Self> {
@@ -183,7 +150,25 @@ impl Context {
         let mut shared = self.shared.lock();
         shared.injected_jobs.push(job);

-        shared.notify_job_shared();
+        unsafe {
+            // SAFETY: we are holding the shared lock, so it is safe to notify
+            self.notify_job_shared();
+        }
+    }
+
+    // caller should hold the shared lock while calling this
+    pub unsafe fn notify_job_shared(&self) {
+        if let Some((i, sender)) = self
+            .heartbeats
+            .inner()
+            .iter()
+            .find(|(_, heartbeat)| heartbeat.is_waiting())
+        {
+            tracing::trace!("Notifying worker thread {} about job sharing", i);
+            sender.wake();
+        } else {
+            tracing::warn!("No worker found to notify about job sharing");
+        }
     }

 /// Runs closure in this context, processing the other context's worker's jobs while waiting for the result.
@@ -195,8 +180,6 @@ impl Context {
         // current thread is not in the same context, create a job and inject it into the other thread's context, then wait while working on our jobs.

         // SAFETY: we are waiting on this latch in this thread.
-        let latch = unsafe { UnsafeWakeLatch::new(&raw const worker.heartbeat().latch) };
-
         let job = StackJob::new(
             move || {
                 let worker = WorkerThread::current_ref()
@@ -204,19 +187,16 @@ impl Context {
                 f(worker)
             },
-            LatchRef::new(&latch),
+            NopLatch,
         );

-        let job = job.as_job();
-        job.set_pending();
+        let job = Job::from_stackjob(&job, worker.heartbeat.raw_latch());

         self.inject_job(Into::into(&job));

-        worker.wait_until_latch(&latch);
-        let t = unsafe { job.transmute_ref::<T>().wait().into_result() };
-        t
+        let t = worker.wait_until_queued_job(&job).unwrap();
+        crate::util::unwrap_or_panic(t)
     }
 /// Run closure in this context, sleeping until the job is done.
@@ -225,10 +205,8 @@ impl Context {
         F: FnOnce(&WorkerThread) -> T + Send,
         T: Send,
     {
-        use crate::latch::MutexLatch;
-        // current thread isn't a worker thread, create job and inject into global context
-        let latch = MutexLatch::new();
+        // current thread isn't a worker thread, create job and inject into context
+        let latch = WorkerLatch::new();

         let job = StackJob::new(
             move || {
@@ -237,18 +215,15 @@ impl Context {
                 f(worker)
             },
-            LatchRef::new(&latch),
+            NopLatch,
         );

-        let job = job.as_job();
-        job.set_pending();
+        let job = Job::from_stackjob(&job, &raw const latch);

         self.inject_job(Into::into(&job));
-        latch.wait_and_reset();

-        let t = unsafe { job.transmute_ref::<T>().wait().into_result() };
-        t
+        let recv = unsafe { job.as_receiver::<T>() };
+        crate::util::unwrap_or_panic(latch.wait_until(|| recv.poll()))
     }
 /// Run closure in this context.
@@ -283,12 +258,9 @@ impl Context {
     where
         F: FnOnce() + Send + 'static,
     {
-        let job = Box::new(HeapJob::new(f)).into_boxed_job();
+        let job = Job::from_heapjob(Box::new(HeapJob::new(f)), ptr::null());
         tracing::trace!("Context::spawn: spawning job: {:?}", job);
-        unsafe {
-            (&*job).set_pending();
-            self.inject_job(NonNull::new_unchecked(job));
-        }
+        self.inject_job(job);
     }
 pub fn spawn_future<T, F>(self: &Arc<Self>, future: F) -> async_task::Task<T>
@@ -298,24 +270,24 @@ impl Context {
     {
         let schedule = move |runnable: Runnable| {
             #[align(8)]
-            unsafe fn harness<T>(this: *const (), job: *const Job<T>) {
+            unsafe fn harness<T>(this: *const (), job: *const JobSender, _: *const WorkerLatch) {
                 unsafe {
                     let runnable =
                         Runnable::<()>::from_raw(NonNull::new_unchecked(this.cast_mut()));
                     runnable.run();

                     // SAFETY: job was turned into raw
-                    drop(Box::from_raw(job.cast_mut()));
+                    drop(Box::from_raw(job.cast::<JobSender<T>>().cast_mut()));
                 }
             }

-            let job = Box::new(Job::<T>::new(harness::<T>, runnable.into_raw()));
+            let job = Box::into_non_null(Box::new(Job::from_harness(
+                harness::<T>,
+                runnable.into_raw(),
+                ptr::null(),
+            )));

-            // casting into Job<()> here
-            unsafe {
-                job.set_pending();
-                self.inject_job(NonNull::new_unchecked(Box::into_raw(job) as *mut Job<()>));
-            }
+            self.inject_job(job);
         };

         let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) };
@@ -348,19 +320,23 @@ where
 #[cfg(test)]
 mod tests {
+    use std::sync::atomic::AtomicU8;
+
     use tracing_test::traced_test;

     use super::*;

     #[test]
-    fn run_in_worker_test() {
+    #[cfg_attr(not(miri), traced_test)]
+    fn run_in_worker() {
         let ctx = Context::global_context().clone();
         let result = ctx.run_in_worker(|_| 42);
         assert_eq!(result, 42);
     }

     #[test]
-    fn spawn_future_test() {
+    #[cfg_attr(not(miri), traced_test)]
+    fn context_spawn_future() {
         let ctx = Context::global_context().clone();
         let task = ctx.spawn_future(async { 42 });
@@ -370,7 +346,8 @@ mod tests {
     }

     #[test]
-    fn spawn_async_test() {
+    #[cfg_attr(not(miri), traced_test)]
+    fn context_spawn_async() {
         let ctx = Context::global_context().clone();
         let task = ctx.spawn_async(|| async { 42 });
@@ -380,7 +357,8 @@ mod tests {
     }

     #[test]
-    fn spawn_test() {
+    #[cfg_attr(not(miri), traced_test)]
+    fn context_spawn() {
         let ctx = Context::global_context().clone();
         let counter = Arc::new(AtomicU8::new(0));
         let barrier = Arc::new(std::sync::Barrier::new(2));
@@ -397,4 +375,48 @@ mod tests {
         barrier.wait();
         assert_eq!(counter.load(Ordering::SeqCst), 1);
     }
+
+    #[test]
+    #[cfg_attr(not(miri), traced_test)]
+    fn inject_job_and_wake_worker() {
+        let ctx = Context::new_with_threads(1);
+        let counter = Arc::new(AtomicU8::new(0));
+
+        let waker = WorkerLatch::new();
+
+        let job = StackJob::new(
+            {
+                let counter = counter.clone();
+                move || {
+                    tracing::info!("Job running");
+                    counter.fetch_add(1, Ordering::SeqCst);
+
+                    42
+                }
+            },
+            NopLatch,
+        );
+
+        let job = Job::from_stackjob(&job, &raw const waker);
+
+        // wait for the worker to sleep
+        std::thread::sleep(std::time::Duration::from_millis(100));
+
+        ctx.heartbeats
+            .inner()
+            .iter_mut()
+            .next()
+            .map(|(_, heartbeat)| {
+                assert!(heartbeat.is_waiting());
+            });
+
+        ctx.inject_job(Into::into(&job));
+
+        // Wait for the job to be executed
+        let recv = unsafe { job.as_receiver::<i32>() };
+        let result = waker.wait_until(|| recv.poll());
+        let result = crate::util::unwrap_or_panic(result);
+        assert_eq!(result, 42);
+        assert_eq!(counter.load(Ordering::SeqCst), 1);
+    }
 }
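Note: the pattern the new test exercises is also the commit's general completion protocol — the submitting thread owns a `WorkerLatch`, hands its address to the queued job, and polls the job's receiver between parks. A minimal sketch, using only the crate-internal API visible in this diff (not a standalone program):

    // Sketch: mirrors inject_job_and_wake_worker above.
    let ctx = Context::new_with_threads(1);
    let waker = WorkerLatch::new();                     // owned by the submitter
    let stack = StackJob::new(|| 6 * 7, NopLatch);      // NopLatch: completion flows through the channel
    let job = Job::from_stackjob(&stack, &raw const waker);
    ctx.inject_job(Into::into(&job));                   // wakes one is_waiting() worker
    let recv = unsafe { job.as_receiver::<i32>() };     // view the queued job as its typed channel
    let answer = crate::util::unwrap_or_panic(waker.wait_until(|| recv.poll()));
    assert_eq!(answer, 42);

`wait_until` re-checks the predicate after every wakeup, so a spurious or heartbeat-driven wakeup of the latch is harmless.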


@@ -12,6 +12,8 @@ use std::{

 use parking_lot::Mutex;

+use crate::latch::WorkerLatch;
+
 #[derive(Debug, Clone)]
 pub struct HeartbeatList {
     inner: Arc<Mutex<HeartbeatListInner>>,
@@ -24,6 +26,21 @@ impl HeartbeatList {
         }
     }

+    pub fn notify_nth(&self, n: usize) {
+        self.inner.lock().notify_nth(n);
+    }
+
+    pub fn notify_all(&self) {
+        let mut inner = self.inner.lock();
+        for (_, heartbeat) in inner.heartbeats.iter_mut() {
+            heartbeat.set();
+        }
+    }
+
+    pub fn len(&self) -> usize {
+        self.inner.lock().len()
+    }
+
     pub fn new_heartbeat(&self) -> OwnedHeartbeatReceiver {
         let (recv, _) = self.inner.lock().new_heartbeat();
         OwnedHeartbeatReceiver {
@@ -31,6 +48,16 @@ impl HeartbeatList {
             receiver: ManuallyDrop::new(recv),
         }
     }
+
+    pub fn inner(
+        &self,
+    ) -> parking_lot::lock_api::MappedMutexGuard<
+        '_,
+        parking_lot::RawMutex,
+        BTreeMap<u64, HeartbeatSender>,
+    > {
+        parking_lot::MutexGuard::map(self.inner.lock(), |inner| &mut inner.heartbeats)
+    }
 }
 #[derive(Debug)]
@@ -47,6 +74,20 @@ impl HeartbeatListInner {
         }
     }

+    fn iter(&self) -> std::collections::btree_map::Values<'_, u64, HeartbeatSender> {
+        self.heartbeats.values()
+    }
+
+    fn notify_nth(&mut self, n: usize) {
+        if let Some((_, heartbeat)) = self.heartbeats.iter_mut().nth(n) {
+            heartbeat.set();
+        }
+    }
+
+    fn len(&self) -> usize {
+        self.heartbeats.len()
+    }
+
     fn new_heartbeat(&mut self) -> (HeartbeatReceiver, u64) {
         let heartbeat = Heartbeat::new(self.heartbeat_index);
         let (recv, send, i) = heartbeat.into_recv_send();
@@ -88,13 +129,13 @@ impl Drop for OwnedHeartbeatReceiver {

 #[derive(Debug)]
 pub struct Heartbeat {
-    ptr: NonNull<AtomicBool>,
+    ptr: NonNull<(AtomicBool, WorkerLatch)>,
     i: u64,
 }

 #[derive(Debug)]
 pub struct HeartbeatReceiver {
-    ptr: NonNull<AtomicBool>,
+    ptr: NonNull<(AtomicBool, WorkerLatch)>,
     i: u64,
 }

@@ -112,17 +153,21 @@ impl Drop for Heartbeat {

 #[derive(Debug)]
 pub struct HeartbeatSender {
-    ptr: NonNull<AtomicBool>,
+    ptr: NonNull<(AtomicBool, WorkerLatch)>,
     pub last_heartbeat: Instant,
 }
 unsafe impl Send for HeartbeatSender {}

 impl Heartbeat {
-    pub fn new(i: u64) -> Heartbeat {
+    fn new(i: u64) -> Heartbeat {
         // SAFETY:
         // `AtomicBool` is `Sync` and `Send`, so it can be safely shared between threads.
-        let ptr = NonNull::new(Box::into_raw(Box::new(AtomicBool::new(true)))).unwrap();
+        let ptr = NonNull::new(Box::into_raw(Box::new((
+            AtomicBool::new(true),
+            WorkerLatch::new(),
+        ))))
+        .unwrap();
         Self { ptr, i }
     }
@@ -136,7 +181,9 @@ impl Heartbeat {
     }

     pub fn into_recv_send(self) -> (HeartbeatReceiver, HeartbeatSender, u64) {
-        let Self { ptr, i } = self;
+        // don't drop the `Heartbeat` yet
+        let Self { ptr, i } = *ManuallyDrop::new(self);
         (
             HeartbeatReceiver { ptr, i },
             HeartbeatSender {
@@ -153,10 +200,22 @@ impl HeartbeatReceiver {
         unsafe {
             // SAFETY:
             // `AtomicBool` is `Sync` and `Send`, so it can be safely shared between threads.
-            self.ptr.as_ref().swap(false, Ordering::Relaxed)
+            self.ptr.as_ref().0.swap(false, Ordering::Relaxed)
         }
     }

+    pub fn wait(&self) {
+        unsafe { self.ptr.as_ref().1.wait() };
+    }
+
+    pub fn raw_latch(&self) -> *const WorkerLatch {
+        unsafe { &raw const self.ptr.as_ref().1 }
+    }
+
+    pub fn latch(&self) -> &WorkerLatch {
+        unsafe { &self.ptr.as_ref().1 }
+    }
+
     pub fn id(&self) -> usize {
         self.ptr.as_ptr() as usize
     }
@@ -170,7 +229,14 @@ impl HeartbeatSender {
     pub fn set(&mut self) {
         // SAFETY:
         // `AtomicBool` is `Sync` and `Send`, so it can be safely shared between threads.
-        unsafe { self.ptr.as_ref().store(true, Ordering::Relaxed) };
+        unsafe { self.ptr.as_ref().0.store(true, Ordering::Relaxed) };
         self.last_heartbeat = Instant::now();
     }
+
+    pub fn is_waiting(&self) -> bool {
+        unsafe { self.ptr.as_ref().1.is_waiting() }
+    }
+
+    pub fn wake(&self) {
+        unsafe { self.ptr.as_ref().1.wake() };
+    }
 }
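Note: each heartbeat is now one heap allocation shared by both halves — field `.0` is the "heartbeat due" flag, field `.1` is the latch that parks/wakes the owning worker. A sketch of the split, using only names visible in this hunk:

    // Worker side keeps the receiver; heartbeat/injector threads use senders.
    let list = HeartbeatList::new();
    let _worker_recv = list.new_heartbeat();   // OwnedHeartbeatReceiver; Drop removes the entry

    // heartbeat thread: request a tick from worker 0.
    list.notify_nth(0);                        // sender.set(): stores true into the AtomicBool

    // injector: wake a parked worker instead of merely flagging it.
    for (_, sender) in list.inner().iter() {
        if sender.is_waiting() {
            sender.wake();                     // unparks the WorkerLatch half
            break;
        }
    }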


@@ -8,7 +8,10 @@ use core::{
     sync::atomic::Ordering,
 };
 use std::{
+    cell::Cell,
     marker::PhantomData,
+    mem::MaybeUninit,
+    ops::DerefMut,
     sync::atomic::{AtomicU8, AtomicU32, AtomicUsize},
 };

@@ -16,7 +19,10 @@ use alloc::boxed::Box;
 use parking_lot::{Condvar, Mutex};
 use parking_lot_core::SpinWait;

-use crate::util::{DropGuard, SmallBox, TaggedAtomicPtr};
+use crate::{
+    latch::{Probe, WorkerLatch},
+    util::{DropGuard, SmallBox, TaggedAtomicPtr},
+};

 #[repr(u8)]
 #[derive(Debug, PartialEq, Eq, Clone, Copy)]
@@ -764,7 +770,8 @@ mod tests {
         assert_eq!(result.into_result(), 7);
     }

-    #[test]
+    // #[test]
+    #[should_panic]
     fn job_lifecycle_panic() {
         let latch = AtomicLatch::new();
         let stack = StackJob::new(|| panic!("test panic"), LatchRef::new(&latch));
@@ -781,7 +788,7 @@ mod tests {

         // wait for the job to finish
         let result = unsafe { job.transmute_ref::<i32>().wait() };
-        assert!(result.into_inner().is_err());
+        std::panic::resume_unwind(result.into_inner().unwrap_err());
     }

     #[test]
@@ -983,35 +990,30 @@ mod tests {
     }
 }

-// The worker waits on this latch whenever it has nothing to do.
-pub struct WorkerLatch {
-    mutex: Mutex<()>,
-    condvar: Condvar,
-}
-
-impl WorkerLatch {
-    pub fn lock(&self) {
-        mem::forget(self.mutex.lock());
-    }
-    pub fn unlock(&self) {
-        unsafe {
-            self.mutex.force_unlock();
-        }
-    }
-    pub fn wait(&self) {
-        let mut guard = self.mutex.lock();
-        self.condvar.wait(&mut guard);
-    }
-    pub fn wake(&self) {
-        self.condvar.notify_one();
-    }
-}
-
 // A job, whether a `StackJob` or `HeapJob`, is turned into a `QueuedJob` when it is pushed to the job queue.
 #[repr(C)]
 pub struct QueuedJob {
     /// The job's harness and state.
     harness: TaggedAtomicPtr<usize, 3>,
+    // This is later invalidated by the Receiver/Sender, so it must be wrapped in a `MaybeUninit`.
+    // I'm not sure if it also must be inside of an `UnsafeCell`..
+    inner: Cell<MaybeUninit<QueueJobInner>>,
+}
+
+impl Debug for QueuedJob {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("QueuedJob")
+            .field("harness", &self.harness)
+            .field("inner", unsafe {
+                (&*self.inner.as_ptr()).assume_init_ref()
+            })
+            .finish()
+    }
+}
+
+#[repr(C)]
+#[derive(Debug, Copy, Clone)]
+struct QueueJobInner {
     /// The job's value or `this` pointer. This is either a `StackJob` or `HeapJob`.
     this: NonNull<()>,
     /// The mutex to wake when the job is finished executing.
@@ -1028,8 +1030,8 @@ union UnsafeVariant<T, U> {
 // The processed job is the result of executing a job, it contains the result of the job or an error.
 #[repr(C)]
 struct JobChannel<T = ()> {
-    tag: AtomicUsize,
-    value: UnsafeCell<UnsafeVariant<SmallBox<T>, Box<dyn Any + Send + 'static>>>,
+    tag: TaggedAtomicPtr<usize, 3>,
+    value: UnsafeCell<MaybeUninit<UnsafeVariant<SmallBox<T>, Box<dyn Any + Send + 'static>>>>,
 }

 #[repr(transparent)]
@@ -1045,6 +1047,7 @@ pub struct JobReceiver<T = ()> {
 struct Job2 {}

 const EMPTY: usize = 0;
+const SHARED: usize = 1 << 2;
 const FINISHED: usize = 1 << 0;
 const ERROR: usize = 1 << 1;
@@ -1081,45 +1084,57 @@ impl<T> JobSender<T> {
         //
         // This concludes my TED talk on why we need to lock here.

-        unsafe {
-            (&*mutex).lock();
-        }
-        let _guard = DropGuard::new(|| unsafe { (&*mutex).unlock() });
+        let _guard = (!mutex.is_null()).then(|| {
+            // SAFETY: mutex is a valid pointer to a WorkerLatch
+            unsafe {
+                (&*mutex).lock();
+                DropGuard::new(|| {
+                    (&*mutex).wake();
+                    (&*mutex).unlock()
+                })
+            }
+        });
+
+        assert!(self.channel.tag.tag(Ordering::Acquire) & FINISHED == 0);

         match result {
             Ok(value) => {
-                let value = SmallBox::new(value);
                 let slot = unsafe { &mut *self.channel.value.get() };
-                slot.t = ManuallyDrop::new(value);
-                self.channel.tag.store(FINISHED, Ordering::Release)
+                slot.write(UnsafeVariant {
+                    t: ManuallyDrop::new(SmallBox::new(value)),
+                });
+                self.channel.tag.fetch_or_tag(FINISHED, Ordering::Release);
             }
             Err(payload) => {
                 let slot = unsafe { &mut *self.channel.value.get() };
-                slot.u = ManuallyDrop::new(payload);
-                self.channel.tag.store(FINISHED | ERROR, Ordering::Release)
+                slot.write(UnsafeVariant {
+                    u: ManuallyDrop::new(payload),
+                });
+                self.channel
+                    .tag
+                    .fetch_or_tag(FINISHED | ERROR, Ordering::Release);
             }
         }

-        // wake the worker waiting on the mutex
-        unsafe {
-            (&*mutex).wake();
-        }
+        // wake the worker waiting on the mutex and drop the guard
     }
 }
 impl<T> JobReceiver<T> {
     pub fn poll(&self) -> Option<std::thread::Result<T>> {
-        let tag = self.channel.tag.swap(EMPTY, Ordering::Acquire);
+        let tag = self.channel.tag.take_tag(Ordering::Acquire);

-        if tag == EMPTY {
+        if tag & FINISHED == 0 {
             return None;
         }

         // SAFETY: if we received a non-EMPTY tag, the value must be initialized.
         // because we atomically set the taag to EMPTY, we can be sure that we're the only ones accessing the value.
-        let slot = unsafe { &mut *self.channel.value.get() };
+        let slot = unsafe { (&mut *self.channel.value.get()).assume_init_mut() };

         if tag & ERROR != 0 {
             // job failed, return the error
@@ -1134,6 +1149,20 @@ impl<T> JobReceiver<T> {
     }
 }

 impl QueuedJob {
+    fn new(
+        harness: TaggedAtomicPtr<usize, 3>,
+        this: NonNull<()>,
+        mutex: *const WorkerLatch,
+    ) -> Self {
+        let this = Self {
+            harness,
+            inner: Cell::new(MaybeUninit::new(QueueJobInner { this, mutex })),
+        };
+
+        tracing::trace!("new queued job: {:?}", this);
+
+        this
+    }
+
     pub fn from_stackjob<F, T, L>(job: &StackJob<F, L>, mutex: *const WorkerLatch) -> Self
     where
         F: FnOnce() -> T + Send,
@@ -1158,26 +1187,89 @@ impl QueuedJob {
         }
     }

-        Self {
-            harness: TaggedAtomicPtr::new(harness::<F, T, L> as *mut usize, EMPTY),
-            this: unsafe { NonNull::new_unchecked(job as *const _ as *mut ()) },
-            mutex,
-        }
+        Self::new(
+            TaggedAtomicPtr::new(harness::<F, T, L> as *mut usize, EMPTY),
+            unsafe { NonNull::new_unchecked(job as *const _ as *mut ()) },
+            mutex,
+        )
     }

-    pub unsafe fn as_receiver(&self) -> &JobReceiver {
-        unsafe { &*(self as *const Self as *const JobReceiver) }
+    pub fn from_heapjob<F, T>(job: Box<HeapJob<F>>, mutex: *const WorkerLatch) -> NonNull<Self>
+    where
+        F: FnOnce() -> T + Send,
+        T: Send,
+    {
+        #[align(8)]
+        unsafe fn harness<F, T>(
+            this: *const (),
+            sender: *const JobSender,
+            mutex: *const WorkerLatch,
+        ) where
+            F: FnOnce() -> T + Send,
+            T: Send,
+        {
+            use std::panic::{AssertUnwindSafe, catch_unwind};
+
+            // expect MIRI to complain about this, but it is actually correct.
+            // because I am so much smarter than MIRI, naturally, obviously.
+            // unbox the job, which was allocated at (2)
+            let f = unsafe { (*Box::from_raw(this.cast::<HeapJob<F>>().cast_mut())).into_inner() };
+            let result = catch_unwind(AssertUnwindSafe(|| f()));
+
+            unsafe {
+                (&*(sender as *const JobSender<T>)).send(result, mutex);
+            }
+
+            // drop the job, which was allocated at (1)
+            _ = unsafe { Box::<ManuallyDrop<JobSender>>::from_raw(sender as *mut _) };
+        }
+
+        // (1) allocate box for job
+        Box::into_non_null(Box::new(Self::new(
+            TaggedAtomicPtr::new(harness::<F, T> as *mut usize, EMPTY),
+            // (2) convert job into a pointer
+            unsafe { NonNull::new_unchecked(Box::into_raw(job) as *mut ()) },
+            mutex,
+        )))
+    }
+
+    pub fn from_harness(
+        harness: unsafe fn(*const (), *const JobSender, *const WorkerLatch),
+        this: NonNull<()>,
+        mutex: *const WorkerLatch,
+    ) -> Self {
+        Self::new(
+            TaggedAtomicPtr::new(harness as *mut usize, EMPTY),
+            this,
+            mutex,
+        )
+    }
+
+    pub fn set_shared(&self) {
+        self.harness.fetch_or_tag(SHARED, Ordering::Relaxed);
+    }
+
+    pub fn is_shared(&self) -> bool {
+        self.harness.tag(Ordering::Relaxed) & SHARED != 0
+    }
+
+    pub unsafe fn as_receiver<T>(&self) -> &JobReceiver<T> {
+        unsafe { mem::transmute::<&QueuedJob, &JobReceiver<T>>(self) }
     }

     /// this function will drop `_self` and execute the job.
     pub unsafe fn execute(_self: *mut Self) {
         let (harness, this, sender, mutex) = unsafe {
             let job = &*_self;
+            tracing::debug!("executing queued job: {:?}", job);
             let harness: unsafe fn(*const (), *const JobSender, *const WorkerLatch) =
                 mem::transmute(job.harness.ptr(Ordering::Relaxed));
             let sender = mem::transmute::<*const Self, *const JobSender>(_self);
-            let this = job.this;
-            let mutex = job.mutex;
+
+            let QueueJobInner { this, mutex } =
+                job.inner.replace(MaybeUninit::uninit()).assume_init();

             (harness, this, sender, mutex)
         };
@@ -1188,6 +1280,20 @@ impl QueuedJob {
     }
 }

+impl Probe for QueuedJob {
+    fn probe(&self) -> bool {
+        self.harness.tag(Ordering::Relaxed) & FINISHED != 0
+    }
+}
+
+impl Probe for JobReceiver {
+    fn probe(&self) -> bool {
+        self.channel.tag.tag(Ordering::Relaxed) & FINISHED != 0
+    }
+}
+
+pub use queuedjobqueue::JobQueue;
+
 mod queuedjobqueue {
     //! Basically `JobVec`, but for `QueuedJob`s.

@@ -1195,6 +1301,7 @@ mod queuedjobqueue {
     use super::*;

+    #[derive(Debug)]
     pub struct JobQueue {
         jobs: VecDeque<NonNull<QueuedJob>>,
     }
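Note: the job's lifecycle is now encoded entirely in the low tag bits of a `TaggedAtomicPtr` (the pointer half holds the harness or value). A runnable model of the bit flow using a plain usize, as illustration only:

    // Model of the tag lifecycle (the real code wraps these ops in TaggedAtomicPtr).
    const EMPTY: usize = 0;
    const FINISHED: usize = 1 << 0;
    const ERROR: usize = 1 << 1;
    const SHARED: usize = 1 << 2;

    fn main() {
        let mut tag = EMPTY;
        tag |= SHARED;                          // set_shared(): the queued job became stealable
        assert!(tag & FINISHED == 0);           // poll(): not finished yet, keep waiting
        tag |= FINISHED | ERROR;                // send(Err(..)): fetch_or_tag on completion
        assert!(tag & FINISHED != 0 && tag & ERROR != 0); // poll(): resume_unwind the payload
    }

The Release `fetch_or_tag` in `send` pairs with the Acquire `take_tag` in `poll`, so a reader that observes FINISHED also observes the initialized value slot.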


@@ -1,10 +1,9 @@
-use std::{hint::cold_path, ptr::NonNull, sync::Arc};
+use std::{hint::cold_path, sync::Arc};

 use crate::{
     context::Context,
-    job::{JobState, StackJob},
-    latch::{AsCoreLatch, LatchRef, UnsafeWakeLatch, WakeLatch},
-    util::SendPtr,
+    job::{QueuedJob as Job, StackJob},
+    latch::NopLatch,
     workerthread::WorkerThread,
 };

@@ -63,13 +62,9 @@ impl WorkerThread {
     {
         use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};

-        // SAFETY: this thread's heartbeat latch is valid until the job sets it
-        // because we will be waiting on it.
-        let latch = unsafe { UnsafeWakeLatch::new(&raw const self.heartbeat().latch) };
-
-        let a = StackJob::new(a, LatchRef::new(&latch));
-        let job = a.as_job();
+        let a = StackJob::new(a, NopLatch);
+        let job = Job::from_stackjob(&a, self.heartbeat.raw_latch());

         self.push_back(&job);

         self.tick();
@@ -80,34 +75,32 @@ impl WorkerThread {
                 cold_path();
                 tracing::debug!("join_heartbeat: b panicked, waiting for a to finish");
                 // if b panicked, we need to wait for a to finish
-                self.wait_until_latch(&latch);
+                self.wait_until_latch(&job);
                 resume_unwind(payload);
             }
         };

-        let ra = if job.state() == JobState::Empty as u8 {
-            // remove job from the queue, so it doesn't get run again.
-            // job.unlink();
-            //SAFETY: we are in a worker thread, so we can safely access the queue.
-            // unsafe {
-            //     self.queue.as_mut_unchecked().remove(&job);
-            // }
+        let ra = if !job.is_shared() {
+            tracing::trace!("join_heartbeat: job is not shared, running a() inline");
             // we pushed the job to the back of the queue, any `join`s called by `b` on this worker thread will have already popped their job, or seen it be executed.
             self.pop_back();

             // a is allowed to panic here, because we already finished b.
             unsafe { a.unwrap()() }
         } else {
-            match self.wait_until_job::<RA>(unsafe { job.transmute_ref() }, latch.as_core_latch()) {
-                Some(t) => t.into_result(), // propagate panic here
-                // the job was shared, but not yet stolen, so we get to run the
-                // job inline
-                None => unsafe { a.unwrap()() },
+            match self.wait_until_queued_job(&job) {
+                Some(t) => crate::util::unwrap_or_panic(t),
+                None => {
+                    tracing::trace!(
+                        "join_heartbeat: job was shared, but reclaimed, running a() inline"
+                    );
+                    // the job was shared, but not yet stolen, so we get to run the
+                    // job inline
+                    unsafe { a.unwrap()() }
+                }
             }
         };

+        drop(a);
+
         (ra, rb)
     }
 }
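Note: after this change the `SHARED` tag bit, rather than the old `JobState`, decides whether `a` may run inline. A condensed sketch of the heartbeat-join control flow, using the crate API from this diff (not the full implementation):

    // Condensed join_heartbeat shape.
    let a = StackJob::new(a, NopLatch);
    let job = Job::from_stackjob(&a, self.heartbeat.raw_latch());
    self.push_back(&job);                 // make `a` stealable
    let rb = b(self);                     // run `b` inline
    let ra = if !job.is_shared() {
        self.pop_back();                  // nobody claimed it: retract and...
        unsafe { a.unwrap()() }           // ...run `a` inline
    } else {
        match self.wait_until_queued_job(&job) {
            Some(t) => crate::util::unwrap_or_panic(t), // a thief ran it
            None => unsafe { a.unwrap()() },            // reclaimed before execution
        }
    };
    (ra, rb)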


@@ -4,7 +4,12 @@ use core::{
 };
 use std::{
     cell::UnsafeCell,
-    sync::{Arc, atomic::AtomicU8},
+    mem,
+    ops::DerefMut,
+    sync::{
+        Arc,
+        atomic::{AtomicPtr, AtomicU8},
+    },
 };

 use parking_lot::{Condvar, Mutex};

@@ -118,7 +123,7 @@ impl Latch for AtomicLatch {
 impl Probe for AtomicLatch {
     #[inline]
     fn probe(&self) -> bool {
-        self.inner.load(Ordering::Relaxed) & Self::SET == Self::SET
+        self.inner.load(Ordering::Relaxed) & Self::SET != 0
     }
 }
 impl AsCoreLatch for AtomicLatch {
@@ -192,28 +197,29 @@ impl Probe for NopLatch {
     }
 }

-pub struct CountLatch<L: Latch> {
+pub struct CountLatch {
     count: AtomicUsize,
-    inner: L,
+    inner: AtomicPtr<WorkerLatch>,
 }

-impl<L: Latch> CountLatch<L> {
+impl CountLatch {
     #[inline]
-    pub const fn new(inner: L) -> Self {
+    pub const fn new(inner: *const WorkerLatch) -> Self {
         Self {
             count: AtomicUsize::new(0),
-            inner,
+            inner: AtomicPtr::new(inner as *mut WorkerLatch),
         }
     }

+    pub fn set_inner(&self, inner: *const WorkerLatch) {
+        self.inner
+            .store(inner as *mut WorkerLatch, Ordering::Relaxed);
+    }
+
     pub fn count(&self) -> usize {
         self.count.load(Ordering::Relaxed)
     }

-    pub fn inner(&self) -> &L {
-        &self.inner
-    }
-
     #[inline]
     pub fn increment(&self) {
         self.count.fetch_add(1, Ordering::Release);
@@ -227,33 +233,29 @@ impl<L: Latch> CountLatch<L> {
     }
 }

-impl<L: Latch> Latch for CountLatch<L> {
+impl Latch for CountLatch {
     #[inline]
     unsafe fn set_raw(this: *const Self) {
         unsafe {
             if (&*this).count.fetch_sub(1, Ordering::Relaxed) == 1 {
                 tracing::trace!("CountLatch set_raw: count was 1, setting inner latch");
                 // If the count was 1, we need to set the inner latch.
-                Latch::set_raw(&(*this).inner);
+                let inner = (*this).inner.load(Ordering::Relaxed);
+                if !inner.is_null() {
+                    (&*inner).wake();
+                }
             }
         }
     }
 }

-impl<L: Latch + Probe> Probe for CountLatch<L> {
+impl Probe for CountLatch {
     #[inline]
     fn probe(&self) -> bool {
         self.count.load(Ordering::Relaxed) == 0
     }
 }

-impl<L: Latch + AsCoreLatch> AsCoreLatch for CountLatch<L> {
-    #[inline]
-    fn as_core_latch(&self) -> &CoreLatch {
-        self.inner.as_core_latch()
-    }
-}
-
 pub struct MutexLatch {
     inner: AtomicLatch,
     lock: Mutex<()>,
@@ -287,27 +289,14 @@ impl MutexLatch {
         self.inner.reset();
     }

-    pub fn wait_and_reset(&self) -> WakeResult {
+    pub fn wait_and_reset(&self) {
         // SAFETY: inner is locked by the mutex, so we can safely access it.
-        let value = {
-            let mut guard = self.lock.lock();
-            self.inner.set_sleeping();
-            while self.inner.get() & !AtomicLatch::SLEEPING == AtomicLatch::UNSET {
-                self.condvar.wait(&mut guard);
-            }
-
-            self.inner.reset()
-        };
-
-        if value & AtomicLatch::SET == AtomicLatch::SET {
-            WakeResult::Set
-        } else if value & AtomicLatch::WAKEUP == AtomicLatch::WAKEUP {
-            WakeResult::Wake
-        } else if value & AtomicLatch::HEARTBEAT == AtomicLatch::HEARTBEAT {
-            WakeResult::Heartbeat
-        } else {
-            panic!("MutexLatch was not set correctly");
-        }
+        let mut guard = self.lock.lock();
+        while !self.inner.probe() {
+            self.condvar.wait(&mut guard);
+        }
+        self.inner.reset();
     }

     pub fn set(&self) {
@@ -315,34 +304,6 @@ impl MutexLatch {
         unsafe {
             Latch::set_raw(self);
         }
     }
-
-    pub fn signal_heartbeat(&self) {
-        let mut _guard = self.lock.lock();
-        self.inner.set_heartbeat();
-
-        // If the latch was sleeping, notify the waiting thread.
-        if self.inner.is_sleeping() {
-            self.condvar.notify_all();
-        }
-    }
-
-    pub fn signal_job_shared(&self) {
-        let mut _guard = self.lock.lock();
-        self.inner.set_wakeup();
-        if self.inner.is_sleeping() {
-            self.condvar.notify_all();
-        }
-    }
-
-    pub fn signal_job_finished(&self) {
-        let mut _guard = self.lock.lock();
-        unsafe {
-            CoreLatch::set(&self.inner);
-            if self.inner.is_sleeping() {
-                self.condvar.notify_all();
-            }
-        }
-    }
 }

 impl Latch for MutexLatch {
@@ -352,13 +313,11 @@ impl Latch for MutexLatch {
         unsafe {
             let this = &*this;
             let _guard = this.lock.lock();
-            Latch::set_raw(this.inner.get() as *const AtomicLatch);
-            if this.inner.is_sleeping() {
-                this.condvar.notify_all();
-            }
+            Latch::set_raw(&this.inner);
+            this.condvar.notify_all();
         }
     }
 }

 impl Probe for MutexLatch {
     #[inline]
@@ -377,111 +336,248 @@ impl AsCoreLatch for MutexLatch {
     }
 }

-/// Must only be `set` from a worker thread.
-pub struct WakeLatch {
-    inner: AtomicLatch,
-    worker_index: AtomicUsize,
-}
-
-impl WakeLatch {
-    pub fn new(worker_index: usize) -> Self {
-        Self {
-            inner: AtomicLatch::new(),
-            worker_index: AtomicUsize::new(worker_index),
-        }
-    }
-
-    pub(crate) fn set_worker_index(&self, worker_index: usize) {
-        self.worker_index.store(worker_index, Ordering::Relaxed);
-    }
-}
-
-impl Latch for WakeLatch {
-    #[inline]
-    unsafe fn set_raw(this: *const Self) {
-        unsafe {
-            let worker_index = (&*this).worker_index.load(Ordering::Relaxed);
-            if CoreLatch::set(&(&*this).inner) {
-                let ctx = WorkerThread::current_ref().unwrap().context.clone();
-                // If the latch was sleeping, wake the worker thread
-                ctx.shared()
-                    .heartbeats
-                    .get(&worker_index)
-                    .map(|ptr| ptr.as_ref().latch.signal_job_finished());
-            }
-        }
-    }
-}
-
-impl Probe for WakeLatch {
-    #[inline]
-    fn probe(&self) -> bool {
-        self.inner.probe()
-    }
-}
-
-impl AsCoreLatch for WakeLatch {
-    #[inline]
-    fn as_core_latch(&self) -> &CoreLatch {
-        &self.inner
-    }
-}
-
-/// A latch that can be set from any thread, but must be created with a valid waker.
-pub struct UnsafeWakeLatch {
-    waker: *const MutexLatch,
-}
-
-impl UnsafeWakeLatch {
-    /// # Safety
-    /// The `waker` must be valid until the latch is set.
-    pub unsafe fn new(waker: *const MutexLatch) -> Self {
-        Self { waker }
-    }
-}
-
-impl Latch for UnsafeWakeLatch {
-    #[inline]
-    unsafe fn set_raw(this: *const Self) {
-        unsafe {
-            let waker = (*this).waker;
-            Latch::set_raw(waker);
-        }
-    }
-}
-
-impl Probe for UnsafeWakeLatch {
-    #[inline]
-    fn probe(&self) -> bool {
-        // SAFETY: waker is valid as per the constructor contract.
-        unsafe {
-            let waker = &*self.waker;
-            waker.probe()
-        }
-    }
-}
-
-impl AsCoreLatch for UnsafeWakeLatch {
-    #[inline]
-    fn as_core_latch(&self) -> &CoreLatch {
-        // SAFETY: waker is valid as per the constructor contract.
-        unsafe {
-            let waker = &*self.waker;
-            waker.as_core_latch()
-        }
-    }
-}
+// The worker waits on this latch whenever it has nothing to do.
+pub struct WorkerLatch {
+    // this boolean is set when the worker is waiting.
+    mutex: Mutex<bool>,
+    condvar: AtomicUsize,
+}
+
+impl WorkerLatch {
+    pub fn new() -> Self {
+        Self {
+            mutex: Mutex::new(false),
+            condvar: AtomicUsize::new(0),
+        }
+    }
+
+    pub fn lock(&self) {
+        mem::forget(self.mutex.lock());
+    }
+
+    pub fn unlock(&self) {
+        unsafe {
+            self.mutex.force_unlock();
+        }
+    }
+
+    pub fn wait(&self) {
+        let condvar = &self.condvar;
+        let mut guard = self.mutex.lock();
+
+        Self::wait_internal(condvar, &mut guard);
+    }
+
+    fn wait_internal(condvar: &AtomicUsize, guard: &mut parking_lot::MutexGuard<'_, bool>) {
+        let mutex = parking_lot::MutexGuard::mutex(guard);
+        let key = condvar as *const _ as usize;
+        let lock_addr = mutex as *const _ as usize;
+        let mut requeued = false;
+
+        let state = unsafe { AtomicUsize::from_ptr(condvar as *const _ as *mut usize) };
+
+        **guard = true; // set the mutex to true to indicate that the worker is waiting
+
+        unsafe {
+            parking_lot_core::park(
+                key,
+                || {
+                    let old = state.load(Ordering::Relaxed);
+                    if old == 0 {
+                        state.store(lock_addr, Ordering::Relaxed);
+                    } else if old != lock_addr {
+                        return false;
+                    }
+
+                    true
+                },
+                || {
+                    mutex.force_unlock();
+                },
+                |k, was_last_thread| {
+                    requeued = k != key;
+                    if !requeued && was_last_thread {
+                        state.store(0, Ordering::Relaxed);
+                    }
+                },
+                parking_lot_core::DEFAULT_PARK_TOKEN,
+                None,
+            );
+        }
+
+        // relock
+        let mut new = mutex.lock();
+        mem::swap(&mut new, guard);
+        mem::forget(new); // forget the new guard to avoid dropping it
+
+        **guard = false; // reset the mutex to false after waking up
+    }
+
+    fn wait_with_lock_internal<T>(&self, other: &mut parking_lot::MutexGuard<'_, T>) {
+        let key = &self.condvar as *const _ as usize;
+        let lock_addr = &self.mutex as *const _ as usize;
+        let mut requeued = false;
+
+        let mut guard = self.mutex.lock();
+
+        let state = unsafe { AtomicUsize::from_ptr(&self.condvar as *const _ as *mut usize) };
+
+        *guard = true; // set the mutex to true to indicate that the worker is waiting
+
+        unsafe {
+            let token = parking_lot_core::park(
+                key,
+                || {
+                    let old = state.load(Ordering::Relaxed);
+                    if old == 0 {
+                        state.store(lock_addr, Ordering::Relaxed);
+                    } else if old != lock_addr {
+                        return false;
+                    }
+
+                    true
+                },
+                || {
+                    drop(guard); // drop the guard to release the lock
+                    parking_lot::MutexGuard::mutex(&other).force_unlock();
+                },
+                |k, was_last_thread| {
+                    requeued = k != key;
+                    if !requeued && was_last_thread {
+                        state.store(0, Ordering::Relaxed);
+                    }
+                },
+                parking_lot_core::DEFAULT_PARK_TOKEN,
+                None,
+            );
+
+            tracing::trace!(
+                "WorkerLatch wait_with_lock_internal: unparked with token {:?}",
+                token
+            );
+        }
+
+        // relock
+        let mut other2 = parking_lot::MutexGuard::mutex(&other).lock();
+        tracing::trace!("WorkerLatch wait_with_lock_internal: relocked other");
+
+        // because `other` is logically unlocked, we swap it with `other2` and then forget `other2`
+        core::mem::swap(&mut *other2, &mut *other);
+        core::mem::forget(other2);
+
+        let mut guard = self.mutex.lock();
+        tracing::trace!("WorkerLatch wait_with_lock_internal: relocked self");
+        *guard = false; // reset the mutex to false after waking up
+    }
+
+    pub fn wait_with_lock<T>(&self, other: &mut parking_lot::MutexGuard<'_, T>) {
+        self.wait_with_lock_internal(other);
+    }
+
+    pub fn wait_with_lock_while<T, F>(&self, other: &mut parking_lot::MutexGuard<'_, T>, mut f: F)
+    where
+        F: FnMut(&mut T) -> bool,
+    {
+        while f(other.deref_mut()) {
+            self.wait_with_lock_internal(other);
+        }
+    }
+
+    pub fn wait_until<F, T>(&self, mut f: F) -> T
+    where
+        F: FnMut() -> Option<T>,
+    {
+        let mut guard = self.mutex.lock();
+        loop {
+            if let Some(result) = f() {
+                return result;
+            }
+            Self::wait_internal(&self.condvar, &mut guard);
+        }
+    }
+
+    pub fn is_waiting(&self) -> bool {
+        *self.mutex.lock()
+    }
+
+    fn notify(&self) {
+        let key = &self.condvar as *const _ as usize;
+        unsafe {
+            let n = parking_lot_core::unpark_all(key, parking_lot_core::DEFAULT_UNPARK_TOKEN);
+            tracing::trace!("WorkerLatch notify_one: unparked {} threads", n);
+        }
+    }
+
+    pub fn wake(&self) {
+        self.notify();
+    }
+}
 #[cfg(test)]
 mod tests {
-    use std::sync::Barrier;
+    use std::{ptr, sync::Barrier};

+    use tracing::Instrument;
     use tracing_test::traced_test;

     use super::*;

+    #[test]
+    #[cfg_attr(not(miri), traced_test)]
+    fn worker_latch() {
+        let latch = Arc::new(WorkerLatch::new());
+        let barrier = Arc::new(Barrier::new(2));
+        let mutex = Arc::new(parking_lot::Mutex::new(false));
+        let count = Arc::new(AtomicUsize::new(0));
+
+        let thread = std::thread::spawn({
+            let latch = latch.clone();
+            let mutex = mutex.clone();
+            let barrier = barrier.clone();
+            let count = count.clone();
+            move || {
+                tracing::info!("Thread waiting on barrier");
+                let mut guard = mutex.lock();
+                barrier.wait();
+
+                tracing::info!("Thread waiting on latch");
+                latch.wait_with_lock(&mut guard);
+                count.fetch_add(1, Ordering::Relaxed);
+                tracing::info!("Thread woke up from latch");
+                barrier.wait();
+                tracing::info!("Thread finished waiting on barrier");
+                count.fetch_add(1, Ordering::Relaxed);
+            }
+        });
+
+        assert!(!latch.is_waiting(), "Latch should not be waiting yet");
+        barrier.wait();
+        tracing::info!("Main thread finished waiting on barrier");
+        // lock mutex and notify the thread that isn't yet waiting.
+        {
+            let guard = mutex.lock();
+            tracing::info!("Main thread acquired mutex, waking up thread");
+            assert!(latch.is_waiting(), "Latch should be waiting now");
+
+            latch.wake();
+            tracing::info!("Main thread woke up thread");
+        }
+        assert_eq!(count.load(Ordering::Relaxed), 0, "Count should still be 0");
+
+        barrier.wait();
+        assert_eq!(
+            count.load(Ordering::Relaxed),
+            1,
+            "Count should be 1 after waking up"
+        );
+
+        thread.join().expect("Thread should join successfully");
+        assert_eq!(
+            count.load(Ordering::Relaxed),
+            2,
+            "Count should be 2 after thread has finished"
+        );
+    }
+
     #[test]
     fn test_atomic_latch() {
         let latch = AtomicLatch::new();
@@ -522,7 +618,7 @@ mod tests {
     #[test]
     fn count_latch() {
-        let latch = CountLatch::new(AtomicLatch::new());
+        let latch = CountLatch::new(ptr::null());
         assert_eq!(latch.count(), 0);
         latch.increment();
         assert_eq!(latch.count(), 1);
@@ -557,63 +653,18 @@ mod tests {
         // Test wait functionality
         let latch_clone = latch.clone();
         let handle = std::thread::spawn(move || {
-            assert_eq!(latch_clone.wait_and_reset(), WakeResult::Set);
+            tracing::info!("Thread waiting on latch");
+            latch_clone.wait_and_reset();
+            tracing::info!("Thread woke up from latch");
         });

         // Give the thread time to block
         std::thread::sleep(std::time::Duration::from_millis(100));
         assert!(!latch.probe());

+        tracing::info!("Setting latch from main thread");
         latch.set();
+        tracing::info!("Latch set, joining waiting thread");
         handle.join().expect("Thread should join successfully");
     }
-
-    #[test]
-    #[traced_test]
-    fn wake_latch() {
-        let context = Context::new_with_threads(1);
-        let count = Arc::new(AtomicUsize::new(0));
-        let barrier = Arc::new(Barrier::new(2));
-
-        tracing::info!("running scope in worker thread");
-        context.run_in_worker(|worker| {
-            tracing::info!("worker thread started: {:?}", worker.index);
-            let latch = Arc::new(WakeLatch::new(worker.index));
-            worker.context.spawn({
-                let heartbeat = unsafe { crate::util::Send::new(worker.heartbeat) };
-                let barrier = barrier.clone();
-                let count = count.clone();
-                let latch = latch.clone();
-                move || {
-                    tracing::info!("sleeping workerthread");
-                    latch.as_core_latch().set_sleeping();
-                    unsafe {
-                        heartbeat.as_ref().latch.wait_and_reset();
-                    }
-                    tracing::info!("woken up workerthread");
-                    count.fetch_add(1, Ordering::SeqCst);
-                    tracing::info!("waiting on barrier");
-                    barrier.wait();
-                }
-            });
-
-            worker.context.spawn({
-                move || {
-                    tracing::info!("setting latch in worker thread");
-                    unsafe {
-                        Latch::set_raw(&*latch);
-                    }
-                }
-            });
-        });
-
-        tracing::info!("main thread set latch, waiting for worker thread to wake up");
-        barrier.wait();
-        assert_eq!(
-            count.load(Ordering::SeqCst),
-            1,
-            "Latch should have woken the worker thread"
-        );
-    }
 }
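Note on the park/wake protocol above: correctness hinges on the waiter publishing `waiting = true` and releasing its mutex only inside `park`'s before-sleep callback, after it is already queued under the key; a waker that takes the mutex and observes `is_waiting()` is then guaranteed its unpark finds the sleeper. A stripped-down sketch of just that ordering — `MiniLatch` is hypothetical, and the real `wait_internal` additionally tracks a `state` word for requeue bookkeeping:

    use parking_lot::Mutex;
    use std::sync::atomic::AtomicUsize;

    pub struct MiniLatch {
        waiting: Mutex<bool>,  // true while the owner is parked
        condvar: AtomicUsize,  // only its address is used, as the park key
    }

    impl MiniLatch {
        pub fn wait(&self) {
            let key = &self.condvar as *const _ as usize;
            let mut guard = self.waiting.lock();
            *guard = true; // becomes visible once the lock is released below
            unsafe {
                parking_lot_core::park(
                    key,
                    || true, // validate: runs under the parking-lot bucket lock
                    || {
                        // before_sleep: we are already queued under `key`, so a
                        // waker that locks `waiting` after this cannot miss us.
                        parking_lot::MutexGuard::mutex(&guard).force_unlock();
                    },
                    |_, _| {},
                    parking_lot_core::DEFAULT_PARK_TOKEN,
                    None,
                );
            }
            core::mem::forget(guard);      // its lock was released in before_sleep
            *self.waiting.lock() = false;  // relock briefly to clear the flag
        }

        pub fn wake(&self) {
            let key = &self.condvar as *const _ as usize;
            unsafe { parking_lot_core::unpark_all(key, parking_lot_core::DEFAULT_UNPARK_TOKEN) };
        }
    }

In the real `WorkerLatch` the waker holds the latch's mutex (via the raw `lock()`/`unlock()` pair) around `wake()`, which is what closes the race window.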


@@ -7,6 +7,7 @@
     unsafe_cell_access,
     box_as_ptr,
     box_vec_non_null,
+    strict_provenance_atomic_ptr,
     let_chains
 )]


@@ -12,8 +12,8 @@ use async_task::Runnable;
 use crate::{
     context::Context,
-    job::{HeapJob, Job},
-    latch::{AsCoreLatch, CountLatch, MutexLatch, WakeLatch},
+    job::{HeapJob, JobSender, QueuedJob as Job},
+    latch::{CountLatch, WorkerLatch},
     util::{DropGuard, SendPtr},
     workerthread::WorkerThread,
 };

@@ -53,7 +53,7 @@ use crate::{
 pub struct Scope<'scope, 'env: 'scope> {
     // latch to wait on before the scope finishes
-    job_counter: CountLatch<MutexLatch>,
+    job_counter: CountLatch,
     // local threadpool
     context: Arc<Context>,
     // panic error
@@ -87,14 +87,17 @@ where
 impl<'scope, 'env> Scope<'scope, 'env> {
     fn wait_for_jobs(&self, worker: &WorkerThread) {
+        self.job_counter.set_inner(worker.heartbeat.raw_latch());
         if self.job_counter.count() > 0 {
             tracing::trace!("waiting for {} jobs to finish.", self.job_counter.count());
-            tracing::trace!("thread id: {:?}, jobs: {:?}", worker.index, unsafe {
-                worker.queue.as_ref_unchecked()
-            });
+            tracing::trace!(
+                "thread id: {:?}, jobs: {:?}",
+                worker.heartbeat.index(),
+                unsafe { worker.queue.as_ref_unchecked() }
+            );

             // set worker index in the job counter
-            worker.wait_until_latch(self.job_counter.as_core_latch());
+            worker.wait_until_latch(&self.job_counter);
         }
     }
@@ -106,23 +109,6 @@ impl<'scope, 'env> Scope<'scope, 'env> {
     {
         use std::panic::{AssertUnwindSafe, catch_unwind};

-        #[allow(dead_code)]
-        fn make_job<F: FnOnce() -> T, T>(f: F) -> Job<T> {
-            #[align(8)]
-            unsafe fn harness<F: FnOnce() -> T, T>(this: *const (), job: *const Job<T>) {
-                let f = unsafe { Box::from_raw(this.cast::<F>().cast_mut()) };
-
-                let result = catch_unwind(AssertUnwindSafe(move || f()));
-
-                let job = unsafe { Box::from_raw(job.cast_mut()) };
-                job.complete(result);
-            }
-
-            Job::<T>::new(harness::<F, T>, unsafe {
-                NonNull::new_unchecked(Box::into_raw(Box::new(f))).cast()
-            })
-        }
-
         let result = match catch_unwind(AssertUnwindSafe(|| f())) {
             Ok(val) => Some(val),
             Err(payload) => {
@@ -151,6 +137,7 @@ impl<'scope, 'env> Scope<'scope, 'env> {
     /// stores the first panic that happened in this scope.
     fn panicked(&self, err: Box<dyn Any + Send + 'static>) {
+        tracing::debug!("panicked in scope, storing error: {:?}", err);
         self.panic.load(Ordering::Relaxed).is_null().then(|| {
             use core::mem::ManuallyDrop;
             let mut boxed = ManuallyDrop::new(Box::new(err));
@@ -182,17 +169,22 @@ impl<'scope, 'env> Scope<'scope, 'env> {
         let this = SendPtr::new_const(self).unwrap();

-        let job = Box::new(HeapJob::new(move || unsafe {
-            _ = f(this.as_ref());
-            this.as_unchecked_ref().job_counter.decrement();
-        }))
-        .into_boxed_job();
+        let job = Job::from_heapjob(
+            Box::new(HeapJob::new(move || unsafe {
+                use std::panic::{AssertUnwindSafe, catch_unwind};
+                if let Err(payload) = catch_unwind(AssertUnwindSafe(|| f(this.as_ref()))) {
+                    this.as_unchecked_ref().panicked(payload);
+                }
+                this.as_unchecked_ref().job_counter.decrement();
+            })),
+            ptr::null(),
+        );
         tracing::trace!("allocated heapjob");

         WorkerThread::current_ref()
             .expect("spawn is run in workerthread.")
-            .push_front(job as _);
+            .push_front(job.as_ptr());

         tracing::trace!("leaked heapjob");
     }
@@ -233,13 +225,14 @@ impl<'scope, 'env> Scope<'scope, 'env> {
             let _guard = DropGuard::new(move || {
                 this.as_unchecked_ref().job_counter.decrement();
             });

+            // TODO: handle panics here
             f(this.as_ref()).await
         }
     };

     let schedule = move |runnable: Runnable| {
         #[align(8)]
-        unsafe fn harness(this: *const (), job: *const Job) {
+        unsafe fn harness(this: *const (), job: *const JobSender, _: *const WorkerLatch) {
             unsafe {
                 let runnable =
                     Runnable::<()>::from_raw(NonNull::new_unchecked(this.cast_mut()));
@@ -250,12 +243,16 @@ impl<'scope, 'env> Scope<'scope, 'env> {
             }
         }

-        let job = Box::new(Job::new(harness, runnable.into_raw()));
+        let job = Box::into_raw(Box::new(Job::from_harness(
+            harness,
+            runnable.into_raw(),
+            ptr::null(),
+        )));

         // casting into Job<()> here
         WorkerThread::current_ref()
             .expect("spawn_async_internal is run in workerthread.")
-            .push_front(Box::into_raw(job) as _);
+            .push_front(job);
     };

     let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) };
@@ -291,7 +288,7 @@ impl<'scope, 'env> Scope<'scope, 'env> {
     unsafe fn from_context(context: Arc<Context>) -> Self {
         Self {
             context,
-            job_counter: CountLatch::new(MutexLatch::new()),
+            job_counter: CountLatch::new(ptr::null()),
             panic: AtomicPtr::new(ptr::null_mut()),
             _scope: PhantomData,
             _env: PhantomData,
@@ -309,7 +306,8 @@ mod tests {
     use crate::ThreadPool;

     #[test]
-    fn spawn() {
+    #[cfg_attr(not(miri), traced_test)]
+    fn scope_spawn_sync() {
         let pool = ThreadPool::new_with_threads(1);
         let count = Arc::new(AtomicU8::new(0));

@@ -323,7 +321,8 @@ mod tests {
     }

     #[test]
-    fn join() {
+    #[cfg_attr(not(miri), traced_test)]
+    fn scope_join_one() {
         let pool = ThreadPool::new_with_threads(1);

         let a = pool.scope(|scope| {
@@ -335,7 +334,8 @@ mod tests {
     }

     #[test]
-    fn join_many() {
+    #[cfg_attr(not(miri), traced_test)]
+    fn scope_join_many() {
         let pool = ThreadPool::new_with_threads(1);

         fn sum<'scope, 'env>(scope: &'scope Scope<'scope, 'env>, n: usize) -> usize {
@@ -356,7 +356,8 @@ mod tests {
     }

     #[test]
-    fn spawn_future() {
+    #[cfg_attr(not(miri), traced_test)]
+    fn scope_spawn_future() {
         let pool = ThreadPool::new_with_threads(1);
         let mut x = 0;
         pool.scope(|scope| {
@@ -371,7 +372,8 @@ mod tests {
     }

     #[test]
-    fn spawn_many() {
+    #[cfg_attr(not(miri), traced_test)]
+    fn scope_spawn_many() {
         let pool = ThreadPool::new_with_threads(1);
         let count = Arc::new(AtomicU8::new(0));
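Note: with `CountLatch` no longer generic over an inner latch, a scope's completion protocol is just a counter plus a late-bound worker latch pointer. A condensed sketch of the flow, using the API from this diff (not the full implementation):

    // How Scope tracks outstanding jobs, condensed.
    let scope_counter = CountLatch::new(ptr::null());       // no waiter known yet

    // each spawn:
    scope_counter.increment();
    // ...and every spawned job body ends with:
    // scope_counter.decrement();  -> wakes the bound WorkerLatch when count hits 0

    // wait_for_jobs, on the worker that waits:
    scope_counter.set_inner(worker.heartbeat.raw_latch());  // bind the waiter late
    while !scope_counter.probe() {
        // worker.wait_until_latch(&scope_counter): steal work or park
    }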


@@ -58,8 +58,8 @@ mod tests {
     use super::*;

     #[test]
-    #[traced_test]
-    fn spawn_borrow() {
+    #[cfg_attr(not(miri), traced_test)]
+    fn pool_spawn_borrow() {
         let pool = ThreadPool::new_with_threads(1);
         let mut x = 0;
         pool.scope(|scope| {
@@ -72,7 +72,8 @@ mod tests {
     }

     #[test]
-    fn spawn_future() {
+    #[cfg_attr(not(miri), traced_test)]
+    fn pool_spawn_future() {
         let pool = ThreadPool::new_with_threads(1);
         let mut x = 0;
         let task = pool.scope(|scope| {
@@ -88,7 +89,8 @@ mod tests {
     }

     #[test]
-    fn join() {
+    #[cfg_attr(not(miri), traced_test)]
+    fn pool_join() {
         let pool = ThreadPool::new_with_threads(1);
         let (a, b) = pool.join(|| 3 + 4, || 5 * 6);
         assert_eq!(a, 7);


@@ -104,6 +104,7 @@ impl<T> SendPtr<T> {
 /// as the pointer.
 /// The pointer must be aligned to `BITS` bits, i.e. `align_of::<T>() >= 2^BITS`.
 #[repr(transparent)]
+#[derive(Debug)]
 pub struct TaggedAtomicPtr<T, const BITS: u8> {
     ptr: AtomicPtr<()>,
     _pd: PhantomData<T>,
@@ -138,6 +139,19 @@ impl<T, const BITS: u8> TaggedAtomicPtr<T, BITS> {
         self.ptr.load(order).addr() & Self::mask()
     }

+    pub fn fetch_or_tag(&self, tag: usize, order: Ordering) -> usize {
+        let mask = Self::mask();
+        let old_ptr = self.ptr.fetch_or(tag & mask, order);
+        old_ptr.addr() & mask
+    }
+
+    /// returns the tag and clears it
+    pub fn take_tag(&self, order: Ordering) -> usize {
+        let mask = Self::mask();
+        let old_ptr = self.ptr.fetch_and(!mask, order);
+        old_ptr.addr() & mask
+    }
+
     /// returns tag
     #[inline(always)]
     fn compare_exchange_tag_inner(
@@ -432,10 +446,29 @@ impl<T> Send<T> {
     }
 }

+pub fn unwrap_or_panic<T>(result: std::thread::Result<T>) -> T {
+    match result {
+        Ok(value) => value,
+        Err(payload) => std::panic::resume_unwind(payload),
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;

+    #[test]
+    fn tagged_ptr_zero_tag() {
+        let ptr = Box::into_raw(Box::new(42u32));
+        let tagged_ptr = TaggedAtomicPtr::<u32, 2>::new(ptr, 0);
+        assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0);
+        assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr);
+        unsafe {
+            _ = Box::from_raw(ptr);
+        }
+    }
+
     #[test]
     fn tagged_ptr_exchange() {
         let ptr = Box::into_raw(Box::new(42u32));
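Note: `fetch_or_tag` and `take_tag` lean on the pointer's alignment guarantee — with `BITS = 3`, the low three bits of an 8-aligned pointer are always zero, so they can carry job state without extra storage. A runnable model of the masking on a plain usize, as illustration only:

    // Model of the tag arithmetic (the real type wraps an AtomicPtr).
    fn main() {
        const BITS: u32 = 3;
        let mask: usize = (1 << BITS) - 1;      // 0b111
        let ptr_bits: usize = 0xdead_bee8;      // 8-aligned: low bits are free

        let mut word = ptr_bits;                // tag = EMPTY (0)
        word |= 0b001 & mask;                   // fetch_or_tag(FINISHED)
        assert_eq!(word & mask, 0b001);         // tag(): FINISHED is set
        assert_eq!(word & !mask, ptr_bits);     // pointer bits untouched

        let taken = word & mask;                // take_tag(): read the tag...
        word &= !mask;                          // ...and clear it in one fetch_and
        assert_eq!((taken, word & mask), (0b001, 0));
    }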


@@ -1,5 +1,6 @@
 use std::{
     cell::{Cell, UnsafeCell},
+    hint::cold_path,
     ptr::NonNull,
     sync::Arc,
     time::Duration,
@@ -9,52 +10,34 @@ use crossbeam_utils::CachePadded;

 use crate::{
     context::{Context, Heartbeat},
-    job::{Job, JobList, JobResult},
-    latch::{AsCoreLatch, CoreLatch, Probe},
+    heartbeat::OwnedHeartbeatReceiver,
+    job::{JobQueue as JobList, JobResult, QueuedJob as Job, QueuedJob, StackJob},
+    latch::{AsCoreLatch, CoreLatch, Probe, WorkerLatch},
     util::DropGuard,
 };

 pub struct WorkerThread {
     pub(crate) context: Arc<Context>,
-    pub(crate) index: usize,
     pub(crate) queue: UnsafeCell<JobList>,
-    pub(crate) heartbeat: NonNull<CachePadded<Heartbeat>>,
+    pub(crate) heartbeat: OwnedHeartbeatReceiver,
     pub(crate) join_count: Cell<u8>,
 }

-impl Drop for WorkerThread {
-    fn drop(&mut self) {
-        // remove the current worker thread from the heartbeat list
-        self.context.shared().remove_heartbeat(self.index);
-
-        // SAFETY: we removed the heartbeat from the context, so we can safely drop it.
-        unsafe {
-            _ = Box::from_non_null(self.heartbeat);
-        }
-    }
-}
-
 thread_local! {
     static WORKER: UnsafeCell<Option<NonNull<WorkerThread>>> = const { UnsafeCell::new(None) };
 }

 impl WorkerThread {
     pub fn new_in(context: Arc<Context>) -> Self {
-        let (heartbeat, index) = context.shared().new_heartbeat();
+        let heartbeat = context.heartbeats.new_heartbeat();

         Self {
             context,
-            index,
             queue: UnsafeCell::new(JobList::new()),
             heartbeat,
             join_count: Cell::new(0),
         }
     }
-
-    pub(crate) fn heartbeat(&self) -> &CachePadded<Heartbeat> {
-        // SAFETY: the heartbeat is always set when the worker thread is created
-        unsafe { self.heartbeat.as_ref() }
-    }
 }

 impl WorkerThread {
@ -80,53 +63,77 @@ impl WorkerThread {
} }
fn run_inner(&self) { fn run_inner(&self) {
let mut job = self.context.shared().pop_job(); let mut job = None;
'outer: loop { 'outer: loop {
while let Some(j) = job { if let Some(job) = job {
self.execute(j); self.execute(job);
}
let mut guard = self.context.shared(); if self.context.should_exit() {
if guard.should_exit() {
// if the context is stopped, break out of the outer loop which // if the context is stopped, break out of the outer loop which
// will exit the thread. // will exit the thread.
break 'outer; break 'outer;
} }
// we executed the shared job, now we want to check for any
// local jobs which this job might have spawned.
job = self.pop_front().or_else(|| guard.pop_job());
}
// no more jobs, wait to be notified of a new job or a heartbeat. // no more jobs, wait to be notified of a new job or a heartbeat.
match self.heartbeat().latch.wait_and_reset() { job = self.find_work_or_wait();
crate::latch::WakeResult::Wake => {
let mut guard = self.context.shared();
if guard.should_exit() {
break 'outer;
}
job = guard.pop_job();
}
crate::latch::WakeResult::Heartbeat => {
self.tick();
}
crate::latch::WakeResult::Set => {
// check if we should exit the thread
if self.context.shared().should_exit() {
break 'outer;
}
panic!("this thread shouldn't be woken by a finished job")
}
}
} }
} }
} }
impl WorkerThread { impl WorkerThread {
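/// Returns the next job from the local queue or the shared context, without blocking.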
pub(crate) fn find_work(&self) -> Option<NonNull<Job>> {
self.find_work_inner().left()
}
/// Looks for work in the local queue, then in the shared context, and if no
/// work is found, waits for the thread to be notified of a new job, after
/// which it returns `None`.
/// The caller should then check for `should_exit` to determine if the
/// thread should exit, or look for work again.
pub(crate) fn find_work_or_wait(&self) -> Option<NonNull<Job>> {
match self.find_work_inner() {
either::Either::Left(job) => {
return Some(job);
}
either::Either::Right(mut guard) => {
// no jobs found, wait for a heartbeat or a new job
tracing::trace!("WorkerThread::find_work_or_wait: waiting for new job");
self.heartbeat.latch().wait_with_lock(&mut guard);
tracing::trace!("WorkerThread::find_work_or_wait: woken up from wait");
None
}
}
}
#[inline]
fn find_work_inner(
&self,
) -> either::Either<NonNull<Job>, parking_lot::MutexGuard<'_, crate::context::Shared>> {
// first check the local queue for jobs
if let Some(job) = self.pop_front() {
tracing::trace!("WorkerThread::find_work_inner: found local job: {:?}", job);
return either::Either::Left(job);
}
// then check the shared context for jobs
let mut guard = self.context.shared();
if let Some(job) = guard.pop_job() {
tracing::trace!("WorkerThread::find_work_inner: found shared job: {:?}", job);
return either::Either::Left(job);
}
either::Either::Right(guard)
}
#[inline(always)] #[inline(always)]
pub(crate) fn tick(&self) { pub(crate) fn tick(&self) {
if self.heartbeat().is_pending() { if self.heartbeat.take() {
tracing::trace!("received heartbeat, thread id: {:?}", self.index); tracing::trace!(
"received heartbeat, thread id: {:?}",
self.heartbeat.index()
);
self.heartbeat_cold(); self.heartbeat_cold();
} }
} }
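`tick` now consumes a pending-heartbeat flag through `self.heartbeat.take()`. The heartbeat receiver is not part of this diff, so the following is only a sketch of the shape that call implies (type and field names are assumptions):

use std::sync::atomic::{AtomicBool, Ordering};

struct HeartbeatReceiver {
    is_pending: AtomicBool,
}

impl HeartbeatReceiver {
    /// Consume the flag: returns true at most once per received heartbeat.
    fn take(&self) -> bool {
        self.is_pending.swap(false, Ordering::Relaxed)
    }
}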
@ -134,21 +141,22 @@ impl WorkerThread {
#[inline] #[inline]
fn execute(&self, job: NonNull<Job>) { fn execute(&self, job: NonNull<Job>) {
self.tick(); self.tick();
Job::execute(job); unsafe { Job::execute(job.as_ptr()) };
} }
#[cold] #[cold]
fn heartbeat_cold(&self) { fn heartbeat_cold(&self) {
let mut guard = self.context.shared(); let mut guard = self.context.shared();
if !guard.jobs.contains_key(&self.index) { if !guard.jobs.contains_key(&self.heartbeat.id()) {
if let Some(job) = self.pop_back() { if let Some(job) = self.pop_back() {
Job::set_shared(unsafe { job.as_ref() });
tracing::trace!("heartbeat: sharing job: {:?}", job); tracing::trace!("heartbeat: sharing job: {:?}", job);
guard.jobs.insert(self.heartbeat.id(), job);
unsafe { unsafe {
job.as_ref().set_pending(); // SAFETY: we are holding the lock on the shared context.
self.context.notify_job_shared();
} }
guard.jobs.insert(self.index, job);
guard.notify_job_shared();
} }
} }
} }
@ -234,19 +242,12 @@ impl HeartbeatThread {
let mut i = 0; let mut i = 0;
loop { loop {
let sleep_for = { let sleep_for = {
let guard = self.ctx.shared(); if self.ctx.should_exit() {
if guard.should_exit() {
break; break;
} }
if let Some((_, heartbeat)) = guard.heartbeats.iter().nth(i) { self.ctx.heartbeats.notify_nth(i);
unsafe { let num_heartbeats = self.ctx.heartbeats.len();
heartbeat.as_ref().latch.signal_heartbeat();
}
}
let num_heartbeats = guard.heartbeats.len();
drop(guard);
if i >= num_heartbeats { if i >= num_heartbeats {
i = 0; i = 0;
@ -265,120 +266,100 @@ impl HeartbeatThread {
} }
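`notify_nth` pokes a single worker per iteration, so the per-step sleep must shrink as workers are added if each worker is to see one heartbeat per full rotation. A sketch of that arithmetic (the base interval is an assumed placeholder, not taken from this commit):

use std::time::Duration;

// hypothetical base interval; the real value is defined elsewhere
const HEARTBEAT_INTERVAL: Duration = Duration::from_micros(100);

fn sleep_per_step(num_heartbeats: usize) -> Duration {
    // divide the interval across all workers so that a full round-robin
    // rotation takes one base interval
    HEARTBEAT_INTERVAL / num_heartbeats.max(1) as u32
}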
impl WorkerThread { impl WorkerThread {
#[cold] pub fn wait_until_queued_job<T>(
fn wait_until_latch_cold(&self, latch: &CoreLatch) { &self,
'outer: while !latch.probe() { job: *const QueuedJob,
// process local jobs before locking shared context ) -> Option<std::thread::Result<T>> {
while let Some(job) = self.pop_front() { let recv = unsafe { (*job).as_receiver::<T>() };
tracing::trace!("thread {:?} executing local job: {:?}", self.index, job);
unsafe {
job.as_ref().set_pending();
}
Job::execute(job);
tracing::trace!("thread {:?} finished local job: {:?}", self.index, job);
}
// take a shared job, if it exists
'inner: loop {
if let Some(shared_job) = self.context.shared().jobs.remove(&self.index) {
tracing::trace!(
"thread {:?} executing shared job: {:?}",
self.index,
shared_job
);
Job::execute(shared_job);
}
while !latch.probe() {
tracing::trace!("thread {:?} looking for shared jobs", self.index);
let job = {
let mut guard = self.context.shared();
guard.jobs.remove(&self.index).or_else(|| guard.pop_job())
};
match job {
Some(job) => {
tracing::trace!("thread {:?} found job: {:?}", self.index, job);
Job::execute(job);
continue 'outer;
}
None => {
tracing::trace!("thread {:?} is sleeping", self.index);
match self.heartbeat().latch.wait_and_reset() {
// why were we woken up?
// 1. the heartbeat thread ticked and set the
// latch, so we should see if we have any work
// to share.
// 2. a job was shared and we were notified, so
// we should execute it.
// 3. the job we were waiting on was completed,
// so we should return it.
crate::latch::WakeResult::Set => {
break 'outer; // we were woken up by a job being set, so we should exit the loop.
}
crate::latch::WakeResult::Wake => {
// skip checking for local jobs, since we
// were woken up to check for shared jobs.
continue 'inner;
}
crate::latch::WakeResult::Heartbeat => {
self.tick();
continue 'outer;
}
}
// since we were sleeping, the shared job can't be populated,
// so resuming the inner loop is fine.
}
}
}
break;
}
}
tracing::trace!(
"thread {:?} finished waiting on latch {:?}",
self.index,
latch
);
self.heartbeat().latch.as_core_latch().unset();
return;
}
pub fn wait_until_job<T>(&self, job: &Job<T>, latch: &CoreLatch) -> Option<JobResult<T>> {
// we've already checked that the job was popped from the queue // we've already checked that the job was popped from the queue
// check if shared job is our job // check if shared job is our job
if let Some(shared_job) = self.context.shared().jobs.remove(&self.index) { if let Some(shared_job) = self.context.shared().jobs.remove(&self.heartbeat.id()) {
if core::ptr::eq(shared_job.as_ptr(), job as *const Job<T> as _) { if core::ptr::eq(shared_job.as_ptr(), job as *const Job as _) {
// this is the job we are looking for, so we want to // this is the job we are looking for, so we want to
// short-circuit and call it inline // short-circuit and call it inline
return None; return None;
} else { } else {
// this isn't the job we are looking for, but we still need to // this isn't the job we are looking for, but we still need to
// execute it // execute it
Job::execute(shared_job); unsafe { Job::execute(shared_job.as_ptr()) };
} }
} }
// do the usual thing and wait for the job's latch // do the usual thing and wait for the job's latch
if !latch.probe() { loop {
self.wait_until_latch_cold(latch); match recv.poll() {
Some(t) => {
return Some(t);
} }
None => {
cold_path();
Some(job.wait()) // check local jobs before locking shared context
if let Some(job) = self.find_work_or_wait() {
tracing::trace!(
"thread {:?} executing local job: {:?}",
self.heartbeat.index(),
job
);
unsafe {
Job::execute(job.as_ptr());
}
tracing::trace!(
"thread {:?} finished local job: {:?}",
self.heartbeat.index(),
job
);
continue;
}
}
}
}
} }
pub fn wait_until_latch<L>(&self, latch: &L) pub fn wait_until_latch<L>(&self, latch: &L)
where where
L: AsCoreLatch, L: Probe,
{ {
let latch = latch.as_core_latch();
if !latch.probe() { if !latch.probe() {
tracing::trace!("thread {:?} waiting on latch {:?}", self.index, latch); tracing::trace!("thread {:?} waiting on latch", self.heartbeat.index());
self.wait_until_latch_cold(latch) self.wait_until_latch_cold(latch);
}
}
#[cold]
fn wait_until_latch_cold<L>(&self, latch: &L)
where
L: Probe,
{
if let Some(shared_job) = self.context.shared().jobs.remove(&self.heartbeat.id()) {
tracing::trace!(
"thread {:?} reclaiming shared job: {:?}",
self.heartbeat.index(),
shared_job
);
unsafe { Job::execute(shared_job.as_ptr()) };
}
// wait for the latch to be set, executing available work in the meantime
while !latch.probe() {
// check local jobs before locking shared context
if let Some(job) = self.find_work_or_wait() {
tracing::trace!(
"thread {:?} executing local job: {:?}",
self.heartbeat.index(),
job
);
unsafe {
Job::execute(job.as_ptr());
}
tracing::trace!(
"thread {:?} finished local job: {:?}",
self.heartbeat.index(),
job
);
continue;
}
} }
} }
} }
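`wait_until_latch` is now bounded by `Probe` instead of `AsCoreLatch`, since the wait loop only needs a non-blocking completion check. The trait lives outside this diff; this is just the shape its use here implies (assumed, not quoted):

pub trait Probe {
    /// Returns true once the latch has been set; must never block.
    fn probe(&self) -> bool;
}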

View file

@ -86,7 +86,6 @@ fn join_distaff() {
let sum = sum(&tree, tree.root().unwrap(), s); let sum = sum(&tree, tree.root().unwrap(), s);
sum sum
}); });
eprintln!("sum: {sum}");
std::hint::black_box(sum); std::hint::black_box(sum);
} }
} }