executor/src/praetor/mod.rs


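// Heartbeat-based fork/join executor. `Scope::join` usually runs both
// closures sequentially; every `TIMES`-th call it instead pushes the
// first closure as a `StackJob` onto the worker's private deque. A
// dedicated heartbeat thread periodically sets a per-worker flag, and
// on the next `tick()` the worker promotes its oldest queued job into
// the shared `Context::jobs` map, where idle workers claim it. This
// keeps the common join path cheap while still exposing parallelism.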
mod util {
use std::{
cell::UnsafeCell,
marker::PhantomData,
mem::ManuallyDrop,
num::NonZero,
ptr::NonNull,
sync::atomic::{AtomicPtr, Ordering},
};
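/// Runs the wrapped closure when dropped; used below to restore
/// thread-local state on scope exit, including during unwinding.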
pub struct DropGuard<F: FnOnce()>(UnsafeCell<ManuallyDrop<F>>);
impl<F> DropGuard<F>
where
F: FnOnce(),
{
pub fn new(f: F) -> DropGuard<F> {
Self(UnsafeCell::new(ManuallyDrop::new(f)))
}
}
impl<F> Drop for DropGuard<F>
where
F: FnOnce(),
{
fn drop(&mut self) {
unsafe {
ManuallyDrop::take(&mut *self.0.get())();
}
}
}
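/// An atomic pointer that packs a small integer tag into the low
/// `BITS` bits of the pointer, which must be free due to `T`'s
/// alignment (checked in `new`). Pointer and tag share one word, so
/// they can be read together with a single atomic load.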
#[repr(transparent)]
pub struct TaggedAtomicPtr<T, const BITS: usize>(AtomicPtr<()>, PhantomData<T>);
impl<T, const BITS: usize> TaggedAtomicPtr<T, BITS> {
const fn mask() -> usize {
!(!0usize << BITS)
}
pub fn new(ptr: *mut T, tag: usize) -> TaggedAtomicPtr<T, BITS> {
debug_assert!(core::mem::align_of::<T>().ilog2() as usize >= BITS);
let mask = Self::mask();
Self(
AtomicPtr::new(ptr.with_addr((ptr.addr() & !mask) | (tag & mask)).cast()),
PhantomData,
)
}
pub fn ptr(&self, order: Ordering) -> NonNull<T> {
unsafe {
NonNull::new_unchecked(self.0.load(order) as _)
.map_addr(|addr| NonZero::new_unchecked(addr.get() & !Self::mask()))
}
}
pub fn tag(&self, order: Ordering) -> usize {
self.0.load(order).addr() & Self::mask()
}
/// Returns the previous tag on success, or the currently stored tag on failure.
#[inline(always)]
fn compare_exchange_tag_inner(
&self,
old: usize,
new: usize,
success: Ordering,
failure: Ordering,
cmpxchg: fn(
&AtomicPtr<()>,
*mut (),
*mut (),
Ordering,
Ordering,
) -> Result<*mut (), *mut ()>,
) -> Result<usize, usize> {
let mask = Self::mask();
let old_ptr = self.0.load(failure);
let old = old_ptr.with_addr((old_ptr.addr() & !mask) | (old & mask));
let new = old_ptr.with_addr((old_ptr.addr() & !mask) | (new & mask));
let result = cmpxchg(&self.0, old, new, success, failure);
result
.map(|ptr| ptr.addr() & mask)
.map_err(|ptr| ptr.addr() & mask)
}
/// Returns the previous tag on success, or the currently stored tag on failure.
#[inline]
#[allow(dead_code)]
pub fn compare_exchange_tag(
&self,
old: usize,
new: usize,
success: Ordering,
failure: Ordering,
) -> Result<usize, usize> {
self.compare_exchange_tag_inner(
old,
new,
success,
failure,
AtomicPtr::<()>::compare_exchange,
)
}
/// Returns the previous tag on success, or the currently stored tag on failure.
#[inline]
pub fn compare_exchange_weak_tag(
&self,
old: usize,
new: usize,
success: Ordering,
failure: Ordering,
) -> Result<usize, usize> {
self.compare_exchange_tag_inner(
old,
new,
success,
failure,
AtomicPtr::<()>::compare_exchange_weak,
)
}
#[allow(dead_code)]
pub fn set_ptr(&self, ptr: *mut T, success: Ordering, failure: Ordering) {
let mask = Self::mask();
let ptr = ptr.cast::<()>();
loop {
let old = self.0.load(failure);
let new = ptr.with_addr((ptr.addr() & !mask) | (old.addr() & mask));
if self
.0
.compare_exchange_weak(old, new, success, failure)
.is_ok()
{
break;
}
}
}
pub fn set_tag(&self, tag: usize, success: Ordering, failure: Ordering) {
let mask = Self::mask();
loop {
let ptr = self.0.load(failure);
let new = ptr.with_addr((ptr.addr() & !mask) | (tag & mask));
if self
.0
.compare_exchange_weak(ptr, new, success, failure)
.is_ok()
{
break;
}
}
}
pub fn ptr_and_tag(&self, order: Ordering) -> (NonNull<T>, usize) {
let mask = Self::mask();
let ptr = self.0.load(order);
let tag = ptr.addr() & mask;
let addr = ptr.addr() & !mask;
let ptr = unsafe { NonNull::new_unchecked(ptr.with_addr(addr).cast()) };
(ptr, tag)
}
}
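// A minimal sketch (not exercised by the executor itself) of how the
// tag shares storage with the pointer's low alignment bits; assumes a
// target where `u64` is 8-byte aligned so 3 low bits are free.
#[cfg(test)]
mod tagged_ptr_example {
use super::TaggedAtomicPtr;
use std::sync::atomic::Ordering;
#[test]
fn tag_roundtrip() {
let mut slot = 0u64;
let tagged = TaggedAtomicPtr::<u64, 3>::new(&raw mut slot, 0b101);
// pointer and tag live in the same word but read back independently
assert_eq!(tagged.tag(Ordering::Relaxed), 0b101);
assert_eq!(tagged.ptr(Ordering::Relaxed).as_ptr(), &raw mut slot);
tagged.set_tag(0b010, Ordering::Release, Ordering::Relaxed);
assert_eq!(tagged.tag(Ordering::Relaxed), 0b010);
}
}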
}
mod job {
use std::{
any::Any,
borrow::{Borrow, BorrowMut},
cell::UnsafeCell,
fmt::{Debug, Display},
hint::cold_path,
marker::PhantomPinned,
mem::{self, ManuallyDrop, MaybeUninit},
ops::{Deref, DerefMut},
pin::Pin,
ptr::{self, NonNull},
sync::atomic::Ordering,
thread::Thread,
};
use parking_lot_core::SpinWait;
use super::util::TaggedAtomicPtr;
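/// A box that stores `T` inline when it fits within the footprint of a
/// `Box<T>` (see `is_inline`), avoiding an allocation for small job
/// results; larger values are heap-allocated as usual.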
#[derive(Debug)]
pub struct SmallBox<T>(pub MaybeUninit<Box<T>>);
impl<T: Display> Display for SmallBox<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
(**self).fmt(f)
}
}
impl<T: Ord> Ord for SmallBox<T> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.as_ref().cmp(other.as_ref())
}
}
impl<T: PartialOrd> PartialOrd for SmallBox<T> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
self.as_ref().partial_cmp(other.as_ref())
}
}
impl<T: Eq> Eq for SmallBox<T> {}
impl<T: PartialEq> PartialEq for SmallBox<T> {
fn eq(&self, other: &Self) -> bool {
self.as_ref().eq(other.as_ref())
}
}
impl<T: Default> Default for SmallBox<T> {
fn default() -> Self {
Self::new(Default::default())
}
}
impl<T: Clone> Clone for SmallBox<T> {
fn clone(&self) -> Self {
Self::new(self.as_ref().clone())
}
}
impl<T> Deref for SmallBox<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.as_ref()
}
}
impl<T> DerefMut for SmallBox<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.as_mut()
}
}
impl<T> AsRef<T> for SmallBox<T> {
fn as_ref(&self) -> &T {
Self::as_ref(self)
}
}
impl<T> AsMut<T> for SmallBox<T> {
fn as_mut(&mut self) -> &mut T {
Self::as_mut(self)
}
}
impl<T> Borrow<T> for SmallBox<T> {
fn borrow(&self) -> &T {
&**self
}
}
impl<T> BorrowMut<T> for SmallBox<T> {
fn borrow_mut(&mut self) -> &mut T {
&mut **self
}
}
impl<T> SmallBox<T> {
/// Must only be called once; takes `&self` so it can also be called from `drop()`.
unsafe fn get_unchecked(&self, inline: bool) -> T {
if inline {
unsafe { mem::transmute_copy::<MaybeUninit<Box<T>>, T>(&self.0) }
} else {
unsafe { *self.0.assume_init_read() }
}
}
pub fn as_ref(&self) -> &T {
unsafe {
if Self::is_inline() {
mem::transmute::<&MaybeUninit<Box<T>>, &T>(&self.0)
} else {
self.0.assume_init_ref()
}
}
}
pub fn as_mut(&mut self) -> &mut T {
unsafe {
if Self::is_inline() {
mem::transmute::<&mut MaybeUninit<Box<T>>, &mut T>(&mut self.0)
} else {
self.0.assume_init_mut()
}
}
}
pub fn into_inner(self) -> T {
let this = ManuallyDrop::new(self);
let inline = Self::is_inline();
// SAFETY: inline is correctly calculated and this function
// consumes `self`
unsafe { this.get_unchecked(inline) }
}
pub fn is_inline() -> bool {
// the value can be stored inline iff the size of T is equal to or
// smaller than the size of the boxed type and the alignment of the
// boxed type is an integer multiple of the alignment of T
mem::size_of::<T>() <= mem::size_of::<Box<MaybeUninit<T>>>()
&& mem::align_of::<Box<MaybeUninit<T>>>() % mem::align_of::<T>() == 0
}
pub fn new(value: T) -> Self {
let inline = Self::is_inline();
if inline {
let mut this = MaybeUninit::new(Self(MaybeUninit::uninit()));
unsafe {
this.as_mut_ptr().cast::<T>().write(value);
this.assume_init()
}
} else {
Self(MaybeUninit::new(Box::new(value)))
}
}
}
impl<T> Drop for SmallBox<T> {
fn drop(&mut self) {
// drop contained value.
drop(unsafe { self.get_unchecked(Self::is_inline()) });
}
}
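// A minimal sketch of the inline-vs-boxed behaviour; assumes a typical
// 64-bit target where `Box<T>` is 8 bytes.
#[cfg(test)]
mod small_box_example {
use super::SmallBox;
#[test]
fn roundtrip() {
assert!(SmallBox::<u32>::is_inline()); // fits in pointer-sized storage
assert!(!SmallBox::<[u64; 4]>::is_inline()); // too big, heap-allocated
assert_eq!(SmallBox::new(17u32).into_inner(), 17);
assert_eq!(SmallBox::new([1u64, 2, 3, 4]).into_inner(), [1, 2, 3, 4]);
}
}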
#[repr(u8)]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum JobState {
Empty,
Locked = 1,
Pending,
Finished,
// Inline = 1 << (u8::BITS - 1),
// IsError = 1 << (u8::BITS - 2),
}
impl JobState {
#[allow(dead_code)]
const MASK: u8 = 0; // Self::Inline as u8 | Self::IsError as u8;
fn from_u8(v: u8) -> Option<Self> {
match v {
0 => Some(Self::Empty),
1 => Some(Self::Locked),
2 => Some(Self::Pending),
3 => Some(Self::Finished),
_ => None,
}
}
}
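/// Intrusive doubly-linked list of `Job`s with boxed sentinel head and
/// tail nodes. The jobs themselves are owned by the stack frames that
/// pushed them, so entries must stay valid until popped or unlinked.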
#[derive(Debug)]
pub struct JobList {
head: Pin<Box<Job>>,
tail: Pin<Box<Job>>,
}
impl JobList {
pub fn new() -> JobList {
let head = Box::pin(Job::empty());
let tail = Box::pin(Job::empty());
// head and tail point at each other
unsafe {
(&mut *head.err_or_link.get()).link.next = None;
(&mut *head.err_or_link.get()).link.prev =
Some(NonNull::new_unchecked((&raw const *tail).cast_mut()));
(&mut *tail.err_or_link.get()).link.next =
Some(NonNull::new_unchecked((&raw const *head).cast_mut()));
(&mut *tail.err_or_link.get()).link.prev = None;
}
Self { head, tail }
}
fn head_ptr(&self) -> *const Job {
&raw const *self.head
}
fn tail_ptr(&self) -> *const Job {
&raw const *self.tail
}
fn head(&self) -> NonNull<Job> {
unsafe { NonNull::new_unchecked(self.head_ptr().cast_mut()) }
}
fn tail(&self) -> NonNull<Job> {
unsafe { NonNull::new_unchecked(self.tail_ptr().cast_mut()) }
}
/// elem must be valid until it is popped.
pub unsafe fn push_front<T>(&mut self, elem: Pin<&Job<T>>) {
let head_link = unsafe { self.head.link_mut() };
// SAFETY: head will always have a previous element.
let prev = head_link.prev.unwrap();
let prev_link = unsafe { prev.as_ref().link_mut() };
let elem_ptr = unsafe { NonNull::new_unchecked(&*elem as *const Job<T> as *mut Job) };
head_link.prev = Some(elem_ptr);
prev_link.next = Some(elem_ptr);
let elem_link = unsafe { elem.link_mut() };
elem_link.prev = Some(prev);
elem_link.next = Some(self.head());
}
/// elem must be valid until it is popped.
pub unsafe fn push_back<T>(&mut self, elem: Pin<&Job<T>>) {
let tail_link = unsafe { self.tail.link_mut() };
// SAFETY: tail will always have a next element.
let next = tail_link.next.unwrap();
let next_link = unsafe { next.as_ref().link_mut() };
let elem_ptr = unsafe { NonNull::new_unchecked(&*elem as *const Job<T> as *mut Job) };
tail_link.next = Some(elem_ptr);
next_link.prev = Some(elem_ptr);
let elem_link = unsafe { elem.link_mut() };
elem_link.next = Some(next);
elem_link.prev = Some(self.tail());
}
#[allow(dead_code)]
pub fn pop_front(&mut self) -> Option<NonNull<Job>> {
let head_link = unsafe { self.head.link_mut() };
// SAFETY: head will always have a previous element.
let elem = head_link.prev.unwrap();
let elem_link = unsafe { elem.as_ref().link_mut() };
let prev = elem_link.prev?.as_ptr();
head_link.prev = unsafe { Some(NonNull::new_unchecked(prev)) };
let prev_link = unsafe { (&*prev).link_mut() };
prev_link.next = Some(self.head());
Some(elem)
}
pub fn pop_back(&mut self) -> Option<NonNull<Job>> {
// TODO: next and elem might be the same
let tail_link = unsafe { self.tail.link_mut() };
// SAFETY: tail will always have a next element.
let elem = tail_link.next.unwrap();
let elem_link = unsafe { elem.as_ref().link_mut() };
let next = elem_link.next?.as_ptr();
tail_link.next = unsafe { Some(NonNull::new_unchecked(next)) };
let next_link = unsafe { (&*next).link_mut() };
next_link.prev = Some(self.tail());
Some(elem)
}
}
union ValueOrThis<T> {
uninit: (),
value: ManuallyDrop<SmallBox<T>>,
this: NonNull<()>,
}
#[derive(Debug, PartialEq, Eq)]
struct Link<T> {
prev: Option<NonNull<T>>,
next: Option<NonNull<T>>,
}
impl<T> Clone for Link<T> {
fn clone(&self) -> Self {
Self {
prev: self.prev.clone(),
next: self.next.clone(),
}
}
}
// manual impl: `derive(Copy)` would require `T: Copy`, but `Link<T>` only stores pointers
impl<T> Copy for Link<T> {}
union LinkOrError<T> {
link: Link<T>,
waker: ManuallyDrop<Option<Thread>>,
error: ManuallyDrop<Option<Box<dyn Any + Send + 'static>>>,
}
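/// A unit of work whose state lives in the tag bits of
/// `harness_and_state`: `Empty` (queued, unions hold the closure
/// pointer and list links), `Pending` (claimed for execution, link
/// slot holds an optional waker), `Locked` (a transition is in
/// progress), and `Finished` (unions hold the result or panic payload).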
pub struct Job<T = ()> {
/// tagged pointer, 8-aligned; the low 3 bits hold the `JobState` tag
harness_and_state: TaggedAtomicPtr<usize, 3>,
/// `NonNull<()>` (the closure) before execute(), `SmallBox<T>` (the result) after
val_or_this: UnsafeCell<ValueOrThis<T>>,
/// queue links while `Empty`, waker while `Pending`, error once `Finished`
err_or_link: UnsafeCell<LinkOrError<Job>>,
_phantom: PhantomPinned,
}
impl<T> Debug for Job<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let state =
JobState::from_u8(self.harness_and_state.tag(Ordering::Relaxed) as u8).unwrap();
let mut debug = f.debug_struct("Job");
debug.field("state", &state).field_with("harness", |f| {
write!(f, "{:?}", self.harness_and_state.ptr(Ordering::Relaxed))
});
match state {
JobState::Empty => {
debug
.field_with("this", |f| {
write!(f, "{:?}", unsafe { &(&*self.val_or_this.get()).this })
})
.field_with("link", |f| {
write!(f, "{:?}", unsafe { &(&*self.err_or_link.get()).link })
});
}
JobState::Locked => {
#[derive(Debug)]
struct Locked;
debug.field("locked", &Locked);
}
JobState::Pending => {
debug
.field_with("this", |f| {
write!(f, "{:?}", unsafe { &(&*self.val_or_this.get()).this })
})
.field_with("waker", |f| {
write!(f, "{:?}", unsafe { &(&*self.err_or_link.get()).waker })
});
}
JobState::Finished => {
let err = unsafe { &(&*self.err_or_link.get()).error };
let result = match err.as_ref() {
Some(err) => Err(err),
None => Ok(unsafe { (&*self.val_or_this.get()).value.0.as_ptr() }),
};
debug.field("result", &result);
}
}
debug.finish()
}
}
unsafe impl<T> Send for Job<T> {}
impl<T> Job<T> {
pub fn new(
harness: unsafe fn(*const (), *const Job<T>, &super::Scope),
this: NonNull<()>,
) -> Job<T> {
Self {
harness_and_state: TaggedAtomicPtr::new(
unsafe { mem::transmute(harness) },
JobState::Empty as usize,
),
val_or_this: UnsafeCell::new(ValueOrThis { this }),
err_or_link: UnsafeCell::new(LinkOrError {
link: Link {
prev: None,
next: None,
},
}),
_phantom: PhantomPinned,
}
}
pub fn empty() -> Job<T> {
Self {
harness_and_state: TaggedAtomicPtr::new(
ptr::dangling_mut(),
JobState::Empty as usize,
),
val_or_this: UnsafeCell::new(ValueOrThis {
this: NonNull::dangling(),
}),
err_or_link: UnsafeCell::new(LinkOrError {
link: Link {
prev: None,
next: None,
},
}),
_phantom: PhantomPinned,
}
}
#[inline]
unsafe fn link_mut(&self) -> &mut Link<Job> {
unsafe { &mut (&mut *self.err_or_link.get()).link }
}
/// assumes job is in joblist
pub unsafe fn unlink(&self) {
unsafe {
let mut dummy = None;
let Link { prev, next } = *self.link_mut();
*prev
.map(|ptr| &mut ptr.as_ref().link_mut().next)
.unwrap_or(&mut dummy) = next;
*next
.map(|ptr| &mut ptr.as_ref().link_mut().prev)
.unwrap_or(&mut dummy) = prev;
}
}
pub fn state(&self) -> u8 {
self.harness_and_state.tag(Ordering::Relaxed) as u8
}
pub fn wait(&self) -> std::thread::Result<T> {
let mut spin = SpinWait::new();
loop {
match self.harness_and_state.compare_exchange_weak_tag(
JobState::Pending as usize,
JobState::Locked as usize,
Ordering::Acquire,
Ordering::Relaxed,
) {
// if still pending, sleep until completed
Ok(state) => {
debug_assert_eq!(state, JobState::Pending as usize);
unsafe {
*(&mut *self.err_or_link.get()).waker = Some(std::thread::current());
}
self.harness_and_state.set_tag(
JobState::Pending as usize,
Ordering::Release,
Ordering::Relaxed,
);
std::thread::park();
spin.reset();
// after sleeping, state should be `Finished`
}
Err(state) => {
// debug_assert_ne!(state, JobState::Pending as usize);
if state == JobState::Finished as usize {
let err = unsafe { (&mut *self.err_or_link.get()).error.take() };
let result: std::thread::Result<T> = if let Some(err) = err {
cold_path();
Err(err)
} else {
let val = unsafe {
ManuallyDrop::take(&mut (&mut *self.val_or_this.get()).value)
};
Ok(val.into_inner())
};
return result;
} else {
// spin until lock is released.
spin.spin();
}
}
}
}
}
/// Call this when popping a job off the local queue in order to share it.
pub fn set_pending(&self) {
let mut spin = SpinWait::new();
loop {
match self.harness_and_state.compare_exchange_weak_tag(
JobState::Empty as usize,
JobState::Pending as usize,
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(state) => {
debug_assert_eq!(state, JobState::Empty as usize);
// set waker to None
unsafe {
(&mut *self.err_or_link.get()).waker = ManuallyDrop::new(None);
}
return;
}
Err(_) => {
// debug_assert_ne!(state, JobState::Empty as usize);
eprintln!("######## what the sigma?");
spin.spin();
}
}
}
}
pub fn execute(&self, scope: &super::Scope) {
// SAFETY: the harness pointer was written in `Job::new` from a fn pointer of this exact signature.
unsafe {
let (ptr, state) = self.harness_and_state.ptr_and_tag(Ordering::Relaxed);
debug_assert_eq!(state, JobState::Pending as usize);
let harness: unsafe fn(*const (), *const Self, scope: &super::Scope) =
mem::transmute(ptr.as_ptr());
let this = (*self.val_or_this.get()).this;
harness(this.as_ptr().cast(), (self as *const Self).cast(), scope);
}
}
pub(crate) fn complete(&self, result: std::thread::Result<T>) {
let mut spin = SpinWait::new();
loop {
match self.harness_and_state.compare_exchange_weak_tag(
JobState::Pending as usize,
JobState::Locked as usize,
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(state) => {
debug_assert_eq!(state, JobState::Pending as usize);
break;
}
Err(_) => {
// debug_assert_ne!(state, JobState::Pending as usize);
spin.spin();
}
}
}
let waker = unsafe { (&mut *self.err_or_link.get()).waker.take() };
match result {
Ok(val) => unsafe {
(&mut *self.val_or_this.get()).value = ManuallyDrop::new(SmallBox::new(val));
(&mut *self.err_or_link.get()).error = ManuallyDrop::new(None);
},
Err(err) => unsafe {
(&mut *self.val_or_this.get()).uninit = ();
(&mut *self.err_or_link.get()).error = ManuallyDrop::new(Some(err));
},
}
if let Some(thread) = waker {
thread.unpark();
}
self.harness_and_state.set_tag(
JobState::Finished as usize,
Ordering::Release,
Ordering::Relaxed,
);
}
}
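/// A job whose closure is boxed on the heap, so it can outlive the
/// frame that spawned it; `into_boxed_job` wires up a harness that
/// reclaims the box when the job runs.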
#[allow(dead_code)]
pub struct HeapJob<F> {
f: F,
}
impl<F> HeapJob<F> {
#[allow(dead_code)]
pub fn new(f: F) -> Box<Self> {
Box::new(Self { f })
}
#[allow(dead_code)]
pub fn into_boxed_job<T>(self: Box<Self>) -> Box<Job<()>>
where
F: FnOnce(&super::Scope) -> T + Send,
T: Send,
{
unsafe fn harness<F, T>(this: *const (), job: *const Job<()>, scope: &super::Scope)
where
F: FnOnce(&super::Scope) -> T + Send,
T: Sized + Send,
{
let this = unsafe { Box::from_raw(this.cast::<HeapJob<F>>().cast_mut()) };
let f = this.f;
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(scope)));
let job = unsafe { &*job.cast::<Job<T>>() };
job.complete(result);
}
Box::new(Job::new(harness::<F, T>, unsafe {
NonNull::new_unchecked(Box::into_raw(self)).cast()
}))
}
}
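/// A job whose closure lives in the caller's stack frame. `unwrap`
/// moves the closure out, so it must be called at most once, and the
/// `StackJob` must outlive the `Job` produced by `as_job`.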
pub struct StackJob<F> {
f: UnsafeCell<ManuallyDrop<F>>,
_phantom: PhantomPinned,
}
impl<F> StackJob<F> {
pub fn new(f: F) -> Self {
Self {
f: UnsafeCell::new(ManuallyDrop::new(f)),
_phantom: PhantomPinned,
}
}
pub unsafe fn unwrap(&self) -> F {
unsafe { ManuallyDrop::take(&mut *self.f.get()) }
}
pub fn as_job<T>(self: Pin<&Self>) -> Job<()>
where
F: FnOnce(&super::Scope) -> T + Send,
T: Send,
{
unsafe fn harness<F, T>(this: *const (), job: *const Job<()>, scope: &super::Scope)
where
F: FnOnce(&super::Scope) -> T + Send,
T: Sized + Send,
{
let this = unsafe { &*this.cast::<StackJob<F>>() };
let f = unsafe { this.unwrap() };
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(scope)));
let job_ref = unsafe { &*job.cast::<Job<T>>() };
job_ref.complete(result);
}
Job::new(harness::<F, T>, unsafe {
NonNull::new_unchecked(&*self as *const _ as *mut ())
})
}
}
}
use std::{
cell::{Cell, UnsafeCell},
collections::BTreeMap,
mem,
pin::{pin, Pin},
ptr::NonNull,
sync::{
atomic::{AtomicBool, Ordering},
Arc, OnceLock, Weak,
},
time::Duration,
};
use crossbeam::utils::CachePadded;
use job::*;
use parking_lot::{Condvar, Mutex};
use util::DropGuard;
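/// Per-worker state: a handle to the shared `Context`, this worker's
/// heartbeat flag and slot index, and a private deque of queued jobs.
/// A pointer to the active `Scope` is kept in the `SCOPE` thread-local.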
pub struct Scope {
join_count: Cell<usize>,
context: Arc<Context>,
index: usize,
heartbeat: Arc<CachePadded<AtomicBool>>,
queue: UnsafeCell<JobList>,
}
thread_local! {
static SCOPE: UnsafeCell<Option<NonNull<Scope>>> = const { UnsafeCell::new(None) };
}
impl Scope {
/// locks shared context
#[allow(dead_code)]
fn new() -> Self {
let context = Context::global().clone();
Self::new_in(context)
}
/// locks shared context
fn new_in(context: Arc<Context>) -> Self {
let (heartbeat, index) = context.shared.lock().new_heartbeat();
Self {
context,
index,
heartbeat,
join_count: Cell::new(0),
queue: UnsafeCell::new(JobList::new()),
}
}
fn with_in<T, F: FnOnce(&Scope) -> T>(ctx: &Arc<Context>, f: F) -> T {
let mut guard = Option::<DropGuard<Box<dyn FnOnce()>>>::None;
let scope = match Self::current_ref() {
Some(scope) if Arc::ptr_eq(&scope.context, ctx) => scope,
Some(_) => {
let old = unsafe { Self::unset_current().unwrap().as_ptr() };
guard = Some(DropGuard::new(Box::new(move || unsafe {
_ = Box::from_raw(Self::unset_current().unwrap().as_ptr());
Self::set_current(old.cast_const());
})));
let current = Box::into_raw(Box::new(Self::new_in(ctx.clone())));
unsafe {
Self::set_current(current.cast_const());
&*current
}
}
None => {
let current = Box::into_raw(Box::new(Self::new_in(ctx.clone())));
guard = Some(DropGuard::new(Box::new(|| unsafe {
_ = Box::from_raw(Self::unset_current().unwrap().as_ptr());
})));
unsafe {
Self::set_current(current.cast_const());
&*current
}
}
};
let t = f(scope);
drop(guard);
t
}
pub fn with<T, F: FnOnce(&Scope) -> T>(f: F) -> T {
Self::with_in(Context::global(), f)
}
unsafe fn set_current(scope: *const Scope) {
SCOPE.with(|ptr| unsafe {
_ = (&mut *ptr.get()).insert(NonNull::new_unchecked(scope.cast_mut()));
})
}
unsafe fn unset_current() -> Option<NonNull<Scope>> {
SCOPE.with(|ptr| unsafe { (&mut *ptr.get()).take() })
}
#[allow(dead_code)]
fn current() -> Option<NonNull<Scope>> {
SCOPE.with(|ptr| unsafe { *ptr.get() })
}
fn current_ref<'a>() -> Option<&'a Scope> {
SCOPE.with(|ptr| unsafe { (&*ptr.get()).map(|ptr| ptr.as_ref()) })
}
fn push_front<T>(&self, job: Pin<&Job<T>>) {
unsafe {
self.queue.as_mut_unchecked().push_front(job);
}
}
#[allow(dead_code)]
fn push_back<T>(&self, job: Pin<&Job<T>>) {
unsafe {
self.queue.as_mut_unchecked().push_back(job);
}
}
fn pop_back(&self) -> Option<NonNull<Job>> {
unsafe { self.queue.as_mut_unchecked().pop_back() }
}
#[allow(dead_code)]
fn pop_front(&self) -> Option<NonNull<Job>> {
unsafe { self.queue.as_mut_unchecked().pop_front() }
}
#[inline]
pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
where
RA: Send,
RB: Send,
A: FnOnce(&Self) -> RA + Send,
B: FnOnce(&Self) -> RB + Send,
{
self.join_heartbeat_every::<_, _, _, _, 64>(a, b)
// self.join_heartbeat(a, b)
}
pub fn join_seq<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
where
RA: Send,
RB: Send,
A: FnOnce(&Self) -> RA + Send,
B: FnOnce(&Self) -> RB + Send,
{
let rb = b(&self);
let ra = a(&self);
(ra, rb)
}
pub fn join_heartbeat_every<A, B, RA, RB, const TIMES: usize>(&self, a: A, b: B) -> (RA, RB)
where
RA: Send,
RB: Send,
A: FnOnce(&Self) -> RA + Send,
B: FnOnce(&Self) -> RB + Send,
{
// let count = self.join_count.get();
// self.join_count.set(count.wrapping_add(1) % TIMES);
let count = self.join_count.update(|n| n.wrapping_add(1) % TIMES);
if count == 1 {
self.join_heartbeat(a, b)
} else {
self.join_seq(a, b)
}
}
pub fn join_heartbeat<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
where
RA: Send,
RB: Send,
A: FnOnce(&Self) -> RA + Send,
B: FnOnce(&Self) -> RB + Send,
{
let a = pin!(StackJob::new(move |scope: &Scope| {
scope.tick();
a(scope)
}));
let job = pin!(a.as_ref().as_job());
self.push_front(job.as_ref());
let rb = b(self);
let ra = if job.state() == JobState::Empty as u8 {
unsafe {
job.unlink();
}
unsafe { a.unwrap()(self) }
} else {
match self.wait_until::<RA>(unsafe {
mem::transmute::<Pin<&Job<()>>, Pin<&Job<RA>>>(job.as_ref())
}) {
Some(Ok(t)) => t,
Some(Err(payload)) => std::panic::resume_unwind(payload),
None => unsafe { a.unwrap()(self) },
}
};
drop(a);
(ra, rb)
}
#[inline(always)]
fn tick(&self) {
if self.heartbeat.load(Ordering::Relaxed) {
self.heartbeat_cold();
}
}
#[inline]
fn execute(&self, job: &Job) {
self.tick();
job.execute(self);
}
#[cold]
fn heartbeat_cold(&self) {
let mut guard = self.context.shared.lock();
if !guard.jobs.contains_key(&self.index) {
if let Some(job) = self.pop_back() {
unsafe {
job.as_ref().set_pending();
}
guard.jobs.insert(self.index, job);
self.context.shared_job.notify_one();
}
}
self.heartbeat.store(false, Ordering::Relaxed);
}
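/// Runs other shared jobs until `job` completes. Returns `None` if
/// `job` was still sitting unclaimed in our shared slot (the caller
/// should run it inline); otherwise blocks on `Job::wait` for the result.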
pub fn wait_until<T>(&self, job: Pin<&Job<T>>) -> Option<std::thread::Result<T>> {
let shared_job = self.context.shared.lock().jobs.remove(&self.index);
if let Some(ptr) = shared_job {
if ptr.as_ptr() == &*job as *const _ as *mut _ {
return None;
} else {
unsafe {
self.execute(ptr.as_ref());
}
}
}
// while the job isn't done, run other shared jobs.
while job.state() != JobState::Finished as u8 {
let Some(job) = self
.context
.shared
.lock()
.jobs
.pop_first()
.map(|(_, job)| job)
// .or_else(|| {
// self.pop_front().inspect(|job| unsafe {
// job.as_ref().set_pending();
// })
// })
else {
break;
};
unsafe {
self.execute(job.as_ref());
}
}
// the job was claimed by another worker; block until it finishes.
Some(job.wait())
}
}
#[allow(dead_code)]
pub fn join<A, B, RA, RB>(a: A, b: B) -> (RA, RB)
where
RA: Send,
RB: Send,
A: FnOnce() -> RA + Send,
B: FnOnce() -> RB + Send,
{
Scope::with(|scope| scope.join(|_| a(), |_| b()))
}
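// A minimal usage sketch of the free `join` function above, splitting a
// slice sum across the global pool; `sum` is a hypothetical helper, not
// part of this module.
#[cfg(test)]
mod join_example {
fn sum(v: &[u64]) -> u64 {
if v.len() <= 1024 {
return v.iter().sum();
}
let (left, right) = v.split_at(v.len() / 2);
let (a, b) = super::join(|| sum(left), || sum(right));
a + b
}
#[test]
fn parallel_sum() {
let v: Vec<u64> = (0..100_000).collect();
assert_eq!(sum(&v), v.iter().sum::<u64>());
}
}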
pub struct ThreadPool {
context: Arc<Context>,
}
impl ThreadPool {
pub fn new() -> ThreadPool {
Self {
context: Context::new(),
}
}
pub fn global() -> ThreadPool {
ThreadPool {
context: Context::global().clone(),
}
}
pub fn scope<T, F: FnOnce(&Scope) -> T>(&self, f: F) -> T {
Scope::with_in(&self.context, f)
}
}
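// A minimal sketch of using a dedicated pool rather than the global
// context; note that `ThreadPool::new` spawns its own set of workers.
#[cfg(test)]
mod pool_example {
use super::ThreadPool;
#[test]
fn pool_join() {
let pool = ThreadPool::new();
let (a, b) = pool.scope(|scope| scope.join(|_| 1 + 1, |_| 2 + 2));
assert_eq!((a, b), (2, 4));
}
}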
struct Context {
shared: Mutex<SharedContext>,
shared_job: Condvar,
}
struct SharedContext {
jobs: BTreeMap<usize, NonNull<Job>>,
heartbeats: BTreeMap<usize, Weak<CachePadded<AtomicBool>>>,
// monotonically increasing id
heartbeats_id: usize,
should_stop: bool,
}
unsafe impl Send for SharedContext {}
impl SharedContext {
fn new_heartbeat(&mut self) -> (Arc<CachePadded<AtomicBool>>, usize) {
let index = self.heartbeats_id;
self.heartbeats_id = self.heartbeats_id.checked_add(1).unwrap();
let is_set = Arc::new(CachePadded::new(AtomicBool::new(false)));
let weak = Arc::downgrade(&is_set);
self.heartbeats.insert(index, weak);
(is_set, index)
}
}
impl Context {
fn new() -> Arc<Context> {
let this = Arc::new(Self {
shared: Mutex::new(SharedContext {
jobs: BTreeMap::new(),
heartbeats: BTreeMap::new(),
heartbeats_id: 0,
should_stop: false,
}),
shared_job: Condvar::new(),
});
eprintln!("created threadpool {:?}", Arc::as_ptr(&this));
let num_threads = available_parallelism();
// let num_threads = 2;
let barrier = Arc::new(std::sync::Barrier::new(num_threads + 1));
for _ in 0..num_threads {
let ctx = this.clone();
let barrier = barrier.clone();
std::thread::spawn(|| worker(ctx, barrier));
}
let ctx = this.clone();
std::thread::spawn(|| heartbeat_worker(ctx));
barrier.wait();
this
}
pub fn global() -> &'static Arc<Self> {
GLOBAL_CONTEXT.get_or_init(|| Self::new())
}
}
static GLOBAL_CONTEXT: OnceLock<Arc<Context>> = OnceLock::new();
const HEARTBEAT_INTERVAL: Duration = Duration::from_micros(100);
fn available_parallelism() -> usize {
std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(1)
}
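// Worker loop: install a fresh `Scope` for this thread, then repeatedly
// claim a job from the shared map, execute it, and sleep on the condvar
// until more work is shared or shutdown is requested.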
fn worker(ctx: Arc<Context>, barrier: Arc<std::sync::Barrier>) {
unsafe {
Scope::set_current(Box::into_raw(Box::new(Scope::new_in(ctx.clone()))).cast_const());
}
let _guard =
DropGuard::new(|| unsafe { drop(Box::from_raw(Scope::unset_current().unwrap().as_ptr())) });
let scope = Scope::current_ref().unwrap();
barrier.wait();
let mut job = ctx.shared.lock().jobs.pop_first();
loop {
if let Some((_, job)) = job {
unsafe {
scope.execute(job.as_ref());
}
}
let mut guard = ctx.shared.lock();
if guard.should_stop {
break;
}
ctx.shared_job.wait(&mut guard);
job = guard.jobs.pop_first();
}
}
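// Round-robin heartbeat: each pass sets exactly one worker's heartbeat
// flag, sleeping `HEARTBEAT_INTERVAL / num_heartbeats` between passes so
// every worker is prodded roughly once per `HEARTBEAT_INTERVAL`.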
fn heartbeat_worker(ctx: Arc<Context>) {
let mut i = 0;
loop {
let sleep_for = {
let mut guard = ctx.shared.lock();
if guard.should_stop {
break;
}
let mut n = 0;
guard.heartbeats.retain(|_, b| {
b.upgrade()
.inspect(|heartbeat| {
if n == i {
heartbeat.store(true, Ordering::Relaxed);
}
n += 1;
})
.is_some()
});
let num_heartbeats = guard.heartbeats.len();
drop(guard);
if i >= num_heartbeats {
i = 0;
} else {
i += 1;
}
HEARTBEAT_INTERVAL.checked_div(num_heartbeats as u32)
};
// if there are no live heartbeats, sleep a full interval instead of busy-looping.
std::thread::sleep(sleep_for.unwrap_or(HEARTBEAT_INTERVAL));
}
}
#[cfg(test)]
mod tests;