use std::{ ptr::NonNull, sync::{ Arc, OnceLock, atomic::{AtomicBool, Ordering}, }, }; use alloc::collections::BTreeMap; use async_task::Runnable; use parking_lot::{Condvar, Mutex}; use crate::{ channel::{Parker, Sender}, heartbeat::HeartbeatList, job::{HeapJob, Job2 as Job, SharedJob, StackJob}, latch::NopLatch, util::DropGuard, workerthread::{HeartbeatThread, WorkerThread}, }; pub struct Context { shared: Mutex, pub shared_job: Condvar, should_exit: AtomicBool, pub heartbeats: HeartbeatList, } pub(crate) struct Shared { pub jobs: BTreeMap, injected_jobs: Vec, } unsafe impl Send for Shared {} impl Shared { pub fn pop_job(&mut self) -> Option { // this is unlikely, so make the function cold? // TODO: profile this if !self.injected_jobs.is_empty() { // SAFETY: we checked that injected_jobs is not empty unsafe { return Some(self.pop_injected_job()) }; } else { self.jobs.pop_first().map(|(_, job)| job) } } #[cold] unsafe fn pop_injected_job(&mut self) -> SharedJob { self.injected_jobs.pop().unwrap() } } impl Context { #[inline] pub(crate) fn shared(&self) -> parking_lot::MutexGuard<'_, Shared> { self.shared.lock() } pub fn new_with_threads(num_threads: usize) -> Arc { #[cfg(feature = "tracing")] tracing::trace!("Creating context with {} threads", num_threads); let this = Arc::new(Self { shared: Mutex::new(Shared { jobs: BTreeMap::new(), injected_jobs: Vec::new(), }), shared_job: Condvar::new(), should_exit: AtomicBool::new(false), heartbeats: HeartbeatList::new(), }); // Create a barrier to synchronize the worker threads and the heartbeat thread let barrier = Arc::new(std::sync::Barrier::new(num_threads + 2)); for i in 0..num_threads { let ctx = this.clone(); let barrier = barrier.clone(); std::thread::Builder::new() .name(format!("worker-{}", i)) .spawn(move || { let worker = Box::new(WorkerThread::new_in(ctx)); worker.run(barrier); }) .expect("Failed to spawn worker thread"); } { let ctx = this.clone(); let barrier = barrier.clone(); std::thread::Builder::new() .name("heartbeat-thread".to_string()) .spawn(move || { HeartbeatThread::new(ctx).run(barrier); }) .expect("Failed to spawn heartbeat thread"); } barrier.wait(); this } pub fn set_should_exit(&self) { self.should_exit.store(true, Ordering::Relaxed); self.heartbeats.notify_all(); } pub fn should_exit(&self) -> bool { self.should_exit.load(Ordering::Relaxed) } pub fn new() -> Arc { Self::new_with_threads(crate::util::available_parallelism()) } pub fn global_context() -> &'static Arc { static GLOBAL_CONTEXT: OnceLock> = OnceLock::new(); GLOBAL_CONTEXT.get_or_init(|| Self::new()) } pub fn inject_job(&self, job: SharedJob) { let mut shared = self.shared.lock(); shared.injected_jobs.push(job); unsafe { // SAFETY: we are holding the shared lock, so it is safe to notify self.notify_job_shared(); } } /// caller should hold the shared lock while calling this pub unsafe fn notify_job_shared(&self) { if let Some((i, sender)) = self .heartbeats .inner() .iter() .find(|(_, heartbeat)| heartbeat.is_waiting()) { #[cfg(feature = "tracing")] tracing::trace!("Notifying worker thread {} about job sharing", i); sender.wake(); } else { #[cfg(feature = "tracing")] tracing::warn!("No worker found to notify about job sharing"); } } /// Runs closure in this context, processing the other context's worker's jobs while waiting for the result. fn run_in_worker_cross(self: &Arc, worker: &WorkerThread, f: F) -> T where F: FnOnce(&WorkerThread) -> T + Send, T: Send, { // current thread is not in the same context, create a job and inject it into the other thread's context, then wait while working on our jobs. // SAFETY: we are waiting on this latch in this thread. let job = StackJob::new( move || { let worker = WorkerThread::current_ref() .expect("WorkerThread::run_in_worker called outside of worker thread"); f(worker) }, NopLatch, ); let job = Job::from_stackjob(&job); self.inject_job(job.share(Some(worker.heartbeat.parker()))); let t = worker.wait_until_shared_job(&job).unwrap(); crate::util::unwrap_or_panic(t) } /// Run closure in this context, sleeping until the job is done. pub fn run_in_worker_cold(self: &Arc, f: F) -> T where F: FnOnce(&WorkerThread) -> T + Send, T: Send, { // current thread isn't a worker thread, create job and inject into context let parker = Parker::new(); let job = StackJob::new( move || { let worker = WorkerThread::current_ref() .expect("WorkerThread::run_in_worker called outside of worker thread"); f(worker) }, NopLatch, ); let job = Job::from_stackjob(&job); self.inject_job(job.share(Some(&parker))); let recv = job.take_receiver().unwrap(); crate::util::unwrap_or_panic(recv.recv()) } /// Run closure in this context. #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))] pub fn run_in_worker(self: &Arc, f: F) -> T where T: Send, F: FnOnce(&WorkerThread) -> T + Send, { let _guard = DropGuard::new(|| { #[cfg(feature = "tracing")] tracing::trace!("run_in_worker: finished"); }); match WorkerThread::current_ref() { Some(worker) => { // check if worker is in the same context if Arc::ptr_eq(&worker.context, self) { #[cfg(feature = "tracing")] tracing::trace!("run_in_worker: current thread"); f(worker) } else { // current thread is a worker for a different context #[cfg(feature = "tracing")] tracing::trace!("run_in_worker: cross-context"); self.run_in_worker_cross(worker, f) } } None => { // current thread is not a worker for any context #[cfg(feature = "tracing")] tracing::trace!("run_in_worker: inject into context"); self.run_in_worker_cold(f) } } } } impl Context { pub fn spawn(self: &Arc, f: F) where F: FnOnce() + Send + 'static, { let job = Job::from_heapjob(Box::new(HeapJob::new(f))); #[cfg(feature = "tracing")] tracing::trace!("Context::spawn: spawning job: {:?}", job); self.inject_job(job.share(None)); } pub fn spawn_future(self: &Arc, future: F) -> async_task::Task where F: Future + Send + 'static, T: Send + 'static, { let schedule = move |runnable: Runnable| { #[align(8)] unsafe fn harness(_: &WorkerThread, this: NonNull<()>, _: Option) { unsafe { let runnable = Runnable::<()>::from_raw(this); runnable.run(); } } let job = Job::::from_harness(harness::, runnable.into_raw()); self.inject_job(job.share(None)); }; let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) }; runnable.schedule(); task } #[allow(dead_code)] fn spawn_async(self: &Arc, f: Fn) -> async_task::Task where Fn: FnOnce() -> Fut + Send + 'static, Fut: Future + Send + 'static, T: Send + 'static, { let future = async move { f().await }; self.spawn_future(future) } } pub fn run_in_worker(f: F) -> T where T: Send, F: FnOnce(&WorkerThread) -> T + Send, { Context::global_context().run_in_worker(f) } #[cfg(test)] mod tests { use std::sync::atomic::AtomicU8; use super::*; #[test] #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)] fn run_in_worker() { let ctx = Context::global_context().clone(); let result = ctx.run_in_worker(|_| 42); assert_eq!(result, 42); } #[test] #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)] fn context_spawn_future() { let ctx = Context::global_context().clone(); let task = ctx.spawn_future(async { 42 }); // Wait for the task to complete let result = futures::executor::block_on(task); assert_eq!(result, 42); } #[test] #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)] fn context_spawn_async() { let ctx = Context::global_context().clone(); let task = ctx.spawn_async(|| async { 42 }); // Wait for the task to complete let result = futures::executor::block_on(task); assert_eq!(result, 42); } #[test] #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)] fn context_spawn() { let ctx = Context::global_context().clone(); let counter = Arc::new(AtomicU8::new(0)); let barrier = Arc::new(std::sync::Barrier::new(2)); ctx.spawn({ let counter = counter.clone(); let barrier = barrier.clone(); move || { counter.fetch_add(1, Ordering::SeqCst); barrier.wait(); } }); barrier.wait(); assert_eq!(counter.load(Ordering::SeqCst), 1); } #[test] #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)] fn inject_job_and_wake_worker() { let ctx = Context::new_with_threads(1); let counter = Arc::new(AtomicU8::new(0)); let parker = Parker::new(); let job = StackJob::new( { let counter = counter.clone(); move || { #[cfg(feature = "tracing")] tracing::info!("Job running"); counter.fetch_add(1, Ordering::SeqCst); 42 } }, NopLatch, ); let job = Job::from_stackjob(&job); // wait for the worker to sleep std::thread::sleep(std::time::Duration::from_millis(100)); ctx.heartbeats .inner() .iter_mut() .next() .map(|(_, heartbeat)| { assert!(heartbeat.is_waiting()); }); ctx.inject_job(job.share(Some(&parker))); // Wait for the job to be executed let recv = job.take_receiver().unwrap(); let result = recv.recv(); let result = crate::util::unwrap_or_panic(result); assert_eq!(result, 42); assert_eq!(counter.load(Ordering::SeqCst), 1); } }