diff --git a/benches/join.rs b/benches/join.rs index fc200c6..4d2a63b 100644 --- a/benches/join.rs +++ b/benches/join.rs @@ -102,7 +102,35 @@ fn join_melange(b: &mut Bencher) { } b.iter(move || { - assert_ne!(sum(&tree, tree.root().unwrap(), &mut scope), 0); + let sum = sum(&tree, tree.root().unwrap(), &mut scope); + //eprintln!("{sum}"); + assert_ne!(sum, 0); + }); +} + +#[bench] +fn join_praetor(b: &mut Bencher) { + use executor::praetor::Scope; + let pool = executor::praetor::ThreadPool::new(); + + let tree = tree::Tree::new(TREE_SIZE, 1u32); + + fn sum(tree: &tree::Tree, node: usize) -> u32 { + let node = tree.get(node); + Scope::with(|s| { + let (l, r) = s.join( + || node.left.map(|node| sum(tree, node)).unwrap_or_default(), + || node.right.map(|node| sum(tree, node)).unwrap_or_default(), + ); + + node.leaf + l + r + }) + } + + b.iter(move || { + let sum = pool.scope(|_| sum(&tree, tree.root().unwrap())); + // eprintln!("{sum}"); + assert_ne!(sum, 0); }); } diff --git a/src/lib.rs b/src/lib.rs index c3db87e..e0f9611 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -387,6 +387,7 @@ pub mod latch { } pub mod melange; +pub mod praetor; pub struct ThreadPoolState { num_threads: AtomicUsize, diff --git a/src/praetor/mod.rs b/src/praetor/mod.rs new file mode 100644 index 0000000..5e8af9d --- /dev/null +++ b/src/praetor/mod.rs @@ -0,0 +1,1141 @@ +mod util { + use std::{ + cell::{Cell, UnsafeCell}, + marker::PhantomData, + mem::ManuallyDrop, + num::NonZero, + ops::{Deref, DerefMut}, + ptr::NonNull, + sync::atomic::{AtomicPtr, Ordering}, + }; + + pub struct DropGuard(UnsafeCell>); + + impl DropGuard + where + F: FnOnce(), + { + pub fn new(f: F) -> DropGuard { + Self(UnsafeCell::new(ManuallyDrop::new(f))) + } + } + + impl Drop for DropGuard + where + F: FnOnce(), + { + fn drop(&mut self) { + unsafe { + ManuallyDrop::take(&mut *self.0.get())(); + } + } + } + + #[repr(transparent)] + pub struct TaggedAtomicPtr(AtomicPtr<()>, PhantomData); + + impl TaggedAtomicPtr { + const fn mask() -> usize { + !(!0usize << BITS) + } + + pub fn new(ptr: *mut T, tag: usize) -> TaggedAtomicPtr { + assert!(core::mem::align_of::().ilog2() as usize >= BITS); + let mask = Self::mask(); + Self( + AtomicPtr::new(ptr.with_addr(ptr.addr() | tag & mask).cast()), + PhantomData, + ) + } + + pub fn ptr(&self, order: Ordering) -> NonNull { + unsafe { + NonNull::new_unchecked(self.0.load(order) as _) + .map_addr(|addr| NonZero::new_unchecked(addr.get() & !Self::mask())) + } + } + + pub fn tag(&self, order: Ordering) -> usize { + self.0.load(order).addr() & Self::mask() + } + + /// returns tag + #[inline] + pub fn compare_exchange_weak_tag( + &self, + old: usize, + new: usize, + success: Ordering, + failure: Ordering, + ) -> Result { + let mask = Self::mask(); + let old_ptr = self.0.load(failure); + + let old = old_ptr.with_addr((old_ptr.addr() & !mask) | (old & mask)); + let new = old_ptr.with_addr((old_ptr.addr() & !mask) | (new & mask)); + + let result = self.0.compare_exchange_weak(old, new, success, failure); + + result + .map(|ptr| ptr.addr() & mask) + .map_err(|ptr| ptr.addr() & mask) + } + + pub fn set_ptr(&self, ptr: *mut T, success: Ordering, failure: Ordering) { + let mask = Self::mask(); + let ptr = ptr.cast::<()>(); + loop { + let old = self.0.load(failure); + let new = ptr.with_addr((ptr.addr() & !mask) | (old.addr() & mask)); + if self + .0 + .compare_exchange_weak(old, new, success, failure) + .is_ok() + { + break; + } + } + } + + pub fn set_tag(&self, tag: usize, success: Ordering, failure: Ordering) { + let mask = 
Self::mask(); + loop { + let ptr = self.0.load(failure); + let new = ptr.with_addr(ptr.addr() | (tag & mask)); + if self + .0 + .compare_exchange_weak(ptr, new, success, failure) + .is_ok() + { + break; + } + } + } + + pub fn ptr_and_tag(&self, order: Ordering) -> (NonNull, usize) { + let mask = Self::mask(); + let ptr = self.0.load(order); + let tag = ptr.addr() & mask; + let addr = ptr.addr() & !mask; + let ptr = unsafe { NonNull::new_unchecked(ptr.with_addr(addr).cast()) }; + (ptr, tag) + } + } + + #[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] + #[repr(transparent)] + pub struct SendPtr(NonNull); + + impl SendPtr { + pub fn as_ptr(&self) -> *mut T { + self.0.as_ptr() + } + pub unsafe fn new_unchecked(t: *const T) -> Self { + unsafe { Self(NonNull::new_unchecked(t.cast_mut())) } + } + pub fn new(t: *const T) -> Option { + NonNull::new(t.cast_mut()).map(Self) + } + pub fn cast(self) -> SendPtr { + SendPtr(self.0.cast::()) + } + } + + impl Deref for SendPtr { + type Target = T; + + fn deref(&self) -> &Self::Target { + unsafe { &*self.0.as_ptr() } + } + } + + impl DerefMut for SendPtr { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.0.as_ptr() } + } + } + + pub struct XorShift64Star { + state: Cell, + } + + impl XorShift64Star { + /// Initializes the prng with a seed. Provided seed must be nonzero. + pub fn new(seed: u64) -> Self { + XorShift64Star { + state: Cell::new(seed), + } + } + + /// Returns a pseudorandom number. + pub fn next(&self) -> u64 { + let mut x = self.state.get(); + debug_assert_ne!(x, 0); + x ^= x >> 12; + x ^= x << 25; + x ^= x >> 27; + self.state.set(x); + x.wrapping_mul(0x2545_f491_4f6c_dd1d) + } + + /// Return a pseudorandom number from `0..n`. + pub fn next_usize(&self, n: usize) -> usize { + (self.next() % n as u64) as usize + } + } +} + +mod job { + use std::{ + any::Any, + cell::{Cell, UnsafeCell}, + fmt::Debug, + mem::{self, ManuallyDrop, MaybeUninit}, + pin::Pin, + ptr::{self, NonNull}, + sync::atomic::{AtomicPtr, AtomicU8, Ordering}, + thread::Thread, + }; + + use parking_lot_core::SpinWait; + + use super::util::{SendPtr, TaggedAtomicPtr}; + + #[cfg_attr(target_pointer_width = "64", repr(align(16)))] + #[cfg_attr(target_pointer_width = "32", repr(align(8)))] + #[derive(Debug, Default, Clone, Copy)] + struct Size2([usize; 2]); + + struct Value(pub MaybeUninit>>); + + impl Value { + /// must only be called once. takes a reference so this can be called in + /// drop() + unsafe fn get_unchecked(&self, inline: bool) -> T { + if inline { + unsafe { mem::transmute_copy(&self.0) } + } else { + unsafe { (*self.0.assume_init_read()).assume_init() } + } + } + + fn get(self) -> T { + let inline = Self::is_inline(); + + // SAFETY: inline is correctly calculated and this function + // consumes `self` + unsafe { self.get_unchecked(inline) } + } + + fn is_inline() -> bool { + // the value can be stored inline iff the size of T is equal or + // smaller than the size of the boxed type and the alignment of the + // boxed type is an integer multiple of the alignment of T + mem::size_of::() > mem::size_of::>>() + || mem::align_of::>>() % mem::align_of::() != 0 + } + + fn new(value: T) -> Self { + let inline = Self::is_inline(); + + // SAFETY: we know the box is allocated if state was `Pending`. 
+ if inline { + Self(MaybeUninit::new(Box::new(MaybeUninit::new(value)))) + } else { + let mut this = Self(MaybeUninit::uninit()); + unsafe { + *mem::transmute::<_, &mut T>(&mut this.0) = value; + } + this + } + } + } + + impl Drop for Value { + fn drop(&mut self) { + unsafe { + // drop contained value. + _ = self.get_unchecked(Self::is_inline()); + } + } + } + + #[repr(u8)] + #[derive(Debug, PartialEq, Eq, Clone, Copy)] + pub enum JobState { + Empty, + Locked = 1, + Pending, + Finished, + // Inline = 1 << (u8::BITS - 1), + // IsError = 1 << (u8::BITS - 2), + } + + impl JobState { + const MASK: u8 = 0; // Self::Inline as u8 | Self::IsError as u8; + fn from_u8(v: u8) -> Option { + match v { + 0 => Some(Self::Empty), + 1 => Some(Self::Locked), + 2 => Some(Self::Pending), + 3 => Some(Self::Finished), + _ => None, + } + } + } + + pub struct JobList { + head: Pin>, + tail: Pin>, + } + + impl JobList { + pub fn new() -> JobList { + let mut head = Box::pin(Job::empty()); + let mut tail = Box::pin(Job::empty()); + + // head and tail point at themselves + unsafe { + (&mut *head.err_or_link.get()).link.next = NonNull::new_unchecked(&mut *head); + (&mut *head.err_or_link.get()).link.prev = NonNull::new_unchecked(&mut *tail); + + (&mut *tail.err_or_link.get()).link.next = NonNull::new_unchecked(&mut *head); + (&mut *tail.err_or_link.get()).link.prev = NonNull::new_unchecked(&mut *tail); + } + Self { head, tail } + } + + /// elem must be valid until it is popped. + pub unsafe fn push_front(&mut self, elem: &Job) { + let head_link = unsafe { self.head.link_mut() }; + + let prev = head_link.prev; + let prev_link = unsafe { prev.as_ref().link_mut() }; + + let elem_ptr = unsafe { NonNull::new_unchecked(elem as *const Job as *mut Job) }; + head_link.prev = elem_ptr; + prev_link.next = elem_ptr; + + let elem_link = unsafe { elem.link_mut() }; + elem_link.prev = prev; + elem_link.next = unsafe { NonNull::new_unchecked(&mut *self.head) }; + } + + /// elem must be valid until it is popped. + pub unsafe fn push_back(&mut self, elem: &Job) { + let tail_link = unsafe { self.tail.link_mut() }; + + let next = tail_link.next; + let next_link = unsafe { next.as_ref().link_mut() }; + + let elem_ptr = unsafe { NonNull::new_unchecked(elem as *const Job as *mut Job) }; + tail_link.next = elem_ptr; + next_link.prev = elem_ptr; + + let elem_link = unsafe { elem.link_mut() }; + elem_link.next = next; + elem_link.prev = unsafe { NonNull::new_unchecked(&mut *self.tail) }; + } + + pub fn pop_front(&mut self) -> Option> { + let head_link = unsafe { self.head.link_mut() }; + + // SAFETY: head will always have a previous element. + let elem = head_link.prev; + let elem_link = unsafe { elem.as_ref().link_mut() }; + + let prev = elem_link.prev.as_ptr(); + head_link.prev = unsafe { NonNull::new_unchecked(prev) }; + + let prev_link = unsafe { (&*prev).link_mut() }; + prev_link.next = unsafe { NonNull::new_unchecked(&mut *self.head) }; + + if elem.as_ptr() == ptr::from_ref(&*self.tail).cast_mut() { + None + } else { + Some(elem) + } + } + + pub fn pop_back(&mut self) -> Option> { + // TODO: next and elem might be the same + let tail_link = unsafe { self.tail.link_mut() }; + + // SAFETY: head will always have a previous element. 
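+        // `tail.next` is the element at the back of the deque; when the list is
+        // empty it points back at the head sentinel, which is caught below and
+        // mapped to `None`.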
+ let elem = tail_link.next; + let elem_link = unsafe { elem.as_ref().link_mut() }; + + let next = elem_link.next.as_ptr(); + tail_link.next = unsafe { NonNull::new_unchecked(next) }; + + let next_link = unsafe { (&*next).link_mut() }; + next_link.prev = unsafe { NonNull::new_unchecked(&mut *self.tail) }; + + if elem.as_ptr() == ptr::from_ref(&*self.head).cast_mut() { + None + } else { + Some(elem) + } + } + } + + union ValueOrThis { + uninit: (), + value: ManuallyDrop>, + this: NonNull<()>, + } + + #[derive(Debug, PartialEq, Eq)] + struct Link { + prev: NonNull, + next: NonNull, + } + + impl Clone for Link { + fn clone(&self) -> Self { + Self { + prev: self.prev.clone(), + next: self.next.clone(), + } + } + } + + // because Copy is invariant over `T` + impl Copy for Link {} + + union LinkOrError { + link: Link, + waker: ManuallyDrop>, + error: ManuallyDrop>>, + } + + pub struct Job { + /// tagged pointer, 8-aligned + harness_and_state: TaggedAtomicPtr, + /// NonNull<()> before execute(), Value after + val_or_this: UnsafeCell>, + /// (prev,next) before execute(), Box<...> after + err_or_link: UnsafeCell>, + } + + impl Debug for Job { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let state = + JobState::from_u8(self.harness_and_state.tag(Ordering::Relaxed) as u8).unwrap(); + let mut debug = f.debug_struct("Job"); + debug.field("state", &state).field_with("harness", |f| { + write!(f, "{:?}", self.harness_and_state.ptr(Ordering::Relaxed)) + }); + + match state { + JobState::Empty => { + debug + .field_with("this", |f| { + write!(f, "{:?}", unsafe { &(&*self.val_or_this.get()).this }) + }) + .field_with("link", |f| { + write!(f, "{:?}", unsafe { &(&*self.err_or_link.get()).link }) + }); + } + JobState::Locked => { + #[derive(Debug)] + struct Locked; + debug.field("locked", &Locked); + } + JobState::Pending => { + debug + .field_with("this", |f| { + write!(f, "{:?}", unsafe { &(&*self.val_or_this.get()).this }) + }) + .field_with("waker", |f| { + write!(f, "{:?}", unsafe { &(&*self.err_or_link.get()).waker }) + }); + } + JobState::Finished => { + let err = unsafe { &(&*self.err_or_link.get()).error }; + + let result = match err.as_ref() { + Some(err) => Err(err), + None => Ok(unsafe { (&*self.val_or_this.get()).value.0.as_ptr() }), + }; + + debug.field("result", &result); + } + } + + debug.finish() + } + } + + unsafe impl Send for Job {} + + impl Job { + pub fn new(harness: unsafe fn(*const (), *const Job), this: NonNull<()>) -> Job { + Self { + harness_and_state: TaggedAtomicPtr::new( + unsafe { mem::transmute(harness) }, + JobState::Empty as usize, + ), + val_or_this: UnsafeCell::new(ValueOrThis { this }), + err_or_link: UnsafeCell::new(LinkOrError { + link: Link { + prev: NonNull::dangling(), + next: NonNull::dangling(), + }, + }), + } + } + pub fn empty() -> Job { + Self { + harness_and_state: TaggedAtomicPtr::new(ptr::dangling_mut(), 0), + val_or_this: UnsafeCell::new(ValueOrThis { + this: NonNull::dangling(), + }), + err_or_link: UnsafeCell::new(LinkOrError { + link: Link { + prev: NonNull::dangling(), + next: NonNull::dangling(), + }, + }), + } + } + + unsafe fn link_mut(&self) -> &mut Link { + unsafe { &mut (&mut *self.err_or_link.get()).link } + } + + /// assumes job is in joblist + pub unsafe fn unlink(&self) { + unsafe { + let link = self.link_mut(); + link.prev.as_ref().link_mut().next = link.next; + link.next.as_ref().link_mut().prev = link.prev; + } + } + + pub fn state(&self) -> u8 { + self.harness_and_state.tag(Ordering::Relaxed) as u8 + } + pub fn 
wait(&self) -> std::thread::Result { + let mut spin = SpinWait::new(); + loop { + match self.harness_and_state.compare_exchange_weak_tag( + JobState::Pending as usize, + JobState::Locked as usize, + Ordering::Acquire, + Ordering::Relaxed, + ) { + // if still pending, sleep until completed + Ok(_) => { + unsafe { + *(&mut *self.err_or_link.get()).waker = Some(std::thread::current()); + } + + self.harness_and_state.set_tag( + JobState::Pending as usize, + Ordering::Release, + Ordering::Relaxed, + ); + std::thread::park(); + spin.reset(); + + // after sleeping, state should be `Finished` + continue; + } + Err(state) => { + if state == JobState::Finished as usize { + let err = unsafe { (&mut *self.err_or_link.get()).error.take() }; + + let result: std::thread::Result = if let Some(err) = err { + Err(err) + } else { + let val = unsafe { + ManuallyDrop::take(&mut (&mut *self.val_or_this.get()).value) + }; + + Ok(val.get()) + }; + + return result; + } else { + // spin until lock is released. + spin.spin(); + } + } + } + } + } + + /// call this when popping value from local queue + pub fn set_pending(&self) { + let mut spin = SpinWait::new(); + loop { + match self.harness_and_state.compare_exchange_weak_tag( + JobState::Empty as usize, + JobState::Pending as usize, + Ordering::Acquire, + Ordering::Relaxed, + ) { + Ok(_) => { + // set waker to None + unsafe { + (&mut *self.err_or_link.get()).waker = ManuallyDrop::new(None); + } + return; + } + Err(_) => { + eprintln!("######## what the sigma?"); + spin.spin(); + } + } + } + } + + pub fn execute(&self) { + // SAFETY: self is non-null + unsafe { + let harness: unsafe fn(*const (), *const Self) = + mem::transmute(self.harness_and_state.ptr(Ordering::Relaxed).as_ptr()); + let this = (*self.val_or_this.get()).this; + + eprintln!("{harness:?}({this:?}, {:?})", self as *const Self); + harness(this.as_ptr().cast(), (self as *const Self).cast()); + } + } + + fn complete(&self, result: std::thread::Result) { + eprintln!("complete({:?}) {:#?}", self as *const _, self); + let mut spin = SpinWait::new(); + loop { + match self.harness_and_state.compare_exchange_weak_tag( + JobState::Pending as usize, + JobState::Locked as usize, + Ordering::Acquire, + Ordering::Relaxed, + ) { + Ok(_) => { + break; + } + Err(tag) => { + // eprintln!( + // "complete(): spin waiting for lock to complete: ({:?})", + // JobState::from_u8(tag as u8).unwrap() + // ); + spin.spin(); + } + } + } + + let waker = unsafe { (&mut *self.err_or_link.get()).waker.take() }; + + match result { + Ok(val) => unsafe { + (&mut *self.val_or_this.get()).value = ManuallyDrop::new(Value::new(val)); + (&mut *self.err_or_link.get()).error = ManuallyDrop::new(None); + }, + Err(err) => unsafe { + (&mut *self.val_or_this.get()).uninit = (); + (&mut *self.err_or_link.get()).error = ManuallyDrop::new(Some(err)); + }, + } + + if let Some(thread) = waker { + thread.unpark(); + } + + self.harness_and_state.set_tag( + JobState::Finished as usize, + Ordering::Release, + Ordering::Relaxed, + ); + eprintln!("complete({:?}): finished", self as *const _); + } + } + + impl Job {} + + pub struct HeapJob { + f: F, + } + + impl HeapJob { + pub fn new(f: F) -> Box { + Box::new(Self { f }) + } + pub fn into_boxed_job(self: Box) -> Box> + where + F: FnOnce() -> T + Send, + T: Send, + { + unsafe fn harness(this: *const (), job: *const Job<()>) + where + F: FnOnce() -> T + Send, + T: Sized + Send, + { + let job = unsafe { &*job.cast::>() }; + + let this = unsafe { Box::from_raw(this.cast::>().cast_mut()) }; + let f = this.f; + + 
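+            // run the user's closure; panics are captured here and stored in the
+            // job so a thread joining on it can observe them via `resume_unwind`.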
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)); + + job.complete(result); + } + + Box::new(Job::new(harness::, unsafe { + NonNull::new_unchecked(Box::into_raw(self)).cast() + })) + } + } + + pub struct StackJob { + f: UnsafeCell>, + } + + impl StackJob { + pub fn new(f: F) -> Self { + Self { + f: UnsafeCell::new(ManuallyDrop::new(f)), + } + } + + pub unsafe fn unwrap(&self) -> F { + unsafe { ManuallyDrop::take(&mut *self.f.get()) } + } + + pub fn as_job(&self) -> Job<()> + where + F: FnOnce() -> T + Send, + T: Send, + { + unsafe fn harness(this: *const (), job: *const Job<()>) + where + F: FnOnce() -> T + Send, + T: Sized + Send, + { + let this = unsafe { &*this.cast::>() }; + let f = unsafe { this.unwrap() }; + + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)); + + let job_ref = unsafe { &*job.cast::>() }; + job_ref.complete(result); + } + + Job::new(harness::, unsafe { + NonNull::new_unchecked(self as *const _ as *mut ()) + }) + } + } +} + +use std::{ + cell::UnsafeCell, + collections::BTreeMap, + mem, + pin::pin, + ptr::NonNull, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, OnceLock, Weak, + }, + time::Duration, +}; + +use job::*; +use parking_lot::{Condvar, Mutex}; +use util::DropGuard; + +pub struct Scope { + context: Arc, + index: usize, + heartbeat: Arc, + queue: UnsafeCell, +} + +thread_local! { + static SCOPE: UnsafeCell>> = const { UnsafeCell::new(None) }; +} + +impl Scope { + fn new() -> Self { + let context = Context::global().clone(); + Self::new_in(context) + } + + fn new_in(context: Arc) -> Self { + let (heartbeat, index) = context.shared.lock().new_heartbeat(); + + Self { + context, + index, + heartbeat, + queue: UnsafeCell::new(JobList::new()), + } + } + + fn with_in T>(ctx: Arc, f: F) -> T { + let mut guard = Option::>>::None; + + let scope = match Self::current_ref() { + Some(scope) if Arc::ptr_eq(&scope.context, &ctx) => scope, + Some(_) => { + let old = unsafe { Self::unset_current().unwrap().as_ptr() }; + guard = Some(DropGuard::new(Box::new(move || unsafe { + _ = Box::from_raw(Self::unset_current().unwrap().as_ptr()); + Self::set_current(old.cast_const()); + }))); + let current = Box::into_raw(Box::new(Self::new())); + unsafe { + Self::set_current(current.cast_const()); + &*current + } + } + None => { + let current = Box::into_raw(Box::new(Self::new())); + + guard = Some(DropGuard::new(Box::new(|| unsafe { + _ = Box::from_raw(Self::unset_current().unwrap().as_ptr()); + }))); + + unsafe { + Self::set_current(current.cast_const()); + + &*current + } + } + }; + + let t = f(scope); + drop(guard); + t + } + + pub fn with T>(f: F) -> T { + let mut guard = None; + + let current = Self::current_ref().unwrap_or_else(|| { + let current = Box::into_raw(Box::new(Self::new())); + + guard = Some(DropGuard::new(|| unsafe { + _ = Box::from_raw(Self::unset_current().unwrap().as_ptr()); + })); + + unsafe { + Self::set_current(current.cast_const()); + + &*current + } + }); + + f(current) + } + + unsafe fn set_current(scope: *const Scope) { + SCOPE.with(|ptr| unsafe { + _ = (&mut *ptr.get()).insert(NonNull::new_unchecked(scope.cast_mut())); + }) + } + + unsafe fn unset_current() -> Option> { + SCOPE.with(|ptr| unsafe { (&mut *ptr.get()).take() }) + } + + fn current() -> Option> { + SCOPE.with(|ptr| unsafe { *ptr.get() }) + } + + fn current_ref<'a>() -> Option<&'a Scope> { + SCOPE.with(|ptr| unsafe { (&*ptr.get()).map(|ptr| ptr.as_ref()) }) + } + + fn push_front(&self, job: &Job) { + unsafe { + 
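+            // SAFETY: the queue is only ever touched from the thread that owns
+            // this `Scope`, so the unsynchronized `&mut` borrow through the
+            // `UnsafeCell` cannot alias.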
self.queue.as_mut_unchecked().push_front(job);
+        }
+    }
+    fn push_back(&self, job: &Job) {
+        unsafe {
+            self.queue.as_mut_unchecked().push_back(job);
+        }
+    }
+    fn pop_back(&self) -> Option<NonNull<Job>> {
+        unsafe { self.queue.as_mut_unchecked().pop_back() }
+    }
+    fn pop_front(&self) -> Option<NonNull<Job>> {
+        unsafe { self.queue.as_mut_unchecked().pop_front() }
+    }
+
+    pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
+    where
+        RA: Send,
+        RB: Send,
+        A: FnOnce() -> RA + Send,
+        B: FnOnce() -> RB + Send,
+    {
+        let b = StackJob::new(b);
+
+        let job = pin!(b.as_job());
+        self.push_front(&job);
+
+        let ra = a();
+
+        let rb = if job.state() == JobState::Empty as u8 {
+            // the job was never shared, so unlink it and run `b` inline.
+            unsafe {
+                job.unlink();
+            }
+
+            self.tick();
+            unsafe { b.unwrap()() }
+        } else {
+            match self.wait_until::<RB>(unsafe { mem::transmute(&*job) }) {
+                Some(Ok(t)) => t,
+                Some(Err(payload)) => std::panic::resume_unwind(payload),
+                None => unsafe { b.unwrap()() },
+            }
+        };
+
+        (ra, rb)
+    }
+
+    #[inline]
+    fn tick(&self) {
+        if self.heartbeat.load(Ordering::Relaxed) {
+            self.heartbeat_cold();
+        }
+    }
+
+    #[inline]
+    fn execute(&self, job: &Job) {
+        self.tick();
+        job.execute();
+    }
+
+    #[cold]
+    fn heartbeat_cold(&self) {
+        let mut guard = self.context.shared.lock();
+
+        if !guard.jobs.contains_key(&self.index) {
+            if let Some(job) = self.pop_back() {
+                unsafe {
+                    job.as_ref().set_pending();
+                }
+                guard.jobs.insert(self.index, job);
+                self.context.shared_job.notify_one();
+            }
+        }
+
+        self.heartbeat.store(false, Ordering::Relaxed);
+    }
+
+    #[cold]
+    pub fn wait_until<T>(&self, job: &Job<T>) -> Option<std::thread::Result<T>> {
+        // TODO: if our own shared job is still unclaimed in `shared.jobs`, reclaim
+        // it and return `None` so the caller can run it inline.
+
+        // while the job isn't done, run other local or shared jobs.
+        while job.state() != JobState::Finished as u8 {
+            let Some(job) = self.pop_front().or_else(|| {
+                self.context
+                    .shared
+                    .lock()
+                    .jobs
+                    .pop_first()
+                    .map(|(_, job)| job)
+            }) else {
+                break;
+            };
+
+            unsafe {
+                self.execute(job.as_ref());
+            }
+        }
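+
+        // nothing left to run locally or steal from the shared queue: park until
+        // the worker that claimed our job marks it finished.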
+        Some(job.wait())
+    }
+}
+
+fn join<A, B, RA, RB>(a: A, b: B) -> (RA, RB)
+where
+    RA: Send,
+    RB: Send,
+    A: FnOnce() -> RA + Send,
+    B: FnOnce() -> RB + Send,
+{
+    Scope::with(|scope| scope.join(a, b))
+}
+
+struct Heartbeat {
+    weak: Arc<AtomicBool>,
+    index: usize,
+}
+
+pub struct ThreadPool {
+    context: Arc<Context>,
+}
+
+impl ThreadPool {
+    pub fn new() -> ThreadPool {
+        Self {
+            context: Context::new(),
+        }
+    }
+
+    pub fn scope<T, F: FnOnce(&Scope) -> T>(&self, f: F) -> T {
+        Scope::with_in(self.context.clone(), f)
+    }
+}
+
+struct Context {
+    shared: Mutex<SharedContext>,
+    shared_job: Condvar,
+}
+
+struct SharedContext {
+    jobs: BTreeMap<usize, NonNull<Job>>,
+    heartbeats: BTreeMap<usize, Weak<AtomicBool>>,
+    // monotonically increasing id
+    heartbeats_id: usize,
+    should_stop: bool,
+    rng: util::XorShift64Star,
+}
+
+unsafe impl Send for SharedContext {}
+
+impl SharedContext {
+    fn new_heartbeat(&mut self) -> (Arc<AtomicBool>, usize) {
+        let index = self.heartbeats_id;
+        self.heartbeats_id = self.heartbeats_id.checked_add(1).unwrap();
+
+        let is_set = Arc::new(AtomicBool::new(false));
+        let weak = Arc::downgrade(&is_set);
+
+        self.heartbeats.insert(index, weak);
+
+        (is_set, index)
+    }
+}
+
+impl Context {
+    fn new() -> Arc<Self> {
+        let this = Arc::new(Self {
+            shared: Mutex::new(SharedContext {
+                jobs: BTreeMap::new(),
+                heartbeats: BTreeMap::new(),
+                heartbeats_id: 0,
+                should_stop: false,
+                rng: util::XorShift64Star::new(37),
+            }),
+            shared_job: Condvar::new(),
+        });
+
+        let num_threads = available_parallelism();
+        let barrier = Arc::new(std::sync::Barrier::new(num_threads + 1));
+
+        for _ in 0..num_threads {
+            let ctx = this.clone();
+            let barrier = barrier.clone();
+            std::thread::spawn(|| worker(ctx, barrier));
+        }
+
+        let ctx = this.clone();
+        std::thread::spawn(|| heartbeat_worker(ctx));
+
+        barrier.wait();
+
+        this
+    }
+
+    pub fn global() -> &'static Arc<Self> {
+        GLOBAL_CONTEXT.get_or_init(|| Self::new())
+    }
+}
+
+static GLOBAL_CONTEXT: OnceLock<Arc<Context>> = OnceLock::new();
+const HEARTBEAT_INTERVAL: Duration = Duration::from_micros(100);
+
+fn available_parallelism() -> usize {
+    std::thread::available_parallelism()
+        .map(|n| n.get())
+        .unwrap_or(1)
+}
+
+fn worker(ctx: Arc<Context>, barrier: Arc<std::sync::Barrier>) {
+    unsafe {
+        Scope::set_current(Box::into_raw(Box::new(Scope::new_in(ctx.clone()))).cast_const());
+    }
+    let _guard =
+        DropGuard::new(|| unsafe { drop(Box::from_raw(Scope::unset_current().unwrap().as_ptr())) });
+
+    let scope = Scope::current_ref().unwrap();
+
+    barrier.wait();
+
+    loop {
+        let job = ctx.shared.lock().jobs.pop_first();
+        if let Some((_, job)) = job {
+            unsafe {
+                scope.execute(job.as_ref());
+            }
+        }
+
+        let mut guard = ctx.shared.lock();
+        if guard.should_stop {
+            break;
+        }
+        ctx.shared_job.wait(&mut guard);
+    }
+}
+
+fn heartbeat_worker(ctx: Arc<Context>) {
+    let mut i = 0;
+    loop {
+        let sleep_for = {
+            let mut guard = ctx.shared.lock();
+            if guard.should_stop {
+                break;
+            }
+
+            let mut n = 0;
+            guard.heartbeats.retain(|_, b| {
+                b.upgrade()
+                    .inspect(|heartbeat| {
+                        if n == i {
+                            heartbeat.store(true, Ordering::Relaxed);
+                        }
+                        n += 1;
+                    })
+                    .is_some()
+            });
+            let num_heartbeats = guard.heartbeats.len();
+
+            if i >= num_heartbeats {
+                i = 0;
+            } else {
+                i += 1;
+            }
+
+            HEARTBEAT_INTERVAL.checked_div(num_heartbeats as u32)
+        };
+
+        if let Some(duration) = sleep_for {
+            std::thread::sleep(duration);
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests;
diff --git a/src/praetor/tests.rs b/src/praetor/tests.rs
new file mode 100644
index 0000000..e5daf56
--- /dev/null
+++ b/src/praetor/tests.rs
@@ -0,0 +1,178 @@
+use std::pin::Pin;
+
+use super::{util::TaggedAtomicPtr, *};
+
+#[test]
+fn job_list_pop_back() {
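+    // push two jobs at the front and one at the back, then drain from the back;
+    // the deque must hand back every element exactly once and then report `None`.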
+ let mut list = JobList::new(); + let mut a = Job::empty(); + let mut b = Job::empty(); + let mut c = Job::empty(); + + unsafe { + list.push_front(&a); + list.push_front(&b); + list.push_back(&c); + } + + assert_eq!(list.pop_back(), NonNull::new(&mut c)); + unsafe { + list.push_front(&c); + } + assert_eq!(list.pop_back(), NonNull::new(&mut a)); + assert_eq!(list.pop_back(), NonNull::new(&mut b)); + assert_eq!(list.pop_back(), NonNull::new(&mut c)); + assert_eq!(list.pop_back(), None); + assert_eq!(list.pop_front(), None); +} + +#[test] +fn job_list_pop_front() { + let mut list = JobList::new(); + let mut a = Job::empty(); + let mut b = Job::empty(); + let mut c = Job::empty(); + + unsafe { + list.push_front(&a); + list.push_front(&b); + list.push_back(&c); + } + + assert_eq!(list.pop_front(), NonNull::new(&mut b)); + unsafe { + list.push_back(&b); + } + assert_eq!(list.pop_front(), NonNull::new(&mut a)); + assert_eq!(list.pop_front(), NonNull::new(&mut c)); + assert_eq!(list.pop_front(), NonNull::new(&mut b)); + assert_eq!(list.pop_front(), None); + assert_eq!(list.pop_back(), None); +} + +#[test] +fn unlink_job_middle() { + let mut list = JobList::new(); + let mut a = Job::empty(); + let b = Job::<()>::empty(); + let mut c = Job::empty(); + + unsafe { + list.push_front(&a); + list.push_front(&b); + list.push_front(&c); + } + + unsafe { + b.unlink(); + } + + assert_eq!(list.pop_front(), NonNull::new(&mut c)); + assert_eq!(list.pop_front(), NonNull::new(&mut a)); + assert_eq!(list.pop_front(), None); + assert_eq!(list.pop_back(), None); +} + +#[test] +fn unlink_job_front() { + let mut list = JobList::new(); + let a = Job::<()>::empty(); + let mut b = Job::<()>::empty(); + let mut c = Job::<()>::empty(); + + unsafe { + list.push_front(&a); + list.push_front(&b); + list.push_front(&c); + } + + unsafe { + a.unlink(); + } + + assert_eq!(list.pop_front(), NonNull::new(&mut c)); + assert_eq!(list.pop_front(), NonNull::new(&mut b)); + assert_eq!(list.pop_front(), None); + assert_eq!(list.pop_back(), None); +} + +#[test] +fn unlink_job_back() { + let mut list = JobList::new(); + let mut a = Job::<()>::empty(); + let mut b = Job::<()>::empty(); + let c = Job::<()>::empty(); + + unsafe { + list.push_front(&a); + list.push_front(&b); + list.push_front(&c); + } + + unsafe { + c.unlink(); + } + + assert_eq!(list.pop_front(), NonNull::new(&mut b)); + assert_eq!(list.pop_front(), NonNull::new(&mut a)); + assert_eq!(list.pop_front(), None); + assert_eq!(list.pop_back(), None); +} + +#[test] +fn unlink_job_single() { + let mut list = JobList::new(); + let a = Job::<()>::empty(); + + unsafe { + list.push_front(&a); + } + + unsafe { + a.unlink(); + } + + assert_eq!(list.pop_front(), None); + assert_eq!(list.pop_back(), None); +} + +#[test] +fn tagged_ptr_exchange() { + let boxed = Box::into_raw(Box::new(42usize)); + let ptr = TaggedAtomicPtr::<_, 3>::new(boxed, 1usize); + + assert_eq!(ptr.tag(Ordering::Relaxed), 1); + assert_eq!(ptr.ptr(Ordering::Relaxed).as_ptr(), boxed); + ptr.set_tag(3, Ordering::Release, Ordering::Relaxed); + assert_eq!(ptr.tag(Ordering::Relaxed), 3); + assert_eq!(ptr.ptr(Ordering::Relaxed).as_ptr(), boxed); + let old = ptr.compare_exchange_weak_tag(3, 4, Ordering::Release, Ordering::Relaxed); + assert_eq!(old, Ok(3)); + assert_eq!(ptr.tag(Ordering::Relaxed), 4); + assert_eq!(ptr.ptr(Ordering::Relaxed).as_ptr(), boxed); +} + +#[test] +fn tagged_ptr_exchange_failure() { + let boxed = Box::into_raw(Box::new(42usize)); + let ptr = TaggedAtomicPtr::<_, 3>::new(boxed, 1usize); + + 
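+    // the stored tag is 1, so exchanging with an expected tag of 3 below must
+    // fail and report the tag that was actually observed.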
assert_eq!(ptr.tag(Ordering::Relaxed), 1);
+    assert_eq!(ptr.ptr(Ordering::Relaxed).as_ptr(), boxed);
+
+    let old = ptr.compare_exchange_weak_tag(3, 4, Ordering::Release, Ordering::Relaxed);
+    assert_eq!(old, Err(1));
+    assert_eq!(ptr.tag(Ordering::Relaxed), 1);
+    assert_eq!(ptr.ptr(Ordering::Relaxed).as_ptr(), boxed);
+}
+
+#[test]
+fn pinstuff() {
+    let pinned = pin!(5);
+
+    let pinned_ref = pinned.as_ref();
+    assert_eq!(takes_pinned_ref(pinned_ref), 5);
+}
+
+fn takes_pinned_ref(r: Pin<&i32>) -> i32 {
+    *r
+}
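+
+// A minimal end-to-end sketch of the public API defined above: build a
+// `ThreadPool`, enter a scope, and `join` two closures. The closure bodies and
+// values are illustrative only; real-world usage mirrors the `join_praetor`
+// benchmark in `benches/join.rs`.
+#[test]
+fn scope_join_smoke() {
+    let pool = ThreadPool::new();
+
+    let (a, b) = pool.scope(|scope| scope.join(|| 40u32, || 2u32));
+    assert_eq!(a + b, 42);
+}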