executor/distaff/src/workerthread.rs

301 lines
9.1 KiB
Rust

#[cfg(feature = "metrics")]
use std::sync::atomic::Ordering;
use std::{
cell::{Cell, UnsafeCell},
ptr::NonNull,
sync::{Arc, Barrier},
time::Duration,
};
#[cfg(feature = "metrics")]
use werkzeug::CachePadded;
use crate::{
channel::Receiver,
context::{Context, Message},
heartbeat::OwnedHeartbeatReceiver,
job::{Job2 as Job, JobQueue as JobList, SharedJob},
queue,
util::DropGuard,
};
/// Per-thread state of one executor worker.
///
/// A worker is created with [`WorkerThread::new_in`] and driven by
/// [`WorkerThread::run`], which registers it as the current worker of the
/// calling thread.
pub struct WorkerThread {
/// Shared executor context; the heartbeat and the receiver below are both
/// obtained from it in `new_in`.
pub(crate) context: Arc<Context>,
/// Receiving end created from `context.queue`; the run loop and the
/// `wait_until_*` methods pull [`Message`]s from it.
pub(crate) receiver: queue::Receiver<Message>,
/// Local job list. It is mutated through `&self` (via
/// `as_mut_unchecked`) by `push_*` / `pop_*`.
// NOTE(review): presumably only ever touched from the owning thread — confirm.
pub(crate) queue: UnsafeCell<JobList>,
/// Heartbeat flag for this worker; `tick` consumes it with `take()`.
pub(crate) heartbeat: OwnedHeartbeatReceiver,
/// Starts at 0 in `new_in`.
// NOTE(review): not read or written anywhere in the visible part of this file.
pub(crate) join_count: Cell<u8>,
/// Counters (`num_heartbeats`, `num_jobs_shared`, ...) printed to stderr
/// when the worker finishes. Only present with the `metrics` feature.
#[cfg(feature = "metrics")]
pub(crate) metrics: CachePadded<crate::metrics::WorkerMetrics>,
}
// Pointer to the `WorkerThread` registered on the current thread, if any.
// Written by `WorkerThread::set_current` / `unset_current` and read by
// `WorkerThread::current_ref`; it starts out as `None`.
thread_local! {
static WORKER: UnsafeCell<Option<NonNull<WorkerThread>>> = const { UnsafeCell::new(None) };
}
impl WorkerThread {
    /// Creates a worker attached to `context`.
    ///
    /// A fresh heartbeat is requested from `context.heartbeats` and a fresh
    /// receiver from `context.queue`. The local job list starts empty and the
    /// join counter at zero.
    pub fn new_in(context: Arc<Context>) -> Self {
        // Both handles are taken from `context` before it is moved into the
        // struct: heartbeat first, then the receiver.
        let heartbeat = context.heartbeats.new_heartbeat();
        let receiver = context.queue.new_receiver();
        let queue = UnsafeCell::new(JobList::new());
        let join_count = Cell::new(0);

        Self {
            context,
            receiver,
            queue,
            heartbeat,
            join_count,
            #[cfg(feature = "metrics")]
            metrics: CachePadded::new(crate::metrics::WorkerMetrics::default()),
        }
    }
}
impl WorkerThread {
    /// Entry point of a worker thread.
    ///
    /// Turns the box into a raw pointer, registers it as the current worker of
    /// this thread, waits on `barrier`, and then processes messages until the
    /// worker is told to stop. The guard created below unregisters the worker
    /// and frees it again (presumably when the guard is dropped, i.e. when this
    /// function returns or unwinds).
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all, fields(
        worker = self.heartbeat.index(),
    )))]
    pub fn run(self: Box<Self>, barrier: Arc<Barrier>) {
        let worker = Box::into_raw(self);

        // SAFETY: `worker` comes straight from `Box::into_raw`, so it is
        // non-null and points to a live `WorkerThread`.
        unsafe { Self::set_current(worker) };

        let _cleanup = DropGuard::new(|| unsafe {
            // SAFETY: this is only called when the thread is exiting
            Self::unset_current();
            Self::drop_in_place(worker);
        });

        #[cfg(feature = "tracing")]
        tracing::trace!("WorkerThread::run: starting worker thread");

        barrier.wait();

        // SAFETY: the allocation stays alive until `_cleanup` runs, which
        // happens after the last use of this shared reference.
        let worker_ref: &Self = unsafe { &*worker };
        worker_ref.run_inner();

        #[cfg(feature = "metrics")]
        eprintln!("{:?}", worker_ref.metrics);

        #[cfg(feature = "tracing")]
        tracing::trace!("WorkerThread::run: worker thread finished");
    }

    /// Message loop: receives from the worker's receiver until the context
    /// asks for exit or an `Exit` message arrives.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    fn run_inner(&self) {
        // The exit flag is re-checked before every blocking receive.
        while !self.context.should_exit() {
            match self.receiver.recv() {
                Message::Exit => break,
                Message::ScopeFinished => {}
                Message::Shared(shared_job) => self.execute(shared_job),
                // A `Finished` message is not expected in the idle loop; it is
                // only reported (when tracing is enabled) and otherwise
                // ignored.
                Message::Finished(werkzeug::util::Send(ptr)) => {
                    #[cfg(feature = "tracing")]
                    tracing::error!(
                        "WorkerThread::run_inner: received finished message: {:?}",
                        ptr
                    );
                }
            }
        }
    }
}
impl WorkerThread {
    /// Polls the heartbeat flag and, if it was set, tries to hand one queued
    /// job to another worker.
    ///
    /// When a job was taken from the local queue but could not be handed off,
    /// it is executed on this thread instead.
    pub(crate) fn tick(&self) {
        if !self.heartbeat.take() {
            return;
        }

        #[cfg(feature = "metrics")]
        self.metrics.num_heartbeats.fetch_add(1, Ordering::Relaxed);
        #[cfg(feature = "tracing")]
        tracing::trace!(
            "received heartbeat, thread id: {:?}",
            self.heartbeat.index()
        );

        self.heartbeat_cold();
    }

    /// Runs a job that was received from the shared queue on this worker.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    fn execute(&self, job: SharedJob) {
        // TODO: maybe tick here?
        unsafe {
            SharedJob::execute(job, self);
        }
    }

    /// Slow path of [`Self::tick`]: offers the job at the back of the local
    /// queue to the other workers of the same context.
    ///
    /// Returns nothing. If the queue is empty this is a no-op; if the hand-off
    /// is rejected, the job that comes back is executed locally.
    #[cold]
    fn heartbeat_cold(&self) {
        let Some(job) = self.pop_back() else {
            return;
        };

        #[cfg(feature = "tracing")]
        tracing::trace!("heartbeat: sharing job: {:?}", job);
        // NOTE(review): the counter is bumped before the hand-off is
        // attempted, so it also counts jobs that end up running locally.
        #[cfg(feature = "metrics")]
        self.metrics.num_jobs_shared.fetch_add(1, Ordering::Relaxed);

        let sender = self.context.queue.as_sender();
        let message =
            Message::Shared(unsafe { job.as_ref().share(Some(self.receiver.get_token())) });

        if let Err(Message::Shared(rejected)) = sender.try_anycast(message) {
            unsafe {
                SharedJob::execute(rejected, self);
            }
        }
    }
}
impl WorkerThread {
    /// Removes and returns the job at the back of the local queue, if any.
    pub fn pop_back(&self) -> Option<NonNull<Job>> {
        // The queue lives in an `UnsafeCell` and is borrowed mutably through
        // `&self` without any runtime check.
        unsafe {
            let jobs = self.queue.as_mut_unchecked();
            jobs.pop_back()
        }
    }

    /// Appends `job` to the back of the local queue. The pointer is cast to
    /// the queue's own job pointer type before it is stored.
    pub fn push_back<T>(&self, job: *const Job<T>) {
        unsafe {
            let jobs = self.queue.as_mut_unchecked();
            jobs.push_back(job.cast())
        }
    }

    /// Inserts `job` at the front of the local queue. The pointer is cast to
    /// the queue's own job pointer type before it is stored.
    pub fn push_front<T>(&self, job: *const Job<T>) {
        unsafe {
            let jobs = self.queue.as_mut_unchecked();
            jobs.push_front(job.cast())
        }
    }

    /// Removes and returns the job at the front of the local queue, if any.
    pub fn pop_front(&self) -> Option<NonNull<Job>> {
        unsafe {
            let jobs = self.queue.as_mut_unchecked();
            jobs.pop_front()
        }
    }
}
impl WorkerThread {
/// Returns the worker registered on the calling thread, or `None` when no
/// worker has been registered (or it has already been unregistered).
// NOTE(review): the lifetime `'a` is picked by the caller and is not tied to
// the worker's real lifetime; the reference must not be kept past the point
// where the worker is unregistered and dropped.
pub fn current_ref<'a>() -> Option<&'a Self> {
unsafe { (*WORKER.with(UnsafeCell::get)).map(|ptr| ptr.as_ref()) }
}
/// Registers `this` as the current worker of the calling thread.
///
/// # Safety
///
/// `this` must be non-null (it is wrapped with `NonNull::new_unchecked`) and
/// must stay valid for as long as it remains registered, because
/// `current_ref` dereferences it.
///
/// # Panics
///
/// Panics if a worker is already registered on this thread.
unsafe fn set_current(this: *const Self) {
WORKER.with(|cell| {
unsafe {
// SAFETY: this cell is only ever accessed from the current thread
// `replace` returns the previous value, which must have been `None`.
assert!(
(&mut *cell.get())
.replace(NonNull::new_unchecked(
this as *const WorkerThread as *mut WorkerThread,
))
.is_none()
);
}
});
}
/// Clears the current-worker slot of the calling thread. The previously
/// registered pointer, if any, is discarded without being freed.
unsafe fn unset_current() {
WORKER.with(|cell| {
unsafe {
// SAFETY: this cell is only ever accessed from the current thread
(&mut *cell.get()).take();
}
});
}
/// Destroys the worker behind `this` and releases its allocation.
///
/// # Safety
///
/// `this` must have been produced by `Box::into_raw` and must not be used
/// again afterwards.
unsafe fn drop_in_place(this: *mut Self) {
unsafe {
// SAFETY: this is only called when the thread is exiting, so we can
// safely drop the thread. We use `drop_in_place` to prevent `Box`
// from creating a no-alias reference to the worker thread.
core::ptr::drop_in_place(this);
// Rebuild the box as `ManuallyDrop<Self>` so that dropping it only frees
// the memory and does not run the destructor a second time.
_ = Box::<core::mem::ManuallyDrop<Self>>::from_raw(this as _);
}
}
}
/// Background thread state that periodically notifies the heartbeats
/// registered in a [`Context`], one per step, in round-robin order.
pub struct HeartbeatThread {
/// Context whose `heartbeats` collection is notified and whose exit flag
/// stops the loop.
ctx: Arc<Context>,
/// Number of workers passed to `new`.
// NOTE(review): only referenced from commented-out code in `run`.
num_workers: usize,
}
impl HeartbeatThread {
    /// Time budget for one full pass over all registered heartbeats. Each
    /// step sleeps for this interval divided by the number of heartbeats.
    const HEARTBEAT_INTERVAL: Duration = Duration::from_micros(100);

    /// Creates the heartbeat driver for `ctx`.
    ///
    /// `num_workers` is stored but not consulted by [`Self::run`]; a
    /// park-based idle path that used it is currently disabled.
    pub fn new(ctx: Arc<Context>, num_workers: usize) -> Self {
        Self { ctx, num_workers }
    }

    /// Runs the heartbeat loop until the context signals exit.
    ///
    /// Waits on `barrier` first, then notifies the heartbeats one at a time
    /// in round-robin order, sleeping between notifications.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
    pub fn run(self, barrier: Arc<Barrier>) {
        #[cfg(feature = "tracing")]
        tracing::trace!("new heartbeat thread {:?}", std::thread::current());

        barrier.wait();

        let mut i = 0;
        while !self.ctx.should_exit() {
            self.ctx.heartbeats.notify_nth(i);

            // Read the length after notifying: heartbeats may be added or
            // removed while this thread is running.
            let num_heartbeats = self.ctx.heartbeats.len();

            // Advance within `0..num_heartbeats`. The index must never be
            // left equal to the length, otherwise one step of every round
            // would notify a slot that does not exist.
            i = if i + 1 < num_heartbeats { i + 1 } else { 0 };

            // With no heartbeats registered the division has no result;
            // sleep for the whole interval instead of spinning on the CPU.
            let sleep_for = u32::try_from(num_heartbeats)
                .ok()
                .and_then(|n| Self::HEARTBEAT_INTERVAL.checked_div(n))
                .unwrap_or(Self::HEARTBEAT_INTERVAL);

            std::thread::sleep(sleep_for);
        }
    }
}
impl WorkerThread {
    /// Blocks until a `Finished` message arrives and returns the result it
    /// carries, executing any shared jobs received in the meantime.
    ///
    /// `job` is not inspected here; the result is taken solely from the
    /// received message.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
    pub fn wait_until_shared_job<T: Send>(&self, job: &Job<T>) -> std::thread::Result<T> {
        self.recv_until_finished()
    }

    /// Blocks until a `Finished` message arrives and returns the result it
    /// carries, executing any shared jobs received in the meantime.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    pub fn wait_until_recv<T: Send>(&self) -> std::thread::Result<T> {
        self.recv_until_finished()
    }

    /// Receive loop shared by the two public wait methods.
    ///
    /// `Shared` messages are executed on this worker, `Exit` and
    /// `ScopeFinished` are ignored, and the first `Finished` message ends the
    /// loop.
    fn recv_until_finished<T: Send>(&self) -> std::thread::Result<T> {
        loop {
            match self.receiver.recv() {
                // The payload pointer is reinterpreted as a boxed
                // `std::thread::Result<T>` and moved out of the box.
                // NOTE(review): relies on the sender having boxed exactly this type — confirm.
                Message::Finished(send) => {
                    return unsafe { *Box::from_non_null(send.0.cast()) };
                }
                Message::Shared(shared_job) => unsafe {
                    SharedJob::execute(shared_job, self);
                },
                Message::Exit | Message::ScopeFinished => {}
            }
        }
    }
}