mod util {
    use std::{
        cell::UnsafeCell,
        marker::PhantomData,
        mem::ManuallyDrop,
        ops::{Deref, DerefMut},
        ptr::NonNull,
        sync::atomic::{AtomicPtr, Ordering},
    };

    /// Runs the wrapped closure when dropped.
    pub struct DropGuard<F: FnOnce()>(UnsafeCell<ManuallyDrop<F>>);

    impl<F> DropGuard<F>
    where
        F: FnOnce(),
    {
        pub fn new(f: F) -> DropGuard<F> {
            Self(UnsafeCell::new(ManuallyDrop::new(f)))
        }
    }

    impl<F> Drop for DropGuard<F>
    where
        F: FnOnce(),
    {
        fn drop(&mut self) {
            // SAFETY: the closure is taken exactly once, here in drop.
            unsafe {
                ManuallyDrop::take(&mut *self.0.get())();
            }
        }
    }

    /// A `NonNull<T>` that is `Send` regardless of `T`; the user is
    /// responsible for making any cross-thread access sound.
    #[repr(transparent)]
    #[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
    pub struct SendPtr<T>(NonNull<T>);

    impl<T> Copy for SendPtr<T> {}

    impl<T> Clone for SendPtr<T> {
        fn clone(&self) -> Self {
            Self(self.0.clone())
        }
    }

    impl<T> std::fmt::Pointer for SendPtr<T> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            <NonNull<T> as core::fmt::Pointer>::fmt(&self.0, f)
        }
    }

    unsafe impl<T> Send for SendPtr<T> {}

    impl<T> Deref for SendPtr<T> {
        type Target = NonNull<T>;

        fn deref(&self) -> &Self::Target {
            &self.0
        }
    }

    impl<T> DerefMut for SendPtr<T> {
        fn deref_mut(&mut self) -> &mut Self::Target {
            &mut self.0
        }
    }

    impl<T> SendPtr<T> {
        pub const fn new(ptr: *mut T) -> Option<Self> {
            // `Option::map` is not callable in a `const fn`, hence the match.
            match NonNull::new(ptr) {
                Some(ptr) => Some(Self(ptr)),
                None => None,
            }
        }

        #[allow(dead_code)]
        pub const unsafe fn new_unchecked(ptr: *mut T) -> Self {
            unsafe { Self(NonNull::new_unchecked(ptr)) }
        }

        pub const fn new_const(ptr: *const T) -> Option<Self> {
            Self::new(ptr.cast_mut())
        }

        #[allow(dead_code)]
        pub const unsafe fn new_const_unchecked(ptr: *const T) -> Self {
            Self::new_unchecked(ptr.cast_mut())
        }
    }

    // Miri doesn't like tagging pointers that it doesn't know the alignment of.
    // This includes function pointers, which aren't guaranteed to be aligned to
    // anything, but generally have an alignment of 8, and can be specified to
    // be aligned to `n` with `#[align(n)]`.
    /// An atomic pointer that packs a `BITS`-wide tag into the low
    /// (alignment) bits of the pointer.
    #[repr(transparent)]
    pub struct TaggedAtomicPtr<T, const BITS: usize> {
        ptr: AtomicPtr<()>,
        _pd: PhantomData<T>,
    }

    impl<T, const BITS: usize> TaggedAtomicPtr<T, BITS> {
        /// Mask with the low `BITS` bits set.
        const fn mask() -> usize {
            !(!0usize << BITS)
        }

        pub fn new(ptr: *mut T, tag: usize) -> TaggedAtomicPtr<T, BITS> {
            debug_assert!(core::mem::align_of::<T>().ilog2() as usize >= BITS);
            let mask = Self::mask();
            Self {
                ptr: AtomicPtr::new(ptr.with_addr((ptr.addr() & !mask) | (tag & mask)).cast()),
                _pd: PhantomData,
            }
        }

        pub fn ptr(&self, order: Ordering) -> NonNull<T> {
            unsafe {
                NonNull::new_unchecked(
                    self.ptr
                        .load(order)
                        .map_addr(|addr| addr & !Self::mask())
                        .cast(),
                )
            }
        }

        pub fn tag(&self, order: Ordering) -> usize {
            self.ptr.load(order).addr() & Self::mask()
        }

        /// returns tag
        #[inline(always)]
        fn compare_exchange_tag_inner(
            &self,
            old: usize,
            new: usize,
            success: Ordering,
            failure: Ordering,
            cmpxchg: fn(
                &AtomicPtr<()>,
                *mut (),
                *mut (),
                Ordering,
                Ordering,
            ) -> Result<*mut (), *mut ()>,
        ) -> Result<usize, usize> {
            let mask = Self::mask();
            let old_ptr = self.ptr.load(failure);

            let old = old_ptr.map_addr(|addr| (addr & !mask) | (old & mask));
            let new = old_ptr.map_addr(|addr| (addr & !mask) | (new & mask));

            let result = cmpxchg(&self.ptr, old, new, success, failure);

            result
                .map(|ptr| ptr.addr() & mask)
                .map_err(|ptr| ptr.addr() & mask)
        }

        /// returns tag
        #[inline]
        #[allow(dead_code)]
        pub fn compare_exchange_tag(
            &self,
            old: usize,
            new: usize,
            success: Ordering,
            failure: Ordering,
        ) -> Result<usize, usize> {
            self.compare_exchange_tag_inner(
                old,
                new,
                success,
                failure,
                AtomicPtr::<()>::compare_exchange,
            )
        }

        /// returns tag
        #[inline]
        pub fn compare_exchange_weak_tag(
            &self,
            old: usize,
            new: usize,
            success: Ordering,
            failure: Ordering,
        ) -> Result<usize, usize> {
            self.compare_exchange_tag_inner(
                old,
                new,
                success,
                failure,
                AtomicPtr::<()>::compare_exchange_weak,
            )
        }

        #[allow(dead_code)]
        pub fn set_ptr(&self, ptr: *mut T, success: Ordering,
failure: Ordering) { let mask = Self::mask(); let ptr = ptr.cast::<()>(); loop { let old = self.ptr.load(failure); let new = ptr.map_addr(|addr| (addr & !mask) | (old.addr() & mask)); if self .ptr .compare_exchange_weak(old, new, success, failure) .is_ok() { break; } } } pub fn set_tag(&self, tag: usize, success: Ordering, failure: Ordering) { let mask = Self::mask(); loop { let ptr = self.ptr.load(failure); let new = ptr.map_addr(|addr| (addr & !mask) | (tag & mask)); if self .ptr .compare_exchange_weak(ptr, new, success, failure) .is_ok() { break; } } } pub fn ptr_and_tag(&self, order: Ordering) -> (NonNull, usize) { let mask = Self::mask(); let ptr = self.ptr.load(order); let tag = ptr.addr() & mask; let ptr = ptr.map_addr(|addr| addr & !mask); let ptr = unsafe { NonNull::new_unchecked(ptr.cast()) }; (ptr, tag) } } } mod job { use std::{ any::Any, borrow::{Borrow, BorrowMut}, cell::UnsafeCell, fmt::{Debug, Display}, hint::cold_path, mem::{self, ManuallyDrop, MaybeUninit}, ops::{Deref, DerefMut}, panic::resume_unwind, ptr::{self, NonNull}, sync::atomic::Ordering, thread::Thread, }; use parking_lot_core::SpinWait; use crate::latch::Latch; use super::util::TaggedAtomicPtr; #[derive(Debug)] #[repr(transparent)] pub struct SmallBox(pub MaybeUninit>); impl Display for SmallBox { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { (**self).fmt(f) } } impl Ord for SmallBox { fn cmp(&self, other: &Self) -> std::cmp::Ordering { self.as_ref().cmp(other.as_ref()) } } impl PartialOrd for SmallBox { fn partial_cmp(&self, other: &Self) -> Option { self.as_ref().partial_cmp(other.as_ref()) } } impl Eq for SmallBox {} impl PartialEq for SmallBox { fn eq(&self, other: &Self) -> bool { self.as_ref().eq(other.as_ref()) } } impl Default for SmallBox { fn default() -> Self { Self::new(Default::default()) } } impl Clone for SmallBox { fn clone(&self) -> Self { Self::new(self.as_ref().clone()) } } impl Deref for SmallBox { type Target = T; fn deref(&self) -> &Self::Target { self.as_ref() } } impl DerefMut for SmallBox { fn deref_mut(&mut self) -> &mut Self::Target { self.as_mut() } } impl AsRef for SmallBox { fn as_ref(&self) -> &T { Self::as_ref(self) } } impl AsMut for SmallBox { fn as_mut(&mut self) -> &mut T { Self::as_mut(self) } } impl Borrow for SmallBox { fn borrow(&self) -> &T { &**self } } impl BorrowMut for SmallBox { fn borrow_mut(&mut self) -> &mut T { &mut **self } } impl SmallBox { /// must only be called once. 
takes a reference so this can be called in /// drop() unsafe fn get_unchecked(&self, inline: bool) -> T { if inline { unsafe { mem::transmute_copy::>, T>(&self.0) } } else { unsafe { *self.0.assume_init_read() } } } pub fn as_ref(&self) -> &T { unsafe { if Self::is_inline() { mem::transmute::<&MaybeUninit>, &T>(&self.0) } else { self.0.assume_init_ref() } } } pub fn as_mut(&mut self) -> &mut T { unsafe { if Self::is_inline() { mem::transmute::<&mut MaybeUninit>, &mut T>(&mut self.0) } else { self.0.assume_init_mut() } } } pub fn into_inner(self) -> T { let this = ManuallyDrop::new(self); let inline = Self::is_inline(); // SAFETY: inline is correctly calculated and this function // consumes `self` unsafe { this.get_unchecked(inline) } } pub fn is_inline() -> bool { // the value can be stored inline iff the size of T is equal or // smaller than the size of the boxed type and the alignment of the // boxed type is an integer multiple of the alignment of T mem::size_of::() <= mem::size_of::>>() && mem::align_of::>>() % mem::align_of::() == 0 } pub fn new(value: T) -> Self { let inline = Self::is_inline(); if inline { let mut this = MaybeUninit::new(Self(MaybeUninit::uninit())); unsafe { this.as_mut_ptr().cast::().write(value); this.assume_init() } } else { Self(MaybeUninit::new(Box::new(value))) } } } impl Drop for SmallBox { fn drop(&mut self) { // drop contained value. drop(unsafe { self.get_unchecked(Self::is_inline()) }); } } #[repr(u8)] #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum JobState { Empty, Locked = 1, Pending, Finished, // Inline = 1 << (u8::BITS - 1), // IsError = 1 << (u8::BITS - 2), } impl JobState { #[allow(dead_code)] const MASK: u8 = 0; // Self::Inline as u8 | Self::IsError as u8; fn from_u8(v: u8) -> Option { match v { 0 => Some(Self::Empty), 1 => Some(Self::Locked), 2 => Some(Self::Pending), 3 => Some(Self::Finished), _ => None, } } } // for some reason I confused head and tail here and the list is something like this: // tail <-> job1 <-> job2 <-> ... <-> head pub struct JobList { // these cannot be boxes because boxes are noalias. head: NonNull, tail: NonNull, job_count: usize, } impl Debug for JobList { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("JobList") .field("head", &self.head) .field("tail", &self.tail) .field_with("jobs", |f| { let mut list = f.debug_list(); // SAFETY: head always has prev let mut job = unsafe { self.head().as_ref().link_mut().prev.unwrap() }; loop { if job == self.tail() { break; } let job_ref = unsafe { job.as_ref() }; list.entry(&job_ref); // SAFETY: we are iterating over the linked list if let Some(next) = unsafe { job_ref.link_mut().prev } { job = next; } else { tracing::trace!("prev job is none?"); break; }; } list.finish() }) .finish() } } impl JobList { pub fn new() -> JobList { let head = Box::into_raw(Box::new(Job::empty())); let tail = Box::into_raw(Box::new(Job::empty())); // head and tail point at themselves unsafe { (&mut *(&mut *head).err_or_link.get()).link.next = None; (&mut *(&mut *head).err_or_link.get()).link.prev = Some(NonNull::new_unchecked(tail)); (&mut *(&mut *tail).err_or_link.get()).link.prev = None; (&mut *(&mut *tail).err_or_link.get()).link.next = Some(NonNull::new_unchecked(head)); Self { head: NonNull::new_unchecked(head), tail: NonNull::new_unchecked(tail), job_count: 0, } } } fn head(&self) -> NonNull { self.head } fn tail(&self) -> NonNull { self.tail } /// elem must be valid until it is popped. 
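/// The job is linked in directly before the `head` sentinel, so a subsequent
/// `pop_front` returns it first. The pointer is stored as-is: the caller must
/// keep the `Job` alive, at a stable address, until it is removed again via
/// `pop_front`/`pop_back` or `Job::unlink`.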
pub unsafe fn push_front(&mut self, elem: *const Job) { self.job_count += 1; let head_link = unsafe { self.head.as_ref().link_mut() }; // SAFETY: head will always have a previous element. let prev = head_link.prev.unwrap(); let prev_link = unsafe { prev.as_ref().link_mut() }; let elem_ptr = unsafe { NonNull::new_unchecked(elem as _) }; head_link.prev = Some(elem_ptr); prev_link.next = Some(elem_ptr); let elem_link = unsafe { (*elem).link_mut() }; elem_link.prev = Some(prev); elem_link.next = Some(self.head()); } /// elem must be valid until it is popped. pub unsafe fn push_back(&mut self, elem: *const Job) { self.job_count += 1; let tail_link = unsafe { self.tail.as_ref().link_mut() }; // SAFETY: tail will always have a previous element. let next = tail_link.next.unwrap(); let next_link = unsafe { next.as_ref().link_mut() }; let elem_ptr = unsafe { NonNull::new_unchecked(elem as _) }; tail_link.next = Some(elem_ptr); next_link.prev = Some(elem_ptr); let elem_link = unsafe { (*elem).link_mut() }; elem_link.next = Some(next); elem_link.prev = Some(self.tail()); } #[allow(dead_code)] pub fn pop_front(&mut self) -> Option> { self.job_count -= 1; let head_link = unsafe { self.head.as_ref().link_mut() }; // SAFETY: head will always have a previous element. let elem = head_link.prev.unwrap(); let elem_link = unsafe { elem.as_ref().link_mut() }; let prev = elem_link.prev?.as_ptr(); head_link.prev = unsafe { Some(NonNull::new_unchecked(prev)) }; let prev_link = unsafe { (&*prev).link_mut() }; prev_link.next = Some(self.head()); Some(elem) } pub fn pop_back(&mut self) -> Option> { self.job_count -= 1; // TODO: next and elem might be the same let tail_link = unsafe { self.tail.as_ref().link_mut() }; // SAFETY: head will always have a previous element. let elem = tail_link.next.unwrap(); let elem_link = unsafe { elem.as_ref().link_mut() }; let next = elem_link.next?.as_ptr(); tail_link.next = unsafe { Some(NonNull::new_unchecked(next)) }; let next_link = unsafe { (&*next).link_mut() }; next_link.prev = Some(self.tail()); Some(elem) } #[allow(dead_code)] pub fn is_empty(&self) -> bool { self.job_count == 0 } pub fn len(&self) -> usize { self.job_count } } impl Drop for JobList { fn drop(&mut self) { // Need to drop the head and tail, which were allocated on the heap. // elements of the list are managed externally. 
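// (`Box::from_non_null` is a nightly-only constructor at the time of writing;
// `Box::from_raw(ptr.as_ptr())` is the stable equivalent.)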
unsafe { drop((Box::from_non_null(self.head), Box::from_non_null(self.tail))); }; } } union ValueOrThis { uninit: (), value: ManuallyDrop>, this: NonNull<()>, } #[derive(Debug, PartialEq, Eq)] struct Link { prev: Option>, next: Option>, } impl Clone for Link { fn clone(&self) -> Self { Self { prev: self.prev.clone(), next: self.next.clone(), } } } // because Copy is invariant over `T` impl Copy for Link {} union LinkOrError { link: Link, waker: ManuallyDrop>, error: ManuallyDrop>>, } #[repr(C)] pub struct Job { /// tagged pointer, 8-aligned harness_and_state: TaggedAtomicPtr, /// NonNull<()> before execute(), Value after val_or_this: UnsafeCell>, /// (prev,next) before execute(), Box<...> after err_or_link: UnsafeCell>, // _phantom: PhantomPinned, } impl Debug for Job { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let state = JobState::from_u8(self.harness_and_state.tag(Ordering::Relaxed) as u8).unwrap(); let mut debug = f.debug_struct("Job"); debug.field("state", &state).field_with("harness", |f| { write!(f, "{:?}", self.harness_and_state.ptr(Ordering::Relaxed)) }); match state { JobState::Empty => { debug .field_with("this", |f| { write!(f, "{:?}", unsafe { &(&*self.val_or_this.get()).this }) }) .field_with("link", |f| { write!(f, "{:?}", unsafe { &(&*self.err_or_link.get()).link }) }); } JobState::Locked => { #[derive(Debug)] struct Locked; debug.field("locked", &Locked); } JobState::Pending => { debug .field_with("this", |f| { write!(f, "{:?}", unsafe { &(&*self.val_or_this.get()).this }) }) .field_with("waker", |f| { write!(f, "{:?}", unsafe { &(&*self.err_or_link.get()).waker }) }); } JobState::Finished => { let err = unsafe { &(&*self.err_or_link.get()).error }; let result = match err.as_ref() { Some(err) => Err(err), None => Ok(unsafe { (&*self.val_or_this.get()).value.0.as_ptr() }), }; debug.field("result", &result); } } debug.finish() } } unsafe impl Send for Job {} impl Job { pub fn new(harness: unsafe fn(*const (), *const Job), this: NonNull<()>) -> Job { Self { harness_and_state: TaggedAtomicPtr::new( unsafe { mem::transmute(harness) }, JobState::Empty as usize, ), val_or_this: UnsafeCell::new(ValueOrThis { this }), err_or_link: UnsafeCell::new(LinkOrError { link: Link { prev: None, next: None, }, }), // _phantom: PhantomPinned, } } // Job is passed around type-erased as `Job<()>`, to complete the job we // need to cast it back to the original type. pub unsafe fn transmute_ref(&self) -> &Job { mem::transmute::<&Job, &Job>(self) } /// unwraps the `this` pointer, which is only valid if the job is in the empty state. 
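/// Once the job has been completed, the same union field stores the
/// (`SmallBox`ed) result, or `()` on panic, instead of `this`, so reading it
/// as a pointer at that point would be undefined behaviour; hence the state
/// assertion.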
#[allow(dead_code)] pub unsafe fn unwrap_this(&self) -> NonNull<()> { assert!(self.state() == JobState::Empty as u8); unsafe { (&*self.val_or_this.get()).this } } pub fn empty() -> Job { Self { harness_and_state: TaggedAtomicPtr::new( ptr::dangling_mut(), JobState::Empty as usize, ), val_or_this: UnsafeCell::new(ValueOrThis { this: NonNull::dangling(), }), err_or_link: UnsafeCell::new(LinkOrError { link: Link { prev: None, next: None, }, }), // _phantom: PhantomPinned, } } #[inline] unsafe fn link_mut(&self) -> &mut Link { unsafe { &mut (&mut *self.err_or_link.get()).link } } /// assumes job is in joblist pub unsafe fn unlink(&self) { unsafe { let mut dummy = None; let Link { prev, next } = *self.link_mut(); *prev .map(|ptr| &mut ptr.as_ref().link_mut().next) .unwrap_or(&mut dummy) = next; *next .map(|ptr| &mut ptr.as_ref().link_mut().prev) .unwrap_or(&mut dummy) = prev; } } pub fn state(&self) -> u8 { self.harness_and_state.tag(Ordering::Relaxed) as u8 } pub fn wait(&self) -> JobResult { let mut spin = SpinWait::new(); loop { match self.harness_and_state.compare_exchange_weak_tag( JobState::Pending as usize, JobState::Locked as usize, Ordering::Acquire, Ordering::Relaxed, ) { // if still pending, sleep until completed Ok(state) => { debug_assert_eq!(state, JobState::Pending as usize); unsafe { *(&mut *self.err_or_link.get()).waker = Some(std::thread::current()); } self.harness_and_state.set_tag( JobState::Pending as usize, Ordering::Release, Ordering::Relaxed, ); std::thread::park(); spin.reset(); // after sleeping, state should be `Finished` } Err(state) => { // job finished under us, check if it was successful if state == JobState::Finished as usize { let err = unsafe { (&mut *self.err_or_link.get()).error.take() }; let result: std::thread::Result = if let Some(err) = err { cold_path(); Err(err) } else { let val = unsafe { ManuallyDrop::take(&mut (&mut *self.val_or_this.get()).value) }; Ok(val.into_inner()) }; return JobResult::new(result); } else { // spin until lock is released. 
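// Not `Finished` yet: either the weak compare-exchange failed spuriously, or
// the job is momentarily `Locked` by `complete()` (or another waiter) while
// it touches the union fields; both clear quickly, so just retry.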
tracing::trace!("spin-waiting for job: {:?}", self); spin.spin(); } } } } } /// must be called before `execute()` pub fn set_pending(&self) { let mut spin = SpinWait::new(); loop { match self.harness_and_state.compare_exchange_weak_tag( JobState::Empty as usize, JobState::Pending as usize, Ordering::Acquire, Ordering::Relaxed, ) { Ok(state) => { debug_assert_eq!(state, JobState::Empty as usize); // set waker to None unsafe { (&mut *self.err_or_link.get()).waker = ManuallyDrop::new(None); } return; } Err(_) => { // debug_assert_ne!(state, JobState::Empty as usize); eprintln!("######## what the sigma?"); spin.spin(); } } } } pub fn execute(job: NonNull) { tracing::trace!( "thread {:?}: executing job: {:?}", std::thread::current().name(), job ); // SAFETY: self is non-null unsafe { let this = job.as_ref(); let (ptr, state) = this.harness_and_state.ptr_and_tag(Ordering::Relaxed); debug_assert_eq!(state, JobState::Pending as usize); let harness: unsafe fn(*const (), *const Self) = mem::transmute(ptr.as_ptr()); let this = (*this.val_or_this.get()).this; harness(this.as_ptr().cast(), job.as_ptr()); } } pub(crate) fn complete(&self, result: std::thread::Result) { let mut spin = SpinWait::new(); loop { match self.harness_and_state.compare_exchange_weak_tag( JobState::Pending as usize, JobState::Locked as usize, Ordering::Acquire, Ordering::Relaxed, ) { Ok(state) => { debug_assert_eq!(state, JobState::Pending as usize); break; } Err(_) => { // debug_assert_ne!(state, JobState::Pending as usize); spin.spin(); } } } let waker = unsafe { (&mut *self.err_or_link.get()).waker.take() }; match result { Ok(val) => unsafe { (&mut *self.val_or_this.get()).value = ManuallyDrop::new(SmallBox::new(val)); (&mut *self.err_or_link.get()).error = ManuallyDrop::new(None); }, Err(err) => unsafe { (&mut *self.val_or_this.get()).uninit = (); (&mut *self.err_or_link.get()).error = ManuallyDrop::new(Some(err)); }, } if let Some(thread) = waker { thread.unpark(); } self.harness_and_state.set_tag( JobState::Finished as usize, Ordering::Release, Ordering::Relaxed, ); } } impl crate::Probe for Job { fn probe(&self) -> bool { self.state() == JobState::Finished as u8 } } #[allow(dead_code)] pub struct HeapJob { f: F, // _phantom: PhantomPinned, } impl HeapJob { #[allow(dead_code)] pub fn new(f: F) -> Box { Box::new(Self { f, // _phantom: PhantomPinned, }) } /// unwraps the job into it's closure. #[allow(dead_code)] pub fn into_inner(self) -> F { self.f } #[allow(dead_code)] pub fn into_boxed_job(self: Box) -> *mut Job<()> where F: FnOnce() -> T + Send, T: Send, { #[align(8)] unsafe fn harness(this: *const (), job: *const Job<()>) where F: FnOnce() -> T + Send, T: Sized + Send, { let job = job.cast_mut(); // turn `this`, which was allocated at (2), into box. // miri complains this is a use-after-free, but it isn't? silly miri... // Turns out this is actually correct on miri's end, but because // we ensure that the scope lives as long as any jobs, this is // actually fine, as far as I can tell. let this = unsafe { Box::from_raw(this.cast::>().cast_mut()) }; let f = this.into_inner(); _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f())); // drop job (this is fine because the job of a HeapJob is pure POD). 
ptr::drop_in_place(job); // free box that was allocated at (1) _ = unsafe { Box::>>::from_raw(job.cast()) }; } // (1) allocate box for job Box::into_raw(Box::new(Job::new(harness::, { // (2) convert self into a pointer Box::into_non_null(self).cast() }))) } } pub struct StackJob { latch: L, f: UnsafeCell>, // _phantom: PhantomPinned, } impl StackJob { pub fn new(f: F, latch: L) -> Self { Self { latch, f: UnsafeCell::new(ManuallyDrop::new(f)), // _phantom: PhantomPinned, } } pub unsafe fn unwrap(&self) -> F { unsafe { ManuallyDrop::take(&mut *self.f.get()) } } } impl StackJob { pub fn as_job(&self) -> Job<()> where F: FnOnce() -> T + Send, T: Send, { #[align(8)] unsafe fn harness(this: *const (), job: *const Job<()>) where F: FnOnce() -> T + Send, T: Sized + Send, { let this = unsafe { &*this.cast::>() }; let f = unsafe { this.unwrap() }; let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f())); let job_ref = unsafe { &*job.cast::>() }; job_ref.complete(result); crate::latch::Latch::set_raw(&this.latch); } Job::new(harness::, unsafe { NonNull::new_unchecked(&*self as *const _ as *mut ()) }) } } pub struct JobResult { result: std::thread::Result, } impl JobResult { pub fn new(result: std::thread::Result) -> Self { Self { result } } /// convert JobResult into a thread result. #[allow(dead_code)] pub fn into_inner(self) -> std::thread::Result { self.result } // unwraps the result, propagating panics pub fn into_result(self) -> T { match self.result { Ok(val) => val, Err(payload) => { cold_path(); resume_unwind(payload); } } } } } use std::{ any::Any, cell::{Cell, UnsafeCell}, collections::BTreeMap, future::Future, hint::cold_path, marker::PhantomData, mem::MaybeUninit, ptr::{self, NonNull}, sync::{ atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering}, Arc, OnceLock, Weak, }, time::Duration, }; use async_task::Runnable; use crossbeam::utils::CachePadded; use job::*; use parking_lot::{Condvar, Mutex}; use parking_lot_core::SpinWait; use util::{DropGuard, SendPtr}; use crate::latch::{AtomicLatch, LatchRef, NopLatch}; #[derive(Debug, Default)] pub struct JobCounter { jobs_pending: AtomicUsize, waker: Mutex>, } impl JobCounter { pub fn increment(&self) { self.jobs_pending.fetch_add(1, Ordering::Relaxed); } pub fn count(&self) -> usize { self.jobs_pending.load(Ordering::Relaxed) } pub fn decrement(&self) { if self.jobs_pending.fetch_sub(1, Ordering::SeqCst) == 1 { if let Some(thread) = self.waker.lock().take() { thread.unpark(); } } } /// must only be called once pub unsafe fn wait(&self) { // SAFETY: this is only called once, so the waker is guaranteed to be None. assert!(self.waker.lock().replace(std::thread::current()).is_none()); let count = self.jobs_pending.load(Ordering::SeqCst); if count > 0 { std::thread::park(); } } } impl crate::latch::Probe for JobCounter { fn probe(&self) -> bool { self.count() == 0 } } struct WorkerThread { context: Arc, index: usize, heartbeat: Arc>, queue: UnsafeCell, join_count: Cell, } pub struct Scope<'scope> { // latch to wait on before the scope finishes job_counter: JobCounter, // local threadpool context: Arc, // panic error panic: AtomicPtr>, // variant lifetime _pd: PhantomData, } thread_local! 
{ static WORKER: UnsafeCell>> = const { UnsafeCell::new(None) }; } impl WorkerThread { /// locks shared context #[allow(dead_code)] fn new() -> Self { let context = Context::global().clone(); Self::new_in(context) } /// locks shared context fn new_in(context: Arc) -> Self { let (heartbeat, index) = context.shared.lock().new_heartbeat(); Self { context, index, heartbeat, queue: UnsafeCell::new(JobList::new()), join_count: Cell::new(0), // _pd: PhantomData, } } #[allow(dead_code)] fn drop_current_guard(new: Option>) -> DropGuard { DropGuard::new(move || unsafe { if let Some(old) = Self::unset_current() { Self::drop_in_place_and_dealloc(old); } else { cold_path(); tracing::error!("WorkerThread drop guard tried to drop None."); } if let Some(new) = new { Self::set_current(new.as_ptr().cast_const()); } }) } unsafe fn drop_in_place_and_dealloc(this: NonNull) { unsafe { let ptr = this.as_ptr(); ptr.drop_in_place(); _ = Box::>::from_raw(ptr.cast()); } } /// sets the thread-local worker to this. unsafe fn set_current(this: *const WorkerThread) { WORKER.with(|ptr| unsafe { _ = (&mut *ptr.get()).insert(NonNull::new_unchecked(this.cast_mut())); }) } /// sets the thread-local worker to None and returns it, if it was occupied. unsafe fn unset_current() -> Option> { WORKER.with(|ptr| unsafe { (&mut *ptr.get()).take() }) } #[allow(dead_code)] #[inline(always)] fn current() -> Option> { unsafe { *WORKER.with(UnsafeCell::get) } } #[inline(always)] fn current_ref<'a>() -> Option<&'a WorkerThread> { unsafe { (*WORKER.with(UnsafeCell::get)).map(|ptr| ptr.as_ref()) } } fn push_front(&self, job: *const Job) { unsafe { self.queue.as_mut_unchecked().push_front(job); } } #[allow(dead_code)] fn push_back(&self, job: *const Job) { unsafe { self.queue.as_mut_unchecked().push_back(job); } } fn pop_back(&self) -> Option> { unsafe { self.queue.as_mut_unchecked().pop_back() } } #[allow(dead_code)] fn pop_front(&self) -> Option> { unsafe { self.queue.as_mut_unchecked().pop_front() } } #[inline(always)] fn tick(&self) { if self.heartbeat.load(Ordering::Relaxed) { self.heartbeat_cold(); } } #[inline] fn execute(&self, job: NonNull) { self.tick(); Job::execute(job); } #[cold] fn heartbeat_cold(&self) { let mut guard = self.context.shared.lock(); if !guard.jobs.contains_key(&self.index) { if let Some(job) = self.pop_back() { tracing::trace!("heartbeat: sharing job: {:?}", job); unsafe { job.as_ref().set_pending(); } guard.jobs.insert(self.index, job); self.context.notify_shared_job(); } } self.heartbeat.store(false, Ordering::Relaxed); } #[inline] fn join_seq(&self, a: A, b: B) -> (RA, RB) where RA: Send, RB: Send, A: FnOnce() -> RA + Send, B: FnOnce() -> RB + Send, { let rb = b(); let ra = a(); (ra, rb) } /// This function must be called from a worker thread. #[inline] fn join_heartbeat_every(&self, a: A, b: B) -> (RA, RB) where RA: Send, RB: Send, A: FnOnce() -> RA + Send, B: FnOnce() -> RB + Send, { // SAFETY: each worker is only ever used by one thread, so this is safe. let count = self.join_count.get(); self.join_count.set(count.wrapping_add(1) % TIMES as u8); // TODO: add counter to job queue, check for low job count to decide whether to use heartbeat or seq. // see: chili // SAFETY: this function runs in a worker thread, so we can access the queue safely. if count == 0 || unsafe { self.queue.as_ref_unchecked().len() } < 3 { cold_path(); self.join_heartbeat(a, b) } else { self.join_seq(a, b) } } /// This function must be called from a worker thread. 
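    /// The heartbeat version of join: `a` is wrapped in a `StackJob` and
    /// pushed onto this worker's queue so another worker can pick it up once a
    /// heartbeat shares it, while `b` runs inline. Afterwards `a` is either
    /// reclaimed and run inline (if nobody took it) or we wait for the thief
    /// to complete it, propagating any panic.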
#[inline] fn join_heartbeat(&self, a: A, b: B) -> (RA, RB) where RA: Send, RB: Send, A: FnOnce() -> RA + Send, B: FnOnce() -> RB + Send, { use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe}; let a = StackJob::new( move || { // TODO: bench whether tick'ing here is good. // turns out this actually costs a lot of time, likely because of the thread local check. // WorkerThread::current_ref() // .expect("stackjob is run in workerthread.") // .tick(); a() }, NopLatch, ); let job = a.as_job(); self.push_front(&job); let rb = match catch_unwind(AssertUnwindSafe(|| b())) { Ok(val) => val, Err(payload) => { cold_path(); // if b panicked, we need to wait for a to finish self.wait_until_job::(unsafe { job.transmute_ref::() }); resume_unwind(payload); } }; let ra = if job.state() == JobState::Empty as u8 { unsafe { job.unlink(); } // a is allowed to panic here, because we already finished b. unsafe { a.unwrap()() } } else { match self.wait_until_job::(unsafe { job.transmute_ref::() }) { Some(t) => t.into_result(), // propagate panic here None => unsafe { a.unwrap()() }, } }; drop(a); (ra, rb) } #[cold] fn wait_until_latch_cold(&self, latch: &Latch) { // does this optimise? assert!(!latch.probe()); self.wait_until_predicate(|| latch.probe()) } pub fn wait_until_latch(&self, latch: &Latch) { if !latch.probe() { self.wait_until_latch_cold(latch) } } #[inline] fn wait_until_predicate(&self, pred: F) where F: Fn() -> bool, { 'outer: while !pred() { // take a shared job, if it exists if let Some(shared_job) = self.context.shared.lock().jobs.remove(&self.index) { self.execute(shared_job); } // process local jobs before locking shared context while let Some(job) = self.pop_front() { unsafe { job.as_ref().set_pending(); } self.execute(job); } while !pred() { let mut guard = self.context.shared.lock(); let mut _spin = SpinWait::new(); match guard.pop_job() { Some(job) => { drop(guard); self.execute(job); continue 'outer; } None => { tracing::trace!("waiting for shared job, thread id: {:?}", self.index); // TODO: wait on latch? if we have something that can // signal being done, e.g. can be waited on instead of // shared jobs, we should wait on it instead, but we // would also want to receive shared jobs still? // Spin? probably just wastes CPU time. // self.context.shared_job.wait(&mut guard); // if spin.spin() { // // wait for more shared jobs. // // self.context.shared_job.wait(&mut guard); // return; // } // Yield? same as spinning, really, so just exit and let the upstream use wait // std::thread::yield_now(); return; } } } } return; } pub fn wait_until_job(&self, job: &Job) -> Option> { self.wait_until_predicate(|| { // check if job is finished job.state() == JobState::Finished as u8 }); // someone else has this job and is working on it, // while job isn't done, suspend thread. Some(job.wait()) } } pub fn scope<'scope, F, R>(f: F) -> R where F: FnOnce(&Scope<'scope>) -> R + Send, R: Send, { Scope::<'scope>::scope(f) } impl<'scope> Scope<'scope> { fn wait_for_jobs(&self, worker: &WorkerThread) { tracing::trace!("waiting for {} jobs to finish.", self.job_counter.count()); tracing::trace!("thread id: {:?}, jobs: {:?}", worker.index, unsafe { worker.queue.as_ref_unchecked() }); worker.wait_until_latch(&self.job_counter); unsafe { self.job_counter.wait() }; } pub fn scope(f: F) -> R where F: FnOnce(&Self) -> R + Send, R: Send, { run_in_worker(|worker| { // SAFETY: we call complete() after creating this scope, which // ensures that any jobs spawned from the scope exit before the // scope closes. 
let this = unsafe { Self::from_context(worker.context.clone()) }; this.complete(worker, || f(&this)) }) } fn scope_with_context(context: Arc, f: F) -> R where F: FnOnce(&Self) -> R + Send, R: Send, { context.run_in_worker(|worker| { // SAFETY: we call complete() after creating this scope, which // ensures that any jobs spawned from the scope exit before the // scope closes. let this = unsafe { Self::from_context(context.clone()) }; this.complete(worker, || f(&this)) }) } /// should be called from within a worker thread. fn complete(&self, worker: &WorkerThread, f: F) -> R where F: FnOnce() -> R + Send, R: Send, { use std::panic::{catch_unwind, AssertUnwindSafe}; #[allow(dead_code)] fn make_job T, T>(f: F) -> Job { #[align(8)] unsafe fn harness T, T>(this: *const (), job: *const Job) { let f = unsafe { Box::from_raw(this.cast::().cast_mut()) }; let result = catch_unwind(AssertUnwindSafe(move || f())); let job = unsafe { Box::from_raw(job.cast_mut()) }; job.complete(result); } Job::::new(harness::, unsafe { NonNull::new_unchecked(Box::into_raw(Box::new(f))).cast() }) } let result = match catch_unwind(AssertUnwindSafe(|| f())) { Ok(val) => Some(val), Err(payload) => { self.panicked(payload); None } }; self.wait_for_jobs(worker); self.maybe_propagate_panic(); // SAFETY: if result panicked, we would have propagated the panic above. result.unwrap() } /// resumes the panic if one happened in this scope. fn maybe_propagate_panic(&self) { let err_ptr = self.panic.load(Ordering::Relaxed); if !err_ptr.is_null() { unsafe { let err = Box::from_raw(err_ptr); std::panic::resume_unwind(*err); } } } /// stores the first panic that happened in this scope. fn panicked(&self, err: Box) { self.panic.load(Ordering::Relaxed).is_null().then(|| { use core::mem::ManuallyDrop; let mut boxed = ManuallyDrop::new(Box::new(err)); let err_ptr: *mut Box = &mut **boxed; if self .panic .compare_exchange( ptr::null_mut(), err_ptr, Ordering::SeqCst, Ordering::Relaxed, ) .is_ok() { // we successfully set the panic, no need to drop } else { // drop the error, someone else already set it _ = ManuallyDrop::into_inner(boxed); } }); } pub fn spawn(&self, f: F) where F: FnOnce(&Scope<'scope>) + Send, { self.context.run_in_worker(|worker| { self.job_counter.increment(); let this = SendPtr::new_const(self).unwrap(); let job = HeapJob::new(move || unsafe { _ = f(this.as_ref()); this.as_ref().job_counter.decrement(); }) .into_boxed_job(); tracing::trace!("allocated heapjob"); worker.push_front(job); tracing::trace!("leaked heapjob"); }); } pub fn spawn_future(&self, future: F) -> async_task::Task where F: Future + Send + 'scope, T: Send + 'scope, { self.context.run_in_worker(|worker| { self.job_counter.increment(); let this = SendPtr::new_const(&self.job_counter).unwrap(); let future = async move { let _guard = DropGuard::new(move || unsafe { this.as_ref().decrement(); }); future.await }; let schedule = move |runnable: Runnable| { #[align(8)] unsafe fn harness(this: *const (), job: *const Job) { let runnable = Runnable::<()>::from_raw(NonNull::new_unchecked(this.cast_mut())); runnable.run(); // SAFETY: job was turned into raw drop(Box::from_raw(job.cast_mut())); } let job = Box::new(Job::::new(harness::, runnable.into_raw())); worker.push_front(Box::into_raw(job)); }; let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) }; runnable.schedule(); task }) } #[allow(dead_code)] fn spawn_async<'a, T, Fut, Fn>(&'a self, f: Fn) -> async_task::Task where Fn: FnOnce(&Scope) -> Fut + Send + 'static, Fut: Future + Send + 
'static, T: Send + 'static, { let this = SendPtr::new_const(self).unwrap(); let future = async move { f(unsafe { this.as_ref() }).await }; self.spawn_future(future) } #[inline] pub fn join(&self, a: A, b: B) -> (RA, RB) where RA: Send, RB: Send, A: FnOnce(&Self) -> RA + Send, B: FnOnce(&Self) -> RB + Send, { let worker = WorkerThread::current_ref().expect("join is run in workerthread."); let this = SendPtr::new_const(self).unwrap(); worker.join_heartbeat_every::<_, _, _, _, 64>( { let this = this; move || a(unsafe { this.as_ref() }) }, { let this = this; move || b(unsafe { this.as_ref() }) }, ) } unsafe fn from_context(ctx: Arc) -> Self { Self { context: ctx, job_counter: JobCounter::default(), panic: AtomicPtr::new(ptr::null_mut()), _pd: PhantomData, } } } /// run two closures potentially in parallel, in the global threadpool. #[allow(dead_code)] pub fn join(a: A, b: B) -> (RA, RB) where RA: Send, RB: Send, A: FnOnce() -> RA + Send, B: FnOnce() -> RB + Send, { join_in(Context::global().clone(), a, b) } /// run two closures potentially in parallel, in the global threadpool. #[allow(dead_code)] fn join_in(context: Arc, a: A, b: B) -> (RA, RB) where RA: Send, RB: Send, A: FnOnce() -> RA + Send, B: FnOnce() -> RB + Send, { context.join(a, b) } pub struct ThreadPool { context: Arc, } impl ThreadPool { pub fn new() -> ThreadPool { Self { context: Context::new(), } } pub fn global() -> ThreadPool { ThreadPool { context: Context::global().clone(), } } pub fn join(&self, a: A, b: B) -> (RA, RB) where RA: Send, RB: Send, A: FnOnce() -> RA + Send, B: FnOnce() -> RB + Send, { self.context.join(a, b) } pub fn scope<'scope, R, F>(&self, f: F) -> R where F: FnOnce(&Scope<'scope>) -> R + Send, R: Send, { Scope::scope_with_context(self.context.clone(), f) } } struct Context { shared: Mutex, shared_job: Condvar, } struct SharedContext { jobs: BTreeMap>, heartbeats: BTreeMap>>, injected_jobs: Vec>, // monotonic increasing id heartbeats_id: usize, should_stop: bool, } unsafe impl Send for SharedContext {} impl SharedContext { fn new_heartbeat(&mut self) -> (Arc>, usize) { let index = self.heartbeats_id; self.heartbeats_id = self.heartbeats_id.checked_add(1).unwrap(); let is_set = Arc::new(CachePadded::new(AtomicBool::new(false))); let weak = Arc::downgrade(&is_set); self.heartbeats.insert(index, weak); (is_set, index) } fn pop_job(&mut self) -> Option> { // this is unlikely, so make the function cold? 
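        // Injected jobs come from outside this context (see `inject_job`, used
        // by `run_in_worker_cold`/`run_in_worker_cross`), where the caller is
        // typically blocked on the result, so they are drained before the
        // heartbeat-shared jobs below.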
// TODO: profile this if !self.injected_jobs.is_empty() { return Some(unsafe { self.pop_injected_job() }); } self.jobs.pop_first().map(|(_, job)| job) } #[cold] unsafe fn pop_injected_job(&mut self) -> NonNull { self.injected_jobs.pop().unwrap() } } impl Context { fn new() -> Arc { let this = Arc::new(Self { shared: Mutex::new(SharedContext { jobs: BTreeMap::new(), heartbeats: BTreeMap::new(), injected_jobs: Vec::new(), heartbeats_id: 0, should_stop: false, }), shared_job: Condvar::new(), }); tracing::trace!("created threadpool {:?}", Arc::as_ptr(&this)); let num_threads = available_parallelism(); // let num_threads = 2; let barrier = Arc::new(std::sync::Barrier::new(num_threads + 1)); for i in 0..num_threads { let ctx = this.clone(); let barrier = barrier.clone(); std::thread::Builder::new() .name(format!("{:?}-worker-{}", Arc::as_ptr(&this), i)) .spawn(|| worker(ctx, barrier)) .expect("Failed to spawn worker thread"); } let ctx = this.clone(); std::thread::Builder::new() .name(format!("{:?}-heartbeat", Arc::as_ptr(&this))) .spawn(|| heartbeat_worker(ctx)) .expect("Failed to spawn heartbeat thread"); barrier.wait(); this } pub fn global() -> &'static Arc { GLOBAL_CONTEXT.get_or_init(|| Self::new()) } pub fn inject_job(&self, job: NonNull) { let mut guard = self.shared.lock(); guard.injected_jobs.push(job); self.notify_shared_job(); } fn notify_shared_job(&self) { self.shared_job.notify_one(); } #[inline] pub fn join(self: &Arc, a: A, b: B) -> (RA, RB) where RA: Send, RB: Send, A: FnOnce() -> RA + Send, B: FnOnce() -> RB + Send, { // SAFETY: join_heartbeat_every is safe to call from a worker thread. self.run_in_worker(move |worker| worker.join_heartbeat_every::<_, _, _, _, 64>(a, b)) } /// Runs closure in this context, processing the other context's worker's jobs while waiting for the result. fn run_in_worker_cross(self: &Arc, worker: &WorkerThread, f: F) -> T where F: FnOnce(&WorkerThread) -> T + Send, T: Send, { // current thread is not in the same context, create a job and inject it into the other thread's context, then wait while working on our jobs. let latch = AtomicLatch::new(); let job = StackJob::new( move || { let worker = WorkerThread::current_ref() .expect("WorkerThread::run_in_worker called outside of worker thread"); f(worker) }, LatchRef::new(&latch), ); let job = job.as_job(); job.set_pending(); self.inject_job(Into::into(&job)); // no need to wait for latch to signal, because we're waiting on the job anyway worker.wait_until_latch(&latch); let t = unsafe { job.transmute_ref::().wait().into_result() }; t } /// Run closure in this context, sleeping until the job is done. pub fn run_in_worker_cold(self: &Arc, f: F) -> T where F: FnOnce(&WorkerThread) -> T + Send, T: Send, { use crate::latch::MutexLatch; // current thread isn't a worker thread, create job and inject into global context let latch = MutexLatch::new(); let job = StackJob::new( move || { let worker = WorkerThread::current_ref() .expect("WorkerThread::run_in_worker called outside of worker thread"); f(worker) }, LatchRef::new(&latch), ); let job = job.as_job(); job.set_pending(); self.inject_job(Into::into(&job)); latch.wait(); let t = unsafe { job.transmute_ref::().wait().into_result() }; t } /// Run closure in this context. 
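    /// Dispatches on the calling thread:
    /// * a worker of this context runs `f` directly,
    /// * a worker of a *different* context injects a job here and keeps
    ///   working on its own context's jobs while it waits,
    /// * a non-worker thread injects a job and sleeps on a latch until the
    ///   result is ready.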
    pub fn run_in_worker<T, F>(self: &Arc<Self>, f: F) -> T
    where
        T: Send,
        F: FnOnce(&WorkerThread) -> T + Send,
    {
        match WorkerThread::current_ref() {
            Some(worker) => {
                // check if worker is in the same context
                if Arc::ptr_eq(&worker.context, self) {
                    tracing::trace!("run_in_worker: current thread");
                    f(worker)
                } else {
                    // current thread is a worker for a different context
                    tracing::trace!("run_in_worker: cross-context");
                    self.run_in_worker_cross(worker, f)
                }
            }
            None => {
                // current thread is not a worker for any context
                tracing::trace!("run_in_worker: inject into context");
                self.run_in_worker_cold(f)
            }
        }
    }
}

fn run_in_worker<T, F>(f: F) -> T
where
    T: Send,
    F: FnOnce(&WorkerThread) -> T + Send,
{
    Context::global().run_in_worker(f)
}

static GLOBAL_CONTEXT: OnceLock<Arc<Context>> = OnceLock::new();

const HEARTBEAT_INTERVAL: Duration = Duration::from_micros(100);

/// returns the number of available hardware threads, or 1 if it cannot be determined.
fn available_parallelism() -> usize {
    std::thread::available_parallelism()
        .map(|n| n.get())
        .unwrap_or(1)
}

fn worker(ctx: Arc<Context>, barrier: Arc<std::sync::Barrier>) {
    tracing::trace!("new worker thread {:?}", std::thread::current());

    unsafe {
        WorkerThread::set_current(
            Box::into_raw(Box::new(WorkerThread::new_in(ctx.clone()))).cast_const(),
        );
    }

    let _guard = DropGuard::new(|| unsafe {
        tracing::trace!("worker thread dropping {:?}", std::thread::current());
        WorkerThread::drop_in_place_and_dealloc(WorkerThread::unset_current().unwrap());
    });

    let worker = WorkerThread::current_ref().unwrap();
    barrier.wait();

    let mut job = ctx.shared.lock().pop_job();
    'outer: loop {
        let mut guard = loop {
            // run the job we currently hold (if any) exactly once.
            if let Some(job) = job.take() {
                worker.execute(job);
            }

            let mut guard = ctx.shared.lock();
            if guard.should_stop {
                // if the context is stopped, break out of the outer loop which
                // will exit the thread.
                break 'outer;
            }

            match guard.pop_job() {
                Some(popped) => {
                    tracing::trace!("worker: popping job: {:?}", popped);
                    // found a job; loop around and execute it outside the lock.
                    job = Some(popped);
                    continue;
                }
                None => {
                    tracing::trace!("worker: no job, waiting for shared job");
                    // no more jobs, break out of inner loop and wait for shared job
                    break guard;
                }
            }
        };

        ctx.shared_job.wait(&mut guard);
        job = guard.pop_job();
    }
}

fn heartbeat_worker(ctx: Arc<Context>) {
    tracing::trace!("new heartbeat thread {:?}", std::thread::current());

    let mut i = 0;
    loop {
        let sleep_for = {
            let mut guard = ctx.shared.lock();
            if guard.should_stop {
                break;
            }

            // set the heartbeat flag of the `i`-th live worker, dropping
            // heartbeats whose worker has already gone away.
            let mut n = 0;
            guard.heartbeats.retain(|_, b| {
                b.upgrade()
                    .inspect(|heartbeat| {
                        if n == i {
                            heartbeat.store(true, Ordering::Relaxed);
                        }
                        n += 1;
                    })
                    .is_some()
            });
            let num_heartbeats = guard.heartbeats.len();

            drop(guard);

            if i >= num_heartbeats {
                i = 0;
            } else {
                i += 1;
            }

            HEARTBEAT_INTERVAL.checked_div(num_heartbeats as u32)
        };

        if let Some(duration) = sleep_for {
            std::thread::sleep(duration);
        }
    }
}

#[cfg(test)]
mod tests;
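// A minimal usage sketch of the public API declared above (`ThreadPool::new`,
// `ThreadPool::join`, `ThreadPool::scope`, `Scope::spawn`), kept separate from
// the crate's own `mod tests`. The module name, the fibonacci helper and the
// concrete values in the asserts are illustrative choices, not part of the
// original code.
#[cfg(test)]
mod usage_sketch {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Classic fork-join shape for a heartbeat scheduler: recursive join.
    fn fib(pool: &ThreadPool, n: u64) -> u64 {
        if n < 2 {
            return n;
        }
        let (a, b) = pool.join(|| fib(pool, n - 1), || fib(pool, n - 2));
        a + b
    }

    #[test]
    fn join_computes_both_closures() {
        let pool = ThreadPool::new();
        assert_eq!(fib(&pool, 20), 6765);
    }

    #[test]
    fn scope_waits_for_spawned_jobs() {
        let pool = ThreadPool::new();
        let counter = AtomicUsize::new(0);
        pool.scope(|s| {
            for _ in 0..100 {
                s.spawn(|_| {
                    counter.fetch_add(1, Ordering::Relaxed);
                });
            }
        });
        // `Scope::scope` only returns once every spawned job has finished.
        assert_eq!(counter.load(Ordering::Relaxed), 100);
    }
}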