From e8a07ce6a546119b0b39e65cb3d56a9cf431ef58 Mon Sep 17 00:00:00 2001 From: Janis Date: Tue, 24 Jun 2025 11:13:17 +0200 Subject: [PATCH] export into distaff crate --- Cargo.toml | 3 + distaff/Cargo.toml | 16 + distaff/rust-toolchain | 1 + distaff/src/context.rs | 279 +++++++++++++++ distaff/src/job.rs | 666 ++++++++++++++++++++++++++++++++++++ distaff/src/join.rs | 110 ++++++ distaff/src/latch.rs | 367 ++++++++++++++++++++ distaff/src/lib.rs | 22 ++ distaff/src/scope.rs | 267 +++++++++++++++ distaff/src/threadpool.rs | 1 + distaff/src/util.rs | 404 ++++++++++++++++++++++ distaff/src/workerthread.rs | 396 +++++++++++++++++++++ 12 files changed, 2532 insertions(+) create mode 100644 distaff/Cargo.toml create mode 100644 distaff/rust-toolchain create mode 100644 distaff/src/context.rs create mode 100644 distaff/src/job.rs create mode 100644 distaff/src/join.rs create mode 100644 distaff/src/latch.rs create mode 100644 distaff/src/lib.rs create mode 100644 distaff/src/scope.rs create mode 100644 distaff/src/threadpool.rs create mode 100644 distaff/src/util.rs create mode 100644 distaff/src/workerthread.rs diff --git a/Cargo.toml b/Cargo.toml index 550b0b6..af5838e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,9 @@ debug = true [profile.release] debug = true +[workspace] +members = ["distaff"] + [dependencies] futures = "0.3" diff --git a/distaff/Cargo.toml b/distaff/Cargo.toml new file mode 100644 index 0000000..5206454 --- /dev/null +++ b/distaff/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "distaff" +version = "0.1.0" +edition = "2024" + +[features] +default = [] +std = [] + +[dependencies] +parking_lot = {version = "0.12.3"} +tracing = "0.1.40" +parking_lot_core = "0.9.10" +crossbeam-utils = "0.8.21" + +async-task = "4.7.1" \ No newline at end of file diff --git a/distaff/rust-toolchain b/distaff/rust-toolchain new file mode 100644 index 0000000..bf867e0 --- /dev/null +++ b/distaff/rust-toolchain @@ -0,0 +1 @@ +nightly diff --git a/distaff/src/context.rs b/distaff/src/context.rs new file mode 100644 index 0000000..71427a0 --- /dev/null +++ b/distaff/src/context.rs @@ -0,0 +1,279 @@ +use std::{ + ptr::NonNull, + sync::{ + Arc, OnceLock, Weak, + atomic::{AtomicU8, Ordering}, + }, +}; + +use alloc::collections::BTreeMap; + +use crossbeam_utils::CachePadded; +use parking_lot::{Condvar, Mutex}; + +use crate::{ + job::{Job, StackJob}, + latch::{LatchRef, MutexLatch, WakeLatch}, + workerthread::{HeartbeatThread, WorkerThread}, +}; + +pub struct Heartbeat { + heartbeat: AtomicU8, + pub latch: MutexLatch, +} + +impl Heartbeat { + pub const CLEAR: u8 = 0; + pub const PENDING: u8 = 1; + pub const SLEEPING: u8 = 2; + + pub fn new() -> (Arc>, Weak>) { + let strong = Arc::new(CachePadded::new(Self { + heartbeat: AtomicU8::new(Self::CLEAR), + latch: MutexLatch::new(), + })); + let weak = Arc::downgrade(&strong); + + (strong, weak) + } + + /// returns true if the heartbeat was previously sleeping. 
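+    ///
+    /// The caller is expected to wake the worker when this returns true. A minimal
+    /// sketch of the intended call site (mirroring `HeartbeatThread::run` later in
+    /// this patch):
+    ///
+    /// ```ignore
+    /// if heartbeat.set_pending() {
+    ///     // the worker was sleeping on its latch, so wake it up
+    ///     heartbeat.latch.set();
+    /// }
+    /// ```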
+ pub fn set_pending(&self) -> bool { + let old = self.heartbeat.swap(Self::PENDING, Ordering::Relaxed); + old == Self::SLEEPING + } + + pub fn clear(&self) { + self.heartbeat.store(Self::CLEAR, Ordering::Relaxed); + } + + pub fn is_pending(&self) -> bool { + self.heartbeat.load(Ordering::Relaxed) == Self::PENDING + } + + pub fn is_sleeping(&self) -> bool { + self.heartbeat.load(Ordering::Relaxed) == Self::SLEEPING + } +} + +pub struct Context { + shared: Mutex, + pub shared_job: Condvar, +} + +pub(crate) struct Shared { + pub jobs: BTreeMap>, + pub heartbeats: BTreeMap>>, + injected_jobs: Vec>, + heartbeat_count: usize, + should_exit: bool, +} + +unsafe impl Send for Shared {} + +impl Shared { + pub fn new_heartbeat(&mut self) -> (Arc>, usize) { + let index = self.heartbeat_count; + self.heartbeat_count = index.wrapping_add(1); + + let (strong, weak) = Heartbeat::new(); + + self.heartbeats.insert(index, weak); + + (strong, index) + } + + pub fn pop_job(&mut self) -> Option> { + // this is unlikely, so make the function cold? + // TODO: profile this + if !self.injected_jobs.is_empty() { + unsafe { return Some(self.pop_injected_job()) }; + } else { + self.jobs.pop_first().map(|(_, job)| job) + } + } + + #[cold] + unsafe fn pop_injected_job(&mut self) -> NonNull { + self.injected_jobs.pop().unwrap() + } + + pub fn should_exit(&self) -> bool { + self.should_exit + } +} + +impl Context { + #[inline] + pub fn shared(&self) -> parking_lot::MutexGuard<'_, Shared> { + self.shared.lock() + } + + pub fn new_with_threads(num_threads: usize) -> Arc { + let this = Arc::new(Self { + shared: Mutex::new(Shared { + jobs: BTreeMap::new(), + heartbeats: BTreeMap::new(), + injected_jobs: Vec::new(), + heartbeat_count: 0, + should_exit: false, + }), + shared_job: Condvar::new(), + }); + + tracing::trace!("Creating thread pool with {} threads", num_threads); + + // Create a barrier to synchronize the worker threads and the heartbeat thread + let barrier = Arc::new(std::sync::Barrier::new(num_threads + 2)); + + for i in 0..num_threads { + let ctx = this.clone(); + let barrier = barrier.clone(); + + std::thread::Builder::new() + .name(format!("worker-{}", i)) + .spawn(move || { + let worker = Box::new(WorkerThread::new_in(ctx)); + + barrier.wait(); + worker.run(); + }) + .expect("Failed to spawn worker thread"); + } + + { + let ctx = this.clone(); + let barrier = barrier.clone(); + + std::thread::Builder::new() + .name("heartbeat-thread".to_string()) + .spawn(move || { + barrier.wait(); + HeartbeatThread::new(ctx).run(); + }) + .expect("Failed to spawn heartbeat thread"); + } + + barrier.wait(); + + this + } + + pub fn new() -> Arc { + Self::new_with_threads(crate::util::available_parallelism()) + } + + pub fn global_context() -> &'static Arc { + static GLOBAL_CONTEXT: OnceLock> = OnceLock::new(); + + GLOBAL_CONTEXT.get_or_init(|| Self::new()) + } + + pub fn inject_job(&self, job: NonNull) { + let mut shared = self.shared.lock(); + shared.injected_jobs.push(job); + self.notify_shared_job(); + } + + pub fn notify_shared_job(&self) { + self.shared_job.notify_one(); + } + + /// Runs closure in this context, processing the other context's worker's jobs while waiting for the result. + fn run_in_worker_cross(self: &Arc, worker: &WorkerThread, f: F) -> T + where + F: FnOnce(&WorkerThread) -> T + Send, + T: Send, + { + // current thread is not in the same context, create a job and inject it into the other thread's context, then wait while working on our jobs. 
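+        // In short: build a `StackJob` on this stack, mark it pending, inject it into
+        // `self` (the target context), then keep draining our own worker's jobs via
+        // `wait_until_latch` until the injected job signals the latch.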
+ + let latch = WakeLatch::new(self.clone(), worker.index); + + let job = StackJob::new( + move || { + let worker = WorkerThread::current_ref() + .expect("WorkerThread::run_in_worker called outside of worker thread"); + + f(worker) + }, + LatchRef::new(&latch), + ); + + let job = job.as_job(); + job.set_pending(); + + self.inject_job(Into::into(&job)); + + worker.wait_until_latch(&latch); + + let t = unsafe { job.transmute_ref::().wait().into_result() }; + + t + } + + /// Run closure in this context, sleeping until the job is done. + pub fn run_in_worker_cold(self: &Arc, f: F) -> T + where + F: FnOnce(&WorkerThread) -> T + Send, + T: Send, + { + use crate::latch::MutexLatch; + // current thread isn't a worker thread, create job and inject into global context + + let latch = MutexLatch::new(); + + let job = StackJob::new( + move || { + let worker = WorkerThread::current_ref() + .expect("WorkerThread::run_in_worker called outside of worker thread"); + + f(worker) + }, + LatchRef::new(&latch), + ); + + let job = job.as_job(); + job.set_pending(); + + self.inject_job(Into::into(&job)); + latch.wait(); + + let t = unsafe { job.transmute_ref::().wait().into_result() }; + + t + } + + /// Run closure in this context. + pub fn run_in_worker(self: &Arc, f: F) -> T + where + T: Send, + F: FnOnce(&WorkerThread) -> T + Send, + { + match WorkerThread::current_ref() { + Some(worker) => { + // check if worker is in the same context + if Arc::ptr_eq(&worker.context, self) { + tracing::trace!("run_in_worker: current thread"); + f(worker) + } else { + // current thread is a worker for a different context + tracing::trace!("run_in_worker: cross-context"); + self.run_in_worker_cross(worker, f) + } + } + None => { + // current thread is not a worker for any context + tracing::trace!("run_in_worker: inject into context"); + self.run_in_worker_cold(f) + } + } + } +} + +pub fn run_in_worker(f: F) -> T +where + T: Send, + F: FnOnce(&WorkerThread) -> T + Send, +{ + Context::global_context().run_in_worker(f) +} diff --git a/distaff/src/job.rs b/distaff/src/job.rs new file mode 100644 index 0000000..ed5830f --- /dev/null +++ b/distaff/src/job.rs @@ -0,0 +1,666 @@ +use core::{ + any::Any, + cell::UnsafeCell, + fmt::Debug, + hint::cold_path, + mem::{self, ManuallyDrop}, + ptr::{self, NonNull}, + sync::atomic::Ordering, +}; + +use alloc::boxed::Box; +use parking_lot_core::SpinWait; + +use crate::util::{SmallBox, TaggedAtomicPtr}; + +#[repr(u8)] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum JobState { + Empty, + Locked = 1, + Pending, + Finished, + // Inline = 1 << (u8::BITS - 1), + // IsError = 1 << (u8::BITS - 2), +} + +impl JobState { + #[allow(dead_code)] + const MASK: u8 = 0; // Self::Inline as u8 | Self::IsError as u8; + + fn from_u8(v: u8) -> Option { + match v { + 0 => Some(Self::Empty), + 1 => Some(Self::Locked), + 2 => Some(Self::Pending), + 3 => Some(Self::Finished), + _ => None, + } + } +} + +pub use joblist::JobList; + +mod joblist { + use core::{fmt::Debug, ptr::NonNull}; + + use alloc::boxed::Box; + + use super::Job; + + // the list looks like this: + // head <-> job1 <-> job2 <-> ... <-> jobN <-> tail + pub struct JobList { + // these cannot be boxes because boxes are noalias. + head: NonNull, + tail: NonNull, + // the number of jobs in the list. + // this is used to judge whether or not to join sync or async. 
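+        // (`join_heartbeat_every` compares this length against a small threshold to
+        // decide whether a heartbeat join is worth the bookkeeping.)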
+ job_count: usize, + } + + impl JobList { + pub fn new() -> Self { + let head = Box::into_raw(Box::new(Job::empty())); + let tail = Box::into_raw(Box::new(Job::empty())); + + // head and tail point at themselves + unsafe { + (&*head).link_mut().prev = None; + (&*head).link_mut().next = Some(NonNull::new_unchecked(tail)); + + (&*tail).link_mut().prev = Some(NonNull::new_unchecked(head)); + (&*tail).link_mut().next = None; + + Self { + head: NonNull::new_unchecked(head), + tail: NonNull::new_unchecked(tail), + job_count: 0, + } + } + } + + fn head(&self) -> NonNull { + self.head + } + fn tail(&self) -> NonNull { + self.tail + } + + /// `job` must be valid until it is removed from the list. + pub unsafe fn push_front(&mut self, job: *const Job) { + self.job_count += 1; + let headlink = unsafe { self.head.as_ref().link_mut() }; + + let next = headlink.next.unwrap(); + let next_link = unsafe { next.as_ref().link_mut() }; + + let job_ptr = unsafe { NonNull::new_unchecked(job as _) }; + + headlink.next = Some(job_ptr); + next_link.prev = Some(job_ptr); + + let job_link = unsafe { job_ptr.as_ref().link_mut() }; + job_link.next = Some(next); + job_link.prev = Some(self.head); + } + + /// `job` must be valid until it is removed from the list. + pub unsafe fn push_back(&mut self, job: *const Job) { + self.job_count += 1; + let taillink = unsafe { self.tail.as_ref().link_mut() }; + + let prev = taillink.prev.unwrap(); + let prev_link = unsafe { prev.as_ref().link_mut() }; + + let job_ptr = unsafe { NonNull::new_unchecked(job as _) }; + + taillink.prev = Some(job_ptr); + prev_link.next = Some(job_ptr); + + let job_link = unsafe { job_ptr.as_ref().link_mut() }; + job_link.prev = Some(prev); + job_link.next = Some(self.tail); + } + + pub fn pop_front(&mut self) -> Option> { + self.job_count -= 1; + + let headlink = unsafe { self.head.as_ref().link_mut() }; + + // SAFETY: headlink.next is guaranteed to be Some. + let job = headlink.next.unwrap(); + let job_link = unsafe { job.as_ref().link_mut() }; + + // short-circuit here if the job is the tail + let next = job_link.next?; + let next_link = unsafe { next.as_ref().link_mut() }; + + headlink.next = Some(next); + next_link.prev = Some(self.head); + + Some(job) + } + + pub fn pop_back(&mut self) -> Option> { + self.job_count -= 1; + + let taillink = unsafe { self.tail.as_ref().link_mut() }; + + // SAFETY: taillink.prev is guaranteed to be Some. + let job = taillink.prev.unwrap(); + let job_link = unsafe { job.as_ref().link_mut() }; + + // short-circuit here if the job is the head + let prev = job_link.prev?; + let prev_link = unsafe { prev.as_ref().link_mut() }; + + taillink.prev = Some(prev); + prev_link.next = Some(self.tail); + + Some(job) + } + + pub fn is_empty(&self) -> bool { + self.job_count == 0 + } + + pub fn len(&self) -> usize { + self.job_count + } + } + + impl Drop for JobList { + fn drop(&mut self) { + // Need to drop the head and tail, which were allocated on the heap. + // elements of the list are managed externally. 
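+            // SAFETY: `head` and `tail` were allocated with `Box::into_raw` in
+            // `JobList::new` and are only ever freed here.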
+ unsafe { + drop((Box::from_non_null(self.head), Box::from_non_null(self.tail))); + } + } + } + + impl Debug for JobList { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("JobList") + .field("head", &self.head) + .field("tail", &self.tail) + .field("job_count", &self.job_count) + .field_with("jobs", |f| { + let mut jobs = f.debug_list(); + + // SAFETY: head.next is guaranteed to be non-null and valid + let mut job = unsafe { self.head.as_ref().link_mut().next.unwrap() }; + + while job != self.tail { + let job_ref = unsafe { job.as_ref() }; + jobs.entry(job_ref); + + // SAFETY: job is guaranteed to be non-null and valid + // only the tail has a next of None + job = unsafe { job_ref.link_mut().next.unwrap() }; + } + + jobs.finish() + }) + .finish() + } + } +} + +#[repr(transparent)] +pub struct JobResult { + inner: std::thread::Result, +} + +impl JobResult { + pub fn new(result: std::thread::Result) -> Self { + Self { inner: result } + } + + /// convert JobResult into a thread result. + #[allow(dead_code)] + pub fn into_inner(self) -> std::thread::Result { + self.inner + } + + // unwraps the result, propagating panics + pub fn into_result(self) -> T { + match self.inner { + Ok(val) => val, + Err(payload) => { + cold_path(); + + std::panic::resume_unwind(payload); + // #[cfg(feature = "std")] + // { + // std::panic::resume_unwind(err); + // } + // #[cfg(not(feature = "std"))] + // { + // // in no-std, we just panic with the error + // // TODO: figure out how to propagate the error + // panic!("Job failed: {:?}", payload); + // } + } + } + } +} + +#[derive(Debug, PartialEq, Eq)] +struct Link { + prev: Option>, + next: Option>, +} + +// `Link` is invariant over `T` +impl Clone for Link { + fn clone(&self) -> Self { + Self { + prev: self.prev.clone(), + next: self.next.clone(), + } + } +} + +// `Link` is invariant over `T` +impl Copy for Link {} + +struct Thread; + +union ValueOrThis { + uninit: (), + value: ManuallyDrop>, + this: NonNull<()>, +} + +union LinkOrError { + link: Link, + waker: ManuallyDrop>, + error: ManuallyDrop>>, +} + +#[repr(C)] +pub struct Job { + /// stores the job's harness as a *const usize + harness_and_state: TaggedAtomicPtr, + /// `this` before `execute()` is called, or `value` after `execute()` + value_or_this: UnsafeCell>, + /// `link` before `execute()` is called, or `error` after `execute()` + error_or_link: UnsafeCell>, +} + +unsafe impl Send for Job {} + +impl Debug for Job { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let state = JobState::from_u8(self.harness_and_state.tag(Ordering::Relaxed) as u8).unwrap(); + let mut debug = f.debug_struct("Job"); + debug.field("state", &state).field_with("harness", |f| { + write!(f, "{:?}", self.harness_and_state.ptr(Ordering::Relaxed)) + }); + + match state { + JobState::Empty => { + debug + .field_with("this", |f| { + write!(f, "{:?}", unsafe { &(&*self.value_or_this.get()).this }) + }) + .field_with("link", |f| { + write!(f, "{:?}", unsafe { &(&*self.error_or_link.get()).link }) + }); + } + JobState::Locked => { + #[derive(Debug)] + struct Locked; + debug.field("locked", &Locked); + } + JobState::Pending => { + debug + .field_with("this", |f| { + write!(f, "{:?}", unsafe { &(&*self.value_or_this.get()).this }) + }) + .field_with("waker", |f| { + write!(f, "{:?}", unsafe { &(&*self.error_or_link.get()).waker }) + }); + } + JobState::Finished => { + let err = unsafe { &(&*self.error_or_link.get()).error }; + + let result = match err.as_ref() { + Some(err) 
=> Err(err), + None => Ok(unsafe { (&*self.value_or_this.get()).value.0.as_ptr() }), + }; + + debug.field("result", &result); + } + } + + debug.finish() + } +} + +impl Job { + pub fn empty() -> Job { + Self { + harness_and_state: TaggedAtomicPtr::new(ptr::dangling_mut(), JobState::Empty as usize), + value_or_this: UnsafeCell::new(ValueOrThis { + this: NonNull::dangling(), + }), + error_or_link: UnsafeCell::new(LinkOrError { + link: Link { + prev: None, + next: None, + }, + }), + // _phantom: PhantomPinned, + } + } + pub fn new(harness: unsafe fn(*const (), *const Job), this: NonNull<()>) -> Job { + Self { + harness_and_state: TaggedAtomicPtr::new( + unsafe { mem::transmute(harness) }, + JobState::Empty as usize, + ), + value_or_this: UnsafeCell::new(ValueOrThis { this }), + error_or_link: UnsafeCell::new(LinkOrError { + link: Link { + prev: None, + next: None, + }, + }), + // _phantom: PhantomPinned, + } + } + + // Job is passed around type-erased as `Job<()>`, to complete the job we + // need to cast it back to the original type. + pub unsafe fn transmute_ref(&self) -> &Job { + unsafe { mem::transmute::<&Job, &Job>(self) } + } + + #[inline] + unsafe fn link_mut(&self) -> &mut Link { + unsafe { &mut (&mut *self.error_or_link.get()).link } + } + + /// assumes job is in a `JobList` + pub unsafe fn unlink(&self) { + unsafe { + let mut dummy = None; + let Link { prev, next } = *self.link_mut(); + + *prev + .map(|ptr| &mut ptr.as_ref().link_mut().next) + .unwrap_or(&mut dummy) = next; + *next + .map(|ptr| &mut ptr.as_ref().link_mut().prev) + .unwrap_or(&mut dummy) = prev; + } + } + + pub fn state(&self) -> u8 { + self.harness_and_state.tag(Ordering::Relaxed) as u8 + } + + pub fn wait(&self) -> JobResult { + let mut spin = SpinWait::new(); + loop { + match self.harness_and_state.compare_exchange_weak_tag( + JobState::Pending as usize, + JobState::Locked as usize, + Ordering::Acquire, + Ordering::Relaxed, + ) { + // if still pending, sleep until completed + Ok(state) => { + debug_assert_eq!(state, JobState::Pending as usize); + unsafe { + *(&mut *self.error_or_link.get()).waker = Some(std::thread::current()); + } + + self.harness_and_state.set_tag( + JobState::Pending as usize, + Ordering::Release, + Ordering::Relaxed, + ); + + std::thread::park(); + spin.reset(); + + // after sleeping, state should be `Finished` + } + Err(state) => { + // job finished under us, check if it was successful + if state == JobState::Finished as usize { + let err = unsafe { (&mut *self.error_or_link.get()).error.take() }; + + let result: std::thread::Result = if let Some(err) = err { + cold_path(); + Err(err) + } else { + let val = unsafe { + ManuallyDrop::take(&mut (&mut *self.value_or_this.get()).value) + }; + + Ok(val.into_inner()) + }; + + return JobResult::new(result); + } else { + // spin until lock is released. 
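+                        // Another thread holds the transient `Locked` state (it is
+                        // either installing a waker or publishing a result), so back
+                        // off and retry the compare-exchange.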
+ tracing::trace!("spin-waiting for job: {:?}", self); + spin.spin(); + } + } + } + } + } + + /// call this when popping value from local queue + pub fn set_pending(&self) { + let mut spin = SpinWait::new(); + loop { + match self.harness_and_state.compare_exchange_weak_tag( + JobState::Empty as usize, + JobState::Pending as usize, + Ordering::Acquire, + Ordering::Relaxed, + ) { + Ok(state) => { + debug_assert_eq!(state, JobState::Empty as usize); + // set waker to None + unsafe { + (&mut *self.error_or_link.get()).waker = ManuallyDrop::new(None); + } + return; + } + Err(_) => { + // debug_assert_ne!(state, JobState::Empty as usize); + + tracing::error!("######## what the sigma?"); + spin.spin(); + } + } + } + } + + pub fn execute(job: NonNull) { + tracing::trace!("executing job: {:?}", job); + + // SAFETY: self is non-null + unsafe { + let this = job.as_ref(); + let (ptr, state) = this.harness_and_state.ptr_and_tag(Ordering::Relaxed); + + debug_assert_eq!(state, JobState::Pending as usize); + let harness: unsafe fn(*const (), *const Self) = mem::transmute(ptr.as_ptr()); + + let this = (*this.value_or_this.get()).this; + + harness(this.as_ptr().cast(), job.as_ptr()); + } + } + + pub(crate) fn complete(&self, result: std::thread::Result) { + let mut spin = SpinWait::new(); + loop { + match self.harness_and_state.compare_exchange_weak_tag( + JobState::Pending as usize, + JobState::Locked as usize, + Ordering::Acquire, + Ordering::Relaxed, + ) { + Ok(state) => { + debug_assert_eq!(state, JobState::Pending as usize); + break; + } + Err(_) => { + // debug_assert_ne!(state, JobState::Pending as usize); + spin.spin(); + } + } + } + + let waker = unsafe { (&mut *self.error_or_link.get()).waker.take() }; + + match result { + Ok(val) => unsafe { + (&mut *self.value_or_this.get()).value = ManuallyDrop::new(SmallBox::new(val)); + (&mut *self.error_or_link.get()).error = ManuallyDrop::new(None); + }, + Err(err) => unsafe { + (&mut *self.value_or_this.get()).uninit = (); + (&mut *self.error_or_link.get()).error = ManuallyDrop::new(Some(err)); + }, + } + + if let Some(thread) = waker { + thread.unpark(); + } + + self.harness_and_state.set_tag( + JobState::Finished as usize, + Ordering::Release, + Ordering::Relaxed, + ); + } +} + +mod stackjob { + use crate::latch::Latch; + + use super::*; + + pub struct StackJob { + latch: L, + f: UnsafeCell>, + } + + impl StackJob { + pub fn new(f: F, latch: L) -> Self { + Self { + latch, + f: UnsafeCell::new(ManuallyDrop::new(f)), + } + } + + pub unsafe fn unwrap(&self) -> F { + unsafe { ManuallyDrop::take(&mut *self.f.get()) } + } + } + + impl StackJob + where + L: Latch, + { + pub fn as_job(&self) -> Job<()> + where + F: FnOnce() -> T + Send, + T: Send, + { + #[align(8)] + unsafe fn harness(this: *const (), job: *const Job<()>) + where + F: FnOnce() -> T + Send, + T: Sized + Send, + { + let this = unsafe { &*this.cast::>() }; + let f = unsafe { this.unwrap() }; + + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f())); + + let job = unsafe { &*job.cast::>() }; + job.complete(result); + + unsafe { + Latch::set_raw(&this.latch); + } + } + + Job::new(harness::, unsafe { + NonNull::new_unchecked(self as *const _ as *mut ()) + }) + } + } +} + +mod heapjob { + use super::*; + + pub struct HeapJob { + f: F, + } + + impl HeapJob { + pub fn new(f: F) -> Self { + Self { f } + } + + pub fn into_inner(self) -> F { + self.f + } + + pub fn into_boxed_job(self: Box) -> *mut Job<()> + where + F: FnOnce() -> T + Send, + T: Send, + { + #[align(8)] + unsafe fn 
harness(this: *const (), job: *const Job<()>) + where + F: FnOnce() -> T + Send, + T: Send, + { + let job = job.cast_mut(); + + // turn `this`, which was allocated at (2), into box. + // miri complains this is a use-after-free, but it isn't? silly miri... + // Turns out this is actually correct on miri's end, but because + // we ensure that the scope lives as long as any jobs, this is + // actually fine, as far as I can tell. + let this = unsafe { Box::from_raw(this.cast::>().cast_mut()) }; + let f = this.into_inner(); + + _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f())); + + // drop job (this is fine because the job of a HeapJob is pure POD). + unsafe { + ptr::drop_in_place(job); + } + + // free box that was allocated at (1) + _ = unsafe { Box::>>::from_raw(job.cast()) }; + } + + // (1) allocate box for job + Box::into_raw(Box::new(Job::new(harness::, { + // (2) convert self into a pointer + Box::into_non_null(self).cast() + }))) + } + } +} + +pub use heapjob::HeapJob; +pub use stackjob::StackJob; diff --git a/distaff/src/join.rs b/distaff/src/join.rs new file mode 100644 index 0000000..eef4ea4 --- /dev/null +++ b/distaff/src/join.rs @@ -0,0 +1,110 @@ +use std::hint::cold_path; + +use crate::{ + job::{JobState, StackJob}, + latch::{AsCoreLatch, LatchRef, WakeLatch}, + workerthread::WorkerThread, +}; + +impl WorkerThread { + #[inline] + fn join_seq(&self, a: A, b: B) -> (RA, RB) + where + RA: Send, + RB: Send, + A: FnOnce() -> RA + Send, + B: FnOnce() -> RB + Send, + { + let rb = b(); + let ra = a(); + + (ra, rb) + } + + /// This function must be called from a worker thread. + #[inline] + pub(crate) fn join_heartbeat_every( + &self, + a: A, + b: B, + ) -> (RA, RB) + where + RA: Send, + RB: Send, + A: FnOnce() -> RA + Send, + B: FnOnce() -> RB + Send, + { + // SAFETY: each worker is only ever used by one thread, so this is safe. + let count = self.join_count.get(); + self.join_count.set(count.wrapping_add(1) % TIMES as u8); + + // TODO: add counter to job queue, check for low job count to decide whether to use heartbeat or seq. + // see: chili + + // SAFETY: this function runs in a worker thread, so we can access the queue safely. + if count == 0 || unsafe { self.queue.as_ref_unchecked().len() } < 3 { + cold_path(); + self.join_heartbeat(a, b) + } else { + self.join_seq(a, b) + } + } + + /// This function must be called from a worker thread. + #[inline] + fn join_heartbeat(&self, a: A, b: B) -> (RA, RB) + where + RA: Send, + RB: Send, + A: FnOnce() -> RA + Send, + B: FnOnce() -> RB + Send, + { + use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind}; + + let latch = WakeLatch::new(self.context.clone(), self.index); + let a = StackJob::new( + move || { + // TODO: bench whether tick'ing here is good. + // turns out this actually costs a lot of time, likely because of the thread local check. + // WorkerThread::current_ref() + // .expect("stackjob is run in workerthread.") + // .tick(); + + a() + }, + LatchRef::new(&latch), + ); + + let job = a.as_job(); + self.push_front(&job); + + let rb = match catch_unwind(AssertUnwindSafe(|| b())) { + Ok(val) => val, + Err(payload) => { + cold_path(); + // if b panicked, we need to wait for a to finish + self.wait_until_latch(&latch); + resume_unwind(payload); + } + }; + + let ra = if job.state() == JobState::Empty as u8 { + unsafe { + job.unlink(); + } + + // a is allowed to panic here, because we already finished b. 
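+            // SAFETY: the job is still `Empty`, so it was never shared with another
+            // worker and the closure is still owned by this stack frame.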
+ unsafe { a.unwrap()() } + } else { + match self.wait_until_job::(unsafe { job.transmute_ref() }, latch.as_core_latch()) { + Some(t) => t.into_result(), // propagate panic here + // the job was shared, but not yet stolen, so we get to run the + // job inline + None => unsafe { a.unwrap()() }, + } + }; + + drop(a); + (ra, rb) + } +} diff --git a/distaff/src/latch.rs b/distaff/src/latch.rs new file mode 100644 index 0000000..6a8c20b --- /dev/null +++ b/distaff/src/latch.rs @@ -0,0 +1,367 @@ +use core::{ + marker::PhantomData, + sync::atomic::{AtomicUsize, Ordering}, +}; +use std::sync::{Arc, atomic::AtomicU8}; + +use parking_lot::{Condvar, Mutex}; + +use crate::context::Context; + +pub trait Latch { + unsafe fn set_raw(this: *const Self); +} + +pub trait Probe { + fn probe(&self) -> bool; +} + +pub type CoreLatch = AtomicLatch; +pub trait AsCoreLatch { + fn as_core_latch(&self) -> &CoreLatch; +} + +#[derive(Debug)] +pub struct AtomicLatch { + inner: AtomicU8, +} + +impl AtomicLatch { + pub const UNSET: u8 = 0; + pub const SET: u8 = 1; + pub const SLEEPING: u8 = 2; + + #[inline] + pub const fn new() -> Self { + Self { + inner: AtomicU8::new(Self::UNSET), + } + } + #[inline] + pub fn reset(&self) { + self.inner.store(Self::UNSET, Ordering::Release); + } + + pub fn get(&self) -> u8 { + self.inner.load(Ordering::Acquire) + } + + /// returns true if the latch was previously sleeping. + #[inline] + pub unsafe fn set(this: *const Self) -> bool { + unsafe { + let old = (*this).inner.swap(Self::SET, Ordering::Release); + old == Self::SLEEPING + } + } +} + +impl Latch for AtomicLatch { + #[inline] + unsafe fn set_raw(this: *const Self) { + unsafe { + Self::set(this); + } + } +} + +impl Probe for AtomicLatch { + #[inline] + fn probe(&self) -> bool { + self.inner.load(Ordering::Acquire) == Self::SET + } +} +impl AsCoreLatch for AtomicLatch { + #[inline] + fn as_core_latch(&self) -> &CoreLatch { + self + } +} + +pub struct LatchRef<'a, L: Latch> { + inner: *const L, + _marker: PhantomData<&'a L>, +} + +impl<'a, L: Latch> LatchRef<'a, L> { + #[inline] + pub const fn new(latch: &'a L) -> Self { + Self { + inner: latch, + _marker: PhantomData, + } + } +} + +impl<'a, L: Latch> Latch for LatchRef<'a, L> { + #[inline] + unsafe fn set_raw(this: *const Self) { + unsafe { + let this = &*this; + Latch::set_raw(this.inner); + } + } +} + +impl<'a, L: Latch + Probe> Probe for LatchRef<'a, L> { + #[inline] + fn probe(&self) -> bool { + unsafe { + let this = &*self.inner; + Probe::probe(this) + } + } +} + +impl<'a, L> AsCoreLatch for LatchRef<'a, L> +where + L: Latch + AsCoreLatch, +{ + #[inline] + fn as_core_latch(&self) -> &CoreLatch { + unsafe { + let this = &*self.inner; + this.as_core_latch() + } + } +} + +pub struct NopLatch; + +impl Latch for NopLatch { + #[inline] + unsafe fn set_raw(_this: *const Self) { + // do nothing + } +} + +impl Probe for NopLatch { + #[inline] + fn probe(&self) -> bool { + false // always returns false + } +} + +pub struct ThreadWakeLatch { + waker: Mutex>, +} + +impl ThreadWakeLatch { + #[inline] + pub const fn new() -> Self { + Self { + waker: Mutex::new(None), + } + } + + #[inline] + pub fn reset(&self) { + let mut waker = self.waker.lock(); + *waker = None; + } + + #[inline] + pub fn set_waker(&self, thread: std::thread::Thread) { + let mut waker = self.waker.lock(); + *waker = Some(thread); + } + + pub unsafe fn wait(&self) { + assert!( + self.waker.lock().replace(std::thread::current()).is_none(), + "ThreadWakeLatch can only be waited once per thread" + ); + + 
std::thread::park(); + } +} + +impl Latch for ThreadWakeLatch { + #[inline] + unsafe fn set_raw(this: *const Self) { + unsafe { + if let Some(thread) = (&*this).waker.lock().take() { + thread.unpark(); + } + } + } +} + +impl Probe for ThreadWakeLatch { + #[inline] + fn probe(&self) -> bool { + self.waker.lock().is_some() + } +} + +pub struct CountLatch { + count: AtomicUsize, + inner: L, +} + +impl CountLatch { + #[inline] + pub const fn new(inner: L) -> Self { + Self { + count: AtomicUsize::new(0), + inner, + } + } + + pub fn count(&self) -> usize { + self.count.load(Ordering::Relaxed) + } + + pub fn inner(&self) -> &L { + &self.inner + } + + #[inline] + pub fn increment(&self) { + self.count.fetch_add(1, Ordering::Release); + } + + #[inline] + pub fn decrement(&self) { + if self.count.fetch_sub(1, Ordering::Release) == 1 { + unsafe { + Latch::set_raw(&self.inner); + } + } + } +} + +impl Latch for CountLatch { + #[inline] + unsafe fn set_raw(this: *const Self) { + unsafe { + let this = &*this; + this.decrement(); + } + } +} + +impl Probe for CountLatch { + #[inline] + fn probe(&self) -> bool { + self.inner.probe() + } +} + +impl AsCoreLatch for CountLatch { + #[inline] + fn as_core_latch(&self) -> &CoreLatch { + self.inner.as_core_latch() + } +} + +pub struct MutexLatch { + inner: Mutex, + condvar: Condvar, +} + +impl MutexLatch { + #[inline] + pub const fn new() -> Self { + Self { + inner: Mutex::new(false), + condvar: Condvar::new(), + } + } + + #[inline] + pub fn reset(&self) { + let mut guard = self.inner.lock(); + *guard = false; + } + + pub fn wait(&self) { + let mut guard = self.inner.lock(); + while !*guard { + self.condvar.wait(&mut guard); + } + } + + pub fn set(&self) { + unsafe { + Latch::set_raw(self); + } + } + + pub fn wait_and_reset(&self) { + let mut guard = self.inner.lock(); + while !*guard { + self.condvar.wait(&mut guard); + } + *guard = false; + } +} + +impl Latch for MutexLatch { + #[inline] + unsafe fn set_raw(this: *const Self) { + unsafe { + *(&*this).inner.lock() = true; + (&*this).condvar.notify_all(); + } + } +} + +impl Probe for MutexLatch { + #[inline] + fn probe(&self) -> bool { + *self.inner.lock() + } +} + +pub struct WakeLatch { + inner: AtomicLatch, + context: Arc, + worker_index: AtomicUsize, +} + +impl WakeLatch { + pub fn new(context: Arc, worker_index: usize) -> Self { + Self { + inner: AtomicLatch::new(), + context, + worker_index: AtomicUsize::new(worker_index), + } + } + + pub(crate) fn set_worker_index(&self, worker_index: usize) { + self.worker_index.store(worker_index, Ordering::Relaxed); + } +} + +impl Latch for WakeLatch { + #[inline] + unsafe fn set_raw(this: *const Self) { + unsafe { + let ctx = (&*this).context.clone(); + let worker_index = (&*this).worker_index.load(Ordering::Relaxed); + + if CoreLatch::set(&(&*this).inner) { + // If the latch was sleeping, wake the worker thread + ctx.shared().heartbeats.get(&worker_index).and_then(|weak| { + weak.upgrade() + .map(|heartbeat| Latch::set_raw(&heartbeat.latch)) + }); + } + } + } +} + +impl Probe for WakeLatch { + #[inline] + fn probe(&self) -> bool { + self.inner.probe() + } +} + +impl AsCoreLatch for WakeLatch { + #[inline] + fn as_core_latch(&self) -> &CoreLatch { + &self.inner + } +} diff --git a/distaff/src/lib.rs b/distaff/src/lib.rs new file mode 100644 index 0000000..2bad781 --- /dev/null +++ b/distaff/src/lib.rs @@ -0,0 +1,22 @@ +// #![no_std] +#![feature( + fn_align, + cold_path, + stmt_expr_attributes, + debug_closure_helpers, + unsafe_cell_access, + box_as_ptr, + 
box_vec_non_null, + let_chains +)] + +extern crate alloc; + +mod context; +mod job; +mod join; +mod latch; +mod scope; +mod threadpool; +pub mod util; +mod workerthread; diff --git a/distaff/src/scope.rs b/distaff/src/scope.rs new file mode 100644 index 0000000..106ab11 --- /dev/null +++ b/distaff/src/scope.rs @@ -0,0 +1,267 @@ +use std::{ + any::Any, + marker::PhantomData, + ptr::{self, NonNull}, + sync::{ + Arc, + atomic::{AtomicPtr, Ordering}, + }, +}; + +use async_task::Runnable; + +use crate::{ + context::{Context, run_in_worker}, + job::{HeapJob, Job}, + latch::{AsCoreLatch, CountLatch, WakeLatch}, + util::{DropGuard, SendPtr}, + workerthread::WorkerThread, +}; + +pub struct Scope<'scope> { + // latch to wait on before the scope finishes + job_counter: CountLatch, + // local threadpool + context: Arc, + // panic error + panic: AtomicPtr>, + // variant lifetime + _pd: PhantomData, +} + +pub fn scope<'scope, F, R>(f: F) -> R +where + F: FnOnce(&Scope<'scope>) -> R + Send, + R: Send, +{ + Scope::<'scope>::scope(f) +} + +impl<'scope> Scope<'scope> { + fn wait_for_jobs(&self, worker: &WorkerThread) { + tracing::trace!("waiting for {} jobs to finish.", self.job_counter.count()); + tracing::trace!("thread id: {:?}, jobs: {:?}", worker.index, unsafe { + worker.queue.as_ref_unchecked() + }); + + // set worker index in the job counter + self.job_counter.inner().set_worker_index(worker.index); + worker.wait_until_latch(self.job_counter.as_core_latch()); + } + + pub fn scope(f: F) -> R + where + F: FnOnce(&Self) -> R + Send, + R: Send, + { + run_in_worker(|worker| { + // SAFETY: we call complete() after creating this scope, which + // ensures that any jobs spawned from the scope exit before the + // scope closes. + let this = unsafe { Self::from_context(worker.context.clone()) }; + this.complete(worker, || f(&this)) + }) + } + + fn scope_with_context(context: Arc, f: F) -> R + where + F: FnOnce(&Self) -> R + Send, + R: Send, + { + context.run_in_worker(|worker| { + // SAFETY: we call complete() after creating this scope, which + // ensures that any jobs spawned from the scope exit before the + // scope closes. + let this = unsafe { Self::from_context(context.clone()) }; + this.complete(worker, || f(&this)) + }) + } + + /// should be called from within a worker thread. + fn complete(&self, worker: &WorkerThread, f: F) -> R + where + F: FnOnce() -> R + Send, + R: Send, + { + use std::panic::{AssertUnwindSafe, catch_unwind}; + + #[allow(dead_code)] + fn make_job T, T>(f: F) -> Job { + #[align(8)] + unsafe fn harness T, T>(this: *const (), job: *const Job) { + let f = unsafe { Box::from_raw(this.cast::().cast_mut()) }; + + let result = catch_unwind(AssertUnwindSafe(move || f())); + + let job = unsafe { Box::from_raw(job.cast_mut()) }; + job.complete(result); + } + + Job::::new(harness::, unsafe { + NonNull::new_unchecked(Box::into_raw(Box::new(f))).cast() + }) + } + + let result = match catch_unwind(AssertUnwindSafe(|| f())) { + Ok(val) => Some(val), + Err(payload) => { + self.panicked(payload); + None + } + }; + + self.wait_for_jobs(worker); + self.maybe_propagate_panic(); + + // SAFETY: if result panicked, we would have propagated the panic above. + result.unwrap() + } + + /// resumes the panic if one happened in this scope. + fn maybe_propagate_panic(&self) { + let err_ptr = self.panic.load(Ordering::Relaxed); + if !err_ptr.is_null() { + unsafe { + let err = Box::from_raw(err_ptr); + std::panic::resume_unwind(*err); + } + } + } + + /// stores the first panic that happened in this scope. 
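+    ///
+    /// Only the first panic is kept; later panics are dropped, and the stored payload
+    /// is rethrown by `maybe_propagate_panic` once all jobs have finished.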
+ fn panicked(&self, err: Box) { + self.panic.load(Ordering::Relaxed).is_null().then(|| { + use core::mem::ManuallyDrop; + let mut boxed = ManuallyDrop::new(Box::new(err)); + + let err_ptr: *mut Box = &mut **boxed; + if self + .panic + .compare_exchange( + ptr::null_mut(), + err_ptr, + Ordering::SeqCst, + Ordering::Relaxed, + ) + .is_ok() + { + // we successfully set the panic, no need to drop + } else { + // drop the error, someone else already set it + _ = ManuallyDrop::into_inner(boxed); + } + }); + } + + pub fn spawn(&self, f: F) + where + F: FnOnce(&Scope<'scope>) + Send, + { + self.context.run_in_worker(|worker| { + self.job_counter.increment(); + + let this = SendPtr::new_const(self).unwrap(); + + let job = Box::new(HeapJob::new(move || unsafe { + _ = f(this.as_ref()); + this.as_ref().job_counter.decrement(); + })) + .into_boxed_job(); + + tracing::trace!("allocated heapjob"); + + worker.push_front(job); + + tracing::trace!("leaked heapjob"); + }); + } + + pub fn spawn_future(&self, future: F) -> async_task::Task + where + F: Future + Send + 'scope, + T: Send + 'scope, + { + self.context.run_in_worker(|worker| { + self.job_counter.increment(); + + let this = SendPtr::new_const(&self.job_counter).unwrap(); + + let future = async move { + let _guard = DropGuard::new(move || unsafe { + this.as_ref().decrement(); + }); + future.await + }; + + let schedule = move |runnable: Runnable| { + #[align(8)] + unsafe fn harness(this: *const (), job: *const Job) { + unsafe { + let runnable = + Runnable::<()>::from_raw(NonNull::new_unchecked(this.cast_mut())); + runnable.run(); + + // SAFETY: job was turned into raw + drop(Box::from_raw(job.cast_mut())); + } + } + + let job = Box::new(Job::::new(harness::, runnable.into_raw())); + + // casting into Job<()> here + worker.push_front(Box::into_raw(job) as _); + }; + + let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) }; + + runnable.schedule(); + + task + }) + } + + #[allow(dead_code)] + fn spawn_async<'a, T, Fut, Fn>(&'a self, f: Fn) -> async_task::Task + where + Fn: FnOnce(&Scope) -> Fut + Send + 'static, + Fut: Future + Send + 'static, + T: Send + 'static, + { + let this = SendPtr::new_const(self).unwrap(); + let future = async move { f(unsafe { this.as_ref() }).await }; + + self.spawn_future(future) + } + + #[inline] + pub fn join(&self, a: A, b: B) -> (RA, RB) + where + RA: Send, + RB: Send, + A: FnOnce(&Self) -> RA + Send, + B: FnOnce(&Self) -> RB + Send, + { + let worker = WorkerThread::current_ref().expect("join is run in workerthread."); + let this = SendPtr::new_const(self).unwrap(); + + worker.join_heartbeat_every::<_, _, _, _, 64>( + { + let this = this; + move || a(unsafe { this.as_ref() }) + }, + { + let this = this; + move || b(unsafe { this.as_ref() }) + }, + ) + } + + unsafe fn from_context(ctx: Arc) -> Self { + Self { + context: ctx.clone(), + job_counter: CountLatch::new(WakeLatch::new(ctx, 0)), + panic: AtomicPtr::new(ptr::null_mut()), + _pd: PhantomData, + } + } +} diff --git a/distaff/src/threadpool.rs b/distaff/src/threadpool.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/distaff/src/threadpool.rs @@ -0,0 +1 @@ + diff --git a/distaff/src/util.rs b/distaff/src/util.rs new file mode 100644 index 0000000..5f250cd --- /dev/null +++ b/distaff/src/util.rs @@ -0,0 +1,404 @@ +use core::{ + borrow::{Borrow, BorrowMut}, + cell::UnsafeCell, + fmt::Display, + marker::PhantomData, + mem::{self, ManuallyDrop, MaybeUninit}, + ops::{Deref, DerefMut}, + ptr::NonNull, + sync::atomic::{AtomicPtr, 
Ordering}, +}; + +use alloc::boxed::Box; + +/// A guard that runs a closure when it is dropped. +pub struct DropGuard(UnsafeCell>); + +impl DropGuard +where + F: FnOnce(), +{ + pub fn new(f: F) -> DropGuard { + Self(UnsafeCell::new(ManuallyDrop::new(f))) + } +} + +impl Drop for DropGuard +where + F: FnOnce(), +{ + fn drop(&mut self) { + // SAFETY: We are the only owner of `self.0`, and we ensure that the + // closure is only called once. + unsafe { + ManuallyDrop::take(&mut *self.0.get())(); + } + } +} + +#[repr(transparent)] +#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct SendPtr(NonNull); + +impl Copy for SendPtr {} + +impl Clone for SendPtr { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl core::fmt::Pointer for SendPtr { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + as core::fmt::Pointer>::fmt(&self.0, f) + } +} + +unsafe impl Send for SendPtr {} + +impl Deref for SendPtr { + type Target = NonNull; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for SendPtr { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl SendPtr { + pub const fn new(ptr: *mut T) -> Option { + match NonNull::new(ptr) { + Some(ptr) => Some(Self(ptr)), + None => None, + } + } + + /// ptr must be non-null + #[allow(dead_code)] + pub const unsafe fn new_unchecked(ptr: *mut T) -> Self { + unsafe { Self(NonNull::new_unchecked(ptr)) } + } + + pub const fn new_const(ptr: *const T) -> Option { + Self::new(ptr.cast_mut()) + } + + /// ptr must be non-null + #[allow(dead_code)] + pub const unsafe fn new_const_unchecked(ptr: *const T) -> Self { + unsafe { Self::new_unchecked(ptr.cast_mut()) } + } +} + +/// A tagged atomic pointer that can store a pointer and a tag `BITS` wide in the same space +/// as the pointer. +/// The pointer must be aligned to `BITS` bits, i.e. `align_of::() >= 2^BITS`. 
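+/// For example, `Job` stores its `JobState` in the low three bits of the harness
+/// function pointer; the `#[align(8)]` attribute on the harness functions is what
+/// keeps those bits free for the tag.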
+#[repr(transparent)] +pub struct TaggedAtomicPtr { + ptr: AtomicPtr<()>, + _pd: PhantomData, +} + +impl TaggedAtomicPtr { + const fn mask() -> usize { + !(!0usize << BITS) + } + + pub fn new(ptr: *mut T, tag: usize) -> TaggedAtomicPtr { + debug_assert!(core::mem::align_of::().ilog2() as u8 >= BITS); + let mask = Self::mask(); + Self { + ptr: AtomicPtr::new(ptr.with_addr((ptr.addr() & !mask) | (tag & mask)).cast()), + _pd: PhantomData, + } + } + + pub fn ptr(&self, order: Ordering) -> NonNull { + unsafe { + NonNull::new_unchecked( + self.ptr + .load(order) + .map_addr(|addr| addr & !Self::mask()) + .cast(), + ) + } + } + + pub fn tag(&self, order: Ordering) -> usize { + self.ptr.load(order).addr() & Self::mask() + } + + /// returns tag + #[inline(always)] + fn compare_exchange_tag_inner( + &self, + old: usize, + new: usize, + success: Ordering, + failure: Ordering, + cmpxchg: fn( + &AtomicPtr<()>, + *mut (), + *mut (), + Ordering, + Ordering, + ) -> Result<*mut (), *mut ()>, + ) -> Result { + let mask = Self::mask(); + let old_ptr = self.ptr.load(failure); + + let old = old_ptr.map_addr(|addr| (addr & !mask) | (old & mask)); + let new = old_ptr.map_addr(|addr| (addr & !mask) | (new & mask)); + + let result = cmpxchg(&self.ptr, old, new, success, failure); + + result + .map(|ptr| ptr.addr() & mask) + .map_err(|ptr| ptr.addr() & mask) + } + + /// returns tag + #[inline] + #[allow(dead_code)] + pub fn compare_exchange_tag( + &self, + old: usize, + new: usize, + success: Ordering, + failure: Ordering, + ) -> Result { + self.compare_exchange_tag_inner( + old, + new, + success, + failure, + AtomicPtr::<()>::compare_exchange, + ) + } + + /// returns tag + #[inline] + pub fn compare_exchange_weak_tag( + &self, + old: usize, + new: usize, + success: Ordering, + failure: Ordering, + ) -> Result { + self.compare_exchange_tag_inner( + old, + new, + success, + failure, + AtomicPtr::<()>::compare_exchange_weak, + ) + } + + #[allow(dead_code)] + pub fn set_ptr(&self, ptr: *mut T, success: Ordering, failure: Ordering) { + let mask = Self::mask(); + let ptr = ptr.cast::<()>(); + loop { + let old = self.ptr.load(failure); + let new = ptr.map_addr(|addr| (addr & !mask) | (old.addr() & mask)); + if self + .ptr + .compare_exchange_weak(old, new, success, failure) + .is_ok() + { + break; + } + } + } + + pub fn set_tag(&self, tag: usize, success: Ordering, failure: Ordering) { + let mask = Self::mask(); + loop { + let ptr = self.ptr.load(failure); + let new = ptr.map_addr(|addr| (addr & !mask) | (tag & mask)); + + if self + .ptr + .compare_exchange_weak(ptr, new, success, failure) + .is_ok() + { + break; + } + } + } + + pub fn ptr_and_tag(&self, order: Ordering) -> (NonNull, usize) { + let mask = Self::mask(); + let ptr = self.ptr.load(order); + let tag = ptr.addr() & mask; + let ptr = ptr.map_addr(|addr| addr & !mask); + let ptr = unsafe { NonNull::new_unchecked(ptr.cast()) }; + (ptr, tag) + } +} + +/// A small box that can store a value inline if the size and alignment of T is +/// less than or equal to the size and alignment of a boxed type. Typically this +/// will be `sizeof::()` bytes, but might be larger if +/// `sizeof::>()` is larger than that, like it is for dynamically sized +/// types like `[T]` or `dyn Trait`. +#[derive(Debug)] +#[repr(transparent)] +// We use a box here because a box can be unboxed, while a pointer cannot. 
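+// Values no larger than a `Box` and whose alignment divides the box's are written
+// directly into the `MaybeUninit` storage; everything else falls back to a real
+// heap allocation (see `is_inline`).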
+pub struct SmallBox(pub MaybeUninit>); + +impl Display for SmallBox { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + (**self).fmt(f) + } +} + +impl Ord for SmallBox { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + self.as_ref().cmp(other.as_ref()) + } +} + +impl PartialOrd for SmallBox { + fn partial_cmp(&self, other: &Self) -> Option { + self.as_ref().partial_cmp(other.as_ref()) + } +} + +impl Eq for SmallBox {} + +impl PartialEq for SmallBox { + fn eq(&self, other: &Self) -> bool { + self.as_ref().eq(other.as_ref()) + } +} + +impl Default for SmallBox { + fn default() -> Self { + Self::new(Default::default()) + } +} + +impl Clone for SmallBox { + fn clone(&self) -> Self { + Self::new(self.as_ref().clone()) + } +} + +impl Deref for SmallBox { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.as_ref() + } +} + +impl DerefMut for SmallBox { + fn deref_mut(&mut self) -> &mut Self::Target { + self.as_mut() + } +} + +impl AsRef for SmallBox { + fn as_ref(&self) -> &T { + Self::as_ref(self) + } +} +impl AsMut for SmallBox { + fn as_mut(&mut self) -> &mut T { + Self::as_mut(self) + } +} + +impl Borrow for SmallBox { + fn borrow(&self) -> &T { + &**self + } +} +impl BorrowMut for SmallBox { + fn borrow_mut(&mut self) -> &mut T { + &mut **self + } +} + +impl SmallBox { + /// must only be called once. takes a reference so this can be called in + /// drop() + unsafe fn get_unchecked(&self, inline: bool) -> T { + if inline { + unsafe { mem::transmute_copy::>, T>(&self.0) } + } else { + unsafe { *self.0.assume_init_read() } + } + } + + pub fn as_ref(&self) -> &T { + unsafe { + if Self::is_inline() { + mem::transmute::<&MaybeUninit>, &T>(&self.0) + } else { + self.0.assume_init_ref() + } + } + } + pub fn as_mut(&mut self) -> &mut T { + unsafe { + if Self::is_inline() { + mem::transmute::<&mut MaybeUninit>, &mut T>(&mut self.0) + } else { + self.0.assume_init_mut() + } + } + } + + pub fn into_inner(self) -> T { + let this = ManuallyDrop::new(self); + let inline = Self::is_inline(); + + // SAFETY: inline is correctly calculated and this function + // consumes `self` + unsafe { this.get_unchecked(inline) } + } + + #[inline(always)] + pub const fn is_inline() -> bool { + // the value can be stored inline iff the size of T is equal or + // smaller than the size of the boxed type and the alignment of the + // boxed type is an integer multiple of the alignment of T + mem::size_of::() <= mem::size_of::>>() + && mem::align_of::>>() % mem::align_of::() == 0 + } + + pub fn new(value: T) -> Self { + let inline = Self::is_inline(); + + if inline { + let mut this = MaybeUninit::new(Self(MaybeUninit::uninit())); + unsafe { + this.as_mut_ptr().cast::().write(value); + this.assume_init() + } + } else { + Self(MaybeUninit::new(Box::new(value))) + } + } +} + +impl Drop for SmallBox { + fn drop(&mut self) { + // drop contained value. + drop(unsafe { self.get_unchecked(Self::is_inline()) }); + } +} + +/// returns the number of available hardware threads, or 1 if it cannot be determined. 
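+/// `Context::new` uses this as the default number of worker threads.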
+pub fn available_parallelism() -> usize { + std::thread::available_parallelism() + .map(|n| n.get()) + .unwrap_or(1) +} diff --git a/distaff/src/workerthread.rs b/distaff/src/workerthread.rs new file mode 100644 index 0000000..f68a8bb --- /dev/null +++ b/distaff/src/workerthread.rs @@ -0,0 +1,396 @@ +use std::{ + cell::{Cell, UnsafeCell}, + ptr::NonNull, + sync::Arc, + time::Duration, +}; + +use crossbeam_utils::CachePadded; +use parking_lot_core::SpinWait; + +use crate::{ + context::{Context, Heartbeat}, + job::{Job, JobList, JobResult}, + latch::{AsCoreLatch, CoreLatch, Probe}, + util::DropGuard, +}; + +pub struct WorkerThread { + pub(crate) context: Arc, + pub(crate) index: usize, + pub(crate) queue: UnsafeCell, + heartbeat: Arc>, + pub(crate) join_count: Cell, +} + +thread_local! { + static WORKER: UnsafeCell>> = const { UnsafeCell::new(None) }; +} + +impl WorkerThread { + pub fn new_in(context: Arc) -> Self { + let (heartbeat, index) = context.shared().new_heartbeat(); + + Self { + context, + index, + queue: UnsafeCell::new(JobList::new()), + heartbeat, + join_count: Cell::new(0), + } + } + + fn new() -> Self { + let context = Context::global_context().clone(); + Self::new_in(context) + } +} + +impl WorkerThread { + pub fn run(self: Box) { + let this = Box::into_raw(self); + unsafe { + Self::set_current(this); + } + + let _guard = DropGuard::new(|| unsafe { + // SAFETY: this is only called when the thread is exiting + Self::unset_current(); + Self::drop_in_place(this); + }); + + tracing::trace!("WorkerThread::run: starting worker thread"); + + unsafe { + (&*this).run_inner(); + } + + tracing::trace!("WorkerThread::run: worker thread finished"); + } + + fn run_inner(&self) { + let mut job = self.context.shared().pop_job(); + 'outer: loop { + let mut guard = loop { + if let Some(job) = job { + self.execute(job); + } + + let mut guard = self.context.shared(); + if guard.should_exit() { + // if the context is stopped, break out of the outer loop which + // will exit the thread. 
+ break 'outer; + } + + match guard.pop_job() { + Some(job) => { + tracing::trace!("worker: popping job: {:?}", job); + // found job, continue inner loop + continue; + } + None => { + tracing::trace!("worker: no job, waiting for shared job"); + // no more jobs, break out of inner loop and wait for shared job + break guard; + } + } + }; + + self.context.shared_job.wait(&mut guard); + job = guard.pop_job(); + } + } +} + +impl WorkerThread { + #[inline(always)] + fn tick(&self) { + if self.heartbeat.is_pending() { + self.heartbeat_cold(); + } + } + + #[inline] + fn execute(&self, job: NonNull) { + self.tick(); + Job::execute(job); + } + + #[cold] + fn heartbeat_cold(&self) { + let mut guard = self.context.shared(); + + if !guard.jobs.contains_key(&self.index) { + if let Some(job) = self.pop_back() { + tracing::trace!("heartbeat: sharing job: {:?}", job); + unsafe { + job.as_ref().set_pending(); + } + guard.jobs.insert(self.index, job); + self.context.notify_shared_job(); + } + } + + self.heartbeat.clear(); + } +} + +impl WorkerThread { + #[inline] + pub fn pop_back(&self) -> Option> { + unsafe { self.queue.as_mut_unchecked().pop_back() } + } + + #[inline] + pub fn push_back(&self, job: *const Job) { + unsafe { self.queue.as_mut_unchecked().push_back(job) } + } + + #[inline] + pub fn pop_front(&self) -> Option> { + unsafe { self.queue.as_mut_unchecked().pop_front() } + } + + #[inline] + pub fn push_front(&self, job: *const Job) { + unsafe { self.queue.as_mut_unchecked().push_front(job) } + } +} + +impl WorkerThread { + #[inline] + pub fn current_ref<'a>() -> Option<&'a Self> { + unsafe { (*WORKER.with(UnsafeCell::get)).map(|ptr| ptr.as_ref()) } + } + + unsafe fn set_current(this: *const Self) { + WORKER.with(|cell| { + unsafe { + // SAFETY: this cell is only ever accessed from the current thread + assert!( + (&mut *cell.get()) + .replace(NonNull::new_unchecked( + this as *const WorkerThread as *mut WorkerThread, + )) + .is_none() + ); + } + }); + } + + unsafe fn unset_current() { + WORKER.with(|cell| { + unsafe { + // SAFETY: this cell is only ever accessed from the current thread + (&mut *cell.get()).take(); + } + }); + } + + unsafe fn drop_in_place(this: *mut Self) { + unsafe { + this.drop_in_place(); + drop(Box::from_raw(this)); + } + } +} + +pub struct HeartbeatThread { + ctx: Arc, +} + +impl HeartbeatThread { + const HEARTBEAT_INTERVAL: Duration = Duration::from_micros(100); + + pub fn new(ctx: Arc) -> Self { + Self { ctx } + } + + pub fn run(self) { + tracing::trace!("new heartbeat thread {:?}", std::thread::current()); + + let mut i = 0; + loop { + let sleep_for = { + let mut guard = self.ctx.shared(); + if guard.should_exit() { + break; + } + + let mut n = 0; + guard.heartbeats.retain(|_, b| { + b.upgrade() + .inspect(|heartbeat| { + if n == i { + if heartbeat.set_pending() { + heartbeat.latch.set(); + } + } + n += 1; + }) + .is_some() + }); + let num_heartbeats = guard.heartbeats.len(); + + drop(guard); + + if i >= num_heartbeats { + i = 0; + } else { + i += 1; + } + + Self::HEARTBEAT_INTERVAL.checked_div(num_heartbeats as u32) + }; + + if let Some(duration) = sleep_for { + std::thread::sleep(duration); + } + } + } +} + +impl WorkerThread { + #[cold] + fn wait_until_latch_cold(&self, latch: &CoreLatch) { + // does this optimise? 
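+        // Both callers only enter this cold path after `probe()` returned false,
+        // hence the assert below.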
+ assert!(!latch.probe()); + + 'outer: while !latch.probe() { + // take a shared job, if it exists + if let Some(shared_job) = self.context.shared().jobs.remove(&self.index) { + self.execute(shared_job); + } + + // process local jobs before locking shared context + while let Some(job) = self.pop_front() { + unsafe { + job.as_ref().set_pending(); + } + self.execute(job); + } + + while !latch.probe() { + let job = self.context.shared().pop_job(); + + match job { + Some(job) => { + self.execute(job); + + continue 'outer; + } + None => { + tracing::trace!("waiting for shared job, thread id: {:?}", self.index); + + // TODO: wait on latch? if we have something that can + // signal being done, e.g. can be waited on instead of + // shared jobs, we should wait on it instead, but we + // would also want to receive shared jobs still? + // Spin? probably just wastes CPU time. + // self.context.shared_job.wait(&mut guard); + // if spin.spin() { + // // wait for more shared jobs. + // // self.context.shared_job.wait(&mut guard); + // return; + // } + // Yield? same as spinning, really, so just exit and let the upstream use wait + // std::thread::yield_now(); + + self.heartbeat.latch.wait_and_reset(); + // since we were sleeping, the shared job can't be populated, + // so resuming the inner loop is fine. + } + } + } + } + return; + } + + pub fn wait_until_job(&self, job: &Job, latch: &CoreLatch) -> Option> { + // we've already checked that the job was popped from the queue + // check if shared job is our job + if let Some(shared_job) = self.context.shared().jobs.remove(&self.index) { + if core::ptr::eq(shared_job.as_ptr(), job as *const Job as _) { + // this is the job we are looking for, so we want to + // short-circuit and call it inline + return None; + } else { + // this isn't the job we are looking for, but we still need to + // execute it + self.execute(shared_job); + } + } + + // do the usual thing and wait for the job's latch + if !latch.probe() { + self.wait_until_latch_cold(latch); + } + + Some(job.wait()) + } + + pub fn wait_until_latch(&self, latch: &L) + where + L: AsCoreLatch, + { + let latch = latch.as_core_latch(); + if !latch.probe() { + self.wait_until_latch_cold(latch) + } + } + + #[inline] + fn wait_until_predicate(&self, pred: F) + where + F: Fn() -> bool, + { + 'outer: while !pred() { + // take a shared job, if it exists + if let Some(shared_job) = self.context.shared().jobs.remove(&self.index) { + self.execute(shared_job); + } + + // process local jobs before locking shared context + while let Some(job) = self.pop_front() { + unsafe { + job.as_ref().set_pending(); + } + self.execute(job); + } + + while !pred() { + let mut guard = self.context.shared(); + let mut _spin = SpinWait::new(); + + match guard.pop_job() { + Some(job) => { + drop(guard); + self.execute(job); + + continue 'outer; + } + None => { + tracing::trace!("waiting for shared job, thread id: {:?}", self.index); + + // TODO: wait on latch? if we have something that can + // signal being done, e.g. can be waited on instead of + // shared jobs, we should wait on it instead, but we + // would also want to receive shared jobs still? + // Spin? probably just wastes CPU time. + // self.context.shared_job.wait(&mut guard); + // if spin.spin() { + // // wait for more shared jobs. + // // self.context.shared_job.wait(&mut guard); + // return; + // } + // Yield? same as spinning, really, so just exit and let the upstream use wait + // std::thread::yield_now(); + return; + } + } + } + } + return; + } +}