From a691b614bc455a1a80975ae68f5d6cc5f8478e09 Mon Sep 17 00:00:00 2001
From: Janis
Date: Fri, 31 Jan 2025 01:17:01 +0100
Subject: [PATCH] inline

Mark the small hot-path methods on tasks, latches, queues, slots and the
worker thread as #[inline]. Also lower REPEAT in the tests, route the test
workloads through a shared spinning() helper, and add a synchronous
baseline test for comparison.
---
 src/lib.rs | 104 +++++++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 90 insertions(+), 14 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index 40fa1b3..09faa9e 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -51,11 +51,13 @@ pub mod task {
             Self { ptr, execute_fn }
         }
 
+        #[inline]
         pub fn id(&self) -> impl Eq {
             (self.ptr, self.execute_fn)
         }
 
         /// caller must ensure that this particular task is [`Send`]
+        #[inline]
         pub fn execute(self) {
             unsafe { (self.execute_fn)(self.ptr) }
         }
@@ -77,19 +79,24 @@ pub mod task {
             }
         }
 
+        #[inline]
         pub fn run(self) {
             self.task.into_inner().unwrap()();
         }
+
+        #[inline]
         pub unsafe fn run_as_ref(&self) {
             ((&mut *self.task.get()).take().unwrap())();
         }
 
+        #[inline]
         pub fn as_task_ref(self: Pin<&Self>) -> TaskRef {
             unsafe { TaskRef::new(&*self) }
         }
     }
 
     impl Task for StackTask {
+        #[inline]
         unsafe fn execute(this: *const ()) {
             let this = &*this.cast::<Self>();
             let task = (&mut *this.task.get()).take().unwrap();
@@ -110,6 +117,7 @@ pub mod task {
             })
         }
 
+        #[inline]
         pub unsafe fn into_static_task_ref(self: Box<Self>) -> TaskRef
         where
             F: 'static,
@@ -117,11 +125,13 @@ pub mod task {
             self.into_task_ref()
         }
 
+        #[inline]
         pub unsafe fn into_task_ref(self: Box<Self>) -> TaskRef {
             TaskRef::new(Box::into_raw(self))
         }
     }
 
     impl Task for HeapTask {
+        #[inline]
         unsafe fn execute(this: *const ()) {
             let this = Box::from_raw(this.cast::<Self>().cast_mut());
             (this.task)();
@@ -154,21 +164,25 @@ pub mod latch {
     pub struct AtomicLatch(AtomicBool);
 
     impl AtomicLatch {
+        #[inline]
         pub const fn new() -> AtomicLatch {
             Self(AtomicBool::new(false))
         }
 
+        #[inline]
         pub fn reset(&self) {
             self.0.store(false, Ordering::Release);
         }
     }
 
     impl Latch for AtomicLatch {
+        #[inline]
         unsafe fn set_raw(this: *const Self) {
             (*this).0.store(true, Ordering::Release);
         }
     }
 
     impl Probe for AtomicLatch {
+        #[inline]
         fn probe(&self) -> bool {
             self.0.load(Ordering::Acquire)
         }
@@ -181,6 +195,7 @@ pub mod latch {
     }
 
     impl ThreadWakeLatch {
+        #[inline]
         pub const fn new(thread: &WorkerThread) -> ThreadWakeLatch {
             Self {
                 inner: AtomicLatch::new(),
@@ -188,12 +203,14 @@ pub mod latch {
                 index: thread.index,
             }
         }
+        #[inline]
         pub fn reset(&self) {
             self.inner.reset()
         }
     }
 
     impl Latch for ThreadWakeLatch {
+        #[inline]
         unsafe fn set_raw(this: *const Self) {
             let (pool, index) = {
                 let this = &*this;
@@ -205,6 +222,7 @@ pub mod latch {
     }
 
     impl Probe for ThreadWakeLatch {
+        #[inline]
         fn probe(&self) -> bool {
             self.inner.probe()
         }
@@ -216,6 +234,7 @@ pub mod latch {
     }
 
     impl MutexLatch {
+        #[inline]
         pub const fn new() -> MutexLatch {
             Self {
                 mutex: Mutex::new(false),
@@ -223,12 +242,14 @@ pub mod latch {
             }
         }
 
+        #[inline]
         pub fn wait(&self) {
             let mut guard = self.mutex.lock();
             while !*guard {
                 self.signal.wait(&mut guard);
             }
         }
+        #[inline]
         pub fn wait_and_reset(&self) {
             let mut guard = self.mutex.lock();
             while !*guard {
@@ -239,6 +260,7 @@ pub mod latch {
     }
 
     impl Latch for MutexLatch {
+        #[inline]
         unsafe fn set_raw(this: *const Self) {
             let mut guard = (*this).mutex.lock();
             *guard = true;
@@ -252,6 +274,7 @@ pub mod latch {
     }
 
     impl CountWakeLatch {
+        #[inline]
         pub const fn new(count: usize, thread: &WorkerThread) -> CountWakeLatch {
             Self {
                 counter: AtomicUsize::new(count),
@@ -259,12 +282,14 @@ pub mod latch {
             }
         }
 
+        #[inline]
        pub fn increment(&self) {
             self.counter.fetch_add(1, Ordering::Relaxed);
         }
     }
 
     impl Latch for CountWakeLatch {
+        #[inline]
         unsafe fn set_raw(this: *const Self) {
             if (*this).counter.fetch_sub(1, Ordering::SeqCst) == 1 {
                 Latch::set_raw(&(*this).inner);
@@ -273,6 +298,7 @@ pub mod latch {
     }
 
     impl Probe for CountWakeLatch {
+        #[inline]
         fn probe(&self) -> bool {
             self.inner.probe()
         }
@@ -281,9 +307,11 @@ pub mod latch {
     pub struct LatchWaker<L>(L);
 
     impl<L> LatchWaker<L> {
+        #[inline]
         pub fn new(latch: L) -> Arc<Self> {
             Arc::new(Self(latch))
         }
+        #[inline]
         pub fn latch(&self) -> &L {
             &self.0
         }
@@ -293,9 +321,11 @@ pub mod latch {
     where
         L: Latch,
     {
+        #[inline]
         fn wake(self: Arc<Self>) {
             self.wake_by_ref();
         }
+        #[inline]
         fn wake_by_ref(self: &Arc<Self>) {
             unsafe {
                 Latch::set_raw(&self.0);
@@ -328,6 +358,7 @@ pub struct ThreadState {
 
 impl ThreadState {
     /// returns true if thread was sleeping
+    #[inline]
     fn wake(&self) -> bool {
         let mut guard = self.status.lock();
         guard.insert(ThreadStatus::SHOULD_WAKE);
@@ -335,6 +366,7 @@ impl ThreadState {
         guard.contains(ThreadStatus::SLEEPING)
     }
 
+    #[inline]
     fn wait_for_running(&self) {
         let mut guard = self.status.lock();
         while !guard.contains(ThreadStatus::RUNNING) {
@@ -342,6 +374,7 @@ impl ThreadState {
         }
     }
 
+    #[inline]
     fn wait_for_should_wake(&self) {
         let mut guard = self.status.lock();
         while !guard.contains(ThreadStatus::SHOULD_WAKE) {
@@ -351,6 +384,7 @@ impl ThreadState {
         guard.remove(ThreadStatus::SHOULD_WAKE | ThreadStatus::SLEEPING);
     }
 
+    #[inline]
     fn wait_for_should_wake_timeout(&self, timeout: Duration) {
         let mut guard = self.status.lock();
         while !guard.contains(ThreadStatus::SHOULD_WAKE) {
@@ -366,6 +400,7 @@ impl ThreadState {
         guard.remove(ThreadStatus::SHOULD_WAKE | ThreadStatus::SLEEPING);
     }
 
+    #[inline]
     fn wait_for_termination(&self) {
         let mut guard = self.status.lock();
         while guard.contains(ThreadStatus::RUNNING) {
@@ -373,18 +408,21 @@ impl ThreadState {
         }
     }
 
+    #[inline]
     fn notify_running(&self) {
         let mut guard = self.status.lock();
         guard.insert(ThreadStatus::RUNNING);
         self.status_changed.notify_all();
     }
 
+    #[inline]
     fn notify_termination(&self) {
         let mut guard = self.status.lock();
         *guard = ThreadStatus::empty();
         self.status_changed.notify_all();
     }
 
+    #[inline]
     fn notify_should_terminate(&self) {
         unsafe {
             Latch::set_raw(&self.should_terminate);
@@ -451,6 +489,7 @@ impl ThreadPool {
         }
     }
 
+    #[inline]
     fn threads(&self) -> &[CachePadded<ThreadState>] {
         &self.threads[..self.pool_state.num_threads.load(Ordering::Relaxed) as usize]
     }
@@ -473,6 +512,7 @@ impl ThreadPool {
         }
     }
 
+    #[inline]
     pub fn id(&self) -> impl Eq {
         core::ptr::from_ref(self) as usize
     }
@@ -824,32 +864,40 @@ std::thread_local! {
 }
 
 impl WorkerThread {
+    #[inline]
     fn info(&self) -> &ThreadState {
         &self.pool.threads[self.index as usize]
     }
+    #[inline]
     fn pool(&self) -> &'static ThreadPool {
         self.pool
     }
+    #[inline]
     fn index(&self) -> usize {
         self.index
    }
+    #[inline]
     fn is_worker_thread() -> bool {
         Self::with(|worker| worker.is_some())
     }
 
     fn with<T, F: FnOnce(Option<&WorkerThread>) -> T>(f: F) -> T {
         WORKER_THREAD_STATE.with(|thread| f(thread.get()))
     }
 
+    #[inline]
     fn pop_task(&self) -> Option<TaskRef> {
         self.queue.pop_front()
     }
+    #[inline]
     fn push_task(&self, task: TaskRef) {
         self.queue.push_front(task);
     }
+    #[inline]
     fn drain(&self) -> impl Iterator<Item = TaskRef> {
         self.queue.drain()
     }
+    #[inline]
     fn claim_shoved_task(&self) -> Option<TaskRef> {
         if let Some(task) = self.info().shoved_task.try_take() {
             return Some(task);
         }
@@ -884,6 +932,7 @@ impl WorkerThread {
         task.execute();
     }
 
+    #[inline]
     fn try_promote(&self) {
         #[cfg(feature = "internal_heartbeat")]
         let now = std::time::Instant::now();
@@ -907,6 +956,7 @@ impl WorkerThread {
         }
     }
 
+    #[inline]
     fn find_any_task(&self) -> Option<TaskRef> {
         // TODO: attempt stealing work here, too.
         self.pop_task()
@@ -914,6 +964,7 @@ impl WorkerThread {
             .or_else(|| self.pool.global_queue.pop())
     }
 
+    #[inline]
     fn run_until<L>(&self, latch: &L)
     where
         L: Probe,
@@ -933,12 +984,14 @@ impl WorkerThread {
         }
     }
 
+    #[inline]
     fn run_until_inner(&self) {
         match self.find_any_task() {
             Some(task) => {
                 self.execute(task);
             }
             None => {
+                debug!("waiting for tasks");
                 self.info().wait_for_should_wake();
             }
         }
@@ -1012,28 +1065,36 @@ pub struct TaskQueue<T>(UnsafeCell<VecDeque<T>>);
 
 impl<T> TaskQueue<T> {
     /// Creates a new [`TaskQueue`].
+    #[inline]
     const fn new() -> Self {
         Self(UnsafeCell::new(VecDeque::new()))
     }
+    #[inline]
     fn get_mut(&self) -> &mut VecDeque<T> {
         unsafe { &mut *self.0.get() }
     }
+    #[inline]
     fn pop_front(&self) -> Option<T> {
         self.get_mut().pop_front()
     }
+    #[inline]
     fn pop_back(&self) -> Option<T> {
         self.get_mut().pop_back()
     }
+    #[inline]
     fn push_back(&self, t: T) {
         self.get_mut().push_back(t);
     }
+    #[inline]
     fn push_front(&self, t: T) {
         self.get_mut().push_front(t);
     }
+    #[inline]
     fn take(&self) -> VecDeque<T> {
         let this = core::mem::replace(self.get_mut(), VecDeque::new());
         this
     }
+    #[inline]
     fn drain(&self) -> impl Iterator<Item = T> {
         self.take().into_iter()
     }
@@ -1084,6 +1145,7 @@ impl Slot<T> {
         }
     }
 
+    #[inline]
     pub fn try_put(&self, t: T) -> Option<T> {
         match self.state.compare_exchange(
             SlotState::empty().into(),
@@ -1105,6 +1167,7 @@ impl Slot<T> {
         }
     }
 
+    #[inline]
     pub fn try_take(&self) -> Option<T> {
         match self.state.compare_exchange(
             SlotState::OCCUPIED.into(),
@@ -1281,7 +1344,7 @@ mod tests {
         1861, 1867, 1871, 1873, 1877, 1879, 1889, 1901, 1907,
     ];
 
-    const REPEAT: usize = 0x8000;
+    const REPEAT: usize = 0x100;
 
     fn run_in_scope<T: Send>(pool: ThreadPool, f: impl FnOnce(Pin<&Scope<'_>>) -> T + Send) -> T {
         let pool = Box::new(pool);
@@ -1316,8 +1379,7 @@ mod tests {
         pool.scope(|s| {
             for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() {
                 s.spawn(move |_| {
-                    let tmp = (0..p).reduce(|a, b| black_box(a & b));
-                    black_box(tmp);
+                    black_box(spinning(p));
                 });
             }
         });
@@ -1339,8 +1401,7 @@ mod tests {
         pool.scope(|s| {
             for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() {
                 s.spawn(async move {
-                    let tmp = (0..p).reduce(|a, b| black_box(a & b));
-                    black_box(tmp);
+                    black_box(spinning(p));
                 });
             }
         });
@@ -1373,15 +1434,7 @@ mod tests {
         run_in_scope(pool, |s| {
             for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() {
                 s.spawn(move |_| {
-                    // std::thread::sleep(Duration::from_micros(p as u64));
-                    // spin for
-                    let tmp = (0..p).reduce(|a, b| black_box(a & b));
-                    black_box(tmp);
-
-                    // WAIT_COUNT.with(|count| {
-                    //     // eprintln!("{} + {p}", count.get());
-                    //     count.set(count.get() + p);
-                    // });
+                    black_box(spinning(p));
                 });
             }
         });
@@ -1389,4 +1442,27 @@ mod tests {
         }
         // eprintln!("total wait count: {}", counter.load(Ordering::Acquire));
     }
+
+    #[test]
+    #[tracing_test::traced_test]
+    fn sync() {
+        let now = std::time::Instant::now();
+        for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() {
+            black_box(spinning(p));
+        }
+        let elapsed = now.elapsed().as_micros();
+
+        eprintln!("(sync) total time: {}ms", elapsed as f32 / 1e3);
+    }
+
+    #[inline]
+    fn spinning(i: usize) {
+        let rng = rng::XorShift64Star::new(i as u64);
+        (0..i).reduce(|a, b| {
+            black_box({
+                let a = rng.next_usize(a.max(1));
+                ((b as f32).exp() * (a as f32).sin().cbrt()).to_bits() as usize
+            })
+        });
+    }
 }
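
Background on the attribute, for review: #[inline] does not force inlining; for a non-generic item it makes the body available for inlining across crate boundaries even without LTO, which is the usual reason to put it on one-line accessors like the ones touched above. A minimal illustration, where the Counter type is made up for this note and is not part of the crate:

// Illustrative sketch only: a tiny accessor of the kind that benefits from
// `#[inline]` when the crate is consumed as a dependency without LTO.
pub struct Counter(u64);

impl Counter {
    // Without the attribute, downstream crates pay a full call for this
    // one-instruction body; with it, the compiler may inline it at the call site.
    #[inline]
    pub fn get(&self) -> u64 {
        self.0
    }
}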