301 lines
9.1 KiB
Rust
301 lines
9.1 KiB
Rust
#[cfg(feature = "metrics")]
|
|
use std::sync::atomic::Ordering;
|
|
|
|
use std::{
|
|
cell::{Cell, UnsafeCell},
|
|
ptr::NonNull,
|
|
sync::{Arc, Barrier},
|
|
time::Duration,
|
|
};
|
|
|
|
#[cfg(feature = "metrics")]
|
|
use werkzeug::CachePadded;
|
|
|
|
use crate::{
|
|
channel::Receiver,
|
|
context::{Context, Message},
|
|
heartbeat::OwnedHeartbeatReceiver,
|
|
job::{Job2 as Job, JobQueue as JobList, SharedJob},
|
|
queue,
|
|
util::DropGuard,
|
|
};
|
|
|
|
pub struct WorkerThread {
|
|
pub(crate) context: Arc<Context>,
|
|
pub(crate) receiver: queue::Receiver<Message>,
|
|
pub(crate) queue: UnsafeCell<JobList>,
|
|
pub(crate) heartbeat: OwnedHeartbeatReceiver,
|
|
pub(crate) join_count: Cell<u8>,
|
|
|
|
#[cfg(feature = "metrics")]
|
|
pub(crate) metrics: CachePadded<crate::metrics::WorkerMetrics>,
|
|
}
|
|
|
|
thread_local! {
|
|
static WORKER: UnsafeCell<Option<NonNull<WorkerThread>>> = const { UnsafeCell::new(None) };
|
|
}
|
|
|
|
impl WorkerThread {
|
|
pub fn new_in(context: Arc<Context>) -> Self {
|
|
let heartbeat = context.heartbeats.new_heartbeat();
|
|
|
|
Self {
|
|
receiver: context.queue.new_receiver(),
|
|
context,
|
|
queue: UnsafeCell::new(JobList::new()),
|
|
heartbeat,
|
|
join_count: Cell::new(0),
|
|
#[cfg(feature = "metrics")]
|
|
metrics: CachePadded::new(crate::metrics::WorkerMetrics::default()),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl WorkerThread {
|
|
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all, fields(
|
|
worker = self.heartbeat.index(),
|
|
)))]
|
|
pub fn run(self: Box<Self>, barrier: Arc<Barrier>) {
|
|
let this = Box::into_raw(self);
|
|
unsafe {
|
|
Self::set_current(this);
|
|
}
|
|
|
|
let _guard = DropGuard::new(|| unsafe {
|
|
// SAFETY: this is only called when the thread is exiting
|
|
Self::unset_current();
|
|
Self::drop_in_place(this);
|
|
});
|
|
|
|
#[cfg(feature = "tracing")]
|
|
tracing::trace!("WorkerThread::run: starting worker thread");
|
|
|
|
barrier.wait();
|
|
unsafe {
|
|
(&*this).run_inner();
|
|
}
|
|
|
|
#[cfg(feature = "metrics")]
|
|
unsafe {
|
|
eprintln!("{:?}", (&*this).metrics);
|
|
}
|
|
|
|
#[cfg(feature = "tracing")]
|
|
tracing::trace!("WorkerThread::run: worker thread finished");
|
|
}
|
|
|
|
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
|
fn run_inner(&self) {
|
|
loop {
|
|
if self.context.should_exit() {
|
|
break;
|
|
}
|
|
|
|
match self.receiver.recv() {
|
|
Message::Shared(shared_job) => {
|
|
self.execute(shared_job);
|
|
}
|
|
Message::Finished(werkzeug::util::Send(ptr)) => {
|
|
#[cfg(feature = "tracing")]
|
|
tracing::error!(
|
|
"WorkerThread::run_inner: received finished message: {:?}",
|
|
ptr
|
|
);
|
|
}
|
|
Message::Exit => break,
|
|
Message::ScopeFinished => {}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
impl WorkerThread {
|
|
/// Checks if the worker thread has received a heartbeat, and if so,
|
|
/// attempts to share a job with other workers. If a job was popped from
|
|
/// the queue, but not shared, this function runs the job locally.
|
|
pub(crate) fn tick(&self) {
|
|
if self.heartbeat.take() {
|
|
#[cfg(feature = "metrics")]
|
|
self.metrics.num_heartbeats.fetch_add(1, Ordering::Relaxed);
|
|
#[cfg(feature = "tracing")]
|
|
tracing::trace!(
|
|
"received heartbeat, thread id: {:?}",
|
|
self.heartbeat.index()
|
|
);
|
|
|
|
self.heartbeat_cold();
|
|
}
|
|
}
|
|
|
|
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
|
fn execute(&self, job: SharedJob) {
|
|
unsafe { SharedJob::execute(job, self) };
|
|
// TODO: maybe tick here?
|
|
}
|
|
|
|
/// Attempts to share a job with other workers within the same context.
|
|
/// returns `true` if the job was shared, `false` if it was not.
|
|
#[cold]
|
|
fn heartbeat_cold(&self) {
|
|
if let Some(job) = self.pop_back() {
|
|
#[cfg(feature = "tracing")]
|
|
tracing::trace!("heartbeat: sharing job: {:?}", job);
|
|
|
|
#[cfg(feature = "metrics")]
|
|
self.metrics.num_jobs_shared.fetch_add(1, Ordering::Relaxed);
|
|
|
|
if let Err(Message::Shared(job)) =
|
|
self.context
|
|
.queue
|
|
.as_sender()
|
|
.try_anycast(Message::Shared(unsafe {
|
|
job.as_ref().share(Some(self.receiver.get_token()))
|
|
}))
|
|
{
|
|
unsafe {
|
|
SharedJob::execute(job, self);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
impl WorkerThread {
|
|
pub fn pop_back(&self) -> Option<NonNull<Job>> {
|
|
unsafe { self.queue.as_mut_unchecked().pop_back() }
|
|
}
|
|
|
|
pub fn push_back<T>(&self, job: *const Job<T>) {
|
|
unsafe { self.queue.as_mut_unchecked().push_back(job.cast()) }
|
|
}
|
|
pub fn push_front<T>(&self, job: *const Job<T>) {
|
|
unsafe { self.queue.as_mut_unchecked().push_front(job.cast()) }
|
|
}
|
|
|
|
pub fn pop_front(&self) -> Option<NonNull<Job>> {
|
|
unsafe { self.queue.as_mut_unchecked().pop_front() }
|
|
}
|
|
}
|
|
|
|
impl WorkerThread {
|
|
pub fn current_ref<'a>() -> Option<&'a Self> {
|
|
unsafe { (*WORKER.with(UnsafeCell::get)).map(|ptr| ptr.as_ref()) }
|
|
}
|
|
|
|
unsafe fn set_current(this: *const Self) {
|
|
WORKER.with(|cell| {
|
|
unsafe {
|
|
// SAFETY: this cell is only ever accessed from the current thread
|
|
assert!(
|
|
(&mut *cell.get())
|
|
.replace(NonNull::new_unchecked(
|
|
this as *const WorkerThread as *mut WorkerThread,
|
|
))
|
|
.is_none()
|
|
);
|
|
}
|
|
});
|
|
}
|
|
|
|
unsafe fn unset_current() {
|
|
WORKER.with(|cell| {
|
|
unsafe {
|
|
// SAFETY: this cell is only ever accessed from the current thread
|
|
(&mut *cell.get()).take();
|
|
}
|
|
});
|
|
}
|
|
|
|
unsafe fn drop_in_place(this: *mut Self) {
|
|
unsafe {
|
|
// SAFETY: this is only called when the thread is exiting, so we can
|
|
// safely drop the thread. We use `drop_in_place` to prevent `Box`
|
|
// from creating a no-alias reference to the worker thread.
|
|
core::ptr::drop_in_place(this);
|
|
_ = Box::<core::mem::ManuallyDrop<Self>>::from_raw(this as _);
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct HeartbeatThread {
|
|
ctx: Arc<Context>,
|
|
num_workers: usize,
|
|
}
|
|
|
|
impl HeartbeatThread {
|
|
const HEARTBEAT_INTERVAL: Duration = Duration::from_micros(100);
|
|
|
|
pub fn new(ctx: Arc<Context>, num_workers: usize) -> Self {
|
|
Self { ctx, num_workers }
|
|
}
|
|
|
|
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
|
|
pub fn run(self, barrier: Arc<Barrier>) {
|
|
#[cfg(feature = "tracing")]
|
|
tracing::trace!("new heartbeat thread {:?}", std::thread::current());
|
|
barrier.wait();
|
|
|
|
let mut i = 0;
|
|
loop {
|
|
let sleep_for = {
|
|
// loop {
|
|
// if self.ctx.should_exit() || self.ctx.queue.num_receivers() != self.num_workers
|
|
// {
|
|
// break;
|
|
// }
|
|
|
|
// self.ctx.heartbeat.park();
|
|
// }
|
|
if self.ctx.should_exit() {
|
|
break;
|
|
}
|
|
|
|
self.ctx.heartbeats.notify_nth(i);
|
|
let num_heartbeats = self.ctx.heartbeats.len();
|
|
|
|
if i >= num_heartbeats {
|
|
i = 0;
|
|
} else {
|
|
i += 1;
|
|
}
|
|
|
|
Self::HEARTBEAT_INTERVAL.checked_div(num_heartbeats as u32)
|
|
};
|
|
|
|
if let Some(duration) = sleep_for {
|
|
std::thread::sleep(duration);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
impl WorkerThread {
|
|
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
|
|
pub fn wait_until_shared_job<T: Send>(&self, job: &Job<T>) -> std::thread::Result<T> {
|
|
loop {
|
|
match self.receiver.recv() {
|
|
Message::Shared(shared_job) => unsafe {
|
|
SharedJob::execute(shared_job, self);
|
|
},
|
|
Message::Finished(send) => {
|
|
break unsafe { *Box::from_non_null(send.0.cast()) };
|
|
}
|
|
Message::Exit | Message::ScopeFinished => {}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
|
|
pub fn wait_until_recv<T: Send>(&self) -> std::thread::Result<T> {
|
|
loop {
|
|
match self.receiver.recv() {
|
|
Message::Shared(shared_job) => unsafe {
|
|
SharedJob::execute(shared_job, self);
|
|
},
|
|
Message::Finished(send) => break unsafe { *Box::from_non_null(send.0.cast()) },
|
|
Message::Exit | Message::ScopeFinished => {}
|
|
}
|
|
}
|
|
}
|
|
}
|