executor/distaff/src/context.rs
2025-06-24 11:13:17 +02:00

280 lines
7.6 KiB
Rust

use std::{
ptr::NonNull,
sync::{
Arc, OnceLock, Weak,
atomic::{AtomicU8, Ordering},
},
};
use alloc::collections::BTreeMap;
use crossbeam_utils::CachePadded;
use parking_lot::{Condvar, Mutex};
use crate::{
job::{Job, StackJob},
latch::{LatchRef, MutexLatch, WakeLatch},
workerthread::{HeartbeatThread, WorkerThread},
};
/// A worker heartbeat: an atomic state flag (`CLEAR` / `PENDING` / `SLEEPING`) paired with a
/// [`MutexLatch`].
///
/// Instances are built by [`Heartbeat::new`], which returns both a strong and a weak handle to
/// the same cache-padded allocation.
pub struct Heartbeat {
    /// Current state, written only with the associated `CLEAR` / `PENDING` constants by the
    /// methods below. All accesses use `Ordering::Relaxed`, so they order nothing but the flag.
    heartbeat: AtomicU8,
    /// Latch stored next to the flag. It is public for use by other modules (presumably to put
    /// a worker to sleep and wake it again — confirm at the use sites).
    pub latch: MutexLatch,
}

impl Heartbeat {
    /// No heartbeat is pending.
    pub const CLEAR: u8 = 0;
    /// A heartbeat has been requested via [`Heartbeat::set_pending`].
    pub const PENDING: u8 = 1;
    /// The owner is marked as sleeping.
    ///
    /// NOTE(review): none of `Heartbeat`'s methods store this value and the `heartbeat` field
    /// is private to this module. Unless something else in this module writes it,
    /// `set_pending` never returns `true` and `is_sleeping` is always `false` — confirm.
    pub const SLEEPING: u8 = 2;

    /// Allocates a cleared heartbeat inside a [`CachePadded`] `Arc`.
    ///
    /// Returns the strong handle together with a weak handle to the same allocation.
    pub fn new() -> (Arc<CachePadded<Self>>, Weak<CachePadded<Self>>) {
        let padded = CachePadded::new(Self {
            heartbeat: AtomicU8::new(Self::CLEAR),
            latch: MutexLatch::new(),
        });
        let strong = Arc::new(padded);
        let weak = Arc::downgrade(&strong);
        (strong, weak)
    }

    /// Atomically sets the state to [`Heartbeat::PENDING`].
    ///
    /// Returns `true` when the state it replaced was [`Heartbeat::SLEEPING`].
    pub fn set_pending(&self) -> bool {
        let previous = self.heartbeat.swap(Self::PENDING, Ordering::Relaxed);
        matches!(previous, Self::SLEEPING)
    }

    /// Resets the state to [`Heartbeat::CLEAR`].
    pub fn clear(&self) {
        self.heartbeat.store(Self::CLEAR, Ordering::Relaxed);
    }

    /// Returns `true` if the state currently reads as [`Heartbeat::PENDING`].
    pub fn is_pending(&self) -> bool {
        matches!(self.heartbeat.load(Ordering::Relaxed), Self::PENDING)
    }

    /// Returns `true` if the state currently reads as [`Heartbeat::SLEEPING`].
    pub fn is_sleeping(&self) -> bool {
        matches!(self.heartbeat.load(Ordering::Relaxed), Self::SLEEPING)
    }
}
/// An executor context: the shared scheduler state plus the condition variable used to signal
/// that work is available. Its worker and heartbeat threads are spawned by
/// `Context::new_with_threads`.
pub struct Context {
/// Scheduler state; `Context::shared` locks this mutex and returns the guard.
shared: Mutex<Shared>,
/// Notified one waiter at a time by `Context::notify_shared_job` (e.g. after
/// `Context::inject_job`). Presumably waited on by idle workers — confirm in `workerthread`.
pub shared_job: Condvar,
}
/// Scheduler state shared by the threads of a [`Context`].
///
/// `Context::new_with_threads` stores it inside `Context`'s mutex, and `Context::shared`
/// hands out a lock guard to it.
pub(crate) struct Shared {
    /// Jobs available to workers; [`Shared::pop_job`] takes the smallest key first.
    ///
    /// NOTE(review): the key's meaning is defined by the code that inserts into this map
    /// (presumably a worker index) — confirm.
    pub jobs: BTreeMap<usize, NonNull<Job>>,
    /// Weak handles to registered heartbeats, keyed by the index returned from
    /// [`Shared::new_heartbeat`].
    pub heartbeats: BTreeMap<usize, Weak<CachePadded<Heartbeat>>>,
    /// Jobs pushed by `Context::inject_job`; [`Shared::pop_job`] serves these before `jobs`.
    injected_jobs: Vec<NonNull<Job>>,
    /// Index that the next call to [`Shared::new_heartbeat`] hands out.
    heartbeat_count: usize,
    /// Value reported by [`Shared::should_exit`]; `Context::new_with_threads` starts it at `false`.
    should_exit: bool,
}

// SAFETY: auto-`Send` is blocked by the `NonNull<Job>` raw pointers (and possibly by other
// field types — confirm). `Shared` lives inside `Context`'s mutex, so those pointers move
// between threads one lock holder at a time. The pointed-to jobs are assumed to stay valid while
// queued; the injecting callers in `Context` block on the job's latch before returning.
unsafe impl Send for Shared {}

impl Shared {
    /// Registers a new heartbeat and returns its strong handle together with its index.
    ///
    /// Only a weak handle is kept in `heartbeats`. Indices come from a wrapping counter.
    pub fn new_heartbeat(&mut self) -> (Arc<CachePadded<Heartbeat>>, usize) {
        let index = self.heartbeat_count;
        self.heartbeat_count = index.wrapping_add(1);
        let (strong, weak) = Heartbeat::new();
        self.heartbeats.insert(index, weak);
        (strong, index)
    }

    /// Removes and returns the next job to run, or `None` if both queues are empty.
    ///
    /// Injected jobs take priority and are taken last-in-first-out (`Vec::pop`). Otherwise the
    /// entry with the smallest key in `jobs` is returned.
    ///
    /// NOTE(review): LIFO means the earliest-injected job runs last — confirm that is intended.
    pub fn pop_job(&mut self) -> Option<NonNull<Job>> {
        // Injected jobs are expected to be rare, so that path lives in a `#[cold]` helper.
        // TODO: profile this
        if self.injected_jobs.is_empty() {
            self.jobs.pop_first().map(|(_, job)| job)
        } else {
            Some(self.pop_injected_job())
        }
    }

    /// Pops the most recently injected job.
    ///
    /// This is a plain safe function: there is no memory-safety contract here, only the
    /// precondition that `injected_jobs` is non-empty, which `pop_job` checks first.
    #[cold]
    fn pop_injected_job(&mut self) -> NonNull<Job> {
        self.injected_jobs
            .pop()
            .expect("pop_injected_job requires a non-empty injected_jobs queue")
    }

    /// Returns the current value of the exit flag.
    pub fn should_exit(&self) -> bool {
        self.should_exit
    }
}
impl Context {
#[inline]
pub fn shared(&self) -> parking_lot::MutexGuard<'_, Shared> {
self.shared.lock()
}
pub fn new_with_threads(num_threads: usize) -> Arc<Self> {
let this = Arc::new(Self {
shared: Mutex::new(Shared {
jobs: BTreeMap::new(),
heartbeats: BTreeMap::new(),
injected_jobs: Vec::new(),
heartbeat_count: 0,
should_exit: false,
}),
shared_job: Condvar::new(),
});
tracing::trace!("Creating thread pool with {} threads", num_threads);
// Create a barrier to synchronize the worker threads and the heartbeat thread
let barrier = Arc::new(std::sync::Barrier::new(num_threads + 2));
for i in 0..num_threads {
let ctx = this.clone();
let barrier = barrier.clone();
std::thread::Builder::new()
.name(format!("worker-{}", i))
.spawn(move || {
let worker = Box::new(WorkerThread::new_in(ctx));
barrier.wait();
worker.run();
})
.expect("Failed to spawn worker thread");
}
{
let ctx = this.clone();
let barrier = barrier.clone();
std::thread::Builder::new()
.name("heartbeat-thread".to_string())
.spawn(move || {
barrier.wait();
HeartbeatThread::new(ctx).run();
})
.expect("Failed to spawn heartbeat thread");
}
barrier.wait();
this
}
pub fn new() -> Arc<Self> {
Self::new_with_threads(crate::util::available_parallelism())
}
pub fn global_context() -> &'static Arc<Self> {
static GLOBAL_CONTEXT: OnceLock<Arc<Context>> = OnceLock::new();
GLOBAL_CONTEXT.get_or_init(|| Self::new())
}
pub fn inject_job(&self, job: NonNull<Job>) {
let mut shared = self.shared.lock();
shared.injected_jobs.push(job);
self.notify_shared_job();
}
pub fn notify_shared_job(&self) {
self.shared_job.notify_one();
}
/// Runs closure in this context, processing the other context's worker's jobs while waiting for the result.
fn run_in_worker_cross<T, F>(self: &Arc<Self>, worker: &WorkerThread, f: F) -> T
where
F: FnOnce(&WorkerThread) -> T + Send,
T: Send,
{
// current thread is not in the same context, create a job and inject it into the other thread's context, then wait while working on our jobs.
let latch = WakeLatch::new(self.clone(), worker.index);
let job = StackJob::new(
move || {
let worker = WorkerThread::current_ref()
.expect("WorkerThread::run_in_worker called outside of worker thread");
f(worker)
},
LatchRef::new(&latch),
);
let job = job.as_job();
job.set_pending();
self.inject_job(Into::into(&job));
worker.wait_until_latch(&latch);
let t = unsafe { job.transmute_ref::<T>().wait().into_result() };
t
}
/// Run closure in this context, sleeping until the job is done.
pub fn run_in_worker_cold<T, F>(self: &Arc<Self>, f: F) -> T
where
F: FnOnce(&WorkerThread) -> T + Send,
T: Send,
{
use crate::latch::MutexLatch;
// current thread isn't a worker thread, create job and inject into global context
let latch = MutexLatch::new();
let job = StackJob::new(
move || {
let worker = WorkerThread::current_ref()
.expect("WorkerThread::run_in_worker called outside of worker thread");
f(worker)
},
LatchRef::new(&latch),
);
let job = job.as_job();
job.set_pending();
self.inject_job(Into::into(&job));
latch.wait();
let t = unsafe { job.transmute_ref::<T>().wait().into_result() };
t
}
/// Run closure in this context.
pub fn run_in_worker<T, F>(self: &Arc<Self>, f: F) -> T
where
T: Send,
F: FnOnce(&WorkerThread) -> T + Send,
{
match WorkerThread::current_ref() {
Some(worker) => {
// check if worker is in the same context
if Arc::ptr_eq(&worker.context, self) {
tracing::trace!("run_in_worker: current thread");
f(worker)
} else {
// current thread is a worker for a different context
tracing::trace!("run_in_worker: cross-context");
self.run_in_worker_cross(worker, f)
}
}
None => {
// current thread is not a worker for any context
tracing::trace!("run_in_worker: inject into context");
self.run_in_worker_cold(f)
}
}
}
}
/// Runs `f` with a worker thread of the process-wide global [`Context`] and returns its result.
///
/// Equivalent to `Context::global_context().run_in_worker(f)`. The global context (and its
/// thread pool) is created lazily on first use.
pub fn run_in_worker<T, F>(f: F) -> T
where
    F: FnOnce(&WorkerThread) -> T + Send,
    T: Send,
{
    let global = Context::global_context();
    global.run_in_worker(f)
}