use std::{ ptr::NonNull, sync::{ Arc, OnceLock, Weak, atomic::{AtomicU8, Ordering}, }, }; use alloc::collections::BTreeMap; use crossbeam_utils::CachePadded; use parking_lot::{Condvar, Mutex}; use crate::{ job::{Job, StackJob}, latch::{LatchRef, MutexLatch, WakeLatch}, workerthread::{HeartbeatThread, WorkerThread}, }; pub struct Heartbeat { heartbeat: AtomicU8, pub latch: MutexLatch, } impl Heartbeat { pub const CLEAR: u8 = 0; pub const PENDING: u8 = 1; pub const SLEEPING: u8 = 2; pub fn new() -> (Arc>, Weak>) { let strong = Arc::new(CachePadded::new(Self { heartbeat: AtomicU8::new(Self::CLEAR), latch: MutexLatch::new(), })); let weak = Arc::downgrade(&strong); (strong, weak) } /// returns true if the heartbeat was previously sleeping. pub fn set_pending(&self) -> bool { let old = self.heartbeat.swap(Self::PENDING, Ordering::Relaxed); old == Self::SLEEPING } pub fn clear(&self) { self.heartbeat.store(Self::CLEAR, Ordering::Relaxed); } pub fn is_pending(&self) -> bool { self.heartbeat.load(Ordering::Relaxed) == Self::PENDING } pub fn is_sleeping(&self) -> bool { self.heartbeat.load(Ordering::Relaxed) == Self::SLEEPING } } pub struct Context { shared: Mutex, pub shared_job: Condvar, } pub(crate) struct Shared { pub jobs: BTreeMap>, pub heartbeats: BTreeMap>>, injected_jobs: Vec>, heartbeat_count: usize, should_exit: bool, } unsafe impl Send for Shared {} impl Shared { pub fn new_heartbeat(&mut self) -> (Arc>, usize) { let index = self.heartbeat_count; self.heartbeat_count = index.wrapping_add(1); let (strong, weak) = Heartbeat::new(); self.heartbeats.insert(index, weak); (strong, index) } pub fn pop_job(&mut self) -> Option> { // this is unlikely, so make the function cold? // TODO: profile this if !self.injected_jobs.is_empty() { unsafe { return Some(self.pop_injected_job()) }; } else { self.jobs.pop_first().map(|(_, job)| job) } } #[cold] unsafe fn pop_injected_job(&mut self) -> NonNull { self.injected_jobs.pop().unwrap() } pub fn should_exit(&self) -> bool { self.should_exit } } impl Context { #[inline] pub fn shared(&self) -> parking_lot::MutexGuard<'_, Shared> { self.shared.lock() } pub fn new_with_threads(num_threads: usize) -> Arc { let this = Arc::new(Self { shared: Mutex::new(Shared { jobs: BTreeMap::new(), heartbeats: BTreeMap::new(), injected_jobs: Vec::new(), heartbeat_count: 0, should_exit: false, }), shared_job: Condvar::new(), }); tracing::trace!("Creating thread pool with {} threads", num_threads); // Create a barrier to synchronize the worker threads and the heartbeat thread let barrier = Arc::new(std::sync::Barrier::new(num_threads + 2)); for i in 0..num_threads { let ctx = this.clone(); let barrier = barrier.clone(); std::thread::Builder::new() .name(format!("worker-{}", i)) .spawn(move || { let worker = Box::new(WorkerThread::new_in(ctx)); barrier.wait(); worker.run(); }) .expect("Failed to spawn worker thread"); } { let ctx = this.clone(); let barrier = barrier.clone(); std::thread::Builder::new() .name("heartbeat-thread".to_string()) .spawn(move || { barrier.wait(); HeartbeatThread::new(ctx).run(); }) .expect("Failed to spawn heartbeat thread"); } barrier.wait(); this } pub fn new() -> Arc { Self::new_with_threads(crate::util::available_parallelism()) } pub fn global_context() -> &'static Arc { static GLOBAL_CONTEXT: OnceLock> = OnceLock::new(); GLOBAL_CONTEXT.get_or_init(|| Self::new()) } pub fn inject_job(&self, job: NonNull) { let mut shared = self.shared.lock(); shared.injected_jobs.push(job); self.notify_shared_job(); } pub fn notify_shared_job(&self) { self.shared_job.notify_one(); } /// Runs closure in this context, processing the other context's worker's jobs while waiting for the result. fn run_in_worker_cross(self: &Arc, worker: &WorkerThread, f: F) -> T where F: FnOnce(&WorkerThread) -> T + Send, T: Send, { // current thread is not in the same context, create a job and inject it into the other thread's context, then wait while working on our jobs. let latch = WakeLatch::new(self.clone(), worker.index); let job = StackJob::new( move || { let worker = WorkerThread::current_ref() .expect("WorkerThread::run_in_worker called outside of worker thread"); f(worker) }, LatchRef::new(&latch), ); let job = job.as_job(); job.set_pending(); self.inject_job(Into::into(&job)); worker.wait_until_latch(&latch); let t = unsafe { job.transmute_ref::().wait().into_result() }; t } /// Run closure in this context, sleeping until the job is done. pub fn run_in_worker_cold(self: &Arc, f: F) -> T where F: FnOnce(&WorkerThread) -> T + Send, T: Send, { use crate::latch::MutexLatch; // current thread isn't a worker thread, create job and inject into global context let latch = MutexLatch::new(); let job = StackJob::new( move || { let worker = WorkerThread::current_ref() .expect("WorkerThread::run_in_worker called outside of worker thread"); f(worker) }, LatchRef::new(&latch), ); let job = job.as_job(); job.set_pending(); self.inject_job(Into::into(&job)); latch.wait(); let t = unsafe { job.transmute_ref::().wait().into_result() }; t } /// Run closure in this context. pub fn run_in_worker(self: &Arc, f: F) -> T where T: Send, F: FnOnce(&WorkerThread) -> T + Send, { match WorkerThread::current_ref() { Some(worker) => { // check if worker is in the same context if Arc::ptr_eq(&worker.context, self) { tracing::trace!("run_in_worker: current thread"); f(worker) } else { // current thread is a worker for a different context tracing::trace!("run_in_worker: cross-context"); self.run_in_worker_cross(worker, f) } } None => { // current thread is not a worker for any context tracing::trace!("run_in_worker: inject into context"); self.run_in_worker_cold(f) } } } } pub fn run_in_worker(f: F) -> T where T: Send, F: FnOnce(&WorkerThread) -> T + Send, { Context::global_context().run_in_worker(f) }