Janis 2025-01-31 00:19:57 +01:00
parent e75094d2a5
commit fd0cd86a8d
2 changed files with 295 additions and 35 deletions

View file

@@ -23,4 +23,8 @@ tracing-subscriber = {version ="0.3.18", features = ["env-filter"]}
anyhow = "1.0.89"
thiserror = "2.0"
bitflags = "2.6"
core_affinity = "0.8.1"
# derive_more = "1.0.0"
[dev-dependencies]
tracing-test = "0.2.5"

View file

@@ -20,7 +20,9 @@ use bitflags::bitflags;
use crossbeam::{queue::SegQueue, utils::CachePadded};
use latch::{AtomicLatch, Latch, LatchWaker, MutexLatch, Probe, ThreadWakeLatch};
use parking_lot::{Condvar, Mutex};
use scope::Scope;
use task::{HeapTask, StackTask, TaskRef};
use tracing::debug;
pub mod task {
use std::{cell::UnsafeCell, marker::PhantomPinned, pin::Pin};
@@ -78,6 +80,9 @@ pub mod task {
pub fn run(self) {
self.task.into_inner().unwrap()();
}
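// Runs the task through a shared reference by taking the closure out of its cell;
// callers must guarantee this happens at most once and without concurrent access.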
pub unsafe fn run_as_ref(&self) {
((&mut *self.task.get()).take().unwrap())();
}
pub fn as_task_ref(self: Pin<&Self>) -> TaskRef {
unsafe { TaskRef::new(&*self) }
@@ -380,23 +385,39 @@ impl ThreadState {
self.status_changed.notify_all();
}
fn notify_should_terminate(&self) {
unsafe {
Latch::set_raw(&self.should_terminate);
}
self.wake();
}
}
const MAX_THREADS: usize = 32;
type ThreadCallback = dyn Fn(&WorkerThread) + Send + Sync + 'static;
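// Optional hooks invoked on each worker thread: `at_entry` right after the thread
// starts, `at_exit` just before it shuts down.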
pub struct ThreadPoolCallbacks {
at_entry: Option<Arc<ThreadCallback>>,
at_exit: Option<Arc<ThreadCallback>>,
}
impl ThreadPoolCallbacks {
pub const fn new_empty() -> ThreadPoolCallbacks {
Self {
at_entry: None,
at_exit: None,
}
}
}
pub struct ThreadPool {
threads: [CachePadded<ThreadState>; MAX_THREADS],
pool_state: CachePadded<ThreadPoolState>,
global_queue: SegQueue<TaskRef>,
callbacks: CachePadded<ThreadPoolCallbacks>,
}
impl ThreadPool {
pub const fn new() -> Self {
const INITIAL_THREAD_STATE: CachePadded<ThreadState> = CachePadded::new(ThreadState {
should_shove: AtomicBool::new(false),
shoved_task: Slot::new(),
@@ -404,15 +425,29 @@ impl ThreadPool {
status_changed: Condvar::new(),
should_terminate: AtomicLatch::new(),
});
pub const fn new() -> Self {
Self {
threads: const { [Self::INITIAL_THREAD_STATE; MAX_THREADS] },
pool_state: CachePadded::new(ThreadPoolState {
num_threads: AtomicUsize::new(0),
lock: Mutex::new(()),
heartbeat_state: Self::INITIAL_THREAD_STATE,
}),
global_queue: SegQueue::new(),
callbacks: CachePadded::new(ThreadPoolCallbacks::new_empty()),
}
}
pub const fn new_with_callbacks(callbacks: ThreadPoolCallbacks) -> ThreadPool {
Self {
threads: const { [Self::INITIAL_THREAD_STATE; MAX_THREADS] },
pool_state: CachePadded::new(ThreadPoolState {
num_threads: AtomicUsize::new(0),
lock: Mutex::new(()),
heartbeat_state: Self::INITIAL_THREAD_STATE,
}),
global_queue: SegQueue::new(),
callbacks: CachePadded::new(callbacks),
}
}
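For orientation, a minimal sketch of how the new constructor and callbacks fit together (mirroring the struct-literal construction used in the test at the bottom of this diff; the closures are placeholders):

let pool = ThreadPool::new_with_callbacks(ThreadPoolCallbacks {
    at_entry: Some(Arc::new(|_worker: &WorkerThread| {
        // runs once on each worker thread right after it spawns
    })),
    at_exit: Some(Arc::new(|_worker: &WorkerThread| {
        // runs once on each worker thread just before it terminates
    })),
});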
@@ -470,13 +505,15 @@ impl ThreadPool {
fn resize<F: Fn(usize) -> usize>(&'static self, size: F) -> usize {
if WorkerThread::is_worker_thread() {
// acquire required here?
debug!("tried to resize from within threadpool!");
return self.pool_state.num_threads.load(Ordering::Acquire);
}
let _guard = self.pool_state.lock.lock();
let current_size = self.pool_state.num_threads.load(Ordering::Acquire);
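// clamp the requested size into the pool's fixed capacity (0..=MAX_THREADS)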
let new_size = size(current_size).clamp(0, MAX_THREADS);
debug!(current_size, new_size, "resizing threadpool");
if new_size == current_size {
return current_size;
@@ -490,7 +527,7 @@ impl ThreadPool {
std::cmp::Ordering::Greater => {
let new_threads = &self.threads[current_size..new_size];
for (i, _) in new_threads.iter().enumerate() {
std::thread::spawn(move || {
WorkerThread::worker_loop(&self, current_size + i);
});
@@ -510,10 +547,14 @@
}
}
std::cmp::Ordering::Less => {
debug!(
"waiting for threads {:?} to terminate.",
new_size..current_size
);
let terminating_threads = &self.threads[new_size..current_size];
for thread in terminating_threads {
thread.notify_should_terminate();
}
for thread in terminating_threads {
thread.wait_for_termination();
@@ -521,7 +562,7 @@ impl ThreadPool {
#[cfg(not(feature = "internal_heartbeat"))]
if new_size == 0 {
self.pool_state.heartbeat_state.notify_should_terminate();
self.pool_state.heartbeat_state.wait_for_termination();
}
}
@@ -699,6 +740,66 @@ impl ThreadPool {
}
});
}
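// Fork-join: `g` is pushed onto the worker's queue as a stack task while `f` runs
// inline; the worker then keeps popping and executing tasks until `g`'s latch is set,
// running `g` itself inline if it gets it back before another thread picked it up.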
fn join<F, G, T, U>(&'static self, f: F, g: G) -> (T, U)
where
F: FnOnce() -> T + Send,
G: FnOnce() -> U + Send,
T: Send,
U: Send,
{
self.in_worker(|worker, _| {
let mut result_b = None;
let latch_b = ThreadWakeLatch::new(worker);
let task_b = pin!(StackTask::new(|| {
result_b = Some(g());
unsafe {
Latch::set_raw(&latch_b);
}
}));
let ref_b = task_b.as_ref().as_task_ref();
let b_id = ref_b.id();
// TODO: maybe try to push this off to another thread immediately first?
worker.push_task(ref_b);
let result_a = f();
while !latch_b.probe() {
match worker.pop_task() {
Some(task) => {
if task.id() == b_id {
worker.try_promote();
unsafe {
task_b.run_as_ref();
}
break;
}
worker.execute(task);
}
None => {
worker.run_until(&latch_b);
}
}
}
(result_a, result_b.unwrap())
})
}
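// Structured spawning: pins a `Scope` on the calling worker, runs the closure, then
// blocks that worker in `complete` until every task spawned into the scope has finished.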
fn scope<'scope, Fn, T>(&'static self, f: Fn) -> T
where
Fn: FnOnce(Pin<&Scope<'scope>>) -> T + Send,
T: Send,
{
self.in_worker(|owner, _| {
let scope = pin!(unsafe { Scope::<'scope>::new(owner) });
let result = f(scope.as_ref());
scope.complete(owner);
result
})
}
}
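A sketch of how `join` and `scope` compose from inside the crate (assuming a leaked `&'static ThreadPool`, as the test helper near the bottom of this diff sets up):

let (left, right) = pool.join(
    || (0u64..1_000).sum::<u64>(),
    || (1_000u64..2_000).sum::<u64>(),
);
let total = pool.scope(|s| {
    // tasks spawned here finish before `scope` returns
    s.spawn(|_| { /* runs on some worker */ });
    left + right
});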
pub struct WorkerThread {
@@ -719,6 +820,12 @@ impl WorkerThread {
fn info(&self) -> &ThreadState {
&self.pool.threads[self.index as usize]
}
fn pool(&self) -> &'static ThreadPool {
self.pool
}
fn index(&self) -> usize {
self.index
}
fn is_worker_thread() -> bool {
Self::with(|worker| worker.is_some())
}
@@ -758,13 +865,12 @@ impl WorkerThread {
match self.info().shoved_task.try_put(task) {
// shoved task is occupied, reinsert into queue
Some(task) => self.queue.push_back(task),
None => {}
}
// wake thread to execute task
self.pool.wake_any(1);
}
}
}
}
fn execute(&self, task: TaskRef) {
self.try_promote();
@@ -843,10 +949,18 @@ impl WorkerThread {
last_heartbeat: UnsafeCell::new(std::time::Instant::now()),
});
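// run the user-supplied entry hook before this worker starts processing tasks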
if let Some(callback) = pool.callbacks.at_entry.as_ref() {
callback(worker);
}
info.notify_running();
// info.notify_running();
worker.run_until(&info.should_terminate);
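// run the exit hook before draining any leftover local tasks back into the pool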
if let Some(callback) = pool.callbacks.at_exit.as_ref() {
callback(worker);
}
for task in worker.drain() {
pool.inject(task);
}
@@ -1039,9 +1153,116 @@ mod rng {
}
}
mod scope {
use std::{
future::Future,
marker::{PhantomData, PhantomPinned},
ptr::{self, NonNull},
};
use async_task::{Runnable, Task};
use crate::{
latch::{CountWakeLatch, Latch},
task::{HeapTask, TaskRef},
ThreadPool, WorkerThread,
};
pub struct Scope<'scope> {
pool: &'static ThreadPool,
tasks_completed_latch: CountWakeLatch,
_marker: PhantomData<Box<dyn FnOnce(&Scope<'scope>) + Send + Sync + 'scope>>,
}
impl<'scope> Scope<'scope> {
pub unsafe fn new(owner: &WorkerThread) -> Scope<'scope> {
Scope {
pool: owner.pool(),
tasks_completed_latch: CountWakeLatch::new(1, owner),
_marker: PhantomData,
}
}
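// Heap-allocates the closure, pushes it to the pool, and signals the scope's counting
// latch when it finishes; `SendPtr` lets the task reach back into the scope from
// whichever worker ends up running it.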
pub fn spawn<Fn>(&self, f: Fn)
where
Fn: FnOnce(&Scope<'scope>) + Send + 'scope,
{
self.tasks_completed_latch.increment();
let ptr = SendPtr::from_ref(self);
let task = HeapTask::new(move || unsafe {
let this = ptr.as_ref();
f(this);
Latch::set_raw(&this.tasks_completed_latch);
});
let taskref = unsafe { task.into_task_ref() };
self.pool.push_local_or_inject(taskref);
}
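// Bridges futures onto the pool: each wake-up converts the `Runnable` into a `TaskRef`
// that is pushed like any other task, and the returned `Task` handle yields the output.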
pub fn spawn_future<Fut, T>(&self, future: Fut) -> Task<T>
where
Fut: Future<Output = T> + Send + 'scope,
T: Send + 'scope,
{
self.tasks_completed_latch.increment();
let ptr = SendPtr::from_ref(self);
let schedule = move |runnable: Runnable| {
let taskref = unsafe {
TaskRef::new_raw(runnable.into_raw().as_ptr(), |this| {
let this = NonNull::new_unchecked(this.cast_mut());
let runnable = Runnable::<()>::from_raw(this);
runnable.run();
})
};
unsafe {
ptr.as_ref().pool.push_local_or_inject(taskref);
}
};
let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) };
runnable.schedule();
task
}
pub fn spawn_async<Fn, Fut, T>(&self, f: Fn) -> Task<T>
where
Fn: FnOnce() -> Fut + Send + 'scope,
Fut: Future<Output = T> + Send + 'scope,
T: Send + 'scope,
{
self.spawn_future(async move { f().await })
}
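// Called once the scope's closure has returned: releases the scope's own count on the
// latch and keeps the owning worker executing tasks until all spawned work is done.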
pub fn complete(&self, owner: &WorkerThread) {
unsafe {
Latch::set_raw(&self.tasks_completed_latch);
}
owner.run_until(&self.tasks_completed_latch);
}
}
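// Tiny wrapper that promises a raw pointer may cross threads; used here because the
// scope is kept alive until every task spawned into it has completed.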
struct SendPtr<T>(*const T);
impl<T> SendPtr<T> {
fn from_ref(t: &T) -> Self {
Self(ptr::from_ref(t).cast())
}
unsafe fn as_ref(&self) -> &T {
&*self.0
}
}
unsafe impl<T> Send for SendPtr<T> {}
}
#[cfg(test)]
mod tests {
use std::{cell::Cell, hint::black_box};
use crate::latch::CountWakeLatch;
use super::*;
@@ -1055,30 +1276,65 @@ mod tests {
1861, 1867, 1871, 1873, 1877, 1879, 1889, 1901, 1907,
];
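// Test helper: leaks the pool to obtain a &'static handle, spins up workers, runs `f`
// inside a scope, then shrinks back to zero threads and reclaims the allocation.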
fn run_in_scope<T: Send>(pool: ThreadPool, f: impl FnOnce(Pin<&Scope<'_>>) -> T + Send) -> T {
let pool = Box::new(pool);
let ptr = Box::into_raw(pool);
let result = {
let pool: &'static ThreadPool = unsafe { &*ptr };
// pool.ensure_one_worker();
pool.resize_to_available();
let result = pool.scope(f);
pool.resize_to(0);
assert!(pool.global_queue.pop().is_none());
result
};
let _pool = unsafe { Box::from_raw(ptr) };
result
}
#[test]
#[tracing_test::traced_test]
fn spawn_random() {
std::thread_local! {
static WAIT_COUNT: Cell<usize> = const {Cell::new(0)};
}
let counter = Arc::new(AtomicUsize::new(0));
let elapsed = {
let pool = ThreadPool::new_with_callbacks(ThreadPoolCallbacks {
at_entry: Some(Arc::new(|_worker| {
// eprintln!("new worker thread: {}", worker.index);
})),
at_exit: Some(Arc::new({
let counter = counter.clone();
move |_worker: &WorkerThread| {
// eprintln!("thread {}: {}", worker.index, WAIT_COUNT.get());
counter.fetch_add(WAIT_COUNT.get(), Ordering::Relaxed);
}
})),
});
let now = std::time::Instant::now();
run_in_scope(pool, |s| {
for &p in core::iter::repeat_n(PRIMES, 0x1000).flatten() {
s.spawn(move |_| {
// std::thread::sleep(Duration::from_micros(p as u64));
// spin for
let tmp = (0..p).reduce(|a, b| black_box(a & b));
black_box(tmp);
// WAIT_COUNT.with(|count| {
// // eprintln!("{} + {p}", count.get());
// count.set(count.get() + p);
// });
});
}
});
now.elapsed().as_micros()
};
eprintln!("total wait count: {}", counter.load(Ordering::Acquire));
eprintln!("total time: {}ms", elapsed as f32 / 1e3);
}
}