scope
This commit is contained in:
parent
e75094d2a5
commit
fd0cd86a8d
|
@ -23,4 +23,8 @@ tracing-subscriber = {version ="0.3.18", features = ["env-filter"]}
|
|||
anyhow = "1.0.89"
|
||||
thiserror = "2.0"
|
||||
bitflags = "2.6"
|
||||
# derive_more = "1.0.0"
|
||||
core_affinity = "0.8.1"
|
||||
# derive_more = "1.0.0"
|
||||
|
||||
[dev-dependencies]
|
||||
tracing-test = "0.2.5"
|
||||
|
|
324
src/lib.rs
324
src/lib.rs
|
@ -20,7 +20,9 @@ use bitflags::bitflags;
|
|||
use crossbeam::{queue::SegQueue, utils::CachePadded};
|
||||
use latch::{AtomicLatch, Latch, LatchWaker, MutexLatch, Probe, ThreadWakeLatch};
|
||||
use parking_lot::{Condvar, Mutex};
|
||||
use scope::Scope;
|
||||
use task::{HeapTask, StackTask, TaskRef};
|
||||
use tracing::debug;
|
||||
|
||||
pub mod task {
|
||||
use std::{cell::UnsafeCell, marker::PhantomPinned, pin::Pin};
|
||||
|
@ -78,6 +80,9 @@ pub mod task {
|
|||
pub fn run(self) {
|
||||
self.task.into_inner().unwrap()();
|
||||
}
|
||||
pub unsafe fn run_as_ref(&self) {
|
||||
((&mut *self.task.get()).take().unwrap())();
|
||||
}
|
||||
|
||||
pub fn as_task_ref(self: Pin<&Self>) -> TaskRef {
|
||||
unsafe { TaskRef::new(&*self) }
|
||||
|
@ -380,39 +385,69 @@ impl ThreadState {
|
|||
self.status_changed.notify_all();
|
||||
}
|
||||
|
||||
fn set_should_terminate(&self) {
|
||||
fn notify_should_terminate(&self) {
|
||||
unsafe {
|
||||
Latch::set_raw(&self.should_terminate);
|
||||
}
|
||||
self.wake();
|
||||
}
|
||||
}
|
||||
|
||||
const MAX_THREADS: usize = 32;
|
||||
|
||||
type ThreadCallback = dyn Fn(&WorkerThread) + Send + Sync + 'static;
|
||||
pub struct ThreadPoolCallbacks {
|
||||
at_entry: Option<Arc<ThreadCallback>>,
|
||||
at_exit: Option<Arc<ThreadCallback>>,
|
||||
}
|
||||
|
||||
impl ThreadPoolCallbacks {
|
||||
pub const fn new_empty() -> ThreadPoolCallbacks {
|
||||
Self {
|
||||
at_entry: None,
|
||||
at_exit: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ThreadPool {
|
||||
threads: [CachePadded<ThreadState>; MAX_THREADS],
|
||||
pool_state: CachePadded<ThreadPoolState>,
|
||||
global_queue: SegQueue<TaskRef>,
|
||||
callbacks: CachePadded<ThreadPoolCallbacks>,
|
||||
}
|
||||
|
||||
impl ThreadPool {
|
||||
const INITIAL_THREAD_STATE: CachePadded<ThreadState> = CachePadded::new(ThreadState {
|
||||
should_shove: AtomicBool::new(false),
|
||||
shoved_task: Slot::new(),
|
||||
status: Mutex::new(ThreadStatus::empty()),
|
||||
status_changed: Condvar::new(),
|
||||
should_terminate: AtomicLatch::new(),
|
||||
});
|
||||
pub const fn new() -> Self {
|
||||
const INITIAL_THREAD_STATE: CachePadded<ThreadState> = CachePadded::new(ThreadState {
|
||||
should_shove: AtomicBool::new(false),
|
||||
shoved_task: Slot::new(),
|
||||
status: Mutex::new(ThreadStatus::empty()),
|
||||
status_changed: Condvar::new(),
|
||||
should_terminate: AtomicLatch::new(),
|
||||
});
|
||||
|
||||
Self {
|
||||
threads: const { [INITIAL_THREAD_STATE; MAX_THREADS] },
|
||||
threads: const { [Self::INITIAL_THREAD_STATE; MAX_THREADS] },
|
||||
pool_state: CachePadded::new(ThreadPoolState {
|
||||
num_threads: AtomicUsize::new(0),
|
||||
lock: Mutex::new(()),
|
||||
heartbeat_state: INITIAL_THREAD_STATE,
|
||||
heartbeat_state: Self::INITIAL_THREAD_STATE,
|
||||
}),
|
||||
global_queue: SegQueue::new(),
|
||||
callbacks: CachePadded::new(ThreadPoolCallbacks::new_empty()),
|
||||
}
|
||||
}
|
||||
|
||||
pub const fn new_with_callbacks(callbacks: ThreadPoolCallbacks) -> ThreadPool {
|
||||
Self {
|
||||
threads: const { [Self::INITIAL_THREAD_STATE; MAX_THREADS] },
|
||||
pool_state: CachePadded::new(ThreadPoolState {
|
||||
num_threads: AtomicUsize::new(0),
|
||||
lock: Mutex::new(()),
|
||||
heartbeat_state: Self::INITIAL_THREAD_STATE,
|
||||
}),
|
||||
global_queue: SegQueue::new(),
|
||||
callbacks: CachePadded::new(callbacks),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -470,13 +505,15 @@ impl ThreadPool {
|
|||
fn resize<F: Fn(usize) -> usize>(&'static self, size: F) -> usize {
|
||||
if WorkerThread::is_worker_thread() {
|
||||
// acquire required here?
|
||||
debug!("tried to resize from within threadpool!");
|
||||
return self.pool_state.num_threads.load(Ordering::Acquire);
|
||||
}
|
||||
|
||||
let _guard = self.pool_state.lock.lock();
|
||||
|
||||
let current_size = self.pool_state.num_threads.load(Ordering::Acquire);
|
||||
let new_size = size(current_size).max(MAX_THREADS);
|
||||
let new_size = size(current_size).clamp(0, MAX_THREADS);
|
||||
debug!(current_size, new_size, "resizing threadpool");
|
||||
|
||||
if new_size == current_size {
|
||||
return current_size;
|
||||
|
@ -490,7 +527,7 @@ impl ThreadPool {
|
|||
std::cmp::Ordering::Greater => {
|
||||
let new_threads = &self.threads[current_size..new_size];
|
||||
|
||||
for (i, thread) in new_threads.iter().enumerate() {
|
||||
for (i, _) in new_threads.iter().enumerate() {
|
||||
std::thread::spawn(move || {
|
||||
WorkerThread::worker_loop(&self, current_size + i);
|
||||
});
|
||||
|
@ -510,10 +547,14 @@ impl ThreadPool {
|
|||
}
|
||||
}
|
||||
std::cmp::Ordering::Less => {
|
||||
debug!(
|
||||
"waiting for threads {:?} to terminate.",
|
||||
new_size..current_size
|
||||
);
|
||||
let terminating_threads = &self.threads[new_size..current_size];
|
||||
|
||||
for thread in terminating_threads {
|
||||
thread.set_should_terminate();
|
||||
thread.notify_should_terminate();
|
||||
}
|
||||
for thread in terminating_threads {
|
||||
thread.wait_for_termination();
|
||||
|
@ -521,7 +562,7 @@ impl ThreadPool {
|
|||
|
||||
#[cfg(not(feature = "internal_heartbeat"))]
|
||||
if new_size == 0 {
|
||||
self.pool_state.heartbeat_state.set_should_terminate();
|
||||
self.pool_state.heartbeat_state.notify_should_terminate();
|
||||
self.pool_state.heartbeat_state.wait_for_termination();
|
||||
}
|
||||
}
|
||||
|
@ -699,6 +740,66 @@ impl ThreadPool {
|
|||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn join<F, G, T, U>(&'static self, f: F, g: G) -> (T, U)
|
||||
where
|
||||
F: FnOnce() -> T + Send,
|
||||
G: FnOnce() -> U + Send,
|
||||
T: Send,
|
||||
U: Send,
|
||||
{
|
||||
self.in_worker(|worker, _| {
|
||||
let mut result_b = None;
|
||||
|
||||
let latch_b = ThreadWakeLatch::new(worker);
|
||||
let task_b = pin!(StackTask::new(|| {
|
||||
result_b = Some(g());
|
||||
unsafe {
|
||||
Latch::set_raw(&latch_b);
|
||||
}
|
||||
}));
|
||||
|
||||
let ref_b = task_b.as_ref().as_task_ref();
|
||||
let b_id = ref_b.id();
|
||||
// TODO: maybe try to push this off to another thread immediately first?
|
||||
worker.push_task(ref_b);
|
||||
|
||||
let result_a = f();
|
||||
|
||||
while !latch_b.probe() {
|
||||
match worker.pop_task() {
|
||||
Some(task) => {
|
||||
if task.id() == b_id {
|
||||
worker.try_promote();
|
||||
unsafe {
|
||||
task_b.run_as_ref();
|
||||
}
|
||||
break;
|
||||
}
|
||||
worker.execute(task);
|
||||
}
|
||||
None => {
|
||||
worker.run_until(&latch_b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(result_a, result_b.unwrap())
|
||||
})
|
||||
}
|
||||
|
||||
fn scope<'scope, Fn, T>(&'static self, f: Fn) -> T
|
||||
where
|
||||
Fn: FnOnce(Pin<&Scope<'scope>>) -> T + Send,
|
||||
T: Send,
|
||||
{
|
||||
self.in_worker(|owner, _| {
|
||||
let scope = pin!(unsafe { Scope::<'scope>::new(owner) });
|
||||
let result = f(scope.as_ref());
|
||||
scope.complete(owner);
|
||||
result
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct WorkerThread {
|
||||
|
@ -719,6 +820,12 @@ impl WorkerThread {
|
|||
fn info(&self) -> &ThreadState {
|
||||
&self.pool.threads[self.index as usize]
|
||||
}
|
||||
fn pool(&self) -> &'static ThreadPool {
|
||||
self.pool
|
||||
}
|
||||
fn index(&self) -> usize {
|
||||
self.index
|
||||
}
|
||||
fn is_worker_thread() -> bool {
|
||||
Self::with(|worker| worker.is_some())
|
||||
}
|
||||
|
@ -758,11 +865,10 @@ impl WorkerThread {
|
|||
match self.info().shoved_task.try_put(task) {
|
||||
// shoved task is occupied, reinsert into queue
|
||||
Some(task) => self.queue.push_back(task),
|
||||
None => {
|
||||
// wake thread to execute task
|
||||
self.pool.wake_any(1);
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
// wake thread to execute task
|
||||
self.pool.wake_any(1);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -843,10 +949,18 @@ impl WorkerThread {
|
|||
last_heartbeat: UnsafeCell::new(std::time::Instant::now()),
|
||||
});
|
||||
|
||||
if let Some(callback) = pool.callbacks.at_entry.as_ref() {
|
||||
callback(worker);
|
||||
}
|
||||
|
||||
info.notify_running();
|
||||
// info.notify_running();
|
||||
worker.run_until(&info.should_terminate);
|
||||
|
||||
if let Some(callback) = pool.callbacks.at_exit.as_ref() {
|
||||
callback(worker);
|
||||
}
|
||||
|
||||
for task in worker.drain() {
|
||||
pool.inject(task);
|
||||
}
|
||||
|
@ -1039,9 +1153,116 @@ mod rng {
|
|||
}
|
||||
}
|
||||
|
||||
mod scope {
|
||||
use std::{
|
||||
future::Future,
|
||||
marker::{PhantomData, PhantomPinned},
|
||||
ptr::{self, NonNull},
|
||||
};
|
||||
|
||||
use async_task::{Runnable, Task};
|
||||
|
||||
use crate::{
|
||||
latch::{CountWakeLatch, Latch},
|
||||
task::{HeapTask, TaskRef},
|
||||
ThreadPool, WorkerThread,
|
||||
};
|
||||
|
||||
pub struct Scope<'scope> {
|
||||
pool: &'static ThreadPool,
|
||||
tasks_completed_latch: CountWakeLatch,
|
||||
_marker: PhantomData<Box<dyn FnOnce(&Scope<'scope>) + Send + Sync + 'scope>>,
|
||||
}
|
||||
|
||||
impl<'scope> Scope<'scope> {
|
||||
pub unsafe fn new(owner: &WorkerThread) -> Scope<'scope> {
|
||||
Scope {
|
||||
pool: owner.pool(),
|
||||
tasks_completed_latch: CountWakeLatch::new(1, owner),
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn spawn<Fn>(&self, f: Fn)
|
||||
where
|
||||
Fn: FnOnce(&Scope<'scope>) + Send + 'scope,
|
||||
{
|
||||
self.tasks_completed_latch.increment();
|
||||
|
||||
let ptr = SendPtr::from_ref(self);
|
||||
let task = HeapTask::new(move || unsafe {
|
||||
let this = ptr.as_ref();
|
||||
f(this);
|
||||
Latch::set_raw(&this.tasks_completed_latch);
|
||||
});
|
||||
|
||||
let taskref = unsafe { task.into_task_ref() };
|
||||
self.pool.push_local_or_inject(taskref);
|
||||
}
|
||||
|
||||
pub fn spawn_future<Fut, T>(&self, future: Fut) -> Task<T>
|
||||
where
|
||||
Fut: Future<Output = T> + Send + 'scope,
|
||||
T: Send + 'scope,
|
||||
{
|
||||
self.tasks_completed_latch.increment();
|
||||
|
||||
let ptr = SendPtr::from_ref(self);
|
||||
let schedule = move |runnable: Runnable| {
|
||||
let taskref = unsafe {
|
||||
TaskRef::new_raw(runnable.into_raw().as_ptr(), |this| {
|
||||
let this = NonNull::new_unchecked(this.cast_mut());
|
||||
|
||||
let runnable = Runnable::<()>::from_raw(this);
|
||||
runnable.run();
|
||||
})
|
||||
};
|
||||
|
||||
unsafe {
|
||||
ptr.as_ref().pool.push_local_or_inject(taskref);
|
||||
}
|
||||
};
|
||||
|
||||
let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) };
|
||||
|
||||
runnable.schedule();
|
||||
task
|
||||
}
|
||||
|
||||
pub fn spawn_async<Fn, Fut, T>(&self, f: Fn) -> Task<T>
|
||||
where
|
||||
Fn: FnOnce() -> Fut + Send + 'scope,
|
||||
Fut: Future<Output = T> + Send + 'scope,
|
||||
T: Send + 'scope,
|
||||
{
|
||||
self.spawn_future(async move { f().await })
|
||||
}
|
||||
|
||||
pub fn complete(&self, owner: &WorkerThread) {
|
||||
unsafe {
|
||||
Latch::set_raw(&self.tasks_completed_latch);
|
||||
}
|
||||
owner.run_until(&self.tasks_completed_latch);
|
||||
}
|
||||
}
|
||||
|
||||
struct SendPtr<T>(*const T);
|
||||
impl<T> SendPtr<T> {
|
||||
fn from_ref(t: &T) -> Self {
|
||||
Self(ptr::from_ref(t).cast())
|
||||
}
|
||||
unsafe fn as_ref(&self) -> &T {
|
||||
&*self.0
|
||||
}
|
||||
}
|
||||
unsafe impl<T> Send for SendPtr<T> {}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::cell::Cell;
|
||||
use std::{cell::Cell, hint::black_box};
|
||||
|
||||
use crate::latch::CountWakeLatch;
|
||||
|
||||
use super::*;
|
||||
|
||||
|
@ -1055,30 +1276,65 @@ mod tests {
|
|||
1861, 1867, 1871, 1873, 1877, 1879, 1889, 1901, 1907,
|
||||
];
|
||||
|
||||
fn run_on_static_pool(f: impl FnOnce(&'static ThreadPool)) {
|
||||
let pool = Box::new(ThreadPool::new());
|
||||
fn run_in_scope<T: Send>(pool: ThreadPool, f: impl FnOnce(Pin<&Scope<'_>>) -> T + Send) -> T {
|
||||
let pool = Box::new(pool);
|
||||
let ptr = Box::into_raw(pool);
|
||||
|
||||
{
|
||||
let result = {
|
||||
let pool: &'static ThreadPool = unsafe { &*ptr };
|
||||
pool.ensure_one_worker();
|
||||
f(pool);
|
||||
// pool.ensure_one_worker();
|
||||
pool.resize_to_available();
|
||||
let result = pool.scope(f);
|
||||
pool.resize_to(0);
|
||||
assert!(pool.global_queue.pop().is_none());
|
||||
}
|
||||
result
|
||||
};
|
||||
|
||||
let _pool = unsafe { Box::from_raw(ptr) };
|
||||
result
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[tracing_test::traced_test]
|
||||
fn spawn_random() {
|
||||
std::thread_local! {static WAIT_COUNT: Cell<usize>= Cell::new(0);}
|
||||
run_on_static_pool(|pool| {
|
||||
for &p in PRIMES {
|
||||
pool.spawn(move || {
|
||||
std::thread::sleep(Duration::from_micros(p as u64));
|
||||
});
|
||||
}
|
||||
});
|
||||
std::thread_local! {
|
||||
static WAIT_COUNT: Cell<usize> = const {Cell::new(0)};
|
||||
}
|
||||
let counter = Arc::new(AtomicUsize::new(0));
|
||||
let elapsed = {
|
||||
let pool = ThreadPool::new_with_callbacks(ThreadPoolCallbacks {
|
||||
at_entry: Some(Arc::new(|_worker| {
|
||||
// eprintln!("new worker thread: {}", worker.index);
|
||||
})),
|
||||
at_exit: Some(Arc::new({
|
||||
let counter = counter.clone();
|
||||
move |_worker: &WorkerThread| {
|
||||
// eprintln!("thread {}: {}", worker.index, WAIT_COUNT.get());
|
||||
counter.fetch_add(WAIT_COUNT.get(), Ordering::Relaxed);
|
||||
}
|
||||
})),
|
||||
});
|
||||
|
||||
let now = std::time::Instant::now();
|
||||
run_in_scope(pool, |s| {
|
||||
for &p in core::iter::repeat_n(PRIMES, 0x1000).flatten() {
|
||||
s.spawn(move |_| {
|
||||
// std::thread::sleep(Duration::from_micros(p as u64));
|
||||
// spin for
|
||||
let tmp = (0..p).reduce(|a, b| black_box(a & b));
|
||||
black_box(tmp);
|
||||
|
||||
// WAIT_COUNT.with(|count| {
|
||||
// // eprintln!("{} + {p}", count.get());
|
||||
// count.set(count.get() + p);
|
||||
// });
|
||||
});
|
||||
}
|
||||
});
|
||||
now.elapsed().as_micros()
|
||||
};
|
||||
|
||||
eprintln!("total wait count: {}", counter.load(Ordering::Acquire));
|
||||
eprintln!("total time: {}ms", elapsed as f32 / 1e3);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue