idk this sucks
This commit is contained in:
		
							parent
							
								
									a691b614bc
								
							
						
					
					
						commit
						736e4e1a60
					
				|  | @ -4,8 +4,12 @@ version = "0.1.0" | |||
| edition = "2021" | ||||
| 
 | ||||
| [features] | ||||
| internal_heartbeat = [] | ||||
| heartbeat = [] | ||||
| spin-slow = [] | ||||
| cpu-pinning = [] | ||||
| work-stealing = [] | ||||
| prefer-local = [] | ||||
| never-local = [] | ||||
| 
 | ||||
| 
 | ||||
| [dependencies] | ||||
|  | @ -16,6 +20,7 @@ bevy_tasks = "0.15.1" | |||
| parking_lot = "0.12.3" | ||||
| thread_local = "1.1.8" | ||||
| crossbeam = "0.8.4" | ||||
| st3 = "0.4" | ||||
| 
 | ||||
| async-task = "4.7.1" | ||||
| 
 | ||||
|  |  | |||
							
								
								
									
										550
									
								
								src/lib.rs
									
									
									
									
									
								
							
							
						
						
									
										550
									
								
								src/lib.rs
									
									
									
									
									
								
							|  | @ -1,11 +1,10 @@ | |||
| use std::{ | ||||
|     cell::{OnceCell, UnsafeCell}, | ||||
|     collections::VecDeque, | ||||
|     cell::{Cell, UnsafeCell}, | ||||
|     future::Future, | ||||
|     mem::MaybeUninit, | ||||
|     num::NonZero, | ||||
|     pin::{pin, Pin}, | ||||
|     ptr::NonNull, | ||||
|     ptr::{self, NonNull}, | ||||
|     sync::{ | ||||
|         atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering}, | ||||
|         Arc, | ||||
|  | @ -17,7 +16,11 @@ use std::{ | |||
| 
 | ||||
| use async_task::{Runnable, Task}; | ||||
| use bitflags::bitflags; | ||||
| use crossbeam::{queue::SegQueue, utils::CachePadded}; | ||||
| use crossbeam::{ | ||||
|     atomic::AtomicCell, | ||||
|     deque::{Injector, Stealer, Worker}, | ||||
|     utils::CachePadded, | ||||
| }; | ||||
| use latch::{AtomicLatch, Latch, LatchWaker, MutexLatch, Probe, ThreadWakeLatch}; | ||||
| use parking_lot::{Condvar, Mutex}; | ||||
| use scope::Scope; | ||||
|  | @ -337,7 +340,7 @@ pub mod latch { | |||
| pub struct ThreadPoolState { | ||||
|     num_threads: AtomicUsize, | ||||
|     lock: Mutex<()>, | ||||
|     heartbeat_state: CachePadded<ThreadState>, | ||||
|     heartbeat_state: CachePadded<ThreadControl>, | ||||
| } | ||||
| 
 | ||||
| bitflags! { | ||||
|  | @ -348,15 +351,21 @@ bitflags! { | |||
|     } | ||||
| } | ||||
| 
 | ||||
| pub struct ThreadState { | ||||
|     should_shove: AtomicBool, | ||||
|     shoved_task: Slot<TaskRef>, | ||||
| pub struct ThreadControl { | ||||
|     status: Mutex<ThreadStatus>, | ||||
|     status_changed: Condvar, | ||||
|     should_terminate: AtomicLatch, | ||||
| } | ||||
| 
 | ||||
| impl ThreadState { | ||||
| pub struct ThreadState { | ||||
|     should_shove: AtomicBool, | ||||
|     control: ThreadControl, | ||||
|     stealer: Stealer<TaskRef>, | ||||
|     worker: AtomicCell<Option<Worker<TaskRef>>>, | ||||
|     shoved_task: CachePadded<Slot<TaskRef>>, | ||||
| } | ||||
| 
 | ||||
| impl ThreadControl { | ||||
|     /// returns true if thread was sleeping
 | ||||
|     #[inline] | ||||
|     fn wake(&self) -> bool { | ||||
|  | @ -451,40 +460,48 @@ impl ThreadPoolCallbacks { | |||
| pub struct ThreadPool { | ||||
|     threads: [CachePadded<ThreadState>; MAX_THREADS], | ||||
|     pool_state: CachePadded<ThreadPoolState>, | ||||
|     global_queue: SegQueue<TaskRef>, | ||||
|     global_queue: Injector<TaskRef>, | ||||
|     callbacks: CachePadded<ThreadPoolCallbacks>, | ||||
| } | ||||
| 
 | ||||
| impl ThreadPool { | ||||
|     const INITIAL_THREAD_STATE: CachePadded<ThreadState> = CachePadded::new(ThreadState { | ||||
|     pub fn new() -> Self { | ||||
|         Self::new_with_callbacks(ThreadPoolCallbacks::new_empty()) | ||||
|     } | ||||
| 
 | ||||
|     pub fn new_with_callbacks(callbacks: ThreadPoolCallbacks) -> ThreadPool { | ||||
|         let threads = [const { MaybeUninit::uninit() }; MAX_THREADS].map(|mut uninit| { | ||||
|             let worker = Worker::<TaskRef>::new_fifo(); | ||||
|             let stealer = worker.stealer(); | ||||
| 
 | ||||
|             let thread = CachePadded::new(ThreadState { | ||||
|                 should_shove: AtomicBool::new(false), | ||||
|         shoved_task: Slot::new(), | ||||
|                 shoved_task: Slot::new().into(), | ||||
|                 control: ThreadControl { | ||||
|                     status: Mutex::new(ThreadStatus::empty()), | ||||
|                     status_changed: Condvar::new(), | ||||
|                     should_terminate: AtomicLatch::new(), | ||||
|                 }, | ||||
|                 stealer, | ||||
|                 worker: AtomicCell::new(Some(worker)), | ||||
|             }); | ||||
|             uninit.write(thread); | ||||
|             unsafe { uninit.assume_init() } | ||||
|         }); | ||||
|     pub const fn new() -> Self { | ||||
|         Self { | ||||
|             threads: const { [Self::INITIAL_THREAD_STATE; MAX_THREADS] }, | ||||
|             pool_state: CachePadded::new(ThreadPoolState { | ||||
|                 num_threads: AtomicUsize::new(0), | ||||
|                 lock: Mutex::new(()), | ||||
|                 heartbeat_state: Self::INITIAL_THREAD_STATE, | ||||
|             }), | ||||
|             global_queue: SegQueue::new(), | ||||
|             callbacks: CachePadded::new(ThreadPoolCallbacks::new_empty()), | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub const fn new_with_callbacks(callbacks: ThreadPoolCallbacks) -> ThreadPool { | ||||
|         Self { | ||||
|             threads: const { [Self::INITIAL_THREAD_STATE; MAX_THREADS] }, | ||||
|             threads, | ||||
|             pool_state: CachePadded::new(ThreadPoolState { | ||||
|                 num_threads: AtomicUsize::new(0), | ||||
|                 lock: Mutex::new(()), | ||||
|                 heartbeat_state: Self::INITIAL_THREAD_STATE, | ||||
|                 heartbeat_state: ThreadControl { | ||||
|                     status: Mutex::new(ThreadStatus::empty()), | ||||
|                     status_changed: Condvar::new(), | ||||
|                     should_terminate: AtomicLatch::new(), | ||||
|                 } | ||||
|                 .into(), | ||||
|             }), | ||||
|             global_queue: SegQueue::new(), | ||||
|             global_queue: Injector::new(), | ||||
|             callbacks: CachePadded::new(callbacks), | ||||
|         } | ||||
|     } | ||||
|  | @ -495,7 +512,7 @@ impl ThreadPool { | |||
|     } | ||||
| 
 | ||||
|     pub fn wake_thread(&self, index: usize) -> Option<bool> { | ||||
|         Some(self.threads.get(index as usize)?.wake()) | ||||
|         Some(self.threads.get(index as usize)?.control.wake()) | ||||
|     } | ||||
| 
 | ||||
|     pub fn wake_any(&self, count: usize) -> usize { | ||||
|  | @ -503,7 +520,7 @@ impl ThreadPool { | |||
|             let num_woken = self | ||||
|                 .threads | ||||
|                 .iter() | ||||
|                 .filter_map(|thread| thread.wake().then_some(())) | ||||
|                 .filter_map(|thread| thread.control.wake().then_some(())) | ||||
|                 .take(count) | ||||
|                 .count(); | ||||
|             num_woken | ||||
|  | @ -517,6 +534,27 @@ impl ThreadPool { | |||
|         core::ptr::from_ref(self) as usize | ||||
|     } | ||||
| 
 | ||||
|     fn push_local_or_inject_balanced(&self, task: TaskRef) { | ||||
|         let global_len = self.global_queue.len(); | ||||
|         WorkerThread::with(|worker| match worker { | ||||
|             Some(worker) if worker.pool.id() == self.id() => { | ||||
|                 let worker_len = worker.worker.len(); | ||||
|                 if worker_len == 0 { | ||||
|                     worker.push_task(task); | ||||
|                 } else if global_len == 0 { | ||||
|                     self.inject(task); | ||||
|                 } else { | ||||
|                     if worker_len >= global_len { | ||||
|                         worker.push_task(task); | ||||
|                     } else { | ||||
|                         self.inject(task); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|             _ => self.inject(task), | ||||
|         }) | ||||
|     } | ||||
| 
 | ||||
|     fn push_local_or_inject(&self, task: TaskRef) { | ||||
|         WorkerThread::with(|worker| match worker { | ||||
|             Some(worker) if worker.pool.id() == self.id() => worker.push_task(task), | ||||
|  | @ -536,6 +574,15 @@ impl ThreadPool { | |||
|         self.wake_any(n); | ||||
|     } | ||||
| 
 | ||||
|     fn inject_maybe_local(&self, task: TaskRef) { | ||||
|         #[cfg(all(not(feature = "never-local"), feature = "prefer-local"))] | ||||
|         self.push_local_or_inject(task); | ||||
|         #[cfg(all(not(feature = "prefer-local"), feature = "never-local"))] | ||||
|         self.inject(task); | ||||
|         #[cfg(not(any(feature = "prefer-local", feature = "never-local")))] | ||||
|         self.push_local_or_inject_balanced(task); | ||||
|     } | ||||
| 
 | ||||
|     fn inject(&self, task: TaskRef) { | ||||
|         self.global_queue.push(task); | ||||
| 
 | ||||
|  | @ -581,10 +628,10 @@ impl ThreadPool { | |||
|                 } | ||||
| 
 | ||||
|                 for thread in new_threads { | ||||
|                     thread.wait_for_running(); | ||||
|                     thread.control.wait_for_running(); | ||||
|                 } | ||||
| 
 | ||||
|                 #[cfg(not(feature = "internal_heartbeat"))] | ||||
|                 #[cfg(feature = "heartbeat")] | ||||
|                 if current_size == 0 { | ||||
|                     std::thread::spawn(move || { | ||||
|                         heartbeat_loop(self); | ||||
|  | @ -601,13 +648,13 @@ impl ThreadPool { | |||
|                 let terminating_threads = &self.threads[new_size..current_size]; | ||||
| 
 | ||||
|                 for thread in terminating_threads { | ||||
|                     thread.notify_should_terminate(); | ||||
|                     thread.control.notify_should_terminate(); | ||||
|                 } | ||||
|                 for thread in terminating_threads { | ||||
|                     thread.wait_for_termination(); | ||||
|                     thread.control.wait_for_termination(); | ||||
|                 } | ||||
| 
 | ||||
|                 #[cfg(not(feature = "internal_heartbeat"))] | ||||
|                 #[cfg(feature = "heartbeat")] | ||||
|                 if new_size == 0 { | ||||
|                     self.pool_state.heartbeat_state.notify_should_terminate(); | ||||
|                     self.pool_state.heartbeat_state.wait_for_termination(); | ||||
|  | @ -712,7 +759,7 @@ impl ThreadPool { | |||
|         })); | ||||
| 
 | ||||
|         let taskref = task.into_ref().as_task_ref(); | ||||
|         self.inject(taskref); | ||||
|         self.push_local_or_inject(taskref); | ||||
| 
 | ||||
|         worker.run_until(&latch); | ||||
|         result.unwrap() | ||||
|  | @ -727,7 +774,7 @@ impl ThreadPool { | |||
|         let task = HeapTask::new(f); | ||||
| 
 | ||||
|         let taskref = unsafe { task.into_static_task_ref() }; | ||||
|         self.push_local_or_inject(taskref); | ||||
|         self.inject_maybe_local(taskref); | ||||
|     } | ||||
| 
 | ||||
|     fn spawn_future<Fut, T>(&'static self, future: Fut) -> Task<T> | ||||
|  | @ -745,7 +792,7 @@ impl ThreadPool { | |||
|                 }) | ||||
|             }; | ||||
| 
 | ||||
|             self.push_local_or_inject(taskref); | ||||
|             self.inject_maybe_local(taskref); | ||||
|         }; | ||||
| 
 | ||||
|         let (runnable, task) = async_task::spawn(future, schedule); | ||||
|  | @ -789,6 +836,30 @@ impl ThreadPool { | |||
|     } | ||||
| 
 | ||||
|     fn join<F, G, T, U>(&'static self, f: F, g: G) -> (T, U) | ||||
|     where | ||||
|         F: FnOnce() -> T + Send, | ||||
|         G: FnOnce() -> U + Send, | ||||
|         T: Send, | ||||
|         U: Send, | ||||
|     { | ||||
|         self.join_threaded(f, g) | ||||
|     } | ||||
| 
 | ||||
|     fn join_seq<F, G, T, U>(&'static self, f: F, g: G) -> (T, U) | ||||
|     where | ||||
|         F: FnOnce() -> T + Send, | ||||
|         G: FnOnce() -> U + Send, | ||||
|         T: Send, | ||||
|         U: Send, | ||||
|     { | ||||
|         let a = f(); | ||||
|         let b = g(); | ||||
| 
 | ||||
|         (a, b) | ||||
|     } | ||||
| 
 | ||||
|     #[inline] | ||||
|     fn join_threaded<F, G, T, U>(&'static self, f: F, g: G) -> (T, U) | ||||
|     where | ||||
|         F: FnOnce() -> T + Send, | ||||
|         G: FnOnce() -> U + Send, | ||||
|  | @ -808,7 +879,6 @@ impl ThreadPool { | |||
| 
 | ||||
|             let ref_b = task_b.as_ref().as_task_ref(); | ||||
|             let b_id = ref_b.id(); | ||||
|             // TODO: maybe try to push this off to another thread immediately first?
 | ||||
|             worker.push_task(ref_b); | ||||
| 
 | ||||
|             let result_a = f(); | ||||
|  | @ -817,14 +887,18 @@ impl ThreadPool { | |||
|                 match worker.pop_task() { | ||||
|                     Some(task) => { | ||||
|                         if task.id() == b_id { | ||||
|                             worker.try_promote(); | ||||
|                             // we're not calling execute() here, so manually try
 | ||||
|                             // shoving a task.
 | ||||
|                             //worker.try_promote();
 | ||||
|                             worker.shove_task(); | ||||
|                             unsafe { | ||||
|                                 task_b.run_as_ref(); | ||||
|                             } | ||||
|                             break; | ||||
|                         } | ||||
|                         } else { | ||||
|                             worker.execute(task); | ||||
|                         } | ||||
|                     } | ||||
|                     None => { | ||||
|                         worker.run_until(&latch_b); | ||||
|                     } | ||||
|  | @ -837,12 +911,12 @@ impl ThreadPool { | |||
| 
 | ||||
|     fn scope<'scope, Fn, T>(&'static self, f: Fn) -> T | ||||
|     where | ||||
|         Fn: FnOnce(Pin<&Scope<'scope>>) -> T + Send, | ||||
|         Fn: FnOnce(&Scope<'scope>) -> T + Send, | ||||
|         T: Send, | ||||
|     { | ||||
|         self.in_worker(|owner, _| { | ||||
|             let scope = pin!(unsafe { Scope::<'scope>::new(owner) }); | ||||
|             let result = f(scope.as_ref()); | ||||
|             let scope = unsafe { Scope::<'scope>::new(owner) }; | ||||
|             let result = f(&scope); | ||||
|             scope.complete(owner); | ||||
|             result | ||||
|         }) | ||||
|  | @ -850,7 +924,8 @@ impl ThreadPool { | |||
| } | ||||
| 
 | ||||
| pub struct WorkerThread { | ||||
|     queue: TaskQueue<TaskRef>, | ||||
|     // queue: TaskQueue<TaskRef>,
 | ||||
|     worker: Worker<TaskRef>, | ||||
|     pool: &'static ThreadPool, | ||||
|     index: usize, | ||||
|     rng: rng::XorShift64Star, | ||||
|  | @ -860,7 +935,7 @@ pub struct WorkerThread { | |||
| const HEARTBEAT_INTERVAL: core::time::Duration = const { core::time::Duration::from_micros(100) }; | ||||
| 
 | ||||
| std::thread_local! { | ||||
|     static WORKER_THREAD_STATE: CachePadded<OnceCell<WorkerThread>> = const {CachePadded::new(OnceCell::new())}; | ||||
|     static WORKER_THREAD_STATE: Cell<*const WorkerThread> = const {Cell::new(ptr::null())}; | ||||
| } | ||||
| 
 | ||||
| impl WorkerThread { | ||||
|  | @ -880,25 +955,47 @@ impl WorkerThread { | |||
|     fn is_worker_thread() -> bool { | ||||
|         Self::with(|worker| worker.is_some()) | ||||
|     } | ||||
| 
 | ||||
|     fn with<T, F: FnOnce(Option<&Self>) -> T>(f: F) -> T { | ||||
|         WORKER_THREAD_STATE.with(|thread| f(thread.get())) | ||||
|         WORKER_THREAD_STATE.with(|thread| { | ||||
|             f(NonNull::<WorkerThread>::new(thread.get().cast_mut()) | ||||
|                 .map(|ptr| unsafe { ptr.as_ref() })) | ||||
|         }) | ||||
|     } | ||||
|     #[inline] | ||||
|     fn pop_task(&self) -> Option<TaskRef> { | ||||
|         self.queue.pop_front() | ||||
|         self.worker.pop() | ||||
|         //self.queue.pop_front(task);
 | ||||
|     } | ||||
|     #[inline] | ||||
|     fn push_task(&self, task: TaskRef) { | ||||
|         self.queue.push_front(task); | ||||
|         self.worker.push(task); | ||||
|         //self.queue.push_front(task);
 | ||||
|     } | ||||
| 
 | ||||
|     #[inline] | ||||
|     fn drain(&self) -> impl Iterator<Item = TaskRef> { | ||||
|         self.queue.drain() | ||||
|         // self.queue.drain()
 | ||||
|         core::iter::empty() | ||||
|     } | ||||
| 
 | ||||
|     #[inline] | ||||
|     fn steal_tasks(&self) -> Option<TaskRef> { | ||||
|         // careful not to call threads() here because that omits any threads
 | ||||
|         // that were killed, which might still have tasks.
 | ||||
|         let threads = &self.pool.threads; | ||||
|         let (start, end) = threads.split_at(self.rng.next_usize(threads.len())); | ||||
| 
 | ||||
|         end.iter() | ||||
|             .chain(start) | ||||
|             .find_map(|thread: &CachePadded<ThreadState>| { | ||||
|                 thread.stealer.steal_batch_and_pop(&self.worker).success() | ||||
|             }) | ||||
|     } | ||||
| 
 | ||||
|     #[inline] | ||||
|     fn claim_shoved_task(&self) -> Option<TaskRef> { | ||||
|         // take own shoved task first
 | ||||
|         if let Some(task) = self.info().shoved_task.try_take() { | ||||
|             return Some(task); | ||||
|         } | ||||
|  | @ -916,12 +1013,16 @@ impl WorkerThread { | |||
| 
 | ||||
|     #[cold] | ||||
|     fn shove_task(&self) { | ||||
|         if let Some(task) = self.queue.pop_back() { | ||||
|         if !self.info().shoved_task.is_occupied() { | ||||
|             if let Some(task) = self.info().stealer.steal().success() { | ||||
|                 match self.info().shoved_task.try_put(task) { | ||||
|                     // shoved task is occupied, reinsert into queue
 | ||||
|                 Some(task) => self.queue.push_back(task), | ||||
|                     // this really shouldn't happen
 | ||||
|                     Some(_task) => unreachable!(), | ||||
|                     None => {} | ||||
|                 } | ||||
|             } | ||||
|         } else { | ||||
|             // wake thread to execute task
 | ||||
|             self.pool.wake_any(1); | ||||
|         } | ||||
|  | @ -934,24 +1035,15 @@ impl WorkerThread { | |||
| 
 | ||||
|     #[inline] | ||||
|     fn try_promote(&self) { | ||||
|         #[cfg(feature = "internal_heartbeat")] | ||||
|         let now = std::time::Instant::now(); | ||||
|         // SAFETY: workerthread is thread-local non-sync
 | ||||
| 
 | ||||
|         #[cfg(feature = "internal_heartbeat")] | ||||
|         let should_shove = | ||||
|             unsafe { *self.last_heartbeat.get() }.duration_since(now) > HEARTBEAT_INTERVAL; | ||||
|         #[cfg(not(feature = "internal_heartbeat"))] | ||||
|         #[cfg(feature = "heartbeat")] | ||||
|         let should_shove = self.info().should_shove.load(Ordering::Acquire); | ||||
|         #[cfg(not(feature = "heartbeat"))] | ||||
|         let should_shove = true; | ||||
| 
 | ||||
|         if should_shove { | ||||
|             // SAFETY: workerthread is thread-local non-sync
 | ||||
|             #[cfg(feature = "internal_heartbeat")] | ||||
|             unsafe { | ||||
|                 *&mut *self.last_heartbeat.get() = now; | ||||
|             } | ||||
|             #[cfg(not(feature = "internal_heartbeat"))] | ||||
|             #[cfg(feature = "heartbeat")] | ||||
|             self.info().should_shove.store(false, Ordering::Release); | ||||
| 
 | ||||
|             self.shove_task(); | ||||
|         } | ||||
|     } | ||||
|  | @ -959,9 +1051,22 @@ impl WorkerThread { | |||
|     #[inline] | ||||
|     fn find_any_task(&self) -> Option<TaskRef> { | ||||
|         // TODO: attempt stealing work here, too.
 | ||||
|         self.pop_task() | ||||
|         let mut task = self | ||||
|             .pop_task() | ||||
|             .or_else(|| self.claim_shoved_task()) | ||||
|             .or_else(|| self.pool.global_queue.pop()) | ||||
|             .or_else(|| { | ||||
|                 self.pool | ||||
|                     .global_queue | ||||
|                     .steal_batch_and_pop(&self.worker) | ||||
|                     .success() | ||||
|             }); | ||||
| 
 | ||||
|         #[cfg(feature = "work-stealing")] | ||||
|         { | ||||
|             task = task.or_else(|| self.steal_tasks()); | ||||
|         } | ||||
| 
 | ||||
|         task | ||||
|     } | ||||
| 
 | ||||
|     #[inline] | ||||
|  | @ -991,34 +1096,36 @@ impl WorkerThread { | |||
|                 self.execute(task); | ||||
|             } | ||||
|             None => { | ||||
|                 debug!("waiting for tasks"); | ||||
|                 self.info().wait_for_should_wake(); | ||||
|                 //debug!("waiting for tasks");
 | ||||
|                 self.info().control.wait_for_should_wake(); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     fn worker_loop(pool: &'static ThreadPool, index: usize) { | ||||
|         let info = &pool.threads()[index as usize]; | ||||
| 
 | ||||
|         WORKER_THREAD_STATE.with(|worker| { | ||||
|             let worker = worker.get_or_init(|| WorkerThread { | ||||
|                 queue: TaskQueue::new(), | ||||
|         let worker = CachePadded::new(WorkerThread { | ||||
|             // queue: TaskQueue::new(),
 | ||||
|             worker: info.worker.take().unwrap(), | ||||
|             pool, | ||||
|             index, | ||||
|             rng: rng::XorShift64Star::new(pool as *const _ as u64 + index as u64), | ||||
|             last_heartbeat: UnsafeCell::new(std::time::Instant::now()), | ||||
|         }); | ||||
| 
 | ||||
|         WORKER_THREAD_STATE.with(|cell| { | ||||
|             cell.set(&*worker); | ||||
| 
 | ||||
|             if let Some(callback) = pool.callbacks.at_entry.as_ref() { | ||||
|                 callback(worker); | ||||
|                 callback(&worker); | ||||
|             } | ||||
| 
 | ||||
|             info.notify_running(); | ||||
|             info.control.notify_running(); | ||||
|             // info.notify_running();
 | ||||
|             worker.run_until(&info.should_terminate); | ||||
|             worker.run_until(&info.control.should_terminate); | ||||
| 
 | ||||
|             if let Some(callback) = pool.callbacks.at_exit.as_ref() { | ||||
|                 callback(worker); | ||||
|                 callback(&worker); | ||||
|             } | ||||
| 
 | ||||
|             for task in worker.drain() { | ||||
|  | @ -1028,9 +1135,14 @@ impl WorkerThread { | |||
|             if let Some(task) = info.shoved_task.try_take() { | ||||
|                 pool.inject(task); | ||||
|             } | ||||
| 
 | ||||
|             cell.set(ptr::null()); | ||||
|         }); | ||||
| 
 | ||||
|         info.notify_termination(); | ||||
|         let WorkerThread { worker, .. } = CachePadded::into_inner(worker); | ||||
|         info.worker.store(Some(worker)); | ||||
| 
 | ||||
|         info.control.notify_termination(); | ||||
|     } | ||||
| } | ||||
| 
 | ||||
|  | @ -1061,56 +1173,68 @@ fn heartbeat_loop(pool: &'static ThreadPool) { | |||
|     state.notify_termination(); | ||||
| } | ||||
| 
 | ||||
| use vec_queue::TaskQueue; | ||||
| 
 | ||||
| mod vec_queue { | ||||
|     use std::{cell::UnsafeCell, collections::VecDeque}; | ||||
| 
 | ||||
|     pub struct TaskQueue<T>(UnsafeCell<VecDeque<T>>); | ||||
| 
 | ||||
|     impl<T> TaskQueue<T> { | ||||
|         /// Creates a new [`TaskQueue<T>`].
 | ||||
|         #[inline] | ||||
|     const fn new() -> Self { | ||||
|         pub const fn new() -> Self { | ||||
|             Self(UnsafeCell::new(VecDeque::new())) | ||||
|         } | ||||
|         #[inline] | ||||
|     fn get_mut(&self) -> &mut VecDeque<T> { | ||||
|         pub fn get_mut(&self) -> &mut VecDeque<T> { | ||||
|             unsafe { &mut *self.0.get() } | ||||
|         } | ||||
|         #[inline] | ||||
|     fn pop_front(&self) -> Option<T> { | ||||
|         pub fn pop_front(&self) -> Option<T> { | ||||
|             self.get_mut().pop_front() | ||||
|         } | ||||
|         #[inline] | ||||
|     fn pop_back(&self) -> Option<T> { | ||||
|         pub fn pop_back(&self) -> Option<T> { | ||||
|             self.get_mut().pop_back() | ||||
|         } | ||||
|         #[inline] | ||||
|     fn push_back(&self, t: T) { | ||||
|         pub fn push_back(&self, t: T) { | ||||
|             self.get_mut().push_back(t); | ||||
|         } | ||||
|         #[inline] | ||||
|     fn push_front(&self, t: T) { | ||||
|         pub fn push_front(&self, t: T) { | ||||
|             self.get_mut().push_front(t); | ||||
|         } | ||||
|         #[inline] | ||||
|     fn take(&self) -> VecDeque<T> { | ||||
|         pub fn take(&self) -> VecDeque<T> { | ||||
|             let this = core::mem::replace(self.get_mut(), VecDeque::new()); | ||||
|             this | ||||
|         } | ||||
|         #[inline] | ||||
|     fn drain(&self) -> impl Iterator<Item = T> { | ||||
|         pub fn drain(&self) -> impl Iterator<Item = T> { | ||||
|             self.take().into_iter() | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| bitflags! { | ||||
|     #[derive(Debug, Clone, Copy)] | ||||
|     pub struct SlotState: u8 { | ||||
|         const LOCKED = 1 << 1; | ||||
|         const OCCUPIED = 1 << 2; | ||||
| #[repr(u8)] | ||||
| #[derive(Debug, Clone, Copy, PartialEq, Eq)] | ||||
| enum SlotState { | ||||
|     None, | ||||
|     Locked, | ||||
|     Occupied, | ||||
| } | ||||
| 
 | ||||
| impl From<u8> for SlotState { | ||||
|     fn from(value: u8) -> Self { | ||||
|         unsafe { core::mem::transmute(value) } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl From<SlotState> for u8 { | ||||
|     fn from(value: SlotState) -> Self { | ||||
|         value.bits() | ||||
|         value as u8 | ||||
|     } | ||||
| } | ||||
| 
 | ||||
|  | @ -1125,10 +1249,7 @@ unsafe impl<T> Sync for Slot<T> where T: Send {} | |||
| impl<T> Drop for Slot<T> { | ||||
|     fn drop(&mut self) { | ||||
|         if core::mem::needs_drop::<T>() { | ||||
|             if SlotState::from_bits(*self.state.get_mut()) | ||||
|                 .unwrap() | ||||
|                 .contains(SlotState::OCCUPIED) | ||||
|             { | ||||
|             if *self.state.get_mut() == SlotState::Occupied as u8 { | ||||
|                 unsafe { | ||||
|                     self.slot.get().drop_in_place(); | ||||
|                 } | ||||
|  | @ -1141,15 +1262,56 @@ impl<T> Slot<T> { | |||
|     pub const fn new() -> Slot<T> { | ||||
|         Self { | ||||
|             slot: UnsafeCell::new(MaybeUninit::uninit()), | ||||
|             state: AtomicU8::new(SlotState::empty().bits()), | ||||
|             state: AtomicU8::new(SlotState::None as u8), | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub fn is_occupied(&self) -> bool { | ||||
|         self.state.load(Ordering::Acquire) == SlotState::Occupied.into() | ||||
|     } | ||||
| 
 | ||||
|     #[inline] | ||||
|     pub fn insert(&self, t: T) -> Option<T> { | ||||
|         let value = match self | ||||
|             .state | ||||
|             .swap(SlotState::Locked.into(), Ordering::AcqRel) | ||||
|             .into() | ||||
|         { | ||||
|             SlotState::Locked => { | ||||
|                 // return early: was already locked.
 | ||||
|                 debug!("slot was already locked"); | ||||
|                 return None; | ||||
|             } | ||||
|             SlotState::Occupied => { | ||||
|                 let slot = self.slot.get(); | ||||
|                 // replace
 | ||||
|                 unsafe { | ||||
|                     let v = (*slot).assume_init_read(); | ||||
|                     (*slot).write(t); | ||||
|                     Some(v) | ||||
|                 } | ||||
|             } | ||||
|             SlotState::None => { | ||||
|                 let slot = self.slot.get(); | ||||
|                 // insert
 | ||||
|                 unsafe { | ||||
|                     (*slot).write(t); | ||||
|                 } | ||||
|                 None | ||||
|             } | ||||
|         }; | ||||
| 
 | ||||
|         // release lock
 | ||||
|         self.state | ||||
|             .store(SlotState::Occupied.into(), Ordering::Release); | ||||
|         value | ||||
|     } | ||||
| 
 | ||||
|     #[inline] | ||||
|     pub fn try_put(&self, t: T) -> Option<T> { | ||||
|         match self.state.compare_exchange( | ||||
|             SlotState::empty().into(), | ||||
|             SlotState::LOCKED.into(), | ||||
|             SlotState::None.into(), | ||||
|             SlotState::Locked.into(), | ||||
|             Ordering::Acquire, | ||||
|             Ordering::Relaxed, | ||||
|         ) { | ||||
|  | @ -1161,7 +1323,7 @@ impl<T> Slot<T> { | |||
| 
 | ||||
|                 // release lock
 | ||||
|                 self.state | ||||
|                     .store(SlotState::OCCUPIED.into(), Ordering::Release); | ||||
|                     .store(SlotState::Occupied.into(), Ordering::Release); | ||||
|                 None | ||||
|             } | ||||
|         } | ||||
|  | @ -1170,8 +1332,8 @@ impl<T> Slot<T> { | |||
|     #[inline] | ||||
|     pub fn try_take(&self) -> Option<T> { | ||||
|         match self.state.compare_exchange( | ||||
|             SlotState::OCCUPIED.into(), | ||||
|             SlotState::LOCKED.into(), | ||||
|             SlotState::Occupied.into(), | ||||
|             SlotState::Locked.into(), | ||||
|             Ordering::Acquire, | ||||
|             Ordering::Relaxed, | ||||
|         ) { | ||||
|  | @ -1181,8 +1343,7 @@ impl<T> Slot<T> { | |||
|                 let t = unsafe { (*slot).assume_init_read() }; | ||||
| 
 | ||||
|                 // release lock
 | ||||
|                 self.state | ||||
|                     .store(SlotState::empty().into(), Ordering::Release); | ||||
|                 self.state.store(SlotState::None.into(), Ordering::Release); | ||||
|                 Some(t) | ||||
|             } | ||||
|             Err(_) => None, | ||||
|  | @ -1227,14 +1388,15 @@ mod scope { | |||
|     use std::{ | ||||
|         future::Future, | ||||
|         marker::{PhantomData, PhantomPinned}, | ||||
|         pin::pin, | ||||
|         ptr::{self, NonNull}, | ||||
|     }; | ||||
| 
 | ||||
|     use async_task::{Runnable, Task}; | ||||
| 
 | ||||
|     use crate::{ | ||||
|         latch::{CountWakeLatch, Latch}, | ||||
|         task::{HeapTask, TaskRef}, | ||||
|         latch::{CountWakeLatch, Latch, Probe, ThreadWakeLatch}, | ||||
|         task::{HeapTask, StackTask, TaskRef}, | ||||
|         ThreadPool, WorkerThread, | ||||
|     }; | ||||
| 
 | ||||
|  | @ -1253,6 +1415,16 @@ mod scope { | |||
|             } | ||||
|         } | ||||
| 
 | ||||
|         pub fn join<F, G, T, U>(&self, f: F, g: G) -> (T, U) | ||||
|         where | ||||
|             F: FnOnce(&Self) -> T + Send, | ||||
|             G: FnOnce(&Self) -> U + Send, | ||||
|             T: Send, | ||||
|             U: Send, | ||||
|         { | ||||
|             self.pool.join(|| f(self), || g(self)) | ||||
|         } | ||||
| 
 | ||||
|         pub fn spawn<Fn>(&self, f: Fn) | ||||
|         where | ||||
|             Fn: FnOnce(&Scope<'scope>) + Send + 'scope, | ||||
|  | @ -1267,7 +1439,7 @@ mod scope { | |||
|             }); | ||||
| 
 | ||||
|             let taskref = unsafe { task.into_task_ref() }; | ||||
|             self.pool.push_local_or_inject(taskref); | ||||
|             self.pool.inject_maybe_local(taskref); | ||||
|         } | ||||
| 
 | ||||
|         pub fn spawn_future<Fut, T>(&self, future: Fut) -> Task<T> | ||||
|  | @ -1289,7 +1461,7 @@ mod scope { | |||
|                 }; | ||||
| 
 | ||||
|                 unsafe { | ||||
|                     ptr.as_ref().pool.push_local_or_inject(taskref); | ||||
|                     ptr.as_ref().pool.inject_maybe_local(taskref); | ||||
|                 } | ||||
|             }; | ||||
| 
 | ||||
|  | @ -1332,8 +1504,58 @@ mod scope { | |||
| mod tests { | ||||
|     use std::{cell::Cell, hint::black_box}; | ||||
| 
 | ||||
|     use tracing::info; | ||||
| 
 | ||||
|     use super::*; | ||||
| 
 | ||||
|     mod tree { | ||||
| 
 | ||||
|         pub struct Tree<T> { | ||||
|             nodes: Box<[Node<T>]>, | ||||
|             root: Option<usize>, | ||||
|         } | ||||
|         pub struct Node<T> { | ||||
|             pub leaf: T, | ||||
|             pub left: Option<usize>, | ||||
|             pub right: Option<usize>, | ||||
|         } | ||||
| 
 | ||||
|         impl<T> Tree<T> { | ||||
|             pub fn new(depth: usize, t: T) -> Tree<T> | ||||
|             where | ||||
|                 T: Copy, | ||||
|             { | ||||
|                 let mut nodes = Vec::with_capacity((0..depth).sum()); | ||||
|                 let root = Self::build_node(&mut nodes, depth, t); | ||||
|                 Self { | ||||
|                     nodes: nodes.into_boxed_slice(), | ||||
|                     root: Some(root), | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             pub fn root(&self) -> Option<usize> { | ||||
|                 self.root | ||||
|             } | ||||
| 
 | ||||
|             pub fn get(&self, index: usize) -> &Node<T> { | ||||
|                 &self.nodes[index] | ||||
|             } | ||||
| 
 | ||||
|             pub fn build_node(nodes: &mut Vec<Node<T>>, depth: usize, t: T) -> usize | ||||
|             where | ||||
|                 T: Copy, | ||||
|             { | ||||
|                 let node = Node { | ||||
|                     leaf: t, | ||||
|                     left: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)), | ||||
|                     right: (depth != 0).then(|| Self::build_node(nodes, depth - 1, t)), | ||||
|                 }; | ||||
|                 nodes.push(node); | ||||
|                 nodes.len() - 1 | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     const PRIMES: &'static [usize] = &[ | ||||
|         1181, 1187, 1193, 1201, 1213, 1217, 1223, 1229, 1231, 1237, 1249, 1259, 1277, 1279, 1283, | ||||
|         1289, 1291, 1297, 1301, 1303, 1307, 1319, 1321, 1327, 1361, 1367, 1373, 1381, 1399, 1409, | ||||
|  | @ -1344,9 +1566,14 @@ mod tests { | |||
|         1861, 1867, 1871, 1873, 1877, 1879, 1889, 1901, 1907, | ||||
|     ]; | ||||
| 
 | ||||
|     const REPEAT: usize = 0x100; | ||||
|     #[cfg(feature = "spin-slow")] | ||||
|     const REPEAT: usize = 0x800; | ||||
|     #[cfg(not(feature = "spin-slow"))] | ||||
|     const REPEAT: usize = 0x8000; | ||||
| 
 | ||||
|     fn run_in_scope<T: Send>(pool: ThreadPool, f: impl FnOnce(Pin<&Scope<'_>>) -> T + Send) -> T { | ||||
|     const TREE_SIZE: usize = 10; | ||||
| 
 | ||||
|     fn run_in_scope<T: Send>(pool: ThreadPool, f: impl FnOnce(&Scope<'_>) -> T + Send) -> T { | ||||
|         let pool = Box::new(pool); | ||||
|         let ptr = Box::into_raw(pool); | ||||
| 
 | ||||
|  | @ -1357,9 +1584,9 @@ mod tests { | |||
|             let now = std::time::Instant::now(); | ||||
|             let result = pool.scope(f); | ||||
|             let elapsed = now.elapsed().as_micros(); | ||||
|             eprintln!("(mine) total time: {}ms", elapsed as f32 / 1e3); | ||||
|             info!("(mine) total time: {}ms", elapsed as f32 / 1e3); | ||||
|             pool.resize_to(0); | ||||
|             assert!(pool.global_queue.pop().is_none()); | ||||
|             assert!(pool.global_queue.is_empty()); | ||||
|             result | ||||
|         }; | ||||
| 
 | ||||
|  | @ -1385,7 +1612,38 @@ mod tests { | |||
|         }); | ||||
|         let elapsed = now.elapsed().as_micros(); | ||||
| 
 | ||||
|         eprintln!("(rayon) total time: {}ms", elapsed as f32 / 1e3); | ||||
|         info!("(rayon) total time: {}ms", elapsed as f32 / 1e3); | ||||
|     } | ||||
| 
 | ||||
|     #[test] | ||||
|     #[tracing_test::traced_test] | ||||
|     fn rayon_join() { | ||||
|         let pool = rayon::ThreadPoolBuilder::new() | ||||
|             .num_threads(bevy_tasks::available_parallelism()) | ||||
|             .build() | ||||
|             .unwrap(); | ||||
| 
 | ||||
|         let tree = tree::Tree::new(TREE_SIZE, 1u32); | ||||
| 
 | ||||
|         fn sum(tree: &tree::Tree<u32>, node: usize) -> u32 { | ||||
|             let node = tree.get(node); | ||||
|             let (l, r) = rayon::join( | ||||
|                 || node.left.map(|node| sum(tree, node)).unwrap_or_default(), | ||||
|                 || node.right.map(|node| sum(tree, node)).unwrap_or_default(), | ||||
|             ); | ||||
| 
 | ||||
|             node.leaf + l + r | ||||
|         } | ||||
| 
 | ||||
|         let now = std::time::Instant::now(); | ||||
|         let sum = pool.scope(move |s| { | ||||
|             let root = tree.root().unwrap(); | ||||
|             sum(&tree, root) | ||||
|         }); | ||||
| 
 | ||||
|         let elapsed = now.elapsed().as_micros(); | ||||
| 
 | ||||
|         info!("(rayon) total time: {}ms", elapsed as f32 / 1e3); | ||||
|     } | ||||
| 
 | ||||
|     #[test] | ||||
|  | @ -1407,7 +1665,7 @@ mod tests { | |||
|         }); | ||||
|         let elapsed = now.elapsed().as_micros(); | ||||
| 
 | ||||
|         eprintln!("(bevy_tasks) total time: {}ms", elapsed as f32 / 1e3); | ||||
|         info!("(bevy_tasks) total time: {}ms", elapsed as f32 / 1e3); | ||||
|     } | ||||
| 
 | ||||
|     #[test] | ||||
|  | @ -1418,18 +1676,7 @@ mod tests { | |||
|         } | ||||
|         let counter = Arc::new(AtomicUsize::new(0)); | ||||
|         { | ||||
|             let pool = ThreadPool::new_with_callbacks(ThreadPoolCallbacks { | ||||
|                 at_entry: Some(Arc::new(|_worker| { | ||||
|                     // eprintln!("new worker thread: {}", worker.index);
 | ||||
|                 })), | ||||
|                 at_exit: Some(Arc::new({ | ||||
|                     let counter = counter.clone(); | ||||
|                     move |_worker: &WorkerThread| { | ||||
|                         // eprintln!("thread {}: {}", worker.index, WAIT_COUNT.get());
 | ||||
|                         counter.fetch_add(WAIT_COUNT.get(), Ordering::Relaxed); | ||||
|                     } | ||||
|                 })), | ||||
|             }); | ||||
|             let pool = ThreadPool::new(); | ||||
| 
 | ||||
|             run_in_scope(pool, |s| { | ||||
|                 for &p in core::iter::repeat_n(PRIMES, REPEAT).flatten() { | ||||
|  | @ -1443,6 +1690,33 @@ mod tests { | |||
|         // eprintln!("total wait count: {}", counter.load(Ordering::Acquire));
 | ||||
|     } | ||||
| 
 | ||||
|     #[test] | ||||
|     #[tracing_test::traced_test] | ||||
|     fn mine_join() { | ||||
|         let pool = ThreadPool::new(); | ||||
| 
 | ||||
|         let tree = tree::Tree::new(TREE_SIZE, 1u32); | ||||
| 
 | ||||
|         fn sum(tree: &tree::Tree<u32>, node: usize, scope: &Scope<'_>) -> u32 { | ||||
|             let node = tree.get(node); | ||||
|             let (l, r) = scope.join( | ||||
|                 |s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(), | ||||
|                 |s| { | ||||
|                     node.right | ||||
|                         .map(|node| sum(tree, node, s)) | ||||
|                         .unwrap_or_default() | ||||
|                 }, | ||||
|             ); | ||||
| 
 | ||||
|             node.leaf + l + r | ||||
|         } | ||||
| 
 | ||||
|         let sum = run_in_scope(pool, move |s| { | ||||
|             let root = tree.root().unwrap(); | ||||
|             sum(&tree, root, s) | ||||
|         }); | ||||
|     } | ||||
| 
 | ||||
|     #[test] | ||||
|     #[tracing_test::traced_test] | ||||
|     fn sync() { | ||||
|  | @ -1452,11 +1726,19 @@ mod tests { | |||
|         } | ||||
|         let elapsed = now.elapsed().as_micros(); | ||||
| 
 | ||||
|         eprintln!("(sync) total time: {}ms", elapsed as f32 / 1e3); | ||||
|         info!("(sync) total time: {}ms", elapsed as f32 / 1e3); | ||||
|     } | ||||
| 
 | ||||
|     #[inline] | ||||
|     fn spinning(i: usize) { | ||||
|         #[cfg(feature = "spin-slow")] | ||||
|         spinning_slow(i); | ||||
|         #[cfg(not(feature = "spin-slow"))] | ||||
|         spinning_fast(i); | ||||
|     } | ||||
| 
 | ||||
|     #[inline] | ||||
|     fn spinning_slow(i: usize) { | ||||
|         let rng = rng::XorShift64Star::new(i as u64); | ||||
|         (0..i).reduce(|a, b| { | ||||
|             black_box({ | ||||
|  | @ -1465,4 +1747,16 @@ mod tests { | |||
|             }) | ||||
|         }); | ||||
|     } | ||||
| 
 | ||||
|     #[inline] | ||||
|     fn spinning_fast(i: usize) { | ||||
|         let rng = rng::XorShift64Star::new(i as u64); | ||||
|         //(0..rng.next_usize(i)).reduce(|a, b| {
 | ||||
|         (0..20).reduce(|a, b| { | ||||
|             black_box({ | ||||
|                 let a = rng.next_usize(a.max(1)); | ||||
|                 a ^ b | ||||
|             }) | ||||
|         }); | ||||
|     } | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in a new issue