405 lines
12 KiB
Rust
405 lines
12 KiB
Rust
use std::{
|
|
ptr::NonNull,
|
|
sync::{
|
|
Arc, OnceLock,
|
|
atomic::{AtomicBool, Ordering},
|
|
},
|
|
};
|
|
|
|
use alloc::collections::BTreeMap;
|
|
|
|
use async_task::Runnable;
|
|
use parking_lot::{Condvar, Mutex};
|
|
|
|
use crate::{
|
|
channel::{Parker, Sender},
|
|
heartbeat::HeartbeatList,
|
|
job::{HeapJob, Job2 as Job, SharedJob, StackJob},
|
|
latch::NopLatch,
|
|
util::DropGuard,
|
|
workerthread::{HeartbeatThread, WorkerThread},
|
|
};
|
|
|
|
pub struct Context {
|
|
shared: Mutex<Shared>,
|
|
pub shared_job: Condvar,
|
|
should_exit: AtomicBool,
|
|
pub heartbeats: HeartbeatList,
|
|
}
|
|
|
|
pub(crate) struct Shared {
|
|
pub jobs: BTreeMap<usize, SharedJob>,
|
|
injected_jobs: Vec<SharedJob>,
|
|
}
|
|
|
|
unsafe impl Send for Shared {}
|
|
|
|
impl Shared {
|
|
pub fn pop_job(&mut self) -> Option<SharedJob> {
|
|
// this is unlikely, so make the function cold?
|
|
// TODO: profile this
|
|
if !self.injected_jobs.is_empty() {
|
|
// SAFETY: we checked that injected_jobs is not empty
|
|
unsafe { return Some(self.pop_injected_job()) };
|
|
} else {
|
|
self.jobs.pop_first().map(|(_, job)| job)
|
|
}
|
|
}
|
|
|
|
#[cold]
|
|
unsafe fn pop_injected_job(&mut self) -> SharedJob {
|
|
self.injected_jobs.pop().unwrap()
|
|
}
|
|
}
|
|
|
|
impl Context {
|
|
#[inline]
|
|
pub(crate) fn shared(&self) -> parking_lot::MutexGuard<'_, Shared> {
|
|
self.shared.lock()
|
|
}
|
|
|
|
pub fn new_with_threads(num_threads: usize) -> Arc<Self> {
|
|
#[cfg(feature = "tracing")]
|
|
tracing::trace!("Creating context with {} threads", num_threads);
|
|
|
|
let this = Arc::new(Self {
|
|
shared: Mutex::new(Shared {
|
|
jobs: BTreeMap::new(),
|
|
injected_jobs: Vec::new(),
|
|
}),
|
|
shared_job: Condvar::new(),
|
|
should_exit: AtomicBool::new(false),
|
|
heartbeats: HeartbeatList::new(),
|
|
});
|
|
|
|
// Create a barrier to synchronize the worker threads and the heartbeat thread
|
|
let barrier = Arc::new(std::sync::Barrier::new(num_threads + 2));
|
|
|
|
for i in 0..num_threads {
|
|
let ctx = this.clone();
|
|
let barrier = barrier.clone();
|
|
|
|
std::thread::Builder::new()
|
|
.name(format!("worker-{}", i))
|
|
.spawn(move || {
|
|
let worker = Box::new(WorkerThread::new_in(ctx));
|
|
|
|
worker.run(barrier);
|
|
})
|
|
.expect("Failed to spawn worker thread");
|
|
}
|
|
|
|
{
|
|
let ctx = this.clone();
|
|
let barrier = barrier.clone();
|
|
|
|
std::thread::Builder::new()
|
|
.name("heartbeat-thread".to_string())
|
|
.spawn(move || {
|
|
HeartbeatThread::new(ctx).run(barrier);
|
|
})
|
|
.expect("Failed to spawn heartbeat thread");
|
|
}
|
|
|
|
barrier.wait();
|
|
|
|
this
|
|
}
|
|
|
|
pub fn set_should_exit(&self) {
|
|
self.should_exit.store(true, Ordering::Relaxed);
|
|
self.heartbeats.notify_all();
|
|
}
|
|
|
|
pub fn should_exit(&self) -> bool {
|
|
self.should_exit.load(Ordering::Relaxed)
|
|
}
|
|
|
|
pub fn new() -> Arc<Self> {
|
|
Self::new_with_threads(crate::util::available_parallelism())
|
|
}
|
|
|
|
pub fn global_context() -> &'static Arc<Self> {
|
|
static GLOBAL_CONTEXT: OnceLock<Arc<Context>> = OnceLock::new();
|
|
|
|
GLOBAL_CONTEXT.get_or_init(|| Self::new())
|
|
}
|
|
|
|
pub fn inject_job(&self, job: SharedJob) {
|
|
let mut shared = self.shared.lock();
|
|
shared.injected_jobs.push(job);
|
|
|
|
unsafe {
|
|
// SAFETY: we are holding the shared lock, so it is safe to notify
|
|
self.notify_job_shared();
|
|
}
|
|
}
|
|
|
|
/// caller should hold the shared lock while calling this
|
|
pub unsafe fn notify_job_shared(&self) {
|
|
if let Some((i, sender)) = self
|
|
.heartbeats
|
|
.inner()
|
|
.iter()
|
|
.find(|(_, heartbeat)| heartbeat.is_waiting())
|
|
{
|
|
#[cfg(feature = "tracing")]
|
|
tracing::trace!("Notifying worker thread {} about job sharing", i);
|
|
sender.wake();
|
|
} else {
|
|
#[cfg(feature = "tracing")]
|
|
tracing::warn!("No worker found to notify about job sharing");
|
|
}
|
|
}
|
|
|
|
/// Runs closure in this context, processing the other context's worker's jobs while waiting for the result.
|
|
fn run_in_worker_cross<T, F>(self: &Arc<Self>, worker: &WorkerThread, f: F) -> T
|
|
where
|
|
F: FnOnce(&WorkerThread) -> T + Send,
|
|
T: Send,
|
|
{
|
|
// current thread is not in the same context, create a job and inject it into the other thread's context, then wait while working on our jobs.
|
|
|
|
// SAFETY: we are waiting on this latch in this thread.
|
|
let job = StackJob::new(
|
|
move || {
|
|
let worker = WorkerThread::current_ref()
|
|
.expect("WorkerThread::run_in_worker called outside of worker thread");
|
|
|
|
f(worker)
|
|
},
|
|
NopLatch,
|
|
);
|
|
|
|
let job = Job::from_stackjob(&job);
|
|
|
|
self.inject_job(job.share(Some(worker.heartbeat.parker())));
|
|
|
|
let t = worker.wait_until_shared_job(&job).unwrap();
|
|
|
|
crate::util::unwrap_or_panic(t)
|
|
}
|
|
|
|
/// Run closure in this context, sleeping until the job is done.
|
|
pub fn run_in_worker_cold<T, F>(self: &Arc<Self>, f: F) -> T
|
|
where
|
|
F: FnOnce(&WorkerThread) -> T + Send,
|
|
T: Send,
|
|
{
|
|
// current thread isn't a worker thread, create job and inject into context
|
|
let parker = Parker::new();
|
|
|
|
let job = StackJob::new(
|
|
move || {
|
|
let worker = WorkerThread::current_ref()
|
|
.expect("WorkerThread::run_in_worker called outside of worker thread");
|
|
|
|
f(worker)
|
|
},
|
|
NopLatch,
|
|
);
|
|
|
|
let job = Job::from_stackjob(&job);
|
|
|
|
self.inject_job(job.share(Some(&parker)));
|
|
|
|
let recv = job.take_receiver().unwrap();
|
|
|
|
crate::util::unwrap_or_panic(recv.recv())
|
|
}
|
|
|
|
/// Run closure in this context.
|
|
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
|
pub fn run_in_worker<T, F>(self: &Arc<Self>, f: F) -> T
|
|
where
|
|
T: Send,
|
|
F: FnOnce(&WorkerThread) -> T + Send,
|
|
{
|
|
let _guard = DropGuard::new(|| {
|
|
#[cfg(feature = "tracing")]
|
|
tracing::trace!("run_in_worker: finished");
|
|
});
|
|
match WorkerThread::current_ref() {
|
|
Some(worker) => {
|
|
// check if worker is in the same context
|
|
if Arc::ptr_eq(&worker.context, self) {
|
|
#[cfg(feature = "tracing")]
|
|
tracing::trace!("run_in_worker: current thread");
|
|
f(worker)
|
|
} else {
|
|
// current thread is a worker for a different context
|
|
#[cfg(feature = "tracing")]
|
|
tracing::trace!("run_in_worker: cross-context");
|
|
self.run_in_worker_cross(worker, f)
|
|
}
|
|
}
|
|
None => {
|
|
// current thread is not a worker for any context
|
|
#[cfg(feature = "tracing")]
|
|
tracing::trace!("run_in_worker: inject into context");
|
|
self.run_in_worker_cold(f)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Context {
|
|
pub fn spawn<F>(self: &Arc<Self>, f: F)
|
|
where
|
|
F: FnOnce() + Send + 'static,
|
|
{
|
|
let job = Job::from_heapjob(Box::new(HeapJob::new(f)));
|
|
#[cfg(feature = "tracing")]
|
|
tracing::trace!("Context::spawn: spawning job: {:?}", job);
|
|
self.inject_job(job.share(None));
|
|
}
|
|
|
|
pub fn spawn_future<T, F>(self: &Arc<Self>, future: F) -> async_task::Task<T>
|
|
where
|
|
F: Future<Output = T> + Send + 'static,
|
|
T: Send + 'static,
|
|
{
|
|
let schedule = move |runnable: Runnable| {
|
|
#[align(8)]
|
|
unsafe fn harness<T>(_: &WorkerThread, this: NonNull<()>, _: Option<Sender>) {
|
|
unsafe {
|
|
let runnable = Runnable::<()>::from_raw(this);
|
|
runnable.run();
|
|
}
|
|
}
|
|
|
|
let job = Job::<T>::from_harness(harness::<T>, runnable.into_raw());
|
|
|
|
self.inject_job(job.share(None));
|
|
};
|
|
|
|
let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) };
|
|
|
|
runnable.schedule();
|
|
|
|
task
|
|
}
|
|
|
|
#[allow(dead_code)]
|
|
fn spawn_async<T, Fut, Fn>(self: &Arc<Self>, f: Fn) -> async_task::Task<T>
|
|
where
|
|
Fn: FnOnce() -> Fut + Send + 'static,
|
|
Fut: Future<Output = T> + Send + 'static,
|
|
T: Send + 'static,
|
|
{
|
|
let future = async move { f().await };
|
|
|
|
self.spawn_future(future)
|
|
}
|
|
}
|
|
|
|
pub fn run_in_worker<T, F>(f: F) -> T
|
|
where
|
|
T: Send,
|
|
F: FnOnce(&WorkerThread) -> T + Send,
|
|
{
|
|
Context::global_context().run_in_worker(f)
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicU8;

    use super::*;

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn run_in_worker() {
        let ctx = Context::global_context().clone();
        let result = ctx.run_in_worker(|_| 42);
        assert_eq!(result, 42);
    }

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn context_spawn_future() {
        let ctx = Context::global_context().clone();
        let task = ctx.spawn_future(async { 42 });

        // Wait for the task to complete.
        let result = futures::executor::block_on(task);
        assert_eq!(result, 42);
    }

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn context_spawn_async() {
        let ctx = Context::global_context().clone();
        let task = ctx.spawn_async(|| async { 42 });

        // Wait for the task to complete.
        let result = futures::executor::block_on(task);
        assert_eq!(result, 42);
    }

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn context_spawn() {
        let ctx = Context::global_context().clone();
        let counter = Arc::new(AtomicU8::new(0));
        let barrier = Arc::new(std::sync::Barrier::new(2));

        ctx.spawn({
            let counter = counter.clone();
            let barrier = barrier.clone();
            move || {
                counter.fetch_add(1, Ordering::SeqCst);
                barrier.wait();
            }
        });

        barrier.wait();
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn inject_job_and_wake_worker() {
        let ctx = Context::new_with_threads(1);
        let counter = Arc::new(AtomicU8::new(0));

        let parker = Parker::new();

        let job = StackJob::new(
            {
                let counter = counter.clone();
                move || {
                    #[cfg(feature = "tracing")]
                    tracing::info!("Job running");
                    counter.fetch_add(1, Ordering::SeqCst);

                    42
                }
            },
            NopLatch,
        );

        let job = Job::from_stackjob(&job);

        // Give the single worker time to go to sleep.
        std::thread::sleep(std::time::Duration::from_millis(100));

        // If a heartbeat is registered, the only worker should now be waiting.
        if let Some((_, heartbeat)) = ctx.heartbeats.inner().iter().next() {
            assert!(heartbeat.is_waiting());
        }

        ctx.inject_job(job.share(Some(&parker)));

        // Wait for the job to be executed.
        let recv = job.take_receiver().unwrap();
        let result = recv.recv();
        let result = crate::util::unwrap_or_panic(result);
        assert_eq!(result, 42);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
}
|