diff --git a/distaff/src/scope.rs b/distaff/src/scope.rs index d5e5d3b..c474ff4 100644 --- a/distaff/src/scope.rs +++ b/distaff/src/scope.rs @@ -301,6 +301,27 @@ mod tests { assert_eq!(a, 18); } + #[test] + fn join_many() { + let pool = ThreadPool::new_with_threads(1); + + fn sum<'scope, 'env>(scope: &'scope Scope<'scope, 'env>, n: usize) -> usize { + if n == 0 { + return 0; + } + + let (l, r) = scope.join(|s| sum(s, n - 1), |s| sum(s, n - 1)); + + l + r + 1 + } + + pool.scope(|scope| { + let total = sum(scope, 5); + assert_eq!(total, 31); + // eprintln!("Total sum: {}", total); + }); + } + #[test] fn spawn_future() { let pool = ThreadPool::new_with_threads(1); @@ -315,4 +336,21 @@ mod tests { assert_eq!(x, 1); } + + #[test] + fn spawn_many() { + let pool = ThreadPool::new_with_threads(1); + let count = Arc::new(AtomicU8::new(0)); + + pool.scope(|scope| { + for _ in 0..10 { + let count = count.clone(); + scope.spawn(move |_| { + count.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + }); + } + }); + + assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 10); + } } diff --git a/distaff/src/workerthread.rs b/distaff/src/workerthread.rs index ac66786..6533c14 100644 --- a/distaff/src/workerthread.rs +++ b/distaff/src/workerthread.rs @@ -111,6 +111,10 @@ impl WorkerThread { self.tick(); } crate::latch::WakeResult::Set => { + // check if we should exit the thread + if self.context.shared().should_exit() { + break 'outer; + } panic!("this thread shouldn't be woken by a finished job") } } @@ -354,6 +358,7 @@ impl WorkerThread { self.index, latch ); + self.heartbeat().latch.as_core_latch().unset(); return; } diff --git a/examples/join.rs b/examples/join.rs index 9a2c31b..5cc1a07 100644 --- a/examples/join.rs +++ b/examples/join.rs @@ -86,6 +86,7 @@ fn join_distaff() { let sum = sum(&tree, tree.root().unwrap(), s); sum }); + eprintln!("sum: {sum}"); std::hint::black_box(sum); } }