use std::sync::Arc; use crate::{Scope, context::Context, scope::scope_with_context}; pub struct ThreadPool { pub(crate) context: Arc, } impl Drop for ThreadPool { fn drop(&mut self) { // TODO: Ensure that the context is properly cleaned up when the thread pool is dropped. // self.context.set_should_exit(); } } impl ThreadPool { pub fn new_with_threads(num_threads: usize) -> Self { let context = Context::new_with_threads(num_threads); Self { context } } /// Creates a new thread pool with a thread per hardware thread. pub fn new() -> Self { let context = Context::new(); Self { context } } pub fn global() -> Self { let context = Context::global_context().clone(); Self { context } } pub fn scope<'env, F, R>(&self, f: F) -> R where F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> R + Send, R: Send, { scope_with_context(&self.context, f) } pub fn spawn(&self, f: F) where F: FnOnce() + Send + 'static, { self.context.spawn(f) } pub fn join(&self, a: A, b: B) -> (RA, RB) where RA: Send, RB: Send, A: FnOnce() -> RA + Send, B: FnOnce() -> RB + Send, { self.context.join(a, b) } } #[cfg(test)] mod tests { use super::*; #[test] #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)] fn pool_spawn_borrow() { let pool = ThreadPool::new_with_threads(1); let mut x = 0; pool.scope(|scope| { scope.spawn(|_| { #[cfg(feature = "tracing")] tracing::info!("Incrementing x"); x += 1; }); }); assert_eq!(x, 1); } #[test] #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)] fn pool_spawn_future() { let pool = ThreadPool::new_with_threads(1); let mut x = 0; let task = pool.scope(|scope| { let task = scope.spawn_async(|_| async { x += 1; }); task }); futures::executor::block_on(task); assert_eq!(x, 1); } #[test] #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)] fn pool_join() { let pool = ThreadPool::new_with_threads(1); let (a, b) = pool.join(|| 3 + 4, || 5 * 6); assert_eq!(a, 7); assert_eq!(b, 30); } }