This commit is contained in:
Janis 2025-01-31 01:17:01 +01:00
parent 8b35cb7f45
commit a691b614bc

View file

@ -51,11 +51,13 @@ pub mod task {
Self { ptr, execute_fn } Self { ptr, execute_fn }
} }
#[inline]
pub fn id(&self) -> impl Eq { pub fn id(&self) -> impl Eq {
(self.ptr, self.execute_fn) (self.ptr, self.execute_fn)
} }
/// caller must ensure that this particular task is [`Send`] /// caller must ensure that this particular task is [`Send`]
#[inline]
pub fn execute(self) { pub fn execute(self) {
unsafe { (self.execute_fn)(self.ptr) } unsafe { (self.execute_fn)(self.ptr) }
} }
@ -77,19 +79,24 @@ pub mod task {
} }
} }
#[inline]
pub fn run(self) { pub fn run(self) {
self.task.into_inner().unwrap()(); self.task.into_inner().unwrap()();
} }
#[inline]
pub unsafe fn run_as_ref(&self) { pub unsafe fn run_as_ref(&self) {
((&mut *self.task.get()).take().unwrap())(); ((&mut *self.task.get()).take().unwrap())();
} }
#[inline]
pub fn as_task_ref(self: Pin<&Self>) -> TaskRef { pub fn as_task_ref(self: Pin<&Self>) -> TaskRef {
unsafe { TaskRef::new(&*self) } unsafe { TaskRef::new(&*self) }
} }
} }
impl<F: FnOnce() + Send> Task for StackTask<F> { impl<F: FnOnce() + Send> Task for StackTask<F> {
#[inline]
unsafe fn execute(this: *const ()) { unsafe fn execute(this: *const ()) {
let this = &*this.cast::<Self>(); let this = &*this.cast::<Self>();
let task = (&mut *this.task.get()).take().unwrap(); let task = (&mut *this.task.get()).take().unwrap();
@ -110,6 +117,7 @@ pub mod task {
}) })
} }
#[inline]
pub unsafe fn into_static_task_ref(self: Box<Self>) -> TaskRef pub unsafe fn into_static_task_ref(self: Box<Self>) -> TaskRef
where where
F: 'static, F: 'static,
@ -117,11 +125,13 @@ pub mod task {
self.into_task_ref() self.into_task_ref()
} }
#[inline]
pub unsafe fn into_task_ref(self: Box<Self>) -> TaskRef { pub unsafe fn into_task_ref(self: Box<Self>) -> TaskRef {
TaskRef::new(Box::into_raw(self)) TaskRef::new(Box::into_raw(self))
} }
} }
impl<F: FnOnce() + Send> Task for HeapTask<F> { impl<F: FnOnce() + Send> Task for HeapTask<F> {
#[inline]
unsafe fn execute(this: *const ()) { unsafe fn execute(this: *const ()) {
let this = Box::from_raw(this.cast::<Self>().cast_mut()); let this = Box::from_raw(this.cast::<Self>().cast_mut());
(this.task)(); (this.task)();
@ -154,21 +164,25 @@ pub mod latch {
pub struct AtomicLatch(AtomicBool); pub struct AtomicLatch(AtomicBool);
impl AtomicLatch { impl AtomicLatch {
#[inline]
pub const fn new() -> AtomicLatch { pub const fn new() -> AtomicLatch {
Self(AtomicBool::new(false)) Self(AtomicBool::new(false))
} }
#[inline]
pub fn reset(&self) { pub fn reset(&self) {
self.0.store(false, Ordering::Release); self.0.store(false, Ordering::Release);
} }
} }
impl Latch for AtomicLatch { impl Latch for AtomicLatch {
#[inline]
unsafe fn set_raw(this: *const Self) { unsafe fn set_raw(this: *const Self) {
(*this).0.store(true, Ordering::Release); (*this).0.store(true, Ordering::Release);
} }
} }
impl Probe for AtomicLatch { impl Probe for AtomicLatch {
#[inline]
fn probe(&self) -> bool { fn probe(&self) -> bool {
self.0.load(Ordering::Acquire) self.0.load(Ordering::Acquire)
} }
@ -181,6 +195,7 @@ pub mod latch {
} }
impl ThreadWakeLatch { impl ThreadWakeLatch {
#[inline]
pub const fn new(thread: &WorkerThread) -> ThreadWakeLatch { pub const fn new(thread: &WorkerThread) -> ThreadWakeLatch {
Self { Self {
inner: AtomicLatch::new(), inner: AtomicLatch::new(),
@ -188,12 +203,14 @@ pub mod latch {
index: thread.index, index: thread.index,
} }
} }
#[inline]
pub fn reset(&self) { pub fn reset(&self) {
self.inner.reset() self.inner.reset()
} }
} }
impl Latch for ThreadWakeLatch { impl Latch for ThreadWakeLatch {
#[inline]
unsafe fn set_raw(this: *const Self) { unsafe fn set_raw(this: *const Self) {
let (pool, index) = { let (pool, index) = {
let this = &*this; let this = &*this;
@ -205,6 +222,7 @@ pub mod latch {
} }
impl Probe for ThreadWakeLatch { impl Probe for ThreadWakeLatch {
#[inline]
fn probe(&self) -> bool { fn probe(&self) -> bool {
self.inner.probe() self.inner.probe()
} }
@ -216,6 +234,7 @@ pub mod latch {
} }
impl MutexLatch { impl MutexLatch {
#[inline]
pub const fn new() -> MutexLatch { pub const fn new() -> MutexLatch {
Self { Self {
mutex: Mutex::new(false), mutex: Mutex::new(false),
@ -223,12 +242,14 @@ pub mod latch {
} }
} }
#[inline]
pub fn wait(&self) { pub fn wait(&self) {
let mut guard = self.mutex.lock(); let mut guard = self.mutex.lock();
while !*guard { while !*guard {
self.signal.wait(&mut guard); self.signal.wait(&mut guard);
} }
} }
#[inline]
pub fn wait_and_reset(&self) { pub fn wait_and_reset(&self) {
let mut guard = self.mutex.lock(); let mut guard = self.mutex.lock();
while !*guard { while !*guard {
@ -239,6 +260,7 @@ pub mod latch {
} }
impl Latch for MutexLatch { impl Latch for MutexLatch {
#[inline]
unsafe fn set_raw(this: *const Self) { unsafe fn set_raw(this: *const Self) {
let mut guard = (*this).mutex.lock(); let mut guard = (*this).mutex.lock();
*guard = true; *guard = true;
@ -252,6 +274,7 @@ pub mod latch {
} }
impl CountWakeLatch { impl CountWakeLatch {
#[inline]
pub const fn new(count: usize, thread: &WorkerThread) -> CountWakeLatch { pub const fn new(count: usize, thread: &WorkerThread) -> CountWakeLatch {
Self { Self {
counter: AtomicUsize::new(count), counter: AtomicUsize::new(count),
@ -259,12 +282,14 @@ pub mod latch {
} }
} }
#[inline]
pub fn increment(&self) { pub fn increment(&self) {
self.counter.fetch_add(1, Ordering::Relaxed); self.counter.fetch_add(1, Ordering::Relaxed);
} }
} }
impl Latch for CountWakeLatch { impl Latch for CountWakeLatch {
#[inline]
unsafe fn set_raw(this: *const Self) { unsafe fn set_raw(this: *const Self) {
if (*this).counter.fetch_sub(1, Ordering::SeqCst) == 1 { if (*this).counter.fetch_sub(1, Ordering::SeqCst) == 1 {
Latch::set_raw(&(*this).inner); Latch::set_raw(&(*this).inner);
@ -273,6 +298,7 @@ pub mod latch {
} }
impl Probe for CountWakeLatch { impl Probe for CountWakeLatch {
#[inline]
fn probe(&self) -> bool { fn probe(&self) -> bool {
self.inner.probe() self.inner.probe()
} }
@ -281,9 +307,11 @@ pub mod latch {
pub struct LatchWaker<L>(L); pub struct LatchWaker<L>(L);
impl<L> LatchWaker<L> { impl<L> LatchWaker<L> {
#[inline]
pub fn new(latch: L) -> Arc<Self> { pub fn new(latch: L) -> Arc<Self> {
Arc::new(Self(latch)) Arc::new(Self(latch))
} }
#[inline]
pub fn latch(&self) -> &L { pub fn latch(&self) -> &L {
&self.0 &self.0
} }
@ -293,9 +321,11 @@ pub mod latch {
where where
L: Latch, L: Latch,
{ {
#[inline]
fn wake(self: Arc<Self>) { fn wake(self: Arc<Self>) {
self.wake_by_ref(); self.wake_by_ref();
} }
#[inline]
fn wake_by_ref(self: &Arc<Self>) { fn wake_by_ref(self: &Arc<Self>) {
unsafe { unsafe {
Latch::set_raw(&self.0); Latch::set_raw(&self.0);
@ -328,6 +358,7 @@ pub struct ThreadState {
impl ThreadState { impl ThreadState {
/// returns true if thread was sleeping /// returns true if thread was sleeping
#[inline]
fn wake(&self) -> bool { fn wake(&self) -> bool {
let mut guard = self.status.lock(); let mut guard = self.status.lock();
guard.insert(ThreadStatus::SHOULD_WAKE); guard.insert(ThreadStatus::SHOULD_WAKE);
@ -335,6 +366,7 @@ impl ThreadState {
guard.contains(ThreadStatus::SLEEPING) guard.contains(ThreadStatus::SLEEPING)
} }
#[inline]
fn wait_for_running(&self) { fn wait_for_running(&self) {
let mut guard = self.status.lock(); let mut guard = self.status.lock();
while !guard.contains(ThreadStatus::RUNNING) { while !guard.contains(ThreadStatus::RUNNING) {
@ -342,6 +374,7 @@ impl ThreadState {
} }
} }
#[inline]
fn wait_for_should_wake(&self) { fn wait_for_should_wake(&self) {
let mut guard = self.status.lock(); let mut guard = self.status.lock();
while !guard.contains(ThreadStatus::SHOULD_WAKE) { while !guard.contains(ThreadStatus::SHOULD_WAKE) {
@ -351,6 +384,7 @@ impl ThreadState {
guard.remove(ThreadStatus::SHOULD_WAKE | ThreadStatus::SLEEPING); guard.remove(ThreadStatus::SHOULD_WAKE | ThreadStatus::SLEEPING);
} }
#[inline]
fn wait_for_should_wake_timeout(&self, timeout: Duration) { fn wait_for_should_wake_timeout(&self, timeout: Duration) {
let mut guard = self.status.lock(); let mut guard = self.status.lock();
while !guard.contains(ThreadStatus::SHOULD_WAKE) { while !guard.contains(ThreadStatus::SHOULD_WAKE) {
@ -366,6 +400,7 @@ impl ThreadState {
guard.remove(ThreadStatus::SHOULD_WAKE | ThreadStatus::SLEEPING); guard.remove(ThreadStatus::SHOULD_WAKE | ThreadStatus::SLEEPING);
} }
#[inline]
fn wait_for_termination(&self) { fn wait_for_termination(&self) {
let mut guard = self.status.lock(); let mut guard = self.status.lock();
while guard.contains(ThreadStatus::RUNNING) { while guard.contains(ThreadStatus::RUNNING) {
@ -373,18 +408,21 @@ impl ThreadState {
} }
} }
#[inline]
fn notify_running(&self) { fn notify_running(&self) {
let mut guard = self.status.lock(); let mut guard = self.status.lock();
guard.insert(ThreadStatus::RUNNING); guard.insert(ThreadStatus::RUNNING);
self.status_changed.notify_all(); self.status_changed.notify_all();
} }
#[inline]
fn notify_termination(&self) { fn notify_termination(&self) {
let mut guard = self.status.lock(); let mut guard = self.status.lock();
*guard = ThreadStatus::empty(); *guard = ThreadStatus::empty();
self.status_changed.notify_all(); self.status_changed.notify_all();
} }
#[inline]
fn notify_should_terminate(&self) { fn notify_should_terminate(&self) {
unsafe { unsafe {
Latch::set_raw(&self.should_terminate); Latch::set_raw(&self.should_terminate);
@ -451,6 +489,7 @@ impl ThreadPool {
} }
} }
#[inline]
fn threads(&self) -> &[CachePadded<ThreadState>] { fn threads(&self) -> &[CachePadded<ThreadState>] {
&self.threads[..self.pool_state.num_threads.load(Ordering::Relaxed) as usize] &self.threads[..self.pool_state.num_threads.load(Ordering::Relaxed) as usize]
} }
@ -473,6 +512,7 @@ impl ThreadPool {
} }
} }
#[inline]
pub fn id(&self) -> impl Eq { pub fn id(&self) -> impl Eq {
core::ptr::from_ref(self) as usize core::ptr::from_ref(self) as usize
} }
@ -824,32 +864,40 @@ std::thread_local! {
} }
impl WorkerThread { impl WorkerThread {
#[inline]
fn info(&self) -> &ThreadState { fn info(&self) -> &ThreadState {
&self.pool.threads[self.index as usize] &self.pool.threads[self.index as usize]
} }
#[inline]
fn pool(&self) -> &'static ThreadPool { fn pool(&self) -> &'static ThreadPool {
self.pool self.pool
} }
#[inline]
fn index(&self) -> usize { fn index(&self) -> usize {
self.index self.index
} }
#[inline]
fn is_worker_thread() -> bool { fn is_worker_thread() -> bool {
Self::with(|worker| worker.is_some()) Self::with(|worker| worker.is_some())
} }
fn with<T, F: FnOnce(Option<&Self>) -> T>(f: F) -> T { fn with<T, F: FnOnce(Option<&Self>) -> T>(f: F) -> T {
WORKER_THREAD_STATE.with(|thread| f(thread.get())) WORKER_THREAD_STATE.with(|thread| f(thread.get()))
} }
#[inline]
fn pop_task(&self) -> Option<TaskRef> { fn pop_task(&self) -> Option<TaskRef> {
self.queue.pop_front() self.queue.pop_front()
} }
#[inline]
fn push_task(&self, task: TaskRef) { fn push_task(&self, task: TaskRef) {
self.queue.push_front(task); self.queue.push_front(task);
} }
#[inline]
fn drain(&self) -> impl Iterator<Item = TaskRef> { fn drain(&self) -> impl Iterator<Item = TaskRef> {
self.queue.drain() self.queue.drain()
} }
#[inline]
fn claim_shoved_task(&self) -> Option<TaskRef> { fn claim_shoved_task(&self) -> Option<TaskRef> {
if let Some(task) = self.info().shoved_task.try_take() { if let Some(task) = self.info().shoved_task.try_take() {
return Some(task); return Some(task);
@ -884,6 +932,7 @@ impl WorkerThread {
task.execute(); task.execute();
} }
#[inline]
fn try_promote(&self) { fn try_promote(&self) {
#[cfg(feature = "internal_heartbeat")] #[cfg(feature = "internal_heartbeat")]
let now = std::time::Instant::now(); let now = std::time::Instant::now();
@ -907,6 +956,7 @@ impl WorkerThread {
} }
} }
#[inline]
fn find_any_task(&self) -> Option<TaskRef> { fn find_any_task(&self) -> Option<TaskRef> {
// TODO: attempt stealing work here, too. // TODO: attempt stealing work here, too.
self.pop_task() self.pop_task()
@ -914,6 +964,7 @@ impl WorkerThread {
.or_else(|| self.pool.global_queue.pop()) .or_else(|| self.pool.global_queue.pop())
} }
#[inline]
fn run_until<L>(&self, latch: &L) fn run_until<L>(&self, latch: &L)
where where
L: Probe, L: Probe,
@ -933,12 +984,14 @@ impl WorkerThread {
} }
} }
#[inline]
fn run_until_inner(&self) { fn run_until_inner(&self) {
match self.find_any_task() { match self.find_any_task() {
Some(task) => { Some(task) => {
self.execute(task); self.execute(task);
} }
None => { None => {
debug!("waiting for tasks");
self.info().wait_for_should_wake(); self.info().wait_for_should_wake();
} }
} }
@ -1012,28 +1065,36 @@ pub struct TaskQueue<T>(UnsafeCell<VecDeque<T>>);
impl<T> TaskQueue<T> { impl<T> TaskQueue<T> {
/// Creates a new [`TaskQueue<T>`]. /// Creates a new [`TaskQueue<T>`].
#[inline]
const fn new() -> Self { const fn new() -> Self {
Self(UnsafeCell::new(VecDeque::new())) Self(UnsafeCell::new(VecDeque::new()))
} }
#[inline]
fn get_mut(&self) -> &mut VecDeque<T> { fn get_mut(&self) -> &mut VecDeque<T> {
unsafe { &mut *self.0.get() } unsafe { &mut *self.0.get() }
} }
#[inline]
fn pop_front(&self) -> Option<T> { fn pop_front(&self) -> Option<T> {
self.get_mut().pop_front() self.get_mut().pop_front()
} }
#[inline]
fn pop_back(&self) -> Option<T> { fn pop_back(&self) -> Option<T> {
self.get_mut().pop_back() self.get_mut().pop_back()
} }
#[inline]
fn push_back(&self, t: T) { fn push_back(&self, t: T) {
self.get_mut().push_back(t); self.get_mut().push_back(t);
} }
#[inline]
fn push_front(&self, t: T) { fn push_front(&self, t: T) {
self.get_mut().push_front(t); self.get_mut().push_front(t);
} }
#[inline]
fn take(&self) -> VecDeque<T> { fn take(&self) -> VecDeque<T> {
let this = core::mem::replace(self.get_mut(), VecDeque::new()); let this = core::mem::replace(self.get_mut(), VecDeque::new());
this this
} }
#[inline]
fn drain(&self) -> impl Iterator<Item = T> { fn drain(&self) -> impl Iterator<Item = T> {
self.take().into_iter() self.take().into_iter()
} }
@ -1084,6 +1145,7 @@ impl<T> Slot<T> {
} }
} }
#[inline]
pub fn try_put(&self, t: T) -> Option<T> { pub fn try_put(&self, t: T) -> Option<T> {
match self.state.compare_exchange( match self.state.compare_exchange(
SlotState::empty().into(), SlotState::empty().into(),
@ -1105,6 +1167,7 @@ impl<T> Slot<T> {
} }
} }
#[inline]
pub fn try_take(&self) -> Option<T> { pub fn try_take(&self) -> Option<T> {
match self.state.compare_exchange( match self.state.compare_exchange(
SlotState::OCCUPIED.into(), SlotState::OCCUPIED.into(),
@ -1281,7 +1344,7 @@ mod tests {
1861, 1867, 1871, 1873, 1877, 1879, 1889, 1901, 1907, 1861, 1867, 1871, 1873, 1877, 1879, 1889, 1901, 1907,
]; ];
const REPEAT: usize = 0x8000; const REPEAT: usize = 0x100;
fn run_in_scope<T: Send>(pool: ThreadPool, f: impl FnOnce(Pin<&Scope<'_>>) -> T + Send) -> T { fn run_in_scope<T: Send>(pool: ThreadPool, f: impl FnOnce(Pin<&Scope<'_>>) -> T + Send) -> T {
let pool = Box::new(pool); let pool = Box::new(pool);
@ -1316,8 +1379,7 @@ mod tests {
pool.scope(|s| { pool.scope(|s| {
for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() { for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() {
s.spawn(move |_| { s.spawn(move |_| {
let tmp = (0..p).reduce(|a, b| black_box(a & b)); black_box(spinning(p));
black_box(tmp);
}); });
} }
}); });
@ -1339,8 +1401,7 @@ mod tests {
pool.scope(|s| { pool.scope(|s| {
for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() { for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() {
s.spawn(async move { s.spawn(async move {
let tmp = (0..p).reduce(|a, b| black_box(a & b)); black_box(spinning(p));
black_box(tmp);
}); });
} }
}); });
@ -1373,15 +1434,7 @@ mod tests {
run_in_scope(pool, |s| { run_in_scope(pool, |s| {
for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() { for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() {
s.spawn(move |_| { s.spawn(move |_| {
// std::thread::sleep(Duration::from_micros(p as u64)); black_box(spinning(p));
// spin for
let tmp = (0..p).reduce(|a, b| black_box(a & b));
black_box(tmp);
// WAIT_COUNT.with(|count| {
// // eprintln!("{} + {p}", count.get());
// count.set(count.get() + p);
// });
}); });
} }
}); });
@ -1389,4 +1442,27 @@ mod tests {
// eprintln!("total wait count: {}", counter.load(Ordering::Acquire)); // eprintln!("total wait count: {}", counter.load(Ordering::Acquire));
} }
#[test]
#[tracing_test::traced_test]
fn sync() {
let now = std::time::Instant::now();
for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() {
black_box(spinning(p));
}
let elapsed = now.elapsed().as_micros();
eprintln!("(sync) total time: {}ms", elapsed as f32 / 1e3);
}
#[inline]
fn spinning(i: usize) {
let rng = rng::XorShift64Star::new(i as u64);
(0..i).reduce(|a, b| {
black_box({
let a = rng.next_usize(a.max(1));
((b as f32).exp() * (a as f32).sin().cbrt()).to_bits() as usize
})
});
}
} }