280 lines
7.6 KiB
Rust
280 lines
7.6 KiB
Rust
use std::{
|
|
ptr::NonNull,
|
|
sync::{
|
|
Arc, OnceLock, Weak,
|
|
atomic::{AtomicU8, Ordering},
|
|
},
|
|
};
|
|
|
|
use alloc::collections::BTreeMap;
|
|
|
|
use crossbeam_utils::CachePadded;
|
|
use parking_lot::{Condvar, Mutex};
|
|
|
|
use crate::{
|
|
job::{Job, StackJob},
|
|
latch::{LatchRef, MutexLatch, WakeLatch},
|
|
workerthread::{HeartbeatThread, WorkerThread},
|
|
};
|
|
|
|
pub struct Heartbeat {
|
|
heartbeat: AtomicU8,
|
|
pub latch: MutexLatch,
|
|
}
|
|
|
|
impl Heartbeat {
|
|
pub const CLEAR: u8 = 0;
|
|
pub const PENDING: u8 = 1;
|
|
pub const SLEEPING: u8 = 2;
|
|
|
|
pub fn new() -> (Arc<CachePadded<Self>>, Weak<CachePadded<Self>>) {
|
|
let strong = Arc::new(CachePadded::new(Self {
|
|
heartbeat: AtomicU8::new(Self::CLEAR),
|
|
latch: MutexLatch::new(),
|
|
}));
|
|
let weak = Arc::downgrade(&strong);
|
|
|
|
(strong, weak)
|
|
}
|
|
|
|
/// returns true if the heartbeat was previously sleeping.
|
|
pub fn set_pending(&self) -> bool {
|
|
let old = self.heartbeat.swap(Self::PENDING, Ordering::Relaxed);
|
|
old == Self::SLEEPING
|
|
}
|
|
|
|
pub fn clear(&self) {
|
|
self.heartbeat.store(Self::CLEAR, Ordering::Relaxed);
|
|
}
|
|
|
|
pub fn is_pending(&self) -> bool {
|
|
self.heartbeat.load(Ordering::Relaxed) == Self::PENDING
|
|
}
|
|
|
|
pub fn is_sleeping(&self) -> bool {
|
|
self.heartbeat.load(Ordering::Relaxed) == Self::SLEEPING
|
|
}
|
|
}
|
|
|
|
pub struct Context {
|
|
shared: Mutex<Shared>,
|
|
pub shared_job: Condvar,
|
|
}
|
|
|
|
pub(crate) struct Shared {
|
|
pub jobs: BTreeMap<usize, NonNull<Job>>,
|
|
pub heartbeats: BTreeMap<usize, Weak<CachePadded<Heartbeat>>>,
|
|
injected_jobs: Vec<NonNull<Job>>,
|
|
heartbeat_count: usize,
|
|
should_exit: bool,
|
|
}
|
|
|
|
// SAFETY: NOTE(review): `Shared` stores `NonNull<Job>` raw pointers, which are
// not `Send` on their own; this impl asserts that handing them to another
// thread is sound. Presumably every queued job stays alive until it has been
// popped and executed -- confirm against the code that enqueues jobs.
unsafe impl Send for Shared {}
|
|
|
|
impl Shared {
|
|
pub fn new_heartbeat(&mut self) -> (Arc<CachePadded<Heartbeat>>, usize) {
|
|
let index = self.heartbeat_count;
|
|
self.heartbeat_count = index.wrapping_add(1);
|
|
|
|
let (strong, weak) = Heartbeat::new();
|
|
|
|
self.heartbeats.insert(index, weak);
|
|
|
|
(strong, index)
|
|
}
|
|
|
|
pub fn pop_job(&mut self) -> Option<NonNull<Job>> {
|
|
// this is unlikely, so make the function cold?
|
|
// TODO: profile this
|
|
if !self.injected_jobs.is_empty() {
|
|
unsafe { return Some(self.pop_injected_job()) };
|
|
} else {
|
|
self.jobs.pop_first().map(|(_, job)| job)
|
|
}
|
|
}
|
|
|
|
#[cold]
|
|
unsafe fn pop_injected_job(&mut self) -> NonNull<Job> {
|
|
self.injected_jobs.pop().unwrap()
|
|
}
|
|
|
|
pub fn should_exit(&self) -> bool {
|
|
self.should_exit
|
|
}
|
|
}
|
|
|
|
impl Context {
|
|
#[inline]
|
|
pub fn shared(&self) -> parking_lot::MutexGuard<'_, Shared> {
|
|
self.shared.lock()
|
|
}
|
|
|
|
pub fn new_with_threads(num_threads: usize) -> Arc<Self> {
|
|
let this = Arc::new(Self {
|
|
shared: Mutex::new(Shared {
|
|
jobs: BTreeMap::new(),
|
|
heartbeats: BTreeMap::new(),
|
|
injected_jobs: Vec::new(),
|
|
heartbeat_count: 0,
|
|
should_exit: false,
|
|
}),
|
|
shared_job: Condvar::new(),
|
|
});
|
|
|
|
tracing::trace!("Creating thread pool with {} threads", num_threads);
|
|
|
|
// Create a barrier to synchronize the worker threads and the heartbeat thread
|
|
let barrier = Arc::new(std::sync::Barrier::new(num_threads + 2));
|
|
|
|
for i in 0..num_threads {
|
|
let ctx = this.clone();
|
|
let barrier = barrier.clone();
|
|
|
|
std::thread::Builder::new()
|
|
.name(format!("worker-{}", i))
|
|
.spawn(move || {
|
|
let worker = Box::new(WorkerThread::new_in(ctx));
|
|
|
|
barrier.wait();
|
|
worker.run();
|
|
})
|
|
.expect("Failed to spawn worker thread");
|
|
}
|
|
|
|
{
|
|
let ctx = this.clone();
|
|
let barrier = barrier.clone();
|
|
|
|
std::thread::Builder::new()
|
|
.name("heartbeat-thread".to_string())
|
|
.spawn(move || {
|
|
barrier.wait();
|
|
HeartbeatThread::new(ctx).run();
|
|
})
|
|
.expect("Failed to spawn heartbeat thread");
|
|
}
|
|
|
|
barrier.wait();
|
|
|
|
this
|
|
}
|
|
|
|
pub fn new() -> Arc<Self> {
|
|
Self::new_with_threads(crate::util::available_parallelism())
|
|
}
|
|
|
|
pub fn global_context() -> &'static Arc<Self> {
|
|
static GLOBAL_CONTEXT: OnceLock<Arc<Context>> = OnceLock::new();
|
|
|
|
GLOBAL_CONTEXT.get_or_init(|| Self::new())
|
|
}
|
|
|
|
pub fn inject_job(&self, job: NonNull<Job>) {
|
|
let mut shared = self.shared.lock();
|
|
shared.injected_jobs.push(job);
|
|
self.notify_shared_job();
|
|
}
|
|
|
|
pub fn notify_shared_job(&self) {
|
|
self.shared_job.notify_one();
|
|
}
|
|
|
|
/// Runs closure in this context, processing the other context's worker's jobs while waiting for the result.
|
|
fn run_in_worker_cross<T, F>(self: &Arc<Self>, worker: &WorkerThread, f: F) -> T
|
|
where
|
|
F: FnOnce(&WorkerThread) -> T + Send,
|
|
T: Send,
|
|
{
|
|
// current thread is not in the same context, create a job and inject it into the other thread's context, then wait while working on our jobs.
|
|
|
|
let latch = WakeLatch::new(self.clone(), worker.index);
|
|
|
|
let job = StackJob::new(
|
|
move || {
|
|
let worker = WorkerThread::current_ref()
|
|
.expect("WorkerThread::run_in_worker called outside of worker thread");
|
|
|
|
f(worker)
|
|
},
|
|
LatchRef::new(&latch),
|
|
);
|
|
|
|
let job = job.as_job();
|
|
job.set_pending();
|
|
|
|
self.inject_job(Into::into(&job));
|
|
|
|
worker.wait_until_latch(&latch);
|
|
|
|
let t = unsafe { job.transmute_ref::<T>().wait().into_result() };
|
|
|
|
t
|
|
}
|
|
|
|
/// Run closure in this context, sleeping until the job is done.
|
|
pub fn run_in_worker_cold<T, F>(self: &Arc<Self>, f: F) -> T
|
|
where
|
|
F: FnOnce(&WorkerThread) -> T + Send,
|
|
T: Send,
|
|
{
|
|
use crate::latch::MutexLatch;
|
|
// current thread isn't a worker thread, create job and inject into global context
|
|
|
|
let latch = MutexLatch::new();
|
|
|
|
let job = StackJob::new(
|
|
move || {
|
|
let worker = WorkerThread::current_ref()
|
|
.expect("WorkerThread::run_in_worker called outside of worker thread");
|
|
|
|
f(worker)
|
|
},
|
|
LatchRef::new(&latch),
|
|
);
|
|
|
|
let job = job.as_job();
|
|
job.set_pending();
|
|
|
|
self.inject_job(Into::into(&job));
|
|
latch.wait();
|
|
|
|
let t = unsafe { job.transmute_ref::<T>().wait().into_result() };
|
|
|
|
t
|
|
}
|
|
|
|
/// Run closure in this context.
|
|
pub fn run_in_worker<T, F>(self: &Arc<Self>, f: F) -> T
|
|
where
|
|
T: Send,
|
|
F: FnOnce(&WorkerThread) -> T + Send,
|
|
{
|
|
match WorkerThread::current_ref() {
|
|
Some(worker) => {
|
|
// check if worker is in the same context
|
|
if Arc::ptr_eq(&worker.context, self) {
|
|
tracing::trace!("run_in_worker: current thread");
|
|
f(worker)
|
|
} else {
|
|
// current thread is a worker for a different context
|
|
tracing::trace!("run_in_worker: cross-context");
|
|
self.run_in_worker_cross(worker, f)
|
|
}
|
|
}
|
|
None => {
|
|
// current thread is not a worker for any context
|
|
tracing::trace!("run_in_worker: inject into context");
|
|
self.run_in_worker_cold(f)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn run_in_worker<T, F>(f: F) -> T
|
|
where
|
|
T: Send,
|
|
F: FnOnce(&WorkerThread) -> T + Send,
|
|
{
|
|
Context::global_context().run_in_worker(f)
|
|
}
|