it works..

This commit is contained in:
Janis 2025-06-25 13:13:03 +02:00
parent 1363f20cfc
commit 4742733683
8 changed files with 344 additions and 267 deletions

View file

@ -30,7 +30,7 @@ parking_lot = {version = "0.12.3"}
thread_local = "1.1.8" thread_local = "1.1.8"
crossbeam = "0.8.4" crossbeam = "0.8.4"
st3 = "0.4" st3 = "0.4"
chili = "0.2.0" chili = "0.2.1"
async-task = "4.7.1" async-task = "4.7.1"

View file

@ -14,13 +14,13 @@ use parking_lot::{Condvar, Mutex};
use crate::{ use crate::{
job::{HeapJob, Job, StackJob}, job::{HeapJob, Job, StackJob},
latch::{LatchRef, MutexLatch, UnsafeWakeLatch}, latch::{AsCoreLatch, HeartbeatLatch, LatchRef, UnsafeWakeLatch},
workerthread::{HeartbeatThread, WorkerThread}, workerthread::{HeartbeatThread, WorkerThread},
}; };
pub struct Heartbeat { pub struct Heartbeat {
heartbeat: AtomicU8, heartbeat: AtomicU8,
pub latch: MutexLatch, pub latch: HeartbeatLatch,
} }
impl Heartbeat { impl Heartbeat {
@ -31,25 +31,15 @@ impl Heartbeat {
pub fn new() -> (Arc<CachePadded<Self>>, Weak<CachePadded<Self>>) { pub fn new() -> (Arc<CachePadded<Self>>, Weak<CachePadded<Self>>) {
let strong = Arc::new(CachePadded::new(Self { let strong = Arc::new(CachePadded::new(Self {
heartbeat: AtomicU8::new(Self::CLEAR), heartbeat: AtomicU8::new(Self::CLEAR),
latch: MutexLatch::new(), latch: HeartbeatLatch::new(),
})); }));
let weak = Arc::downgrade(&strong); let weak = Arc::downgrade(&strong);
(strong, weak) (strong, weak)
} }
/// returns true if the heartbeat was previously sleeping.
pub fn set_pending(&self) -> bool {
let old = self.heartbeat.swap(Self::PENDING, Ordering::Relaxed);
old == Self::SLEEPING
}
pub fn clear(&self) {
self.heartbeat.store(Self::CLEAR, Ordering::Relaxed);
}
pub fn is_pending(&self) -> bool { pub fn is_pending(&self) -> bool {
self.heartbeat.load(Ordering::Relaxed) == Self::PENDING self.latch.as_core_latch().poll_heartbeat()
} }
} }
@ -80,6 +70,10 @@ impl Shared {
(strong, index) (strong, index)
} }
pub(crate) fn remove_heartbeat(&mut self, index: usize) {
self.heartbeats.remove(&index);
}
pub fn pop_job(&mut self) -> Option<NonNull<Job>> { pub fn pop_job(&mut self) -> Option<NonNull<Job>> {
// this is unlikely, so make the function cold? // this is unlikely, so make the function cold?
// TODO: profile this // TODO: profile this
@ -96,6 +90,17 @@ impl Shared {
self.injected_jobs.pop().unwrap() self.injected_jobs.pop().unwrap()
} }
pub fn notify_job_shared(&self) {
_ = self.heartbeats.iter().find(|(_, heartbeat)| {
if let Some(heartbeat) = heartbeat.upgrade() {
heartbeat.latch.signal_job_shared();
true
} else {
false
}
});
}
pub fn should_exit(&self) -> bool { pub fn should_exit(&self) -> bool {
self.should_exit self.should_exit
} }
@ -162,7 +167,7 @@ impl Context {
shared.should_exit = true; shared.should_exit = true;
for (_, heartbeat) in shared.heartbeats.iter() { for (_, heartbeat) in shared.heartbeats.iter() {
if let Some(heartbeat) = heartbeat.upgrade() { if let Some(heartbeat) = heartbeat.upgrade() {
heartbeat.latch.set(); heartbeat.latch.signal_job_shared();
} }
} }
self.shared_job.notify_all(); self.shared_job.notify_all();
@ -181,11 +186,8 @@ impl Context {
pub fn inject_job(&self, job: NonNull<Job>) { pub fn inject_job(&self, job: NonNull<Job>) {
let mut shared = self.shared.lock(); let mut shared = self.shared.lock();
shared.injected_jobs.push(job); shared.injected_jobs.push(job);
self.notify_shared_job();
}
pub fn notify_shared_job(&self) { shared.notify_job_shared();
self.shared_job.notify_one();
} }
/// Runs closure in this context, processing the other context's worker's jobs while waiting for the result. /// Runs closure in this context, processing the other context's worker's jobs while waiting for the result.
@ -227,10 +229,10 @@ impl Context {
F: FnOnce(&WorkerThread) -> T + Send, F: FnOnce(&WorkerThread) -> T + Send,
T: Send, T: Send,
{ {
use crate::latch::MutexLatch; use crate::latch::HeartbeatLatch;
// current thread isn't a worker thread, create job and inject into global context // current thread isn't a worker thread, create job and inject into global context
let latch = MutexLatch::new(); let latch = HeartbeatLatch::new();
let job = StackJob::new( let job = StackJob::new(
move || { move || {
@ -246,7 +248,7 @@ impl Context {
job.set_pending(); job.set_pending();
self.inject_job(Into::into(&job)); self.inject_job(Into::into(&job));
latch.wait(); latch.wait_and_reset();
let t = unsafe { job.transmute_ref::<T>().wait().into_result() }; let t = unsafe { job.transmute_ref::<T>().wait().into_result() };

View file

@ -650,7 +650,7 @@ mod stackjob {
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f())); let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f()));
tracing::trace!("job completed: {:?}", job); tracing::trace!("stack job completed: {:?}", job);
let job = unsafe { &*job.cast::<Job<T>>() }; let job = unsafe { &*job.cast::<Job<T>>() };
job.complete(result); job.complete(result);
@ -703,13 +703,20 @@ mod heapjob {
let this = unsafe { Box::from_raw(this.cast::<HeapJob<F>>().cast_mut()) }; let this = unsafe { Box::from_raw(this.cast::<HeapJob<F>>().cast_mut()) };
let f = this.into_inner(); let f = this.into_inner();
_ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f())); {
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f()));
let job = unsafe { &*job.cast::<Job<T>>() };
job.complete(result);
}
// drop job (this is fine because the job of a HeapJob is pure POD). // drop job (this is fine because the job of a HeapJob is pure POD).
unsafe { unsafe {
ptr::drop_in_place(job); ptr::drop_in_place(job);
} }
tracing::trace!("heap job completed: {:?}", job);
// free box that was allocated at (1) // free box that was allocated at (1)
_ = unsafe { Box::<ManuallyDrop<Job<T>>>::from_raw(job.cast()) }; _ = unsafe { Box::<ManuallyDrop<Job<T>>>::from_raw(job.cast()) };
} }

View file

@ -67,20 +67,12 @@ impl WorkerThread {
// because we will be waiting on it. // because we will be waiting on it.
let latch = unsafe { UnsafeWakeLatch::new(&raw const (*self.heartbeat).latch) }; let latch = unsafe { UnsafeWakeLatch::new(&raw const (*self.heartbeat).latch) };
let a = StackJob::new( let a = StackJob::new(a, LatchRef::new(&latch));
move || {
// TODO: bench whether tick'ing here is good.
// turns out this actually costs a lot of time, likely because of the thread local check.
// WorkerThread::current_ref()
// .expect("stackjob is run in workerthread.")
// .tick();
a()
},
LatchRef::new(&latch),
);
let job = a.as_job(); let job = a.as_job();
self.push_front(&job); self.push_back(&job);
self.tick();
let rb = match catch_unwind(AssertUnwindSafe(|| b())) { let rb = match catch_unwind(AssertUnwindSafe(|| b())) {
Ok(val) => val, Ok(val) => val,
@ -97,9 +89,12 @@ impl WorkerThread {
// remove job from the queue, so it doesn't get run again. // remove job from the queue, so it doesn't get run again.
// job.unlink(); // job.unlink();
//SAFETY: we are in a worker thread, so we can safely access the queue. //SAFETY: we are in a worker thread, so we can safely access the queue.
unsafe { // unsafe {
self.queue.as_mut_unchecked().remove(&job); // self.queue.as_mut_unchecked().remove(&job);
} // }
// we pushed the job to the back of the queue, any `join`s called by `b` on this worker thread will have already popped their job, or seen it be executed.
self.pop_back();
// a is allowed to panic here, because we already finished b. // a is allowed to panic here, because we already finished b.
unsafe { a.unwrap()() } unsafe { a.unwrap()() }

View file

@ -2,7 +2,10 @@ use core::{
marker::PhantomData, marker::PhantomData,
sync::atomic::{AtomicUsize, Ordering}, sync::atomic::{AtomicUsize, Ordering},
}; };
use std::sync::{Arc, atomic::AtomicU8}; use std::{
cell::UnsafeCell,
sync::{Arc, atomic::AtomicU8},
};
use parking_lot::{Condvar, Mutex}; use parking_lot::{Condvar, Mutex};
@ -30,6 +33,8 @@ impl AtomicLatch {
pub const UNSET: u8 = 0; pub const UNSET: u8 = 0;
pub const SET: u8 = 1; pub const SET: u8 = 1;
pub const SLEEPING: u8 = 2; pub const SLEEPING: u8 = 2;
pub const WAKEUP: u8 = 4;
pub const HEARTBEAT: u8 = 8;
#[inline] #[inline]
pub const fn new() -> Self { pub const fn new() -> Self {
@ -45,24 +50,58 @@ impl AtomicLatch {
} }
#[inline] #[inline]
pub fn reset(&self) { pub fn unset(&self) {
self.inner.store(Self::UNSET, Ordering::Release); self.inner.fetch_and(!Self::SET, Ordering::Release);
}
pub fn reset(&self) -> u8 {
self.inner.swap(Self::UNSET, Ordering::Release)
} }
pub fn get(&self) -> u8 { pub fn get(&self) -> u8 {
self.inner.load(Ordering::Acquire) self.inner.load(Ordering::Acquire)
} }
pub fn set_sleeping(&self) { pub fn poll_heartbeat(&self) -> bool {
self.inner.store(Self::SLEEPING, Ordering::Release); self.inner.fetch_and(!Self::HEARTBEAT, Ordering::Relaxed) & Self::HEARTBEAT
== Self::HEARTBEAT
}
/// returns true if the latch was already set.
pub fn set_sleeping(&self) -> bool {
self.inner.fetch_or(Self::SLEEPING, Ordering::Relaxed) & Self::SET == Self::SET
}
pub fn is_sleeping(&self) -> bool {
self.inner.load(Ordering::Relaxed) & Self::SLEEPING == Self::SLEEPING
}
pub fn is_heartbeat(&self) -> bool {
self.inner.load(Ordering::Relaxed) & Self::HEARTBEAT == Self::HEARTBEAT
}
pub fn is_wakeup(&self) -> bool {
self.inner.load(Ordering::Relaxed) & Self::WAKEUP == Self::WAKEUP
}
pub fn is_set(&self) -> bool {
self.inner.load(Ordering::Relaxed) & Self::SET == Self::SET
}
pub fn set_wakeup(&self) {
self.inner.fetch_or(Self::WAKEUP, Ordering::Relaxed);
}
pub fn set_heartbeat(&self) {
self.inner.fetch_or(Self::HEARTBEAT, Ordering::Relaxed);
} }
/// returns true if the latch was previously sleeping. /// returns true if the latch was previously sleeping.
#[inline] #[inline]
pub unsafe fn set(this: *const Self) -> bool { pub unsafe fn set(this: *const Self) -> bool {
unsafe { unsafe {
let old = (*this).inner.swap(Self::SET, Ordering::Release); let old = (*this).inner.fetch_or(Self::SET, Ordering::Relaxed);
old == Self::SLEEPING old & Self::SLEEPING == Self::SLEEPING
} }
} }
} }
@ -79,7 +118,7 @@ impl Latch for AtomicLatch {
impl Probe for AtomicLatch { impl Probe for AtomicLatch {
#[inline] #[inline]
fn probe(&self) -> bool { fn probe(&self) -> bool {
self.inner.load(Ordering::Acquire) == Self::SET self.inner.load(Ordering::Relaxed) & Self::SET == Self::SET
} }
} }
impl AsCoreLatch for AtomicLatch { impl AsCoreLatch for AtomicLatch {
@ -153,58 +192,6 @@ impl Probe for NopLatch {
} }
} }
pub struct ThreadWakeLatch {
waker: Mutex<Option<std::thread::Thread>>,
}
impl ThreadWakeLatch {
#[inline]
pub const fn new() -> Self {
Self {
waker: Mutex::new(None),
}
}
#[inline]
pub fn reset(&self) {
let mut waker = self.waker.lock();
*waker = None;
}
#[inline]
pub fn set_waker(&self, thread: std::thread::Thread) {
let mut waker = self.waker.lock();
*waker = Some(thread);
}
pub unsafe fn wait(&self) {
assert!(
self.waker.lock().replace(std::thread::current()).is_none(),
"ThreadWakeLatch can only be waited once per thread"
);
std::thread::park();
}
}
impl Latch for ThreadWakeLatch {
#[inline]
unsafe fn set_raw(this: *const Self) {
unsafe {
if let Some(thread) = (&*this).waker.lock().take() {
thread.unpark();
}
}
}
}
impl Probe for ThreadWakeLatch {
#[inline]
fn probe(&self) -> bool {
self.waker.lock().is_some()
}
}
pub struct CountLatch<L: Latch> { pub struct CountLatch<L: Latch> {
count: AtomicUsize, count: AtomicUsize,
inner: L, inner: L,
@ -234,10 +221,8 @@ impl<L: Latch> CountLatch<L> {
#[inline] #[inline]
pub fn decrement(&self) { pub fn decrement(&self) {
if self.count.fetch_sub(1, Ordering::Release) == 1 { unsafe {
unsafe { Latch::set_raw(self);
Latch::set_raw(&self.inner);
}
} }
} }
} }
@ -246,8 +231,11 @@ impl<L: Latch> Latch for CountLatch<L> {
#[inline] #[inline]
unsafe fn set_raw(this: *const Self) { unsafe fn set_raw(this: *const Self) {
unsafe { unsafe {
let this = &*this; if (&*this).count.fetch_sub(1, Ordering::Relaxed) == 1 {
this.decrement(); tracing::trace!("CountLatch set_raw: count was 1, setting inner latch");
// If the count was 1, we need to set the inner latch.
Latch::set_raw(&(*this).inner);
}
} }
} }
} }
@ -266,30 +254,60 @@ impl<L: Latch + AsCoreLatch> AsCoreLatch for CountLatch<L> {
} }
} }
pub struct MutexLatch { pub struct HeartbeatLatch {
inner: Mutex<bool>, inner: UnsafeCell<AtomicLatch>,
lock: Mutex<()>,
condvar: Condvar, condvar: Condvar,
} }
impl MutexLatch { unsafe impl Send for HeartbeatLatch {}
unsafe impl Sync for HeartbeatLatch {}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum WakeResult {
Wake,
Heartbeat,
Set,
}
impl HeartbeatLatch {
#[inline] #[inline]
pub const fn new() -> Self { pub const fn new() -> Self {
Self { Self {
inner: Mutex::new(false), inner: UnsafeCell::new(AtomicLatch::new()),
lock: Mutex::new(()),
condvar: Condvar::new(), condvar: Condvar::new(),
} }
} }
#[inline] #[inline]
pub fn reset(&self) { pub fn reset(&self) {
let mut guard = self.inner.lock(); let _guard = self.lock.lock();
*guard = false; // SAFETY: inner is atomic, so we can safely access it.
unsafe { self.inner.as_mut_unchecked().unset() };
} }
pub fn wait(&self) { pub fn wait_and_reset(&self) -> WakeResult {
let mut guard = self.inner.lock(); // SAFETY: inner is locked by the mutex, so we can safely access it.
while !*guard { let value = unsafe {
self.condvar.wait(&mut guard); let mut guard = self.lock.lock();
let inner = self.inner.as_ref_unchecked();
inner.set_sleeping();
while inner.get() & !AtomicLatch::SLEEPING == AtomicLatch::UNSET {
self.condvar.wait(&mut guard);
}
inner.reset()
};
if value & AtomicLatch::SET == AtomicLatch::SET {
WakeResult::Set
} else if value & AtomicLatch::WAKEUP == AtomicLatch::WAKEUP {
WakeResult::Wake
} else if value & AtomicLatch::HEARTBEAT == AtomicLatch::HEARTBEAT {
WakeResult::Heartbeat
} else {
panic!("MutexLatch was not set correctly");
} }
} }
@ -299,29 +317,72 @@ impl MutexLatch {
} }
} }
pub fn wait_and_reset(&self) { pub fn signal_heartbeat(&self) {
let mut guard = self.inner.lock(); let mut _guard = self.lock.lock();
while !*guard { // SAFETY: inner is locked by the mutex, so we can safely access it.
self.condvar.wait(&mut guard); unsafe {
let inner = self.inner.as_ref_unchecked();
inner.set_heartbeat();
// If the latch was sleeping, notify the waiting thread.
if inner.is_sleeping() {
self.condvar.notify_all();
}
}
}
pub fn signal_job_shared(&self) {
let mut _guard = self.lock.lock();
// SAFETY: inner is locked by the mutex, so we can safely access it.
unsafe {
self.inner.as_ref_unchecked().set_wakeup();
if self.inner.as_ref_unchecked().is_sleeping() {
self.condvar.notify_all();
}
}
}
pub fn signal_job_finished(&self) {
let mut _guard = self.lock.lock();
// SAFETY: inner is locked by the mutex, so we can safely access it.
unsafe {
CoreLatch::set(self.inner.get());
if self.inner.as_ref_unchecked().is_sleeping() {
self.condvar.notify_all();
}
} }
*guard = false;
} }
} }
impl Latch for MutexLatch { impl Latch for HeartbeatLatch {
#[inline] #[inline]
unsafe fn set_raw(this: *const Self) { unsafe fn set_raw(this: *const Self) {
// SAFETY: `this` is valid until the guard is dropped.
unsafe { unsafe {
*(&*this).inner.lock() = true; let this = &*this;
(&*this).condvar.notify_all(); let _guard = this.lock.lock();
Latch::set_raw(this.inner.get() as *const AtomicLatch);
if this.inner.as_ref_unchecked().is_sleeping() {
this.condvar.notify_all();
}
} }
} }
} }
impl Probe for MutexLatch { impl Probe for HeartbeatLatch {
#[inline] #[inline]
fn probe(&self) -> bool { fn probe(&self) -> bool {
*self.inner.lock() let _guard = self.lock.lock();
// SAFETY: inner is atomic, so we can safely access it.
unsafe { self.inner.as_ref_unchecked().probe() }
}
}
impl AsCoreLatch for HeartbeatLatch {
#[inline]
fn as_core_latch(&self) -> &CoreLatch {
// SAFETY: inner is atomic, so we can safely access it.
unsafe { self.inner.as_ref_unchecked() }
} }
} }
@ -354,8 +415,10 @@ impl Latch for WakeLatch {
let ctx = WorkerThread::current_ref().unwrap().context.clone(); let ctx = WorkerThread::current_ref().unwrap().context.clone();
// If the latch was sleeping, wake the worker thread // If the latch was sleeping, wake the worker thread
ctx.shared().heartbeats.get(&worker_index).and_then(|weak| { ctx.shared().heartbeats.get(&worker_index).and_then(|weak| {
weak.upgrade() weak.upgrade().map(|heartbeat| {
.map(|heartbeat| Latch::set_raw(&heartbeat.latch)) // we set the latch to wake the worker so it knows to check the heartbeat
heartbeat.latch.signal_job_finished()
})
}); });
} }
} }
@ -376,19 +439,16 @@ impl AsCoreLatch for WakeLatch {
} }
} }
/// A latch that can be set from any thread, but must be created with a valid waker.
pub struct UnsafeWakeLatch { pub struct UnsafeWakeLatch {
inner: AtomicLatch, waker: *const HeartbeatLatch,
waker: *const MutexLatch,
} }
impl UnsafeWakeLatch { impl UnsafeWakeLatch {
/// # Safety /// # Safety
/// The `waker` must be valid until the latch is set. /// The `waker` must be valid until the latch is set.
pub unsafe fn new(waker: *const MutexLatch) -> Self { pub unsafe fn new(waker: *const HeartbeatLatch) -> Self {
Self { Self { waker }
inner: AtomicLatch::new(),
waker,
}
} }
} }
@ -397,9 +457,7 @@ impl Latch for UnsafeWakeLatch {
unsafe fn set_raw(this: *const Self) { unsafe fn set_raw(this: *const Self) {
unsafe { unsafe {
let waker = (*this).waker; let waker = (*this).waker;
if CoreLatch::set(&(&*this).inner) { Latch::set_raw(waker);
Latch::set_raw(waker);
}
} }
} }
} }
@ -407,14 +465,22 @@ impl Latch for UnsafeWakeLatch {
impl Probe for UnsafeWakeLatch { impl Probe for UnsafeWakeLatch {
#[inline] #[inline]
fn probe(&self) -> bool { fn probe(&self) -> bool {
self.inner.probe() // SAFETY: waker is valid as per the constructor contract.
unsafe {
let waker = &*self.waker;
waker.probe()
}
} }
} }
impl AsCoreLatch for UnsafeWakeLatch { impl AsCoreLatch for UnsafeWakeLatch {
#[inline] #[inline]
fn as_core_latch(&self) -> &CoreLatch { fn as_core_latch(&self) -> &CoreLatch {
&self.inner // SAFETY: waker is valid as per the constructor contract.
unsafe {
let waker = &*self.waker;
waker.as_core_latch()
}
} }
} }
@ -437,7 +503,7 @@ mod tests {
} }
assert_eq!(latch.get(), AtomicLatch::SET); assert_eq!(latch.get(), AtomicLatch::SET);
assert!(latch.probe()); assert!(latch.probe());
latch.reset(); latch.unset();
assert_eq!(latch.get(), AtomicLatch::UNSET); assert_eq!(latch.get(), AtomicLatch::UNSET);
} }
@ -451,7 +517,7 @@ mod tests {
assert!(!latch.probe()); assert!(!latch.probe());
assert!(AtomicLatch::set(&latch)); assert!(AtomicLatch::set(&latch));
} }
assert_eq!(latch.get(), AtomicLatch::SET); assert_eq!(latch.get(), AtomicLatch::SET | AtomicLatch::SLEEPING);
assert!(latch.probe()); assert!(latch.probe());
latch.reset(); latch.reset();
assert_eq!(latch.get(), AtomicLatch::UNSET); assert_eq!(latch.get(), AtomicLatch::UNSET);
@ -465,32 +531,6 @@ mod tests {
); );
} }
#[test]
fn thread_wake_latch() {
let latch = Arc::new(ThreadWakeLatch::new());
let main = Arc::new(ThreadWakeLatch::new());
let handle = std::thread::spawn({
let latch = latch.clone();
let main = main.clone();
move || unsafe {
Latch::set_raw(&*main);
latch.wait();
}
});
unsafe {
main.wait();
Latch::set_raw(&*latch);
}
handle.join().expect("Thread should join successfully");
assert!(
!latch.probe() && !main.probe(),
"Latch should be set after waiting thread wakes up"
);
}
#[test] #[test]
fn count_latch() { fn count_latch() {
let latch = CountLatch::new(AtomicLatch::new()); let latch = CountLatch::new(AtomicLatch::new());
@ -516,8 +556,9 @@ mod tests {
} }
#[test] #[test]
#[traced_test]
fn mutex_latch() { fn mutex_latch() {
let latch = Arc::new(MutexLatch::new()); let latch = Arc::new(HeartbeatLatch::new());
assert!(!latch.probe()); assert!(!latch.probe());
latch.set(); latch.set();
assert!(latch.probe()); assert!(latch.probe());
@ -527,7 +568,7 @@ mod tests {
// Test wait functionality // Test wait functionality
let latch_clone = latch.clone(); let latch_clone = latch.clone();
let handle = std::thread::spawn(move || { let handle = std::thread::spawn(move || {
latch_clone.wait(); assert_eq!(latch_clone.wait_and_reset(), WakeResult::Set);
}); });
// Give the thread time to block // Give the thread time to block
@ -535,11 +576,11 @@ mod tests {
assert!(!latch.probe()); assert!(!latch.probe());
latch.set(); latch.set();
assert!(latch.probe());
handle.join().expect("Thread should join successfully"); handle.join().expect("Thread should join successfully");
} }
#[test] #[test]
#[traced_test]
fn wake_latch() { fn wake_latch() {
let context = Context::new_with_threads(1); let context = Context::new_with_threads(1);
let count = Arc::new(AtomicUsize::new(0)); let count = Arc::new(AtomicUsize::new(0));

View file

@ -13,14 +13,14 @@ use async_task::Runnable;
use crate::{ use crate::{
context::Context, context::Context,
job::{HeapJob, Job}, job::{HeapJob, Job},
latch::{AsCoreLatch, CountLatch, WakeLatch}, latch::{AsCoreLatch, CountLatch, HeartbeatLatch, WakeLatch},
util::{DropGuard, SendPtr}, util::{DropGuard, SendPtr},
workerthread::WorkerThread, workerthread::WorkerThread,
}; };
pub struct Scope<'scope, 'env: 'scope> { pub struct Scope<'scope, 'env: 'scope> {
// latch to wait on before the scope finishes // latch to wait on before the scope finishes
job_counter: CountLatch<WakeLatch>, job_counter: CountLatch<HeartbeatLatch>,
// local threadpool // local threadpool
context: Arc<Context>, context: Arc<Context>,
// panic error // panic error
@ -61,7 +61,6 @@ impl<'scope, 'env> Scope<'scope, 'env> {
}); });
// set worker index in the job counter // set worker index in the job counter
self.job_counter.inner().set_worker_index(worker.index);
worker.wait_until_latch(self.job_counter.as_core_latch()); worker.wait_until_latch(self.job_counter.as_core_latch());
} }
} }
@ -146,23 +145,23 @@ impl<'scope, 'env> Scope<'scope, 'env> {
where where
F: FnOnce(&'scope Self) + Send, F: FnOnce(&'scope Self) + Send,
{ {
self.context.run_in_worker(|worker| { self.job_counter.increment();
self.job_counter.increment();
let this = SendPtr::new_const(self).unwrap(); let this = SendPtr::new_const(self).unwrap();
let job = Box::new(HeapJob::new(move || unsafe { let job = Box::new(HeapJob::new(move || unsafe {
_ = f(this.as_ref()); _ = f(this.as_ref());
this.as_ref().job_counter.decrement(); this.as_unchecked_ref().job_counter.decrement();
})) }))
.into_boxed_job(); .into_boxed_job();
tracing::trace!("allocated heapjob"); tracing::trace!("allocated heapjob");
worker.push_front(job); WorkerThread::current_ref()
.expect("spawn is run in workerthread.")
.push_front(job as _);
tracing::trace!("leaked heapjob"); tracing::trace!("leaked heapjob");
});
} }
pub fn spawn_future<T, F>(&'scope self, future: F) -> async_task::Task<T> pub fn spawn_future<T, F>(&'scope self, future: F) -> async_task::Task<T>
@ -259,7 +258,7 @@ impl<'scope, 'env> Scope<'scope, 'env> {
unsafe fn from_context(context: Arc<Context>) -> Self { unsafe fn from_context(context: Arc<Context>) -> Self {
Self { Self {
context, context,
job_counter: CountLatch::new(WakeLatch::new(0)), job_counter: CountLatch::new(HeartbeatLatch::new()),
panic: AtomicPtr::new(ptr::null_mut()), panic: AtomicPtr::new(ptr::null_mut()),
_scope: PhantomData, _scope: PhantomData,
_env: PhantomData, _env: PhantomData,
@ -291,7 +290,6 @@ mod tests {
} }
#[test] #[test]
#[traced_test]
fn join() { fn join() {
let pool = ThreadPool::new_with_threads(1); let pool = ThreadPool::new_with_threads(1);

View file

@ -53,14 +53,18 @@ impl ThreadPool {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use tracing_test::traced_test;
use super::*; use super::*;
#[test] #[test]
#[traced_test]
fn spawn_borrow() { fn spawn_borrow() {
let pool = ThreadPool::new_with_threads(1); let pool = ThreadPool::new_with_threads(1);
let mut x = 0; let mut x = 0;
pool.scope(|scope| { pool.scope(|scope| {
scope.spawn(|_| { scope.spawn(|_| {
tracing::info!("Incrementing x");
x += 1; x += 1;
}); });
}); });

View file

@ -22,6 +22,13 @@ pub struct WorkerThread {
pub(crate) join_count: Cell<u8>, pub(crate) join_count: Cell<u8>,
} }
impl Drop for WorkerThread {
fn drop(&mut self) {
// remove the current worker thread from the heartbeat list
self.context.shared().remove_heartbeat(self.index);
}
}
thread_local! { thread_local! {
static WORKER: UnsafeCell<Option<NonNull<WorkerThread>>> = const { UnsafeCell::new(None) }; static WORKER: UnsafeCell<Option<NonNull<WorkerThread>>> = const { UnsafeCell::new(None) };
} }
@ -65,49 +72,38 @@ impl WorkerThread {
fn run_inner(&self) { fn run_inner(&self) {
let mut job = self.context.shared().pop_job(); let mut job = self.context.shared().pop_job();
'outer: loop { 'outer: loop {
let mut guard = loop { while let Some(j) = job {
if let Some(job) = job.take() { self.execute(j);
self.execute(job);
let mut guard = self.context.shared();
if guard.should_exit() {
// if the context is stopped, break out of the outer loop which
// will exit the thread.
break 'outer;
} }
// we executed the shared job, now we want to check for any // we executed the shared job, now we want to check for any
// local jobs which this job might have spawned. // local jobs which this job might have spawned.
let next = self job = self.pop_front().or_else(|| guard.pop_job());
.pop_front() }
.map(|job| (Some(job), None))
.unwrap_or_else(|| {
let mut guard = self.context.shared();
(guard.pop_job(), Some(guard))
});
match next { // no more jobs, wait to be notified of a new job or a heartbeat.
// no job, but guard => check if we should exit match self.heartbeat.latch.wait_and_reset() {
(None, Some(guard)) => { crate::latch::WakeResult::Wake => {
tracing::trace!("worker: no local job, waiting for shared job"); let mut guard = self.context.shared();
if guard.should_exit() {
if guard.should_exit() { break 'outer;
// if the context is stopped, break out of the outer loop which
// will exit the thread.
break 'outer;
}
// no local jobs, wait for shared job
break guard;
} }
// some job => drop guard, continue inner loop
(Some(next), _) => { job = guard.pop_job();
tracing::trace!("worker: executing job: {:?}", next);
job = Some(next);
continue;
}
// no job, no guard ought to be unreachable.
_ => unreachable!(),
} }
}; crate::latch::WakeResult::Heartbeat => {
self.tick();
self.context.shared_job.wait(&mut guard); }
// a job was shared and we were notified, so we want to execute that job before any possible local jobs. crate::latch::WakeResult::Set => {
job = guard.pop_job(); panic!("this thread shouldn't be woken by a finished job")
}
}
} }
} }
} }
@ -138,11 +134,9 @@ impl WorkerThread {
job.as_ref().set_pending(); job.as_ref().set_pending();
} }
guard.jobs.insert(self.index, job); guard.jobs.insert(self.index, job);
self.context.notify_shared_job(); guard.notify_job_shared();
} }
} }
self.heartbeat.clear();
} }
} }
@ -236,9 +230,7 @@ impl HeartbeatThread {
b.upgrade() b.upgrade()
.inspect(|heartbeat| { .inspect(|heartbeat| {
if n == i { if n == i {
if heartbeat.set_pending() { heartbeat.latch.signal_heartbeat();
heartbeat.latch.set();
}
} }
n += 1; n += 1;
}) })
@ -267,60 +259,97 @@ impl HeartbeatThread {
impl WorkerThread { impl WorkerThread {
#[cold] #[cold]
fn wait_until_latch_cold(&self, latch: &CoreLatch) { fn wait_until_latch_cold(&self, latch: &CoreLatch) {
// does this optimise?
assert!(!latch.probe());
'outer: while !latch.probe() { 'outer: while !latch.probe() {
// process local jobs before locking shared context // process local jobs before locking shared context
while let Some(job) = self.pop_front() { while let Some(job) = self.pop_front() {
tracing::trace!("thread {:?} executing local job: {:?}", self.index, job);
unsafe { unsafe {
job.as_ref().set_pending(); job.as_ref().set_pending();
} }
self.execute(job); Job::execute(job);
tracing::trace!("thread {:?} finished local job: {:?}", self.index, job);
} }
// take a shared job, if it exists // take a shared job, if it exists
if let Some(shared_job) = self.context.shared().jobs.remove(&self.index) { 'inner: loop {
self.execute(shared_job); if let Some(shared_job) = self.context.shared().jobs.remove(&self.index) {
} tracing::trace!(
"thread {:?} executing shared job: {:?}",
self.index,
shared_job
);
Job::execute(shared_job);
}
while !latch.probe() { while !latch.probe() {
let job = { tracing::trace!("thread {:?} looking for shared jobs", self.index);
let mut guard = self.context.shared();
guard.jobs.remove(&self.index).or_else(|| guard.pop_job())
};
match job { let job = {
Some(job) => { let mut guard = self.context.shared();
self.execute(job); guard.jobs.remove(&self.index).or_else(|| guard.pop_job())
};
continue 'outer; match job {
} Some(job) => {
None => { tracing::trace!("thread {:?} found job: {:?}", self.index, job);
// TODO: wait on latch? if we have something that can Job::execute(job);
// signal being done, e.g. can be waited on instead of
// shared jobs, we should wait on it instead, but we
// would also want to receive shared jobs still?
// Spin? probably just wastes CPU time.
// self.context.shared_job.wait(&mut guard);
// if spin.spin() {
// // wait for more shared jobs.
// // self.context.shared_job.wait(&mut guard);
// return;
// }
// Yield? same as spinning, really, so just exit and let the upstream use wait
// std::thread::yield_now();
tracing::trace!("thread {:?} is sleeping", self.index); continue 'outer;
}
None => {
// TODO: wait on latch? if we have something that can
// signal being done, e.g. can be waited on instead of
// shared jobs, we should wait on it instead, but we
// would also want to receive shared jobs still?
// Spin? probably just wastes CPU time.
// self.context.shared_job.wait(&mut guard);
// if spin.spin() {
// // wait for more shared jobs.
// // self.context.shared_job.wait(&mut guard);
// return;
// }
// Yield? same as spinning, really, so just exit and let the upstream use wait
// std::thread::yield_now();
latch.set_sleeping(); tracing::trace!("thread {:?} is sleeping", self.index);
self.heartbeat.latch.wait_and_reset();
// since we were sleeping, the shared job can't be populated, match self.heartbeat.latch.wait_and_reset() {
// so resuming the inner loop is fine. // why were we woken up?
// 1. the heartbeat thread ticked and set the
// latch, so we should see if we have any work
// to share.
// 2. a job was shared and we were notified, so
// we should execute it.
// 3. the job we were waiting on was completed,
// so we should return it.
crate::latch::WakeResult::Set => {
break 'outer; // we were woken up by a job being set, so we should exit the loop.
}
crate::latch::WakeResult::Wake => {
// skip checking for local jobs, since we
// were woken up to check for shared jobs.
continue 'inner;
}
crate::latch::WakeResult::Heartbeat => {
self.tick();
continue 'outer;
}
}
// since we were sleeping, the shared job can't be populated,
// so resuming the inner loop is fine.
}
} }
} }
break;
} }
} }
tracing::trace!(
"thread {:?} finished waiting on latch {:?}",
self.index,
latch
);
return; return;
} }
@ -335,7 +364,7 @@ impl WorkerThread {
} else { } else {
// this isn't the job we are looking for, but we still need to // this isn't the job we are looking for, but we still need to
// execute it // execute it
self.execute(shared_job); Job::execute(shared_job);
} }
} }
@ -353,6 +382,7 @@ impl WorkerThread {
{ {
let latch = latch.as_core_latch(); let latch = latch.as_core_latch();
if !latch.probe() { if !latch.probe() {
tracing::trace!("thread {:?} waiting on latch {:?}", self.index, latch);
self.wait_until_latch_cold(latch) self.wait_until_latch_cold(latch)
} }
} }