idk this sucks
This commit is contained in:
		
							parent
							
								
									a691b614bc
								
							
						
					
					
						commit
						736e4e1a60
					
				|  | @ -4,8 +4,12 @@ version = "0.1.0" | ||||||
| edition = "2021" | edition = "2021" | ||||||
| 
 | 
 | ||||||
| [features] | [features] | ||||||
| internal_heartbeat = [] | heartbeat = [] | ||||||
|  | spin-slow = [] | ||||||
| cpu-pinning = [] | cpu-pinning = [] | ||||||
|  | work-stealing = [] | ||||||
|  | prefer-local = [] | ||||||
|  | never-local = [] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| [dependencies] | [dependencies] | ||||||
|  | @ -16,6 +20,7 @@ bevy_tasks = "0.15.1" | ||||||
| parking_lot = "0.12.3" | parking_lot = "0.12.3" | ||||||
| thread_local = "1.1.8" | thread_local = "1.1.8" | ||||||
| crossbeam = "0.8.4" | crossbeam = "0.8.4" | ||||||
|  | st3 = "0.4" | ||||||
| 
 | 
 | ||||||
| async-task = "4.7.1" | async-task = "4.7.1" | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
							
								
								
									
										550
									
								
								src/lib.rs
									
									
									
									
									
								
							
							
						
						
									
										550
									
								
								src/lib.rs
									
									
									
									
									
								
							|  | @ -1,11 +1,10 @@ | ||||||
| use std::{ | use std::{ | ||||||
|     cell::{OnceCell, UnsafeCell}, |     cell::{Cell, UnsafeCell}, | ||||||
|     collections::VecDeque, |  | ||||||
|     future::Future, |     future::Future, | ||||||
|     mem::MaybeUninit, |     mem::MaybeUninit, | ||||||
|     num::NonZero, |     num::NonZero, | ||||||
|     pin::{pin, Pin}, |     pin::{pin, Pin}, | ||||||
|     ptr::NonNull, |     ptr::{self, NonNull}, | ||||||
|     sync::{ |     sync::{ | ||||||
|         atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering}, |         atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering}, | ||||||
|         Arc, |         Arc, | ||||||
|  | @ -17,7 +16,11 @@ use std::{ | ||||||
| 
 | 
 | ||||||
| use async_task::{Runnable, Task}; | use async_task::{Runnable, Task}; | ||||||
| use bitflags::bitflags; | use bitflags::bitflags; | ||||||
| use crossbeam::{queue::SegQueue, utils::CachePadded}; | use crossbeam::{ | ||||||
|  |     atomic::AtomicCell, | ||||||
|  |     deque::{Injector, Stealer, Worker}, | ||||||
|  |     utils::CachePadded, | ||||||
|  | }; | ||||||
| use latch::{AtomicLatch, Latch, LatchWaker, MutexLatch, Probe, ThreadWakeLatch}; | use latch::{AtomicLatch, Latch, LatchWaker, MutexLatch, Probe, ThreadWakeLatch}; | ||||||
| use parking_lot::{Condvar, Mutex}; | use parking_lot::{Condvar, Mutex}; | ||||||
| use scope::Scope; | use scope::Scope; | ||||||
|  | @ -337,7 +340,7 @@ pub mod latch { | ||||||
| pub struct ThreadPoolState { | pub struct ThreadPoolState { | ||||||
|     num_threads: AtomicUsize, |     num_threads: AtomicUsize, | ||||||
|     lock: Mutex<()>, |     lock: Mutex<()>, | ||||||
|     heartbeat_state: CachePadded<ThreadState>, |     heartbeat_state: CachePadded<ThreadControl>, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| bitflags! { | bitflags! { | ||||||
|  | @ -348,15 +351,21 @@ bitflags! { | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| pub struct ThreadState { | pub struct ThreadControl { | ||||||
|     should_shove: AtomicBool, |  | ||||||
|     shoved_task: Slot<TaskRef>, |  | ||||||
|     status: Mutex<ThreadStatus>, |     status: Mutex<ThreadStatus>, | ||||||
|     status_changed: Condvar, |     status_changed: Condvar, | ||||||
|     should_terminate: AtomicLatch, |     should_terminate: AtomicLatch, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl ThreadState { | pub struct ThreadState { | ||||||
|  |     should_shove: AtomicBool, | ||||||
|  |     control: ThreadControl, | ||||||
|  |     stealer: Stealer<TaskRef>, | ||||||
|  |     worker: AtomicCell<Option<Worker<TaskRef>>>, | ||||||
|  |     shoved_task: CachePadded<Slot<TaskRef>>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl ThreadControl { | ||||||
|     /// returns true if thread was sleeping
 |     /// returns true if thread was sleeping
 | ||||||
|     #[inline] |     #[inline] | ||||||
|     fn wake(&self) -> bool { |     fn wake(&self) -> bool { | ||||||
|  | @ -451,40 +460,48 @@ impl ThreadPoolCallbacks { | ||||||
| pub struct ThreadPool { | pub struct ThreadPool { | ||||||
|     threads: [CachePadded<ThreadState>; MAX_THREADS], |     threads: [CachePadded<ThreadState>; MAX_THREADS], | ||||||
|     pool_state: CachePadded<ThreadPoolState>, |     pool_state: CachePadded<ThreadPoolState>, | ||||||
|     global_queue: SegQueue<TaskRef>, |     global_queue: Injector<TaskRef>, | ||||||
|     callbacks: CachePadded<ThreadPoolCallbacks>, |     callbacks: CachePadded<ThreadPoolCallbacks>, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl ThreadPool { | impl ThreadPool { | ||||||
|     const INITIAL_THREAD_STATE: CachePadded<ThreadState> = CachePadded::new(ThreadState { |     pub fn new() -> Self { | ||||||
|  |         Self::new_with_callbacks(ThreadPoolCallbacks::new_empty()) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn new_with_callbacks(callbacks: ThreadPoolCallbacks) -> ThreadPool { | ||||||
|  |         let threads = [const { MaybeUninit::uninit() }; MAX_THREADS].map(|mut uninit| { | ||||||
|  |             let worker = Worker::<TaskRef>::new_fifo(); | ||||||
|  |             let stealer = worker.stealer(); | ||||||
|  | 
 | ||||||
|  |             let thread = CachePadded::new(ThreadState { | ||||||
|                 should_shove: AtomicBool::new(false), |                 should_shove: AtomicBool::new(false), | ||||||
|         shoved_task: Slot::new(), |                 shoved_task: Slot::new().into(), | ||||||
|  |                 control: ThreadControl { | ||||||
|                     status: Mutex::new(ThreadStatus::empty()), |                     status: Mutex::new(ThreadStatus::empty()), | ||||||
|                     status_changed: Condvar::new(), |                     status_changed: Condvar::new(), | ||||||
|                     should_terminate: AtomicLatch::new(), |                     should_terminate: AtomicLatch::new(), | ||||||
|  |                 }, | ||||||
|  |                 stealer, | ||||||
|  |                 worker: AtomicCell::new(Some(worker)), | ||||||
|  |             }); | ||||||
|  |             uninit.write(thread); | ||||||
|  |             unsafe { uninit.assume_init() } | ||||||
|         }); |         }); | ||||||
|     pub const fn new() -> Self { |  | ||||||
|         Self { |  | ||||||
|             threads: const { [Self::INITIAL_THREAD_STATE; MAX_THREADS] }, |  | ||||||
|             pool_state: CachePadded::new(ThreadPoolState { |  | ||||||
|                 num_threads: AtomicUsize::new(0), |  | ||||||
|                 lock: Mutex::new(()), |  | ||||||
|                 heartbeat_state: Self::INITIAL_THREAD_STATE, |  | ||||||
|             }), |  | ||||||
|             global_queue: SegQueue::new(), |  | ||||||
|             callbacks: CachePadded::new(ThreadPoolCallbacks::new_empty()), |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 | 
 | ||||||
|     pub const fn new_with_callbacks(callbacks: ThreadPoolCallbacks) -> ThreadPool { |  | ||||||
|         Self { |         Self { | ||||||
|             threads: const { [Self::INITIAL_THREAD_STATE; MAX_THREADS] }, |             threads, | ||||||
|             pool_state: CachePadded::new(ThreadPoolState { |             pool_state: CachePadded::new(ThreadPoolState { | ||||||
|                 num_threads: AtomicUsize::new(0), |                 num_threads: AtomicUsize::new(0), | ||||||
|                 lock: Mutex::new(()), |                 lock: Mutex::new(()), | ||||||
|                 heartbeat_state: Self::INITIAL_THREAD_STATE, |                 heartbeat_state: ThreadControl { | ||||||
|  |                     status: Mutex::new(ThreadStatus::empty()), | ||||||
|  |                     status_changed: Condvar::new(), | ||||||
|  |                     should_terminate: AtomicLatch::new(), | ||||||
|  |                 } | ||||||
|  |                 .into(), | ||||||
|             }), |             }), | ||||||
|             global_queue: SegQueue::new(), |             global_queue: Injector::new(), | ||||||
|             callbacks: CachePadded::new(callbacks), |             callbacks: CachePadded::new(callbacks), | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  | @ -495,7 +512,7 @@ impl ThreadPool { | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub fn wake_thread(&self, index: usize) -> Option<bool> { |     pub fn wake_thread(&self, index: usize) -> Option<bool> { | ||||||
|         Some(self.threads.get(index as usize)?.wake()) |         Some(self.threads.get(index as usize)?.control.wake()) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub fn wake_any(&self, count: usize) -> usize { |     pub fn wake_any(&self, count: usize) -> usize { | ||||||
|  | @ -503,7 +520,7 @@ impl ThreadPool { | ||||||
|             let num_woken = self |             let num_woken = self | ||||||
|                 .threads |                 .threads | ||||||
|                 .iter() |                 .iter() | ||||||
|                 .filter_map(|thread| thread.wake().then_some(())) |                 .filter_map(|thread| thread.control.wake().then_some(())) | ||||||
|                 .take(count) |                 .take(count) | ||||||
|                 .count(); |                 .count(); | ||||||
|             num_woken |             num_woken | ||||||
|  | @ -517,6 +534,27 @@ impl ThreadPool { | ||||||
|         core::ptr::from_ref(self) as usize |         core::ptr::from_ref(self) as usize | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     fn push_local_or_inject_balanced(&self, task: TaskRef) { | ||||||
|  |         let global_len = self.global_queue.len(); | ||||||
|  |         WorkerThread::with(|worker| match worker { | ||||||
|  |             Some(worker) if worker.pool.id() == self.id() => { | ||||||
|  |                 let worker_len = worker.worker.len(); | ||||||
|  |                 if worker_len == 0 { | ||||||
|  |                     worker.push_task(task); | ||||||
|  |                 } else if global_len == 0 { | ||||||
|  |                     self.inject(task); | ||||||
|  |                 } else { | ||||||
|  |                     if worker_len >= global_len { | ||||||
|  |                         worker.push_task(task); | ||||||
|  |                     } else { | ||||||
|  |                         self.inject(task); | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             _ => self.inject(task), | ||||||
|  |         }) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     fn push_local_or_inject(&self, task: TaskRef) { |     fn push_local_or_inject(&self, task: TaskRef) { | ||||||
|         WorkerThread::with(|worker| match worker { |         WorkerThread::with(|worker| match worker { | ||||||
|             Some(worker) if worker.pool.id() == self.id() => worker.push_task(task), |             Some(worker) if worker.pool.id() == self.id() => worker.push_task(task), | ||||||
|  | @ -536,6 +574,15 @@ impl ThreadPool { | ||||||
|         self.wake_any(n); |         self.wake_any(n); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     fn inject_maybe_local(&self, task: TaskRef) { | ||||||
|  |         #[cfg(all(not(feature = "never-local"), feature = "prefer-local"))] | ||||||
|  |         self.push_local_or_inject(task); | ||||||
|  |         #[cfg(all(not(feature = "prefer-local"), feature = "never-local"))] | ||||||
|  |         self.inject(task); | ||||||
|  |         #[cfg(not(any(feature = "prefer-local", feature = "never-local")))] | ||||||
|  |         self.push_local_or_inject_balanced(task); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     fn inject(&self, task: TaskRef) { |     fn inject(&self, task: TaskRef) { | ||||||
|         self.global_queue.push(task); |         self.global_queue.push(task); | ||||||
| 
 | 
 | ||||||
|  | @ -581,10 +628,10 @@ impl ThreadPool { | ||||||
|                 } |                 } | ||||||
| 
 | 
 | ||||||
|                 for thread in new_threads { |                 for thread in new_threads { | ||||||
|                     thread.wait_for_running(); |                     thread.control.wait_for_running(); | ||||||
|                 } |                 } | ||||||
| 
 | 
 | ||||||
|                 #[cfg(not(feature = "internal_heartbeat"))] |                 #[cfg(feature = "heartbeat")] | ||||||
|                 if current_size == 0 { |                 if current_size == 0 { | ||||||
|                     std::thread::spawn(move || { |                     std::thread::spawn(move || { | ||||||
|                         heartbeat_loop(self); |                         heartbeat_loop(self); | ||||||
|  | @ -601,13 +648,13 @@ impl ThreadPool { | ||||||
|                 let terminating_threads = &self.threads[new_size..current_size]; |                 let terminating_threads = &self.threads[new_size..current_size]; | ||||||
| 
 | 
 | ||||||
|                 for thread in terminating_threads { |                 for thread in terminating_threads { | ||||||
|                     thread.notify_should_terminate(); |                     thread.control.notify_should_terminate(); | ||||||
|                 } |                 } | ||||||
|                 for thread in terminating_threads { |                 for thread in terminating_threads { | ||||||
|                     thread.wait_for_termination(); |                     thread.control.wait_for_termination(); | ||||||
|                 } |                 } | ||||||
| 
 | 
 | ||||||
|                 #[cfg(not(feature = "internal_heartbeat"))] |                 #[cfg(feature = "heartbeat")] | ||||||
|                 if new_size == 0 { |                 if new_size == 0 { | ||||||
|                     self.pool_state.heartbeat_state.notify_should_terminate(); |                     self.pool_state.heartbeat_state.notify_should_terminate(); | ||||||
|                     self.pool_state.heartbeat_state.wait_for_termination(); |                     self.pool_state.heartbeat_state.wait_for_termination(); | ||||||
|  | @ -712,7 +759,7 @@ impl ThreadPool { | ||||||
|         })); |         })); | ||||||
| 
 | 
 | ||||||
|         let taskref = task.into_ref().as_task_ref(); |         let taskref = task.into_ref().as_task_ref(); | ||||||
|         self.inject(taskref); |         self.push_local_or_inject(taskref); | ||||||
| 
 | 
 | ||||||
|         worker.run_until(&latch); |         worker.run_until(&latch); | ||||||
|         result.unwrap() |         result.unwrap() | ||||||
|  | @ -727,7 +774,7 @@ impl ThreadPool { | ||||||
|         let task = HeapTask::new(f); |         let task = HeapTask::new(f); | ||||||
| 
 | 
 | ||||||
|         let taskref = unsafe { task.into_static_task_ref() }; |         let taskref = unsafe { task.into_static_task_ref() }; | ||||||
|         self.push_local_or_inject(taskref); |         self.inject_maybe_local(taskref); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     fn spawn_future<Fut, T>(&'static self, future: Fut) -> Task<T> |     fn spawn_future<Fut, T>(&'static self, future: Fut) -> Task<T> | ||||||
|  | @ -745,7 +792,7 @@ impl ThreadPool { | ||||||
|                 }) |                 }) | ||||||
|             }; |             }; | ||||||
| 
 | 
 | ||||||
|             self.push_local_or_inject(taskref); |             self.inject_maybe_local(taskref); | ||||||
|         }; |         }; | ||||||
| 
 | 
 | ||||||
|         let (runnable, task) = async_task::spawn(future, schedule); |         let (runnable, task) = async_task::spawn(future, schedule); | ||||||
|  | @ -789,6 +836,30 @@ impl ThreadPool { | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     fn join<F, G, T, U>(&'static self, f: F, g: G) -> (T, U) |     fn join<F, G, T, U>(&'static self, f: F, g: G) -> (T, U) | ||||||
|  |     where | ||||||
|  |         F: FnOnce() -> T + Send, | ||||||
|  |         G: FnOnce() -> U + Send, | ||||||
|  |         T: Send, | ||||||
|  |         U: Send, | ||||||
|  |     { | ||||||
|  |         self.join_threaded(f, g) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn join_seq<F, G, T, U>(&'static self, f: F, g: G) -> (T, U) | ||||||
|  |     where | ||||||
|  |         F: FnOnce() -> T + Send, | ||||||
|  |         G: FnOnce() -> U + Send, | ||||||
|  |         T: Send, | ||||||
|  |         U: Send, | ||||||
|  |     { | ||||||
|  |         let a = f(); | ||||||
|  |         let b = g(); | ||||||
|  | 
 | ||||||
|  |         (a, b) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     #[inline] | ||||||
|  |     fn join_threaded<F, G, T, U>(&'static self, f: F, g: G) -> (T, U) | ||||||
|     where |     where | ||||||
|         F: FnOnce() -> T + Send, |         F: FnOnce() -> T + Send, | ||||||
|         G: FnOnce() -> U + Send, |         G: FnOnce() -> U + Send, | ||||||
|  | @ -808,7 +879,6 @@ impl ThreadPool { | ||||||
| 
 | 
 | ||||||
|             let ref_b = task_b.as_ref().as_task_ref(); |             let ref_b = task_b.as_ref().as_task_ref(); | ||||||
|             let b_id = ref_b.id(); |             let b_id = ref_b.id(); | ||||||
|             // TODO: maybe try to push this off to another thread immediately first?
 |  | ||||||
|             worker.push_task(ref_b); |             worker.push_task(ref_b); | ||||||
| 
 | 
 | ||||||
|             let result_a = f(); |             let result_a = f(); | ||||||
|  | @ -817,14 +887,18 @@ impl ThreadPool { | ||||||
|                 match worker.pop_task() { |                 match worker.pop_task() { | ||||||
|                     Some(task) => { |                     Some(task) => { | ||||||
|                         if task.id() == b_id { |                         if task.id() == b_id { | ||||||
|                             worker.try_promote(); |                             // we're not calling execute() here, so manually try
 | ||||||
|  |                             // shoving a task.
 | ||||||
|  |                             //worker.try_promote();
 | ||||||
|  |                             worker.shove_task(); | ||||||
|                             unsafe { |                             unsafe { | ||||||
|                                 task_b.run_as_ref(); |                                 task_b.run_as_ref(); | ||||||
|                             } |                             } | ||||||
|                             break; |                             break; | ||||||
|                         } |                         } else { | ||||||
|                             worker.execute(task); |                             worker.execute(task); | ||||||
|                         } |                         } | ||||||
|  |                     } | ||||||
|                     None => { |                     None => { | ||||||
|                         worker.run_until(&latch_b); |                         worker.run_until(&latch_b); | ||||||
|                     } |                     } | ||||||
|  | @ -837,12 +911,12 @@ impl ThreadPool { | ||||||
| 
 | 
 | ||||||
|     fn scope<'scope, Fn, T>(&'static self, f: Fn) -> T |     fn scope<'scope, Fn, T>(&'static self, f: Fn) -> T | ||||||
|     where |     where | ||||||
|         Fn: FnOnce(Pin<&Scope<'scope>>) -> T + Send, |         Fn: FnOnce(&Scope<'scope>) -> T + Send, | ||||||
|         T: Send, |         T: Send, | ||||||
|     { |     { | ||||||
|         self.in_worker(|owner, _| { |         self.in_worker(|owner, _| { | ||||||
|             let scope = pin!(unsafe { Scope::<'scope>::new(owner) }); |             let scope = unsafe { Scope::<'scope>::new(owner) }; | ||||||
|             let result = f(scope.as_ref()); |             let result = f(&scope); | ||||||
|             scope.complete(owner); |             scope.complete(owner); | ||||||
|             result |             result | ||||||
|         }) |         }) | ||||||
|  | @ -850,7 +924,8 @@ impl ThreadPool { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| pub struct WorkerThread { | pub struct WorkerThread { | ||||||
|     queue: TaskQueue<TaskRef>, |     // queue: TaskQueue<TaskRef>,
 | ||||||
|  |     worker: Worker<TaskRef>, | ||||||
|     pool: &'static ThreadPool, |     pool: &'static ThreadPool, | ||||||
|     index: usize, |     index: usize, | ||||||
|     rng: rng::XorShift64Star, |     rng: rng::XorShift64Star, | ||||||
|  | @ -860,7 +935,7 @@ pub struct WorkerThread { | ||||||
| const HEARTBEAT_INTERVAL: core::time::Duration = const { core::time::Duration::from_micros(100) }; | const HEARTBEAT_INTERVAL: core::time::Duration = const { core::time::Duration::from_micros(100) }; | ||||||
| 
 | 
 | ||||||
| std::thread_local! { | std::thread_local! { | ||||||
|     static WORKER_THREAD_STATE: CachePadded<OnceCell<WorkerThread>> = const {CachePadded::new(OnceCell::new())}; |     static WORKER_THREAD_STATE: Cell<*const WorkerThread> = const {Cell::new(ptr::null())}; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl WorkerThread { | impl WorkerThread { | ||||||
|  | @ -880,25 +955,47 @@ impl WorkerThread { | ||||||
|     fn is_worker_thread() -> bool { |     fn is_worker_thread() -> bool { | ||||||
|         Self::with(|worker| worker.is_some()) |         Self::with(|worker| worker.is_some()) | ||||||
|     } |     } | ||||||
|  | 
 | ||||||
|     fn with<T, F: FnOnce(Option<&Self>) -> T>(f: F) -> T { |     fn with<T, F: FnOnce(Option<&Self>) -> T>(f: F) -> T { | ||||||
|         WORKER_THREAD_STATE.with(|thread| f(thread.get())) |         WORKER_THREAD_STATE.with(|thread| { | ||||||
|  |             f(NonNull::<WorkerThread>::new(thread.get().cast_mut()) | ||||||
|  |                 .map(|ptr| unsafe { ptr.as_ref() })) | ||||||
|  |         }) | ||||||
|     } |     } | ||||||
|     #[inline] |     #[inline] | ||||||
|     fn pop_task(&self) -> Option<TaskRef> { |     fn pop_task(&self) -> Option<TaskRef> { | ||||||
|         self.queue.pop_front() |         self.worker.pop() | ||||||
|  |         //self.queue.pop_front(task);
 | ||||||
|     } |     } | ||||||
|     #[inline] |     #[inline] | ||||||
|     fn push_task(&self, task: TaskRef) { |     fn push_task(&self, task: TaskRef) { | ||||||
|         self.queue.push_front(task); |         self.worker.push(task); | ||||||
|  |         //self.queue.push_front(task);
 | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     #[inline] |     #[inline] | ||||||
|     fn drain(&self) -> impl Iterator<Item = TaskRef> { |     fn drain(&self) -> impl Iterator<Item = TaskRef> { | ||||||
|         self.queue.drain() |         // self.queue.drain()
 | ||||||
|  |         core::iter::empty() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     #[inline] | ||||||
|  |     fn steal_tasks(&self) -> Option<TaskRef> { | ||||||
|  |         // careful not to call threads() here because that omits any threads
 | ||||||
|  |         // that were killed, which might still have tasks.
 | ||||||
|  |         let threads = &self.pool.threads; | ||||||
|  |         let (start, end) = threads.split_at(self.rng.next_usize(threads.len())); | ||||||
|  | 
 | ||||||
|  |         end.iter() | ||||||
|  |             .chain(start) | ||||||
|  |             .find_map(|thread: &CachePadded<ThreadState>| { | ||||||
|  |                 thread.stealer.steal_batch_and_pop(&self.worker).success() | ||||||
|  |             }) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     #[inline] |     #[inline] | ||||||
|     fn claim_shoved_task(&self) -> Option<TaskRef> { |     fn claim_shoved_task(&self) -> Option<TaskRef> { | ||||||
|  |         // take own shoved task first
 | ||||||
|         if let Some(task) = self.info().shoved_task.try_take() { |         if let Some(task) = self.info().shoved_task.try_take() { | ||||||
|             return Some(task); |             return Some(task); | ||||||
|         } |         } | ||||||
|  | @ -916,12 +1013,16 @@ impl WorkerThread { | ||||||
| 
 | 
 | ||||||
|     #[cold] |     #[cold] | ||||||
|     fn shove_task(&self) { |     fn shove_task(&self) { | ||||||
|         if let Some(task) = self.queue.pop_back() { |         if !self.info().shoved_task.is_occupied() { | ||||||
|  |             if let Some(task) = self.info().stealer.steal().success() { | ||||||
|                 match self.info().shoved_task.try_put(task) { |                 match self.info().shoved_task.try_put(task) { | ||||||
|                     // shoved task is occupied, reinsert into queue
 |                     // shoved task is occupied, reinsert into queue
 | ||||||
|                 Some(task) => self.queue.push_back(task), |                     // this really shouldn't happen
 | ||||||
|  |                     Some(_task) => unreachable!(), | ||||||
|                     None => {} |                     None => {} | ||||||
|                 } |                 } | ||||||
|  |             } | ||||||
|  |         } else { | ||||||
|             // wake thread to execute task
 |             // wake thread to execute task
 | ||||||
|             self.pool.wake_any(1); |             self.pool.wake_any(1); | ||||||
|         } |         } | ||||||
|  | @ -934,24 +1035,15 @@ impl WorkerThread { | ||||||
| 
 | 
 | ||||||
|     #[inline] |     #[inline] | ||||||
|     fn try_promote(&self) { |     fn try_promote(&self) { | ||||||
|         #[cfg(feature = "internal_heartbeat")] |         #[cfg(feature = "heartbeat")] | ||||||
|         let now = std::time::Instant::now(); |  | ||||||
|         // SAFETY: workerthread is thread-local non-sync
 |  | ||||||
| 
 |  | ||||||
|         #[cfg(feature = "internal_heartbeat")] |  | ||||||
|         let should_shove = |  | ||||||
|             unsafe { *self.last_heartbeat.get() }.duration_since(now) > HEARTBEAT_INTERVAL; |  | ||||||
|         #[cfg(not(feature = "internal_heartbeat"))] |  | ||||||
|         let should_shove = self.info().should_shove.load(Ordering::Acquire); |         let should_shove = self.info().should_shove.load(Ordering::Acquire); | ||||||
|  |         #[cfg(not(feature = "heartbeat"))] | ||||||
|  |         let should_shove = true; | ||||||
| 
 | 
 | ||||||
|         if should_shove { |         if should_shove { | ||||||
|             // SAFETY: workerthread is thread-local non-sync
 |             #[cfg(feature = "heartbeat")] | ||||||
|             #[cfg(feature = "internal_heartbeat")] |  | ||||||
|             unsafe { |  | ||||||
|                 *&mut *self.last_heartbeat.get() = now; |  | ||||||
|             } |  | ||||||
|             #[cfg(not(feature = "internal_heartbeat"))] |  | ||||||
|             self.info().should_shove.store(false, Ordering::Release); |             self.info().should_shove.store(false, Ordering::Release); | ||||||
|  | 
 | ||||||
|             self.shove_task(); |             self.shove_task(); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  | @ -959,9 +1051,22 @@ impl WorkerThread { | ||||||
|     #[inline] |     #[inline] | ||||||
|     fn find_any_task(&self) -> Option<TaskRef> { |     fn find_any_task(&self) -> Option<TaskRef> { | ||||||
|         // TODO: attempt stealing work here, too.
 |         // TODO: attempt stealing work here, too.
 | ||||||
|         self.pop_task() |         let mut task = self | ||||||
|  |             .pop_task() | ||||||
|             .or_else(|| self.claim_shoved_task()) |             .or_else(|| self.claim_shoved_task()) | ||||||
|             .or_else(|| self.pool.global_queue.pop()) |             .or_else(|| { | ||||||
|  |                 self.pool | ||||||
|  |                     .global_queue | ||||||
|  |                     .steal_batch_and_pop(&self.worker) | ||||||
|  |                     .success() | ||||||
|  |             }); | ||||||
|  | 
 | ||||||
|  |         #[cfg(feature = "work-stealing")] | ||||||
|  |         { | ||||||
|  |             task = task.or_else(|| self.steal_tasks()); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         task | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     #[inline] |     #[inline] | ||||||
|  | @ -991,34 +1096,36 @@ impl WorkerThread { | ||||||
|                 self.execute(task); |                 self.execute(task); | ||||||
|             } |             } | ||||||
|             None => { |             None => { | ||||||
|                 debug!("waiting for tasks"); |                 //debug!("waiting for tasks");
 | ||||||
|                 self.info().wait_for_should_wake(); |                 self.info().control.wait_for_should_wake(); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     fn worker_loop(pool: &'static ThreadPool, index: usize) { |     fn worker_loop(pool: &'static ThreadPool, index: usize) { | ||||||
|         let info = &pool.threads()[index as usize]; |         let info = &pool.threads()[index as usize]; | ||||||
| 
 |         let worker = CachePadded::new(WorkerThread { | ||||||
|         WORKER_THREAD_STATE.with(|worker| { |             // queue: TaskQueue::new(),
 | ||||||
|             let worker = worker.get_or_init(|| WorkerThread { |             worker: info.worker.take().unwrap(), | ||||||
|                 queue: TaskQueue::new(), |  | ||||||
|             pool, |             pool, | ||||||
|             index, |             index, | ||||||
|             rng: rng::XorShift64Star::new(pool as *const _ as u64 + index as u64), |             rng: rng::XorShift64Star::new(pool as *const _ as u64 + index as u64), | ||||||
|             last_heartbeat: UnsafeCell::new(std::time::Instant::now()), |             last_heartbeat: UnsafeCell::new(std::time::Instant::now()), | ||||||
|         }); |         }); | ||||||
| 
 | 
 | ||||||
|  |         WORKER_THREAD_STATE.with(|cell| { | ||||||
|  |             cell.set(&*worker); | ||||||
|  | 
 | ||||||
|             if let Some(callback) = pool.callbacks.at_entry.as_ref() { |             if let Some(callback) = pool.callbacks.at_entry.as_ref() { | ||||||
|                 callback(worker); |                 callback(&worker); | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             info.notify_running(); |             info.control.notify_running(); | ||||||
|             // info.notify_running();
 |             // info.notify_running();
 | ||||||
|             worker.run_until(&info.should_terminate); |             worker.run_until(&info.control.should_terminate); | ||||||
| 
 | 
 | ||||||
|             if let Some(callback) = pool.callbacks.at_exit.as_ref() { |             if let Some(callback) = pool.callbacks.at_exit.as_ref() { | ||||||
|                 callback(worker); |                 callback(&worker); | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             for task in worker.drain() { |             for task in worker.drain() { | ||||||
|  | @ -1028,9 +1135,14 @@ impl WorkerThread { | ||||||
|             if let Some(task) = info.shoved_task.try_take() { |             if let Some(task) = info.shoved_task.try_take() { | ||||||
|                 pool.inject(task); |                 pool.inject(task); | ||||||
|             } |             } | ||||||
|  | 
 | ||||||
|  |             cell.set(ptr::null()); | ||||||
|         }); |         }); | ||||||
| 
 | 
 | ||||||
|         info.notify_termination(); |         let WorkerThread { worker, .. } = CachePadded::into_inner(worker); | ||||||
|  |         info.worker.store(Some(worker)); | ||||||
|  | 
 | ||||||
|  |         info.control.notify_termination(); | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -1061,56 +1173,68 @@ fn heartbeat_loop(pool: &'static ThreadPool) { | ||||||
|     state.notify_termination(); |     state.notify_termination(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | use vec_queue::TaskQueue; | ||||||
|  | 
 | ||||||
|  | mod vec_queue { | ||||||
|  |     use std::{cell::UnsafeCell, collections::VecDeque}; | ||||||
|  | 
 | ||||||
|     pub struct TaskQueue<T>(UnsafeCell<VecDeque<T>>); |     pub struct TaskQueue<T>(UnsafeCell<VecDeque<T>>); | ||||||
| 
 | 
 | ||||||
|     impl<T> TaskQueue<T> { |     impl<T> TaskQueue<T> { | ||||||
|         /// Creates a new [`TaskQueue<T>`].
 |         /// Creates a new [`TaskQueue<T>`].
 | ||||||
|         #[inline] |         #[inline] | ||||||
|     const fn new() -> Self { |         pub const fn new() -> Self { | ||||||
|             Self(UnsafeCell::new(VecDeque::new())) |             Self(UnsafeCell::new(VecDeque::new())) | ||||||
|         } |         } | ||||||
|         #[inline] |         #[inline] | ||||||
|     fn get_mut(&self) -> &mut VecDeque<T> { |         pub fn get_mut(&self) -> &mut VecDeque<T> { | ||||||
|             unsafe { &mut *self.0.get() } |             unsafe { &mut *self.0.get() } | ||||||
|         } |         } | ||||||
|         #[inline] |         #[inline] | ||||||
|     fn pop_front(&self) -> Option<T> { |         pub fn pop_front(&self) -> Option<T> { | ||||||
|             self.get_mut().pop_front() |             self.get_mut().pop_front() | ||||||
|         } |         } | ||||||
|         #[inline] |         #[inline] | ||||||
|     fn pop_back(&self) -> Option<T> { |         pub fn pop_back(&self) -> Option<T> { | ||||||
|             self.get_mut().pop_back() |             self.get_mut().pop_back() | ||||||
|         } |         } | ||||||
|         #[inline] |         #[inline] | ||||||
|     fn push_back(&self, t: T) { |         pub fn push_back(&self, t: T) { | ||||||
|             self.get_mut().push_back(t); |             self.get_mut().push_back(t); | ||||||
|         } |         } | ||||||
|         #[inline] |         #[inline] | ||||||
|     fn push_front(&self, t: T) { |         pub fn push_front(&self, t: T) { | ||||||
|             self.get_mut().push_front(t); |             self.get_mut().push_front(t); | ||||||
|         } |         } | ||||||
|         #[inline] |         #[inline] | ||||||
|     fn take(&self) -> VecDeque<T> { |         pub fn take(&self) -> VecDeque<T> { | ||||||
|             let this = core::mem::replace(self.get_mut(), VecDeque::new()); |             let this = core::mem::replace(self.get_mut(), VecDeque::new()); | ||||||
|             this |             this | ||||||
|         } |         } | ||||||
|         #[inline] |         #[inline] | ||||||
|     fn drain(&self) -> impl Iterator<Item = T> { |         pub fn drain(&self) -> impl Iterator<Item = T> { | ||||||
|             self.take().into_iter() |             self.take().into_iter() | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  | } | ||||||
| 
 | 
 | ||||||
| bitflags! { | #[repr(u8)] | ||||||
|     #[derive(Debug, Clone, Copy)] | #[derive(Debug, Clone, Copy, PartialEq, Eq)] | ||||||
|     pub struct SlotState: u8 { | enum SlotState { | ||||||
|         const LOCKED = 1 << 1; |     None, | ||||||
|         const OCCUPIED = 1 << 2; |     Locked, | ||||||
|  |     Occupied, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl From<u8> for SlotState { | ||||||
|  |     fn from(value: u8) -> Self { | ||||||
|  |         unsafe { core::mem::transmute(value) } | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl From<SlotState> for u8 { | impl From<SlotState> for u8 { | ||||||
|     fn from(value: SlotState) -> Self { |     fn from(value: SlotState) -> Self { | ||||||
|         value.bits() |         value as u8 | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -1125,10 +1249,7 @@ unsafe impl<T> Sync for Slot<T> where T: Send {} | ||||||
| impl<T> Drop for Slot<T> { | impl<T> Drop for Slot<T> { | ||||||
|     fn drop(&mut self) { |     fn drop(&mut self) { | ||||||
|         if core::mem::needs_drop::<T>() { |         if core::mem::needs_drop::<T>() { | ||||||
|             if SlotState::from_bits(*self.state.get_mut()) |             if *self.state.get_mut() == SlotState::Occupied as u8 { | ||||||
|                 .unwrap() |  | ||||||
|                 .contains(SlotState::OCCUPIED) |  | ||||||
|             { |  | ||||||
|                 unsafe { |                 unsafe { | ||||||
|                     self.slot.get().drop_in_place(); |                     self.slot.get().drop_in_place(); | ||||||
|                 } |                 } | ||||||
|  | @ -1141,15 +1262,56 @@ impl<T> Slot<T> { | ||||||
|     pub const fn new() -> Slot<T> { |     pub const fn new() -> Slot<T> { | ||||||
|         Self { |         Self { | ||||||
|             slot: UnsafeCell::new(MaybeUninit::uninit()), |             slot: UnsafeCell::new(MaybeUninit::uninit()), | ||||||
|             state: AtomicU8::new(SlotState::empty().bits()), |             state: AtomicU8::new(SlotState::None as u8), | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     pub fn is_occupied(&self) -> bool { | ||||||
|  |         self.state.load(Ordering::Acquire) == SlotState::Occupied.into() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     #[inline] | ||||||
|  |     pub fn insert(&self, t: T) -> Option<T> { | ||||||
|  |         let value = match self | ||||||
|  |             .state | ||||||
|  |             .swap(SlotState::Locked.into(), Ordering::AcqRel) | ||||||
|  |             .into() | ||||||
|  |         { | ||||||
|  |             SlotState::Locked => { | ||||||
|  |                 // return early: was already locked.
 | ||||||
|  |                 debug!("slot was already locked"); | ||||||
|  |                 return None; | ||||||
|  |             } | ||||||
|  |             SlotState::Occupied => { | ||||||
|  |                 let slot = self.slot.get(); | ||||||
|  |                 // replace
 | ||||||
|  |                 unsafe { | ||||||
|  |                     let v = (*slot).assume_init_read(); | ||||||
|  |                     (*slot).write(t); | ||||||
|  |                     Some(v) | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             SlotState::None => { | ||||||
|  |                 let slot = self.slot.get(); | ||||||
|  |                 // insert
 | ||||||
|  |                 unsafe { | ||||||
|  |                     (*slot).write(t); | ||||||
|  |                 } | ||||||
|  |                 None | ||||||
|  |             } | ||||||
|  |         }; | ||||||
|  | 
 | ||||||
|  |         // release lock
 | ||||||
|  |         self.state | ||||||
|  |             .store(SlotState::Occupied.into(), Ordering::Release); | ||||||
|  |         value | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     #[inline] |     #[inline] | ||||||
|     pub fn try_put(&self, t: T) -> Option<T> { |     pub fn try_put(&self, t: T) -> Option<T> { | ||||||
|         match self.state.compare_exchange( |         match self.state.compare_exchange( | ||||||
|             SlotState::empty().into(), |             SlotState::None.into(), | ||||||
|             SlotState::LOCKED.into(), |             SlotState::Locked.into(), | ||||||
|             Ordering::Acquire, |             Ordering::Acquire, | ||||||
|             Ordering::Relaxed, |             Ordering::Relaxed, | ||||||
|         ) { |         ) { | ||||||
|  | @ -1161,7 +1323,7 @@ impl<T> Slot<T> { | ||||||
| 
 | 
 | ||||||
|                 // release lock
 |                 // release lock
 | ||||||
|                 self.state |                 self.state | ||||||
|                     .store(SlotState::OCCUPIED.into(), Ordering::Release); |                     .store(SlotState::Occupied.into(), Ordering::Release); | ||||||
|                 None |                 None | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  | @ -1170,8 +1332,8 @@ impl<T> Slot<T> { | ||||||
|     #[inline] |     #[inline] | ||||||
|     pub fn try_take(&self) -> Option<T> { |     pub fn try_take(&self) -> Option<T> { | ||||||
|         match self.state.compare_exchange( |         match self.state.compare_exchange( | ||||||
|             SlotState::OCCUPIED.into(), |             SlotState::Occupied.into(), | ||||||
|             SlotState::LOCKED.into(), |             SlotState::Locked.into(), | ||||||
|             Ordering::Acquire, |             Ordering::Acquire, | ||||||
|             Ordering::Relaxed, |             Ordering::Relaxed, | ||||||
|         ) { |         ) { | ||||||
|  | @ -1181,8 +1343,7 @@ impl<T> Slot<T> { | ||||||
|                 let t = unsafe { (*slot).assume_init_read() }; |                 let t = unsafe { (*slot).assume_init_read() }; | ||||||
| 
 | 
 | ||||||
|                 // release lock
 |                 // release lock
 | ||||||
|                 self.state |                 self.state.store(SlotState::None.into(), Ordering::Release); | ||||||
|                     .store(SlotState::empty().into(), Ordering::Release); |  | ||||||
|                 Some(t) |                 Some(t) | ||||||
|             } |             } | ||||||
|             Err(_) => None, |             Err(_) => None, | ||||||
|  | @ -1227,14 +1388,15 @@ mod scope { | ||||||
|     use std::{ |     use std::{ | ||||||
|         future::Future, |         future::Future, | ||||||
|         marker::{PhantomData, PhantomPinned}, |         marker::{PhantomData, PhantomPinned}, | ||||||
|  |         pin::pin, | ||||||
|         ptr::{self, NonNull}, |         ptr::{self, NonNull}, | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|     use async_task::{Runnable, Task}; |     use async_task::{Runnable, Task}; | ||||||
| 
 | 
 | ||||||
|     use crate::{ |     use crate::{ | ||||||
|         latch::{CountWakeLatch, Latch}, |         latch::{CountWakeLatch, Latch, Probe, ThreadWakeLatch}, | ||||||
|         task::{HeapTask, TaskRef}, |         task::{HeapTask, StackTask, TaskRef}, | ||||||
|         ThreadPool, WorkerThread, |         ThreadPool, WorkerThread, | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|  | @ -1253,6 +1415,16 @@ mod scope { | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|  |         pub fn join<F, G, T, U>(&self, f: F, g: G) -> (T, U) | ||||||
|  |         where | ||||||
|  |             F: FnOnce(&Self) -> T + Send, | ||||||
|  |             G: FnOnce(&Self) -> U + Send, | ||||||
|  |             T: Send, | ||||||
|  |             U: Send, | ||||||
|  |         { | ||||||
|  |             self.pool.join(|| f(self), || g(self)) | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|         pub fn spawn<Fn>(&self, f: Fn) |         pub fn spawn<Fn>(&self, f: Fn) | ||||||
|         where |         where | ||||||
|             Fn: FnOnce(&Scope<'scope>) + Send + 'scope, |             Fn: FnOnce(&Scope<'scope>) + Send + 'scope, | ||||||
|  | @ -1267,7 +1439,7 @@ mod scope { | ||||||
|             }); |             }); | ||||||
| 
 | 
 | ||||||
|             let taskref = unsafe { task.into_task_ref() }; |             let taskref = unsafe { task.into_task_ref() }; | ||||||
|             self.pool.push_local_or_inject(taskref); |             self.pool.inject_maybe_local(taskref); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         pub fn spawn_future<Fut, T>(&self, future: Fut) -> Task<T> |         pub fn spawn_future<Fut, T>(&self, future: Fut) -> Task<T> | ||||||
|  | @ -1289,7 +1461,7 @@ mod scope { | ||||||
|                 }; |                 }; | ||||||
| 
 | 
 | ||||||
|                 unsafe { |                 unsafe { | ||||||
|                     ptr.as_ref().pool.push_local_or_inject(taskref); |                     ptr.as_ref().pool.inject_maybe_local(taskref); | ||||||
|                 } |                 } | ||||||
|             }; |             }; | ||||||
| 
 | 
 | ||||||
|  | @ -1332,8 +1504,58 @@ mod scope { | ||||||
| mod tests { | mod tests { | ||||||
|     use std::{cell::Cell, hint::black_box}; |     use std::{cell::Cell, hint::black_box}; | ||||||
| 
 | 
 | ||||||
|  |     use tracing::info; | ||||||
|  | 
 | ||||||
|     use super::*; |     use super::*; | ||||||
| 
 | 
 | ||||||
|  |     mod tree { | ||||||
|  | 
 | ||||||
|  |         pub struct Tree<T> { | ||||||
|  |             nodes: Box<[Node<T>]>, | ||||||
|  |             root: Option<usize>, | ||||||
|  |         } | ||||||
|  |         pub struct Node<T> { | ||||||
|  |             pub leaf: T, | ||||||
|  |             pub left: Option<usize>, | ||||||
|  |             pub right: Option<usize>, | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         impl<T> Tree<T> { | ||||||
|  |             pub fn new(depth: usize, t: T) -> Tree<T> | ||||||
|  |             where | ||||||
|  |                 T: Copy, | ||||||
|  |             { | ||||||
|  |                 let mut nodes = Vec::with_capacity((0..depth).sum()); | ||||||
|  |                 let root = Self::build_node(&mut nodes, depth, t); | ||||||
|  |                 Self { | ||||||
|  |                     nodes: nodes.into_boxed_slice(), | ||||||
|  |                     root: Some(root), | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             pub fn root(&self) -> Option<usize> { | ||||||
|  |                 self.root | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             pub fn get(&self, index: usize) -> &Node<T> { | ||||||
|  |                 &self.nodes[index] | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             pub fn build_node(nodes: &mut Vec<Node<T>>, depth: usize, t: T) -> usize | ||||||
|  |             where | ||||||
|  |                 T: Copy, | ||||||
|  |             { | ||||||
|  |                 let node = Node { | ||||||
|  |                     leaf: t, | ||||||
|  |                     left: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)), | ||||||
|  |                     right: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)), | ||||||
|  |                 }; | ||||||
|  |                 nodes.push(node); | ||||||
|  |                 nodes.len() - 1 | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     const PRIMES: &'static [usize] = &[ |     const PRIMES: &'static [usize] = &[ | ||||||
|         1181, 1187, 1193, 1201, 1213, 1217, 1223, 1229, 1231, 1237, 1249, 1259, 1277, 1279, 1283, |         1181, 1187, 1193, 1201, 1213, 1217, 1223, 1229, 1231, 1237, 1249, 1259, 1277, 1279, 1283, | ||||||
|         1289, 1291, 1297, 1301, 1303, 1307, 1319, 1321, 1327, 1361, 1367, 1373, 1381, 1399, 1409, |         1289, 1291, 1297, 1301, 1303, 1307, 1319, 1321, 1327, 1361, 1367, 1373, 1381, 1399, 1409, | ||||||
|  | @ -1344,9 +1566,14 @@ mod tests { | ||||||
|         1861, 1867, 1871, 1873, 1877, 1879, 1889, 1901, 1907, |         1861, 1867, 1871, 1873, 1877, 1879, 1889, 1901, 1907, | ||||||
|     ]; |     ]; | ||||||
| 
 | 
 | ||||||
|     const REPEAT: usize = 0x100; |     #[cfg(feature = "spin-slow")] | ||||||
|  |     const REPEAT: usize = 0x800; | ||||||
|  |     #[cfg(not(feature = "spin-slow"))] | ||||||
|  |     const REPEAT: usize = 0x8000; | ||||||
| 
 | 
 | ||||||
|     fn run_in_scope<T: Send>(pool: ThreadPool, f: impl FnOnce(Pin<&Scope<'_>>) -> T + Send) -> T { |     const TREE_SIZE: usize = 10; | ||||||
|  | 
 | ||||||
|  |     fn run_in_scope<T: Send>(pool: ThreadPool, f: impl FnOnce(&Scope<'_>) -> T + Send) -> T { | ||||||
|         let pool = Box::new(pool); |         let pool = Box::new(pool); | ||||||
|         let ptr = Box::into_raw(pool); |         let ptr = Box::into_raw(pool); | ||||||
| 
 | 
 | ||||||
|  | @ -1357,9 +1584,9 @@ mod tests { | ||||||
|             let now = std::time::Instant::now(); |             let now = std::time::Instant::now(); | ||||||
|             let result = pool.scope(f); |             let result = pool.scope(f); | ||||||
|             let elapsed = now.elapsed().as_micros(); |             let elapsed = now.elapsed().as_micros(); | ||||||
|             eprintln!("(mine) total time: {}ms", elapsed as f32 / 1e3); |             info!("(mine) total time: {}ms", elapsed as f32 / 1e3); | ||||||
|             pool.resize_to(0); |             pool.resize_to(0); | ||||||
|             assert!(pool.global_queue.pop().is_none()); |             assert!(pool.global_queue.is_empty()); | ||||||
|             result |             result | ||||||
|         }; |         }; | ||||||
| 
 | 
 | ||||||
|  | @ -1385,7 +1612,38 @@ mod tests { | ||||||
|         }); |         }); | ||||||
|         let elapsed = now.elapsed().as_micros(); |         let elapsed = now.elapsed().as_micros(); | ||||||
| 
 | 
 | ||||||
|         eprintln!("(rayon) total time: {}ms", elapsed as f32 / 1e3); |         info!("(rayon) total time: {}ms", elapsed as f32 / 1e3); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     #[test] | ||||||
|  |     #[tracing_test::traced_test] | ||||||
|  |     fn rayon_join() { | ||||||
|  |         let pool = rayon::ThreadPoolBuilder::new() | ||||||
|  |             .num_threads(bevy_tasks::available_parallelism()) | ||||||
|  |             .build() | ||||||
|  |             .unwrap(); | ||||||
|  | 
 | ||||||
|  |         let tree = tree::Tree::new(TREE_SIZE, 1u32); | ||||||
|  | 
 | ||||||
|  |         fn sum(tree: &tree::Tree<u32>, node: usize) -> u32 { | ||||||
|  |             let node = tree.get(node); | ||||||
|  |             let (l, r) = rayon::join( | ||||||
|  |                 || node.left.map(|node| sum(tree, node)).unwrap_or_default(), | ||||||
|  |                 || node.right.map(|node| sum(tree, node)).unwrap_or_default(), | ||||||
|  |             ); | ||||||
|  | 
 | ||||||
|  |             node.leaf + l + r | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         let now = std::time::Instant::now(); | ||||||
|  |         let sum = pool.scope(move |s| { | ||||||
|  |             let root = tree.root().unwrap(); | ||||||
|  |             sum(&tree, root) | ||||||
|  |         }); | ||||||
|  | 
 | ||||||
|  |         let elapsed = now.elapsed().as_micros(); | ||||||
|  | 
 | ||||||
|  |         info!("(rayon) total time: {}ms", elapsed as f32 / 1e3); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     #[test] |     #[test] | ||||||
|  | @ -1407,7 +1665,7 @@ mod tests { | ||||||
|         }); |         }); | ||||||
|         let elapsed = now.elapsed().as_micros(); |         let elapsed = now.elapsed().as_micros(); | ||||||
| 
 | 
 | ||||||
|         eprintln!("(bevy_tasks) total time: {}ms", elapsed as f32 / 1e3); |         info!("(bevy_tasks) total time: {}ms", elapsed as f32 / 1e3); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     #[test] |     #[test] | ||||||
|  | @ -1418,18 +1676,7 @@ mod tests { | ||||||
|         } |         } | ||||||
|         let counter = Arc::new(AtomicUsize::new(0)); |         let counter = Arc::new(AtomicUsize::new(0)); | ||||||
|         { |         { | ||||||
|             let pool = ThreadPool::new_with_callbacks(ThreadPoolCallbacks { |             let pool = ThreadPool::new(); | ||||||
|                 at_entry: Some(Arc::new(|_worker| { |  | ||||||
|                     // eprintln!("new worker thread: {}", worker.index);
 |  | ||||||
|                 })), |  | ||||||
|                 at_exit: Some(Arc::new({ |  | ||||||
|                     let counter = counter.clone(); |  | ||||||
|                     move |_worker: &WorkerThread| { |  | ||||||
|                         // eprintln!("thread {}: {}", worker.index, WAIT_COUNT.get());
 |  | ||||||
|                         counter.fetch_add(WAIT_COUNT.get(), Ordering::Relaxed); |  | ||||||
|                     } |  | ||||||
|                 })), |  | ||||||
|             }); |  | ||||||
| 
 | 
 | ||||||
|             run_in_scope(pool, |s| { |             run_in_scope(pool, |s| { | ||||||
|                 for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() { |                 for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() { | ||||||
|  | @ -1443,6 +1690,33 @@ mod tests { | ||||||
|         // eprintln!("total wait count: {}", counter.load(Ordering::Acquire));
 |         // eprintln!("total wait count: {}", counter.load(Ordering::Acquire));
 | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     #[test] | ||||||
|  |     #[tracing_test::traced_test] | ||||||
|  |     fn mine_join() { | ||||||
|  |         let pool = ThreadPool::new(); | ||||||
|  | 
 | ||||||
|  |         let tree = tree::Tree::new(TREE_SIZE, 1u32); | ||||||
|  | 
 | ||||||
|  |         fn sum(tree: &tree::Tree<u32>, node: usize, scope: &Scope<'_>) -> u32 { | ||||||
|  |             let node = tree.get(node); | ||||||
|  |             let (l, r) = scope.join( | ||||||
|  |                 |s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(), | ||||||
|  |                 |s| { | ||||||
|  |                     node.right | ||||||
|  |                         .map(|node| sum(tree, node, s)) | ||||||
|  |                         .unwrap_or_default() | ||||||
|  |                 }, | ||||||
|  |             ); | ||||||
|  | 
 | ||||||
|  |             node.leaf + l + r | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         let sum = run_in_scope(pool, move |s| { | ||||||
|  |             let root = tree.root().unwrap(); | ||||||
|  |             sum(&tree, root, s) | ||||||
|  |         }); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     #[test] |     #[test] | ||||||
|     #[tracing_test::traced_test] |     #[tracing_test::traced_test] | ||||||
|     fn sync() { |     fn sync() { | ||||||
|  | @ -1452,11 +1726,19 @@ mod tests { | ||||||
|         } |         } | ||||||
|         let elapsed = now.elapsed().as_micros(); |         let elapsed = now.elapsed().as_micros(); | ||||||
| 
 | 
 | ||||||
|         eprintln!("(sync) total time: {}ms", elapsed as f32 / 1e3); |         info!("(sync) total time: {}ms", elapsed as f32 / 1e3); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     #[inline] |     #[inline] | ||||||
|     fn spinning(i: usize) { |     fn spinning(i: usize) { | ||||||
|  |         #[cfg(feature = "spin-slow")] | ||||||
|  |         spinning_slow(i); | ||||||
|  |         #[cfg(not(feature = "spin-slow"))] | ||||||
|  |         spinning_fast(i); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     #[inline] | ||||||
|  |     fn spinning_slow(i: usize) { | ||||||
|         let rng = rng::XorShift64Star::new(i as u64); |         let rng = rng::XorShift64Star::new(i as u64); | ||||||
|         (0..i).reduce(|a, b| { |         (0..i).reduce(|a, b| { | ||||||
|             black_box({ |             black_box({ | ||||||
|  | @ -1465,4 +1747,16 @@ mod tests { | ||||||
|             }) |             }) | ||||||
|         }); |         }); | ||||||
|     } |     } | ||||||
|  | 
 | ||||||
|  |     #[inline] | ||||||
|  |     fn spinning_fast(i: usize) { | ||||||
|  |         let rng = rng::XorShift64Star::new(i as u64); | ||||||
|  |         //(0..rng.next_usize(i)).reduce(|a, b| {
 | ||||||
|  |         (0..20).reduce(|a, b| { | ||||||
|  |             black_box({ | ||||||
|  |                 let a = rng.next_usize(a.max(1)); | ||||||
|  |                 a ^ b | ||||||
|  |             }) | ||||||
|  |         }); | ||||||
|  |     } | ||||||
| } | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in a new issue