From 8b35cb7f45ba0df8890510b62a27360d1f2d7654 Mon Sep 17 00:00:00 2001 From: Janis Date: Fri, 31 Jan 2025 00:53:11 +0100 Subject: [PATCH] comparison tests --- Cargo.toml | 2 ++ src/lib.rs | 66 ++++++++++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 59 insertions(+), 9 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 93b6150..641c08c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,12 +5,14 @@ edition = "2021" [features] internal_heartbeat = [] +cpu-pinning = [] [dependencies] futures = "0.3" rayon = "1.10" +bevy_tasks = "0.15.1" parking_lot = "0.12.3" thread_local = "1.1.8" crossbeam = "0.8.4" diff --git a/src/lib.rs b/src/lib.rs index 2b2bd05..40fa1b3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -509,6 +509,7 @@ impl ThreadPool { return self.pool_state.num_threads.load(Ordering::Acquire); } + #[cfg(feature = "cpu-pinning")] let cpus = core_affinity::get_core_ids().unwrap(); let _guard = self.pool_state.lock.lock(); @@ -530,8 +531,10 @@ impl ThreadPool { let new_threads = &self.threads[current_size..new_size]; for (i, _) in new_threads.iter().enumerate() { + #[cfg(feature = "cpu-pinning")] let core = cpus[i]; std::thread::spawn(move || { + #[cfg(feature = "cpu-pinning")] core_affinity::set_for_current(core); WorkerThread::worker_loop(&self, current_size + i); }); @@ -1266,8 +1269,6 @@ mod scope { mod tests { use std::{cell::Cell, hint::black_box}; - use crate::latch::CountWakeLatch; - use super::*; const PRIMES: &'static [usize] = &[ @@ -1280,6 +1281,8 @@ mod tests { 1861, 1867, 1871, 1873, 1877, 1879, 1889, 1901, 1907, ]; + const REPEAT: usize = 0x8000; + fn run_in_scope(pool: ThreadPool, f: impl FnOnce(Pin<&Scope<'_>>) -> T + Send) -> T { let pool = Box::new(pool); let ptr = Box::into_raw(pool); @@ -1288,7 +1291,10 @@ mod tests { let pool: &'static ThreadPool = unsafe { &*ptr }; // pool.ensure_one_worker(); pool.resize_to_available(); + let now = std::time::Instant::now(); let result = pool.scope(f); + let elapsed = now.elapsed().as_micros(); + eprintln!("(mine) total time: {}ms", elapsed as f32 / 1e3); pool.resize_to(0); assert!(pool.global_queue.pop().is_none()); result @@ -1300,12 +1306,57 @@ mod tests { #[test] #[tracing_test::traced_test] - fn spawn_random() { + fn rayon() { + let pool = rayon::ThreadPoolBuilder::new() + .num_threads(bevy_tasks::available_parallelism()) + .build() + .unwrap(); + + let now = std::time::Instant::now(); + pool.scope(|s| { + for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() { + s.spawn(move |_| { + let tmp = (0..p).reduce(|a, b| black_box(a & b)); + black_box(tmp); + }); + } + }); + let elapsed = now.elapsed().as_micros(); + + eprintln!("(rayon) total time: {}ms", elapsed as f32 / 1e3); + } + + #[test] + #[tracing_test::traced_test] + fn bevy_tasks() { + let pool = bevy_tasks::ComputeTaskPool::get_or_init(|| { + bevy_tasks::TaskPoolBuilder::new() + .num_threads(bevy_tasks::available_parallelism()) + .build() + }); + + let now = std::time::Instant::now(); + pool.scope(|s| { + for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() { + s.spawn(async move { + let tmp = (0..p).reduce(|a, b| black_box(a & b)); + black_box(tmp); + }); + } + }); + let elapsed = now.elapsed().as_micros(); + + eprintln!("(bevy_tasks) total time: {}ms", elapsed as f32 / 1e3); + } + + #[test] + #[tracing_test::traced_test] + fn mine() { std::thread_local! { static WAIT_COUNT: Cell = const {Cell::new(0)}; } let counter = Arc::new(AtomicUsize::new(0)); - let elapsed = { + { let pool = ThreadPool::new_with_callbacks(ThreadPoolCallbacks { at_entry: Some(Arc::new(|_worker| { // eprintln!("new worker thread: {}", worker.index); @@ -1319,9 +1370,8 @@ mod tests { })), }); - let now = std::time::Instant::now(); run_in_scope(pool, |s| { - for &p in core::iter::repeat_n(PRIMES, 0x200).flatten() { + for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() { s.spawn(move |_| { // std::thread::sleep(Duration::from_micros(p as u64)); // spin for @@ -1335,10 +1385,8 @@ mod tests { }); } }); - now.elapsed().as_micros() }; - eprintln!("total wait count: {}", counter.load(Ordering::Acquire)); - eprintln!("total time: {}ms", elapsed as f32 / 1e3); + // eprintln!("total wait count: {}", counter.load(Ordering::Acquire)); } }