// executor/src/melange.rs
use std::{
collections::VecDeque,
marker::PhantomPinned,
ptr::NonNull,
sync::{
atomic::{AtomicBool, Ordering},
Arc, Weak,
},
thread,
time::Duration,
};
use crossbeam::utils::CachePadded;
use parking_lot::{Condvar, Mutex};
use crate::{latch::*, ThreadControl};
mod job {
use core::{
cell::UnsafeCell,
mem::{self, ManuallyDrop, MaybeUninit},
sync::atomic::{AtomicU8, Ordering},
};
use std::thread::Thread;
use parking_lot_core::SpinWait;
use crate::util::SendPtr;
use super::WorkerThread;
#[allow(dead_code)]
#[cfg_attr(target_pointer_width = "64", repr(align(16)))]
#[cfg_attr(target_pointer_width = "32", repr(align(8)))]
#[derive(Debug, Default, Clone, Copy)]
struct Size2([usize; 2]);
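// `Value<T>` stores a job's result either inline in the pointer-sized slot
// (when `T` fits) or behind a heap allocation; `get` must be told which
// representation was used, via the `Inline` bit of the job state.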
struct Value<T>(pub MaybeUninit<Box<MaybeUninit<T>>>);
impl<T> Value<T> {
unsafe fn get(self, inline: bool) -> T {
if inline {
unsafe { mem::transmute_copy(&self.0) }
} else {
unsafe { (*self.0.assume_init()).assume_init() }
}
}
}
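// Job lifecycle: `Empty` -> `Pending` (visible to other workers) -> `Locked`
// (held briefly while a thread writes the result or registers itself as the
// waiter) -> `Finished`. The high `Inline` bit is orthogonal and records
// whether the result is stored inline rather than boxed.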
#[repr(u8)]
pub enum JobState {
Empty,
Locked = 1,
Pending,
Finished,
Inline = 1 << (u8::BITS - 1),
}
pub struct Job<T = ()> {
state: AtomicU8,
this: SendPtr<()>,
harness: unsafe fn(*const (), *const Job<()>, &mut WorkerThread),
maybe_boxed_val: UnsafeCell<MaybeUninit<Value<T>>>,
waiting_thread: UnsafeCell<Option<Thread>>,
}
impl<T> Job<T> {
pub unsafe fn cast_box<U>(self: Box<Self>) -> Box<Job<U>>
where
T: Sized,
U: Sized,
{
let ptr = Box::into_raw(self);
Box::from_raw(ptr.cast())
}
pub unsafe fn cast<U>(self: &Self) -> &Job<U>
where
T: Sized,
U: Sized,
{
// SAFETY: `Job`'s layout does not depend on `T`: the result lives in a
// pointer-sized `Value<T>` slot, so `Job<T>` and `Job<U>` are
// layout-compatible.
unsafe { mem::transmute(self) }
}
pub fn state(&self) -> u8 {
self.state.load(Ordering::Relaxed) & !(JobState::Inline as u8)
}
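// Blocks until the job is `Finished`, then returns the result. The
// lock/park dance: briefly move `Pending` -> `Locked` so `complete` cannot
// run while we register the current thread as the waiter, restore `Pending`,
// then park until `complete` unparks us.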
pub fn wait(&self) -> T {
let mut state = self.state.load(Ordering::Relaxed);
let mask = JobState::Inline as u8;
let mut spin = SpinWait::new();
loop {
match self.state.compare_exchange(
JobState::Pending as u8 | (state & mask),
JobState::Locked as u8 | (state & mask),
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(x) => {
state = x;
unsafe {
*self.waiting_thread.get() = Some(std::thread::current());
}
self.state
.store(JobState::Pending as u8 | (state & mask), Ordering::Release);
std::thread::park();
spin.reset();
continue;
}
Err(x) => {
// Treat the job as done only when the state is exactly
// `Finished`: `x & Finished != 0` would also match `Locked`,
// i.e. a `complete` call that is still writing the result.
if x & !mask == JobState::Finished as u8 {
let val = unsafe {
let value = (&*self.maybe_boxed_val.get()).assume_init_read();
value.get(state & JobState::Inline as u8 != 0)
};
return val;
} else {
spin.spin();
}
}
}
}
}
/// Transitions this job from `Empty` to `Pending`; call this when popping a
/// job off the local queue to publish it in the shared context.
pub fn set_pending(&self) {
let state = self.state.load(Ordering::Relaxed);
let mask = JobState::Inline as u8;
let mut spin = SpinWait::new();
loop {
match self.state.compare_exchange(
JobState::Empty as u8 | (state & mask),
JobState::Pending as u8 | (state & mask),
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(_) => {
return;
}
Err(_) => {
spin.spin();
}
}
}
}
pub fn execute(&self, s: &mut WorkerThread) {
// SAFETY: `this` points to the job's closure data and `self` outlives the
// call, so both pointers passed to the harness are valid and non-null.
unsafe { (self.harness)(self.this.as_ptr().cast(), (self as *const Self).cast(), s) };
}
#[allow(dead_code)]
fn complete(&self, result: T) {
let mut state = self.state.load(Ordering::Relaxed);
let mask = JobState::Inline as u8;
let mut spin = SpinWait::new();
loop {
match self.state.compare_exchange(
JobState::Pending as u8 | (state & mask),
JobState::Locked as u8 | (state & mask),
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(x) => {
state = x;
break;
}
Err(_) => {
spin.spin();
}
}
}
unsafe {
let value = (&mut *self.maybe_boxed_val.get()).assume_init_mut();
if state & JobState::Inline as u8 == 0 {
    // The result does not fit inline; heap-allocate it.
    value.0 = MaybeUninit::new(Box::new(MaybeUninit::new(result)));
} else {
    // The result fits in the pointer-sized slot; write it in place.
    *mem::transmute::<_, &mut T>(&mut value.0) = result;
}
}
if let Some(thread) = unsafe { &mut *self.waiting_thread.get() }.take() {
thread.unpark();
}
self.state
.store(JobState::Finished as u8 | (state & mask), Ordering::Release);
}
}
impl Job {}
#[allow(dead_code)]
pub struct HeapJob<F> {
f: F,
}
impl<F> HeapJob<F> {
#[allow(dead_code)]
pub fn new(f: F) -> Box<Self> {
Box::new(Self { f })
}
#[allow(dead_code)]
pub fn into_boxed_job<T>(self: Box<Self>) -> Box<Job<T>>
where
F: FnOnce(&mut WorkerThread) -> T + Send,
T: Send,
{
unsafe fn harness<F, T>(this: *const (), job: *const Job<()>, s: &mut WorkerThread)
where
F: FnOnce(&mut WorkerThread) -> T + Send,
T: Sized + Send,
{
let job = unsafe { &*job.cast::<Job<T>>() };
let this = unsafe { Box::from_raw(this.cast::<HeapJob<F>>().cast_mut()) };
let f = this.f;
job.complete(f(s));
}
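// Store the result inline when it fits in (and is no more aligned than) a
// pointer-sized slot; otherwise it will be boxed on completion.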
let size = mem::size_of::<T>();
let align = mem::align_of::<T>();
let new_state = if size > mem::size_of::<Box<T>>() || align > mem::align_of::<Box<T>>()
{
JobState::Empty as u8
} else {
JobState::Inline as u8
};
Box::new(Job {
state: AtomicU8::new(new_state),
this: SendPtr::new(Box::into_raw(self)).unwrap().cast(),
waiting_thread: UnsafeCell::new(None),
harness: harness::<F, T>,
maybe_boxed_val: UnsafeCell::new(MaybeUninit::uninit()),
})
}
}
impl<T> crate::latch::Probe for &Job<T> {
fn probe(&self) -> bool {
self.state() == JobState::Finished as u8
}
}
#[allow(dead_code)]
pub struct StackJob<F> {
f: UnsafeCell<ManuallyDrop<F>>,
}
impl<F> StackJob<F> {
#[allow(dead_code)]
pub fn new(f: F) -> Self {
Self {
f: UnsafeCell::new(ManuallyDrop::new(f)),
}
}
#[allow(dead_code)]
pub unsafe fn unwrap(&self) -> F {
unsafe { ManuallyDrop::take(&mut *self.f.get()) }
}
#[allow(dead_code)]
pub fn as_job<T>(&self) -> Job<T>
where
F: FnOnce(&mut WorkerThread) -> T + Send,
T: Send,
{
unsafe fn harness<F, T>(this: *const (), job: *const Job<()>, s: &mut WorkerThread)
where
F: FnOnce(&mut WorkerThread) -> T + Send,
T: Send,
{
let job = unsafe { &*job.cast::<Job<T>>() };
let this = unsafe { &*this.cast::<StackJob<F>>() };
let f = unsafe { this.unwrap() };
job.complete(f(s));
}
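// Same inline-vs-boxed decision as `HeapJob::into_boxed_job` above.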
let size = mem::size_of::<T>();
let align = mem::align_of::<T>();
let new_state = if size > mem::size_of::<Box<T>>() || align > mem::align_of::<Box<T>>()
{
JobState::Empty as u8
} else {
JobState::Inline as u8
};
Job {
state: AtomicU8::new(new_state),
this: SendPtr::new(self).unwrap().cast(),
waiting_thread: UnsafeCell::new(None),
harness: harness::<F, T>,
maybe_boxed_val: UnsafeCell::new(MaybeUninit::uninit()),
}
}
}
}
//use job::{Future, Job, JobQueue, JobStack};
use crate::job::v2::{Job as JobArchetype, JobState, StackJob};
// use crate::job::{Job, JobRef, StackJob};
type Job<T = ()> = JobArchetype<T, WorkerThread>;
struct ThreadState {
control: ThreadControl,
}
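// Each worker owns the `Arc<AtomicBool>` behind this `Weak`; the heartbeat
// thread only upgrades it, so a dropped worker's slot can be detected and
// skipped.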
struct Heartbeat {
is_set: Weak<AtomicBool>,
}
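// One shared-task slot and one heartbeat entry per worker, indexed by the
// worker's `index`; a worker publishes at most one job at a time.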
pub struct SharedContext {
shared_tasks: Vec<Option<NonNull<Job>>>,
heartbeats: Vec<Option<Heartbeat>>,
rng: crate::rng::XorShift64Star,
}
// SAFETY: Job is Send
unsafe impl Send for SharedContext {}
pub struct Context {
shared: Mutex<SharedContext>,
threads: Box<[CachePadded<ThreadState>]>,
heartbeat_control: CachePadded<ThreadControl>,
task_shared: Condvar,
}
pub struct ThreadPool {
context: Arc<Context>,
}
impl SharedContext {
fn new_heartbeat(&mut self) -> (Arc<AtomicBool>, usize) {
let is_set = Arc::new(AtomicBool::new(true));
let heartbeat = Heartbeat {
is_set: Arc::downgrade(&is_set),
};
let index = match self.heartbeats.iter().position(|a| a.is_none()) {
Some(i) => {
self.heartbeats[i] = Some(heartbeat);
i
}
None => {
self.heartbeats.push(Some(heartbeat));
self.shared_tasks.push(None);
self.heartbeats.len() - 1
}
};
(is_set, index)
}
fn pop_first_task(&mut self) -> Option<NonNull<Job>> {
self.shared_tasks
.iter_mut()
.filter_map(|task| task.take())
.next()
}
#[allow(dead_code)]
fn pop_random_task(&mut self) -> Option<NonNull<Job>> {
let i = self.rng.next_usize(self.shared_tasks.len());
let (a, b) = self.shared_tasks.split_at_mut(i);
// Scan from the random index `i`, wrapping around to the front.
b.iter_mut().chain(a).filter_map(|task| task.take()).next()
}
}
pub struct WorkerThread {
context: Arc<Context>,
index: usize,
queue: VecDeque<NonNull<Job>>,
heartbeat: Arc<AtomicBool>,
join_count: u8,
_marker: PhantomPinned,
}
// SAFETY: Job is Send
unsafe impl Send for WorkerThread {}
impl WorkerThread {
fn new(context: Arc<Context>, heartbeat: Arc<AtomicBool>, index: usize) -> WorkerThread {
WorkerThread {
context,
index,
queue: VecDeque::default(),
join_count: 0,
heartbeat,
_marker: PhantomPinned,
}
}
#[allow(dead_code)]
fn state(&self) -> &CachePadded<ThreadState> {
&self.context.threads[self.index]
}
fn control(&self) -> &ThreadControl {
&self.context.threads[self.index].control
}
fn shared(&self) -> &Mutex<SharedContext> {
&self.context.shared
}
fn ctx(&self) -> &Arc<Context> {
&self.context
}
}
impl WorkerThread {
fn worker(mut self) {
self.control().notify_running();
'outer: loop {
// The inner loop runs until no shared tasks remain.
loop {
if self.control().should_terminate.probe() {
break 'outer;
}
let task = { self.shared().lock().pop_first_task() };
if let Some(task) = task {
self.execute_job(task);
} else {
break;
}
}
// signal heartbeat thread that we would really like another task
//self.ctx().heartbeat_control.wake();
// spin here maybe?
// wait to be signaled since no more shared tasks exist.
let mut guard = self.shared().lock();
self.ctx().task_shared.wait(&mut guard);
}
self.control().notify_termination();
}
fn execute_job(&mut self, job: NonNull<Job>) {
self.heartbeat();
unsafe {
job.as_ref().execute(self);
}
}
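// The heartbeat flag is set periodically by the heartbeat thread; when a
// worker notices it, the cold path below promotes the oldest job in its
// local queue into its shared slot so idle workers can pick it up.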
#[inline]
fn heartbeat(&mut self) {
if self.heartbeat.load(Ordering::Relaxed) {
self.heartbeat_cold();
}
}
#[cold]
fn heartbeat_cold(&mut self) {
let mut guard = self.context.shared.lock();
if guard.shared_tasks[self.index].is_none() {
if let Some(task) = self.queue.pop_front() {
unsafe {
task.as_ref().set_pending();
}
guard.shared_tasks[self.index] = Some(task);
self.context.task_shared.notify_one();
}
}
self.heartbeat.store(false, Ordering::Relaxed);
}
pub fn join<A, B, RA, RB>(&mut self, a: A, b: B) -> (RA, RB)
where
A: FnOnce(&mut WorkerThread) -> RA + Send,
B: FnOnce(&mut WorkerThread) -> RB + Send,
RA: Send,
RB: Send,
{
self.join_with_every::<64, _, _, _, _>(a, b)
}
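// Take the full heartbeat-join path only on every `T`-th call or while the
// local queue is shallow; otherwise run both closures sequentially, which is
// cheaper for fine-grained joins.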
pub fn join_with_every<const T: u8, A, B, RA, RB>(&mut self, a: A, b: B) -> (RA, RB)
where
A: FnOnce(&mut WorkerThread) -> RA + Send,
B: FnOnce(&mut WorkerThread) -> RB + Send,
RA: Send,
RB: Send,
{
self.join_count = self.join_count.wrapping_add(1) % T;
if self.join_count == 0 || self.queue.len() < 3 {
self.join_heartbeat(a, b)
} else {
self.join_seq(a, b)
}
}
fn join_seq<A, B, RA, RB>(&mut self, a: A, b: B) -> (RA, RB)
where
A: FnOnce(&mut WorkerThread) -> RA + Send,
B: FnOnce(&mut WorkerThread) -> RB + Send,
RA: Send,
RB: Send,
{
let rb = b(self);
let ra = a(self);
(ra, rb)
}
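// Heartbeat join: `b` is pushed onto the local queue while `a` runs inline.
// Afterwards, either `b` was never shared, so we pop it back and run it
// inline, or we help execute shared jobs until another worker finishes it.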
pub fn join_heartbeat<A, B, RA, RB>(&mut self, a: A, b: B) -> (RA, RB)
where
A: FnOnce(&mut WorkerThread) -> RA + Send,
B: FnOnce(&mut WorkerThread) -> RB + Send,
RA: Send,
RB: Send,
{
let b = StackJob::new(b);
let job = Box::into_raw(Box::new(b.as_job()));
let job_ref = unsafe { &*job };
// let job = Box::new(b.as_job());
self.queue
.push_back(unsafe { NonNull::new_unchecked(job as *mut _) });
let ra = a(self);
let rb =
if job_ref.state() == JobState::Empty as u8 && self.pop_job_ptr(job.cast()).is_some() {
unsafe { b.unwrap()(self) }
} else {
self.run_until(job_ref);
job_ref.wait()
};
let _job = unsafe { Box::from_raw(job) };
(ra, rb)
}
fn pop_job_ptr(&mut self, id: *const Job) -> Option<&Job<()>> {
self.queue
.iter()
.rposition(|job| job.as_ptr() == id.cast_mut())
.and_then(|i| self.queue.remove(i))
.map(|job| unsafe { job.as_ref() })
}
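// Make progress while waiting for `latch`: drain our own shared slot first,
// then keep executing other workers' shared tasks until the latch trips.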
fn run_until<L: Probe>(&mut self, latch: &L) {
if !latch.probe() {
self.run_until_cold(latch);
}
}
#[cold]
fn run_until_cold<L: Probe>(&mut self, latch: &L) {
let job = self.shared().lock().shared_tasks[self.index].take();
if let Some(job) = job {
self.execute_job(job);
}
while !latch.probe() {
let job = self.context.shared.lock().pop_first_task();
if let Some(job) = job {
self.execute_job(job);
}
}
}
}
impl Context {
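// The heartbeat thread round-robins over live workers, setting one worker's
// flag per tick; the tick is `interval / num_workers`, so each worker is
// poked roughly once per `interval`.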
#[allow(dead_code)]
fn heartbeat(self: Arc<Self>, interval: Duration) {
let mut n = 0;
loop {
if self.heartbeat_control.should_terminate.probe() {
break;
}
let sleep_for = {
let guard = self.shared.lock();
let num_heartbeats = guard
.heartbeats
.iter()
.filter_map(Option::as_ref)
.filter_map(|h| h.is_set.upgrade())
.enumerate()
.inspect(|(i, is_set)| {
if *i == n {
is_set.store(true, Ordering::Relaxed);
}
})
.count();
if n >= num_heartbeats {
n = 0;
} else {
n += 1;
}
interval.checked_div(num_heartbeats as u32)
};
if let Some(duration) = sleep_for {
thread::sleep(duration);
}
}
}
#[allow(dead_code)]
fn heartbeat2(self: Arc<Self>, interval: Duration) {
let mut i = 0;
loop {
if self.heartbeat_control.should_terminate.probe() {
break;
}
let sleep_for = {
let guard = self.shared.lock();
let mut num = 0;
for is_set in guard
.heartbeats
.iter()
.filter_map(Option::as_ref)
.filter_map(|h| h.is_set.upgrade())
{
if num == i {
is_set.store(true, Ordering::Relaxed);
}
num += 1;
}
// Advance the round-robin cursor, wrapping once it passes the last
// live worker (mirroring `heartbeat` above).
if i >= num {
    i = 0;
} else {
    i += 1;
}
interval.checked_div(num)
};
if let Some(duration) = sleep_for {
self.heartbeat_control
.wait_for_should_wake_timeout(duration);
// thread::sleep(duration);
}
}
}
}
impl Drop for Context {
fn drop(&mut self) {
for thread in &self.threads {
thread.control.notify_should_terminate();
}
self.heartbeat_control.notify_should_terminate();
for thread in &self.threads {
thread.control.wait_for_termination();
}
self.heartbeat_control.wait_for_termination();
}
}
impl ThreadPool {
pub fn new_worker(&self) -> WorkerThread {
let (heartbeat, index) = self.context.shared.lock().new_heartbeat();
WorkerThread::new(self.context.clone(), heartbeat, index)
}
pub fn new(num_threads: usize) -> ThreadPool {
let threads = (0..num_threads)
.map(|_| {
CachePadded::new(ThreadState {
control: ThreadControl::new(),
})
})
.collect::<Box<_>>();
let context = Arc::new(Context {
shared: Mutex::new(SharedContext {
shared_tasks: Vec::with_capacity(num_threads),
heartbeats: Vec::with_capacity(num_threads),
rng: crate::rng::XorShift64Star::new(num_threads as u64),
}),
threads,
heartbeat_control: CachePadded::new(ThreadControl::new()),
task_shared: Condvar::new(),
});
let this = Self { context };
for _ in 0..num_threads {
let worker = this.new_worker();
std::thread::spawn(move || {
worker.worker();
});
}
let ctx = this.context.clone();
std::thread::spawn(|| {
ctx.heartbeat2(Duration::from_micros(100));
});
this
}
}
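// A minimal usage sketch (not part of the original file; assumes a
// `WorkerThread` obtained via `new_worker` may be driven from the calling
// thread, as the constructor above permits):
//
//     let pool = ThreadPool::new(4);
//     let mut worker = pool.new_worker();
//     let (a, b) = worker.join(|_| 1 + 1, |_| 2 + 2);
//     assert_eq!((a, b), (2, 4));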