// executor/distaff/src/context.rs
// 2025-07-01 11:54:39 +02:00
//
// 405 lines
// 12 KiB
// Rust

use std::{
ptr::NonNull,
sync::{
Arc, OnceLock,
atomic::{AtomicBool, Ordering},
},
};
use alloc::collections::BTreeMap;
use async_task::Runnable;
use parking_lot::{Condvar, Mutex};
use crate::{
channel::{Parker, Sender},
heartbeat::HeartbeatList,
job::{HeapJob, Job2 as Job, SharedJob, StackJob},
latch::NopLatch,
util::DropGuard,
workerthread::{HeartbeatThread, WorkerThread},
};
/// State shared by everything that runs in one executor instance: the job
/// queues, the exit flag and the list of heartbeats.
///
/// The constructors in this file hand instances out as `Arc<Context>`.
pub struct Context {
/// Job queues; locked through `Context::shared` and `Context::inject_job`.
shared: Mutex<Shared>,
/// NOTE(review): not used within this file — presumably waited on / signalled
/// by worker code when jobs are shared; confirm against `workerthread`.
pub shared_job: Condvar,
/// Written by `set_should_exit`, read by `should_exit` (both with `Relaxed` ordering).
should_exit: AtomicBool,
/// Scanned by `notify_job_shared` for a waiting entry to wake;
/// `set_should_exit` calls `notify_all` on it.
pub heartbeats: HeartbeatList,
}
/// The job queues stored behind the mutex in `Context::shared`.
pub(crate) struct Shared {
/// Jobs keyed by a `usize`; `pop_job` removes the entry with the smallest key.
/// NOTE(review): what the key represents is not visible in this file — confirm.
pub jobs: BTreeMap<usize, SharedJob>,
/// Jobs pushed by `Context::inject_job`. `pop_job` serves these (most recently
/// pushed first) before it looks at `jobs`.
injected_jobs: Vec<SharedJob>,
}
// SAFETY: NOTE(review) — no justification is recorded here. This impl asserts that a
// `Shared`, and therefore every `SharedJob` it stores, may be moved to another thread;
// presumably `SharedJob` is not automatically `Send`. Confirm against its definition.
unsafe impl Send for Shared {}
impl Shared {
    /// Removes and returns the next job, if any.
    ///
    /// Injected jobs take priority; only when none are queued is the entry with
    /// the smallest key taken from `jobs`.
    pub fn pop_job(&mut self) -> Option<SharedJob> {
        // Common case first: nothing was injected, fall back to the keyed queue.
        if self.injected_jobs.is_empty() {
            return self.jobs.pop_first().map(|entry| entry.1);
        }

        // Injected jobs are expected to be rare, hence the `#[cold]` helper.
        // TODO: profile this
        // SAFETY: `injected_jobs` was just checked to be non-empty.
        let job = unsafe { self.pop_injected_job() };
        Some(job)
    }

    /// Pops the most recently injected job.
    ///
    /// # Safety
    ///
    /// Callers are expected to have checked that `injected_jobs` is not empty.
    ///
    /// # Panics
    ///
    /// Panics if `injected_jobs` is empty.
    #[cold]
    unsafe fn pop_injected_job(&mut self) -> SharedJob {
        let popped = self.injected_jobs.pop();
        popped.unwrap()
    }
}
impl Context {
#[inline]
pub(crate) fn shared(&self) -> parking_lot::MutexGuard<'_, Shared> {
self.shared.lock()
}
/// Creates a context and spawns its threads: `num_threads` workers named
/// `worker-{i}` plus one `heartbeat-thread`.
///
/// Every spawned thread is handed a barrier sized for all of them and the
/// caller; this function waits on that barrier before returning the context.
///
/// # Panics
///
/// Panics if a thread cannot be spawned.
pub fn new_with_threads(num_threads: usize) -> Arc<Self> {
    #[cfg(feature = "tracing")]
    tracing::trace!("Creating context with {} threads", num_threads);

    let queues = Shared {
        jobs: BTreeMap::new(),
        injected_jobs: Vec::new(),
    };
    let context = Arc::new(Self {
        shared: Mutex::new(queues),
        shared_job: Condvar::new(),
        should_exit: AtomicBool::new(false),
        heartbeats: HeartbeatList::new(),
    });

    // Participants: each worker, the heartbeat thread, and the calling thread.
    let startup = Arc::new(std::sync::Barrier::new(num_threads + 2));

    for index in 0..num_threads {
        let (ctx, startup) = (Arc::clone(&context), Arc::clone(&startup));
        let builder = std::thread::Builder::new().name(format!("worker-{}", index));
        builder
            .spawn(move || {
                let worker = Box::new(WorkerThread::new_in(ctx));
                worker.run(startup);
            })
            .expect("Failed to spawn worker thread");
    }

    {
        let (ctx, startup) = (Arc::clone(&context), Arc::clone(&startup));
        let builder = std::thread::Builder::new().name(String::from("heartbeat-thread"));
        builder
            .spawn(move || {
                HeartbeatThread::new(ctx).run(startup);
            })
            .expect("Failed to spawn heartbeat thread");
    }

    startup.wait();
    context
}
/// Raises the exit flag and then calls `notify_all` on the heartbeat list.
///
/// NOTE(review): the notification presumably wakes parked workers so they
/// observe the flag — confirm in `HeartbeatList`.
pub fn set_should_exit(&self) {
    let exit_flag = &self.should_exit;
    exit_flag.store(true, Ordering::Relaxed);
    self.heartbeats.notify_all();
}
/// Returns `true` once `set_should_exit` has been called on this context.
pub fn should_exit(&self) -> bool {
    let exit_flag = &self.should_exit;
    exit_flag.load(Ordering::Relaxed)
}
/// Creates a context with `crate::util::available_parallelism()` worker threads.
pub fn new() -> Arc<Self> {
    let worker_count = crate::util::available_parallelism();
    Self::new_with_threads(worker_count)
}
/// Returns the process-wide context, creating it with [`Context::new`] on first use.
pub fn global_context() -> &'static Arc<Self> {
    static GLOBAL: OnceLock<Arc<Context>> = OnceLock::new();
    GLOBAL.get_or_init(Self::new)
}
/// Queues `job` on the injected-jobs list and notifies a waiting worker, if any.
pub fn inject_job(&self, job: SharedJob) {
    let mut queues = self.shared();
    queues.injected_jobs.push(job);

    // SAFETY: `queues` keeps the shared lock held until the end of this function,
    // which is what `notify_job_shared` asks of its caller.
    unsafe { self.notify_job_shared() };
}
/// Wakes one worker whose heartbeat reports `is_waiting()`, if there is one.
///
/// Only the first waiting entry found in `self.heartbeats` is woken. When no
/// entry is waiting nothing is woken (a warning is logged if the `tracing`
/// feature is enabled).
///
/// # Safety
///
/// The caller should hold the shared lock (see `Context::shared`) while calling this.
/// NOTE(review): what exactly goes wrong otherwise is not visible in this file — confirm.
pub unsafe fn notify_job_shared(&self) {
// NOTE(review): `inner()` presumably returns a lock guard that stays alive as a
// temporary for the duration of the `if let` — confirm in `HeartbeatList`.
if let Some((i, sender)) = self
.heartbeats
.inner()
.iter()
.find(|(_, heartbeat)| heartbeat.is_waiting())
{
#[cfg(feature = "tracing")]
tracing::trace!("Notifying worker thread {} about job sharing", i);
sender.wake();
} else {
#[cfg(feature = "tracing")]
tracing::warn!("No worker found to notify about job sharing");
}
}
/// Runs `f` on a worker of this context on behalf of `worker`, which belongs to a
/// different context.
///
/// The job is injected into this context; the calling worker then stays in
/// `wait_until_shared_job` (working on its own jobs meanwhile) until the result is
/// available.
///
/// # Panics
///
/// Panics if `wait_until_shared_job` yields `None`; a failure reported by the job is
/// passed to `crate::util::unwrap_or_panic`.
fn run_in_worker_cross<T, F>(self: &Arc<Self>, worker: &WorkerThread, f: F) -> T
where
    F: FnOnce(&WorkerThread) -> T + Send,
    T: Send,
{
    // The job lives in this stack frame; we do not return before
    // `wait_until_shared_job` below has handed back its result.
    let stack_job = StackJob::new(
        move || {
            f(WorkerThread::current_ref()
                .expect("WorkerThread::run_in_worker called outside of worker thread"))
        },
        NopLatch,
    );
    let job = Job::from_stackjob(&stack_job);

    // Share the job with the calling worker's parker, then hand it to this context.
    let shared_job = job.share(Some(worker.heartbeat.parker()));
    self.inject_job(shared_job);

    let outcome = worker.wait_until_shared_job(&job).unwrap();
    crate::util::unwrap_or_panic(outcome)
}
/// Runs `f` on a worker of this context from a thread that is not a worker itself.
///
/// The job is injected together with a fresh `Parker`, and the calling thread then
/// waits on the job's receiver until the result arrives.
///
/// # Panics
///
/// Panics if the job has no receiver to take; a failure reported by the job is
/// passed to `crate::util::unwrap_or_panic`.
pub fn run_in_worker_cold<T, F>(self: &Arc<Self>, f: F) -> T
where
    F: FnOnce(&WorkerThread) -> T + Send,
    T: Send,
{
    // Declared first so that it outlives the job that refers to it.
    let parker = Parker::new();

    let stack_job = StackJob::new(
        move || {
            f(WorkerThread::current_ref()
                .expect("WorkerThread::run_in_worker called outside of worker thread"))
        },
        NopLatch,
    );
    let job = Job::from_stackjob(&stack_job);

    let shared_job = job.share(Some(&parker));
    self.inject_job(shared_job);

    let receiver = job.take_receiver().unwrap();
    let outcome = receiver.recv();
    crate::util::unwrap_or_panic(outcome)
}
/// Runs `f` on a worker thread of this context and returns its result.
///
/// * Called from a worker of this context: `f` runs right away on the current thread.
/// * Called from a worker of another context: delegates to `run_in_worker_cross`.
/// * Called from any other thread: delegates to `run_in_worker_cold`.
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn run_in_worker<T, F>(self: &Arc<Self>, f: F) -> T
where
    T: Send,
    F: FnOnce(&WorkerThread) -> T + Send,
{
    // Logs on every exit path, including the early returns below.
    let _guard = DropGuard::new(|| {
        #[cfg(feature = "tracing")]
        tracing::trace!("run_in_worker: finished");
    });

    let Some(worker) = WorkerThread::current_ref() else {
        // current thread is not a worker for any context
        #[cfg(feature = "tracing")]
        tracing::trace!("run_in_worker: inject into context");
        return self.run_in_worker_cold(f);
    };

    if !Arc::ptr_eq(&worker.context, self) {
        // current thread is a worker for a different context
        #[cfg(feature = "tracing")]
        tracing::trace!("run_in_worker: cross-context");
        return self.run_in_worker_cross(worker, f);
    }

    // the worker already belongs to this context
    #[cfg(feature = "tracing")]
    tracing::trace!("run_in_worker: current thread");
    f(worker)
}
}
impl Context {
/// Boxes `f` into a heap job and injects it into this context without waiting for it.
pub fn spawn<F>(self: &Arc<Self>, f: F)
where
    F: FnOnce() + Send + 'static,
{
    let heap_job = Box::new(HeapJob::new(f));
    let job = Job::from_heapjob(heap_job);

    #[cfg(feature = "tracing")]
    tracing::trace!("Context::spawn: spawning job: {:?}", job);

    // No parker: nobody waits for this job.
    let shared_job = job.share(None);
    self.inject_job(shared_job);
}
/// Spawns `future` onto this context and returns the `async_task::Task` handle
/// for its output.
///
/// Each time the task is scheduled, its `Runnable` is wrapped in a job and pushed
/// into this context with `inject_job`.
///
/// The scheduling callback owns a clone of the `Arc<Context>`, so the context stays
/// alive for as long as the task (or one of its wakers) exists. Previously the
/// callback captured the caller's `&Arc<Self>` and was passed to `spawn_unchecked`,
/// so a wake-up after the caller's `Arc` had moved or been dropped would have
/// dereferenced a dangling reference.
pub fn spawn_future<T, F>(self: &Arc<Self>, future: F) -> async_task::Task<T>
where
    F: Future<Output = T> + Send + 'static,
    T: Send + 'static,
{
    // Owned handle: the callback may run long after this call has returned.
    let ctx = Arc::clone(self);

    let schedule = move |runnable: Runnable| {
        #[align(8)]
        unsafe fn harness<T>(_: &WorkerThread, this: NonNull<()>, _: Option<Sender>) {
            // SAFETY: `this` is the pointer produced by `Runnable::into_raw` below.
            // NOTE(review): relies on the job machinery invoking the harness at most
            // once per job — confirm in `job`.
            unsafe {
                let runnable = Runnable::<()>::from_raw(this);
                runnable.run();
            }
        }

        let job = Job::<T>::from_harness(harness::<T>, runnable.into_raw());
        ctx.inject_job(job.share(None));
    };

    // Everything captured is owned and `'static`, so the checked constructor applies.
    let (runnable, task) = async_task::spawn(future, schedule);
    runnable.schedule();
    task
}
/// Spawns the future produced by `f`; `f` itself is only called once the spawned
/// task is first polled.
#[allow(dead_code)]
fn spawn_async<T, Fut, Fn>(self: &Arc<Self>, f: Fn) -> async_task::Task<T>
where
    Fn: FnOnce() -> Fut + Send + 'static,
    Fut: Future<Output = T> + Send + 'static,
    T: Send + 'static,
{
    self.spawn_future(async move { f().await })
}
}
/// Runs `f` on a worker thread of the global context; see [`Context::run_in_worker`].
pub fn run_in_worker<T, F>(f: F) -> T
where
    T: Send,
    F: FnOnce(&WorkerThread) -> T + Send,
{
    let global = Context::global_context();
    global.run_in_worker(f)
}
#[cfg(test)]
mod tests {
use std::sync::atomic::AtomicU8;
use super::*;
// `run_in_worker` hands back whatever the closure returns.
#[test]
#[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
fn run_in_worker() {
    let ctx = Arc::clone(Context::global_context());
    assert_eq!(ctx.run_in_worker(|_| 42), 42);
}
// A spawned future's output is delivered through the returned task.
#[test]
#[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
fn context_spawn_future() {
    let ctx = Arc::clone(Context::global_context());
    let task = ctx.spawn_future(async { 42 });

    // Block this thread until the task has completed.
    assert_eq!(futures::executor::block_on(task), 42);
}
// Same as `context_spawn_future`, but the future is produced by a closure.
#[test]
#[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
fn context_spawn_async() {
    let ctx = Arc::clone(Context::global_context());
    let task = ctx.spawn_async(|| async { 42 });

    // Block this thread until the task has completed.
    assert_eq!(futures::executor::block_on(task), 42);
}
// A job passed to `spawn` runs exactly once on some worker.
#[test]
#[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
fn context_spawn() {
    let ctx = Arc::clone(Context::global_context());
    let hits = Arc::new(AtomicU8::new(0));
    // Two parties: the spawned job and this test thread.
    let rendezvous = Arc::new(std::sync::Barrier::new(2));

    let (job_hits, job_rendezvous) = (Arc::clone(&hits), Arc::clone(&rendezvous));
    ctx.spawn(move || {
        job_hits.fetch_add(1, Ordering::SeqCst);
        job_rendezvous.wait();
    });

    // Returns only after the job has incremented the counter.
    rendezvous.wait();
    assert_eq!(hits.load(Ordering::SeqCst), 1);
}
// Injecting a job into a context whose single worker is asleep wakes that worker,
// which then runs the job exactly once.
#[test]
#[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
fn inject_job_and_wake_worker() {
    let ctx = Context::new_with_threads(1);
    let counter = Arc::new(AtomicU8::new(0));
    let parker = Parker::new();
    let job = StackJob::new(
        {
            let counter = counter.clone();
            move || {
                #[cfg(feature = "tracing")]
                tracing::info!("Job running");
                counter.fetch_add(1, Ordering::SeqCst);
                42
            }
        },
        NopLatch,
    );
    let job = Job::from_stackjob(&job);

    // wait for the worker to sleep
    std::thread::sleep(std::time::Duration::from_millis(100));

    // Check the first heartbeat, if one is registered. `if let` replaces the former
    // `Option::map` used only for its side effect, and `iter()` replaces `iter_mut()`
    // because nothing is mutated. The temporary from `inner()` is released at the end
    // of this statement, before `inject_job` accesses the heartbeat list again.
    if let Some((_, heartbeat)) = ctx.heartbeats.inner().iter().next() {
        assert!(heartbeat.is_waiting());
    }

    ctx.inject_job(job.share(Some(&parker)));

    // Wait for the job to be executed
    let recv = job.take_receiver().unwrap();
    let result = recv.recv();
    let result = crate::util::unwrap_or_panic(result);
    assert_eq!(result, 42);
    assert_eq!(counter.load(Ordering::SeqCst), 1);
}
}