erm.......

Janis 2025-06-24 18:03:23 +02:00
parent e8a07ce6a5
commit ed4acbfbd7
13 changed files with 1038 additions and 188 deletions

View file

@ -48,3 +48,4 @@ cfg-if = "1.0.0"
[dev-dependencies]
async-std = "1.13.0"
tracing-test = "0.2.5"
distaff = {path = "distaff"}

View file

@ -184,3 +184,36 @@ fn join_rayon(b: &mut Bencher) {
assert_ne!(sum(&tree, tree.root().unwrap()), 0);
});
}
#[bench]
fn join_distaff(b: &mut Bencher) {
use distaff::*;
let pool = ThreadPool::new();
let tree = tree::Tree::new(TREE_SIZE, 1u32);
fn sum<'scope, 'env>(
tree: &tree::Tree<u32>,
node: usize,
scope: &'scope Scope<'scope, 'env>,
) -> u32 {
let node = tree.get(node);
let (l, r) = scope.join(
|s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(),
|s| {
node.right
.map(|node| sum(tree, node, s))
.unwrap_or_default()
},
);
node.leaf + l + r
}
b.iter(move || {
pool.scope(|s| {
let sum = sum(&tree, tree.root().unwrap(), s);
// eprintln!("{sum}");
assert_ne!(sum, 0);
});
});
}

View file

@ -13,4 +13,8 @@ tracing = "0.1.40"
parking_lot_core = "0.9.10"
crossbeam-utils = "0.8.21"
async-task = "4.7.1"
async-task = "4.7.1"
[dev-dependencies]
tracing-test = "0.2.5"
futures = "0.3"

View file

@ -8,11 +8,12 @@ use std::{
use alloc::collections::BTreeMap;
use async_task::Runnable;
use crossbeam_utils::CachePadded;
use parking_lot::{Condvar, Mutex};
use crate::{
job::{Job, StackJob},
job::{HeapJob, Job, StackJob},
latch::{LatchRef, MutexLatch, WakeLatch},
workerthread::{HeartbeatThread, WorkerThread},
};
@ -50,10 +51,6 @@ impl Heartbeat {
pub fn is_pending(&self) -> bool {
self.heartbeat.load(Ordering::Relaxed) == Self::PENDING
}
pub fn is_sleeping(&self) -> bool {
self.heartbeat.load(Ordering::Relaxed) == Self::SLEEPING
}
}
pub struct Context {
@ -87,6 +84,7 @@ impl Shared {
// this is unlikely, so make the function cold?
// TODO: profile this
if !self.injected_jobs.is_empty() {
// SAFETY: we checked that injected_jobs is not empty
unsafe { return Some(self.pop_injected_job()) };
} else {
self.jobs.pop_first().map(|(_, job)| job)
@ -105,7 +103,7 @@ impl Shared {
impl Context {
#[inline]
pub fn shared(&self) -> parking_lot::MutexGuard<'_, Shared> {
pub(crate) fn shared(&self) -> parking_lot::MutexGuard<'_, Shared> {
self.shared.lock()
}
@ -159,6 +157,17 @@ impl Context {
this
}
pub fn set_should_exit(&self) {
let mut shared = self.shared.lock();
shared.should_exit = true;
for (_, heartbeat) in shared.heartbeats.iter() {
if let Some(heartbeat) = heartbeat.upgrade() {
heartbeat.latch.set();
}
}
self.shared_job.notify_all();
}
pub fn new() -> Arc<Self> {
Self::new_with_threads(crate::util::available_parallelism())
}
@ -270,6 +279,66 @@ impl Context {
}
}
impl Context {
pub fn spawn<F>(self: &Arc<Self>, f: F)
where
F: FnOnce() + Send + 'static,
{
let job = Box::new(HeapJob::new(f)).into_boxed_job();
tracing::trace!("Context::spawn: spawning job: {:?}", job);
unsafe {
(&*job).set_pending();
self.inject_job(NonNull::new_unchecked(job));
}
}
pub fn spawn_future<T, F>(self: &Arc<Self>, future: F) -> async_task::Task<T>
where
F: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
let schedule = move |runnable: Runnable| {
#[align(8)]
unsafe fn harness<T>(this: *const (), job: *const Job<T>) {
unsafe {
let runnable =
Runnable::<()>::from_raw(NonNull::new_unchecked(this.cast_mut()));
runnable.run();
// SAFETY: job was turned into raw
drop(Box::from_raw(job.cast_mut()));
}
}
let job = Box::new(Job::<T>::new(harness::<T>, runnable.into_raw()));
// casting into Job<()> here
unsafe {
job.set_pending();
self.inject_job(NonNull::new_unchecked(Box::into_raw(job) as *mut Job<()>));
}
};
let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) };
runnable.schedule();
task
}
#[allow(dead_code)]
fn spawn_async<T, Fut, Fn>(self: &Arc<Self>, f: Fn) -> async_task::Task<T>
where
Fn: FnOnce() -> Fut + Send + 'static,
Fut: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
let future = async move { f().await };
self.spawn_future(future)
}
}
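// The `schedule` closure above is the glue `async_task` uses: each time the
// future is woken, the `Runnable` is handed back to `schedule`, which must
// re-enqueue it somewhere a worker will `run()` it. A minimal standalone
// sketch of that contract, with a plain channel standing in for the worker
// queue (hypothetical setup, not this crate's code):
fn schedule_contract_sketch() {
    use async_task::Runnable;
    let (tx, rx) = std::sync::mpsc::channel::<Runnable>();
    // called once per wakeup of the future
    let schedule = move |runnable: Runnable| tx.send(runnable).unwrap();
    let (runnable, task) = async_task::spawn(async { 1 + 1 }, schedule);
    runnable.schedule(); // enqueue the first poll
    while let Ok(runnable) = rx.try_recv() {
        runnable.run(); // polls the future once; it re-schedules itself if woken again
    }
    assert_eq!(futures::executor::block_on(task), 2);
}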
pub fn run_in_worker<T, F>(f: F) -> T
where
T: Send,
@ -277,3 +346,56 @@ where
{
Context::global_context().run_in_worker(f)
}
#[cfg(test)]
mod tests {
use tracing_test::traced_test;
use super::*;
#[test]
fn run_in_worker_test() {
let ctx = Context::global_context().clone();
let result = ctx.run_in_worker(|_| 42);
assert_eq!(result, 42);
}
#[test]
fn spawn_future_test() {
let ctx = Context::global_context().clone();
let task = ctx.spawn_future(async { 42 });
// Wait for the task to complete
let result = futures::executor::block_on(task);
assert_eq!(result, 42);
}
#[test]
fn spawn_async_test() {
let ctx = Context::global_context().clone();
let task = ctx.spawn_async(|| async { 42 });
// Wait for the task to complete
let result = futures::executor::block_on(task);
assert_eq!(result, 42);
}
#[test]
fn spawn_test() {
let ctx = Context::global_context().clone();
let counter = Arc::new(AtomicU8::new(0));
let barrier = Arc::new(std::sync::Barrier::new(2));
ctx.spawn({
let counter = counter.clone();
let barrier = barrier.clone();
move || {
counter.fetch_add(1, Ordering::SeqCst);
barrier.wait();
}
});
barrier.wait();
assert_eq!(counter.load(Ordering::SeqCst), 1);
}
}

View file

@ -40,6 +40,59 @@ impl JobState {
}
pub use joblist::JobList;
pub use jobvec::JobVec;
// A replacement for `JobList` that uses a `VecDeque` instead of an intrusive linked list.
mod jobvec {
use std::ptr::NonNull;
use super::Job;
use alloc::collections::VecDeque;
#[derive(Debug)]
pub struct JobVec {
jobs: VecDeque<NonNull<Job>>,
}
impl JobVec {
pub fn new() -> Self {
Self {
jobs: VecDeque::new(),
}
}
pub fn remove(&mut self, job: &Job) {
// SAFETY: job is guaranteed to be valid and non-null
let job_ptr = unsafe { NonNull::new_unchecked(job as *const Job as _) };
self.jobs.retain(|j| *j != job_ptr);
}
pub fn push_front<T>(&mut self, job: *const Job<T>) {
let job_ptr = unsafe { NonNull::new_unchecked(job as _) };
self.jobs.push_front(job_ptr);
}
pub fn push_back<T>(&mut self, job: *const Job<T>) {
let job_ptr = unsafe { NonNull::new_unchecked(job as _) };
self.jobs.push_back(job_ptr);
}
pub fn pop_front(&mut self) -> Option<NonNull<Job>> {
self.jobs.pop_front()
}
pub fn pop_back(&mut self) -> Option<NonNull<Job>> {
self.jobs.pop_back()
}
pub fn is_empty(&self) -> bool {
self.jobs.is_empty()
}
pub fn len(&self) -> usize {
self.jobs.len()
}
}
}
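// Design note: compared with the intrusive list below, the `VecDeque` variant
// gives up O(1) `unlink` (`remove` is a linear scan) but needs no prev/next
// pointer surgery and no unsafe linking invariants. The removal idiom in
// miniature, with plain values standing in for job pointers:
fn retain_removal_sketch() {
    use std::collections::VecDeque;
    let mut queue: VecDeque<u32> = (0..5).collect();
    // `JobVec::remove` does exactly this, comparing raw job pointers instead
    queue.retain(|&job| job != 2);
    assert_eq!(queue, VecDeque::from([0, 1, 3, 4]));
}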
mod joblist {
use core::{fmt::Debug, ptr::NonNull};
@ -87,6 +140,12 @@ mod joblist {
self.tail
}
pub fn remove(&mut self, job: &Job) {
job.unlink();
self.job_count -= 1;
}
/// `job` must be valid until it is removed from the list.
pub unsafe fn push_front<T>(&mut self, job: *const Job<T>) {
self.job_count += 1;
@ -124,8 +183,6 @@ mod joblist {
}
pub fn pop_front(&mut self) -> Option<NonNull<Job>> {
self.job_count -= 1;
let headlink = unsafe { self.head.as_ref().link_mut() };
// SAFETY: headlink.next is guaranteed to be Some.
@ -139,12 +196,13 @@ mod joblist {
headlink.next = Some(next);
next_link.prev = Some(self.head);
// decrement job count after having potentially short-circuited
self.job_count -= 1;
Some(job)
}
pub fn pop_back(&mut self) -> Option<NonNull<Job>> {
self.job_count -= 1;
let taillink = unsafe { self.tail.as_ref().link_mut() };
// SAFETY: taillink.prev is guaranteed to be Some.
@ -158,6 +216,9 @@ mod joblist {
taillink.prev = Some(prev);
prev_link.next = Some(self.tail);
// decrement job count after having potentially short-circuited
self.job_count -= 1;
Some(job)
}
@ -266,8 +327,6 @@ impl<T> Clone for Link<T> {
// `Link` is invariant over `T`
impl<T> Copy for Link<T> {}
struct Thread;
union ValueOrThis<T> {
uninit: (),
value: ManuallyDrop<SmallBox<T>>,
@ -385,7 +444,8 @@ impl<T> Job<T> {
}
/// assumes job is in a `JobList`
pub unsafe fn unlink(&self) {
pub fn unlink(&self) {
// SAFETY: if the job isn't linked, these will operate on a dummy value.
unsafe {
let mut dummy = None;
let Link { prev, next } = *self.link_mut();
@ -590,6 +650,7 @@ mod stackjob {
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f()));
tracing::trace!("job completed: {:?}", job);
let job = unsafe { &*job.cast::<Job<T>>() };
job.complete(result);
@ -664,3 +725,248 @@ mod heapjob {
pub use heapjob::HeapJob;
pub use stackjob::StackJob;
#[cfg(test)]
mod tests {
use crate::latch::{AtomicLatch, LatchRef};
use super::*;
#[test]
fn job_lifecycle() {
let latch = AtomicLatch::new();
let stack = StackJob::new(|| 3 + 4, LatchRef::new(&latch));
let job = stack.as_job::<i32>();
assert_eq!(job.state(), JobState::Empty as u8);
job.set_pending();
assert_eq!(job.state(), JobState::Pending as u8);
// execute the job
Job::<()>::execute(unsafe { NonNull::new_unchecked(&job as *const Job as _) });
// wait for the job to finish
let result = unsafe { job.transmute_ref::<i32>().wait() };
assert_eq!(result.into_result(), 7);
}
#[test]
fn job_lifecycle_panic() {
let latch = AtomicLatch::new();
let stack = StackJob::new(|| panic!("test panic"), LatchRef::new(&latch));
let job = stack.as_job::<i32>();
assert_eq!(job.state(), JobState::Empty as u8);
job.set_pending();
assert_eq!(job.state(), JobState::Pending as u8);
// execute the job
Job::<()>::execute(unsafe { NonNull::new_unchecked(&job as *const Job as _) });
// wait for the job to finish
let result = unsafe { job.transmute_ref::<i32>().wait() };
assert!(result.into_inner().is_err());
}
#[test]
fn joblist_popback() {
let mut list = JobList::new();
let job1 = Box::into_raw(Box::new(Job::<i32>::empty()));
let job2 = Box::into_raw(Box::new(Job::<i32>::empty()));
unsafe {
list.push_back(job1);
list.push_back(job2);
}
assert_eq!(list.len(), 2);
let popped_job = list.pop_back().unwrap();
assert_eq!(popped_job.as_ptr(), job2 as _);
let popped_job = list.pop_back().unwrap();
assert_eq!(popped_job.as_ptr(), job1 as _);
assert!(list.is_empty());
}
#[test]
fn joblist_popfront() {
let mut list = JobList::new();
let job1 = Box::into_raw(Box::new(Job::<i32>::empty()));
let job2 = Box::into_raw(Box::new(Job::<i32>::empty()));
unsafe {
list.push_front(job1);
list.push_front(job2);
}
assert_eq!(list.len(), 2);
let popped_job = list.pop_front().unwrap();
assert_eq!(popped_job.as_ptr(), job2 as _);
let popped_job = list.pop_front().unwrap();
assert_eq!(popped_job.as_ptr(), job1 as _);
assert!(list.is_empty());
}
#[test]
fn joblist_unlink_middle() {
let mut list = JobList::new();
let job1 = Box::into_raw(Box::new(Job::<i32>::empty()));
let job2 = Box::into_raw(Box::new(Job::<i32>::empty()));
let job3 = Box::into_raw(Box::new(Job::<i32>::empty()));
unsafe {
list.push_back(job1);
list.push_back(job2);
list.push_back(job3);
}
assert_eq!(list.len(), 3);
// Unlink the middle job (job2)
unsafe {
(&*job2).unlink();
}
// Check that job1 and job3 are still in the list
let popped_job1 = list.pop_front().unwrap();
assert_eq!(popped_job1.as_ptr(), job1 as _);
let popped_job3 = list.pop_front().unwrap();
assert_eq!(popped_job3.as_ptr(), job3 as _);
}
#[test]
fn joblist_unlink_head() {
let mut list = JobList::new();
let job1 = Box::into_raw(Box::new(Job::<i32>::empty()));
let job2 = Box::into_raw(Box::new(Job::<i32>::empty()));
unsafe {
list.push_back(job1);
list.push_back(job2);
}
assert_eq!(list.len(), 2);
unsafe {
(&*job1).unlink();
}
// Check that job2 is still in the list
let popped_job2 = list.pop_front().unwrap();
assert_eq!(popped_job2.as_ptr(), job2 as _);
}
#[test]
fn joblist_unlink_tail() {
let mut list = JobList::new();
let job1 = Box::into_raw(Box::new(Job::<i32>::empty()));
let job2 = Box::into_raw(Box::new(Job::<i32>::empty()));
unsafe {
list.push_back(job1);
list.push_back(job2);
}
assert_eq!(list.len(), 2);
unsafe {
(&*job2).unlink();
}
// Check that job1 is still in the list
let popped_job1 = list.pop_front().unwrap();
assert_eq!(popped_job1.as_ptr(), job1 as _);
}
#[test]
fn joblist_unlink_single() {
let mut list = JobList::new();
let job1 = Box::into_raw(Box::new(Job::<i32>::empty()));
unsafe {
list.push_back(job1);
}
assert_eq!(list.len(), 1);
unsafe {
(&*job1).unlink();
}
// Check that popping from an empty list returns None
assert!(list.pop_front().is_none());
}
#[test]
fn joblist_pop_empty() {
let mut list = JobList::new();
// Popping from an empty list should return None
assert!(list.pop_front().is_none());
assert!(list.pop_back().is_none());
}
#[test]
fn jobvec_push_front() {
let mut vec = JobVec::new();
let job1 = Box::into_raw(Box::new(Job::<i32>::empty()));
let job2 = Box::into_raw(Box::new(Job::<i32>::empty()));
vec.push_front(job1);
vec.push_front(job2);
assert_eq!(vec.len(), 2);
let popped_job = vec.pop_front().unwrap();
assert_eq!(popped_job.as_ptr(), job2 as _);
let popped_job = vec.pop_front().unwrap();
assert_eq!(popped_job.as_ptr(), job1 as _);
assert!(vec.is_empty());
}
#[test]
fn jobvec_push_back() {
let mut vec = JobVec::new();
let job1 = Box::into_raw(Box::new(Job::<i32>::empty()));
let job2 = Box::into_raw(Box::new(Job::<i32>::empty()));
vec.push_back(job1);
vec.push_back(job2);
assert_eq!(vec.len(), 2);
let popped_job = vec.pop_back().unwrap();
assert_eq!(popped_job.as_ptr(), job2 as _);
let popped_job = vec.pop_back().unwrap();
assert_eq!(popped_job.as_ptr(), job1 as _);
assert!(vec.is_empty());
}
#[test]
fn jobvec_push_front_pop_back() {
let mut vec = JobVec::new();
let job1 = Box::into_raw(Box::new(Job::<i32>::empty()));
let job2 = Box::into_raw(Box::new(Job::<i32>::empty()));
vec.push_front(job1);
vec.push_front(job2);
assert_eq!(vec.len(), 2);
let popped_job = vec.pop_back().unwrap();
assert_eq!(popped_job.as_ptr(), job1 as _);
let popped_job = vec.pop_back().unwrap();
assert_eq!(popped_job.as_ptr(), job2 as _);
assert!(vec.is_empty());
}
}

View file

@ -1,6 +1,7 @@
use std::hint::cold_path;
use std::{hint::cold_path, sync::Arc};
use crate::{
context::Context,
job::{JobState, StackJob},
latch::{AsCoreLatch, LatchRef, WakeLatch},
workerthread::WorkerThread,
@ -69,7 +70,6 @@ impl WorkerThread {
// WorkerThread::current_ref()
// .expect("stackjob is run in workerthread.")
// .tick();
a()
},
LatchRef::new(&latch),
@ -82,6 +82,7 @@ impl WorkerThread {
Ok(val) => val,
Err(payload) => {
cold_path();
tracing::debug!("join_heartbeat: b panicked, waiting for a to finish");
// if b panicked, we need to wait for a to finish
self.wait_until_latch(&latch);
resume_unwind(payload);
@ -89,8 +90,11 @@ impl WorkerThread {
};
let ra = if job.state() == JobState::Empty as u8 {
// remove job from the queue, so it doesn't get run again.
// job.unlink();
// SAFETY: we are in a worker thread, so we can safely access the queue.
unsafe {
job.unlink();
self.queue.as_mut_unchecked().remove(&job);
}
// a is allowed to panic here, because we already finished b.
@ -108,3 +112,41 @@ impl WorkerThread {
(ra, rb)
}
}
impl Context {
#[inline]
pub fn join<A, B, RA, RB>(self: &Arc<Self>, a: A, b: B) -> (RA, RB)
where
RA: Send,
RB: Send,
A: FnOnce() -> RA + Send,
B: FnOnce() -> RB + Send,
{
// SAFETY: join_heartbeat_every is safe to call from a worker thread.
self.run_in_worker(move |worker| worker.join_heartbeat_every::<_, _, _, _, 64>(a, b))
}
}
/// Run two closures, potentially in parallel, on the global threadpool.
#[allow(dead_code)]
pub fn join<A, B, RA, RB>(a: A, b: B) -> (RA, RB)
where
RA: Send,
RB: Send,
A: FnOnce() -> RA + Send,
B: FnOnce() -> RB + Send,
{
join_in(Context::global_context().clone(), a, b)
}
/// Run two closures, potentially in parallel, on the given context's threadpool.
#[allow(dead_code)]
fn join_in<A, B, RA, RB>(context: Arc<Context>, a: A, b: B) -> (RA, RB)
where
RA: Send,
RB: Send,
A: FnOnce() -> RA + Send,
B: FnOnce() -> RB + Send,
{
context.join(a, b)
}
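// The divide-and-conquer shape `join` is meant for, as a sketch; the
// `parallel_sum` helper is hypothetical, not part of this crate. Note that
// `join` takes non-'static closures, so borrowing the slice halves is fine.
fn parallel_sum(slice: &[u64]) -> u64 {
    if slice.len() < 1024 {
        // small inputs aren't worth forking for
        return slice.iter().sum();
    }
    let (left, right) = slice.split_at(slice.len() / 2);
    // the two halves may run in parallel on the global threadpool
    let (a, b) = join(|| parallel_sum(left), || parallel_sum(right));
    a + b
}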

View file

@ -37,6 +37,13 @@ impl AtomicLatch {
inner: AtomicU8::new(Self::UNSET),
}
}
pub const fn new_set() -> Self {
Self {
inner: AtomicU8::new(Self::SET),
}
}
#[inline]
pub fn reset(&self) {
self.inner.store(Self::UNSET, Ordering::Release);
@ -46,6 +53,10 @@ impl AtomicLatch {
self.inner.load(Ordering::Acquire)
}
pub fn set_sleeping(&self) {
self.inner.store(Self::SLEEPING, Ordering::Release);
}
/// returns true if the latch was previously sleeping.
#[inline]
pub unsafe fn set(this: *const Self) -> bool {
@ -244,7 +255,7 @@ impl<L: Latch> Latch for CountLatch<L> {
impl<L: Latch + Probe> Probe for CountLatch<L> {
#[inline]
fn probe(&self) -> bool {
self.inner.probe()
self.count.load(Ordering::Relaxed) == 0
}
}
@ -365,3 +376,168 @@ impl AsCoreLatch for WakeLatch {
&self.inner
}
}
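// The latch state machine the worker loop relies on: UNSET -> (optionally
// SLEEPING, when a worker parks) -> SET. `AtomicLatch::set` reports whether
// the latch was SLEEPING, so the setter knows it must also wake the sleeping
// worker. A sketch mirroring the `core_latch_sleep` test below:
fn sleep_wake_protocol_sketch() {
    let latch = AtomicLatch::new();
    latch.set_sleeping(); // a worker is about to park itself
    // SAFETY: `latch` outlives this call.
    let was_sleeping = unsafe { AtomicLatch::set(&latch) };
    assert!(was_sleeping); // the setter is now responsible for the wakeup
    assert!(latch.probe());
}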
#[cfg(test)]
mod tests {
use std::sync::Barrier;
use tracing::Instrument;
use tracing_test::traced_test;
use super::*;
#[test]
fn test_atomic_latch() {
let latch = AtomicLatch::new();
assert_eq!(latch.get(), AtomicLatch::UNSET);
unsafe {
assert!(!latch.probe());
AtomicLatch::set_raw(&latch);
}
assert_eq!(latch.get(), AtomicLatch::SET);
assert!(latch.probe());
latch.reset();
assert_eq!(latch.get(), AtomicLatch::UNSET);
}
#[test]
fn core_latch_sleep() {
let latch = AtomicLatch::new();
assert_eq!(latch.get(), AtomicLatch::UNSET);
latch.set_sleeping();
assert_eq!(latch.get(), AtomicLatch::SLEEPING);
unsafe {
assert!(!latch.probe());
assert!(AtomicLatch::set(&latch));
}
assert_eq!(latch.get(), AtomicLatch::SET);
assert!(latch.probe());
latch.reset();
assert_eq!(latch.get(), AtomicLatch::UNSET);
}
#[test]
fn nop_latch() {
assert!(
core::mem::size_of::<NopLatch>() == 0,
"NopLatch should be zero-sized"
);
}
#[test]
fn thread_wake_latch() {
let latch = Arc::new(ThreadWakeLatch::new());
let main = Arc::new(ThreadWakeLatch::new());
let handle = std::thread::spawn({
let latch = latch.clone();
let main = main.clone();
move || unsafe {
Latch::set_raw(&*main);
latch.wait();
}
});
unsafe {
main.wait();
Latch::set_raw(&*latch);
}
handle.join().expect("Thread should join successfully");
assert!(
!latch.probe() && !main.probe(),
"Latch should be set after waiting thread wakes up"
);
}
#[test]
fn count_latch() {
let latch = CountLatch::new(AtomicLatch::new());
assert_eq!(latch.count(), 0);
latch.increment();
assert_eq!(latch.count(), 1);
assert!(!latch.probe());
latch.increment();
assert_eq!(latch.count(), 2);
assert!(!latch.probe());
unsafe {
Latch::set_raw(&latch);
}
assert!(!latch.probe());
assert_eq!(latch.count(), 1);
unsafe {
Latch::set_raw(&latch);
}
assert!(latch.probe());
assert_eq!(latch.count(), 0);
}
#[test]
fn mutex_latch() {
let latch = Arc::new(MutexLatch::new());
assert!(!latch.probe());
latch.set();
assert!(latch.probe());
latch.reset();
assert!(!latch.probe());
// Test wait functionality
let latch_clone = latch.clone();
let handle = std::thread::spawn(move || {
latch_clone.wait();
});
// Give the thread time to block
std::thread::sleep(std::time::Duration::from_millis(100));
assert!(!latch.probe());
latch.set();
assert!(latch.probe());
handle.join().expect("Thread should join successfully");
}
#[test]
fn wake_latch() {
let context = Context::new_with_threads(1);
let count = Arc::new(AtomicUsize::new(0));
let barrier = Arc::new(Barrier::new(2));
tracing::info!("running scope in worker thread");
let latch = context.run_in_worker(|worker| {
tracing::info!("worker thread started: {:?}", worker.index);
let latch = WakeLatch::new(worker.context.clone(), worker.index);
worker.context.spawn({
let heartbeat = worker.heartbeat.clone();
let barrier = barrier.clone();
let count = count.clone();
// set sleeping outside of the closure so we don't have to deal with lifetimes
latch.as_core_latch().set_sleeping();
move || {
tracing::info!("sleeping workerthread");
heartbeat.latch.wait_and_reset();
tracing::info!("woken up workerthread");
count.fetch_add(1, Ordering::SeqCst);
tracing::info!("waiting on barrier");
barrier.wait();
}
});
latch
});
tracing::info!("setting latch in main thread");
unsafe {
Latch::set_raw(&latch);
}
tracing::info!("main thread set latch, waiting for worker thread to wake up");
barrier.wait();
assert_eq!(
count.load(Ordering::SeqCst),
1,
"Latch should have woken the worker thread"
);
}
}

View file

@ -20,3 +20,9 @@ mod scope;
mod threadpool;
pub mod util;
mod workerthread;
pub use context::run_in_worker;
pub use join::join;
pub use scope::{Scope, scope};
pub use threadpool::ThreadPool;
pub use workerthread::WorkerThread;
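// Putting the new exports together: a sketch of consumer code mirroring the
// tests elsewhere in this commit (assumes the crate is used as `distaff`).
fn public_api_sketch() {
    use distaff::{ThreadPool, join};
    // the free `join` runs on the lazily-created global context
    let (a, b) = join(|| 1 + 2, || 3 + 4);
    assert_eq!((a, b), (3, 7));
    // an explicit pool; scoped jobs may borrow local state
    let pool = ThreadPool::new_with_threads(2);
    let mut hits = 0;
    pool.scope(|s| {
        s.spawn(|_| hits += 1);
    });
    assert_eq!(hits, 1);
}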

View file

@ -11,14 +11,14 @@ use std::{
use async_task::Runnable;
use crate::{
context::{Context, run_in_worker},
context::Context,
job::{HeapJob, Job},
latch::{AsCoreLatch, CountLatch, WakeLatch},
util::{DropGuard, SendPtr},
workerthread::WorkerThread,
};
pub struct Scope<'scope> {
pub struct Scope<'scope, 'env: 'scope> {
// latch to wait on before the scope finishes
job_counter: CountLatch<WakeLatch>,
// local threadpool
@ -26,55 +26,44 @@ pub struct Scope<'scope> {
// panic error
panic: AtomicPtr<Box<dyn Any + Send + 'static>>,
// variant lifetime
_pd: PhantomData<fn(&'scope ())>,
_scope: PhantomData<&'scope mut &'scope ()>,
_env: PhantomData<&'env mut &'env ()>,
}
pub fn scope<'scope, F, R>(f: F) -> R
pub fn scope<'env, F, R>(f: F) -> R
where
F: FnOnce(&Scope<'scope>) -> R + Send,
F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> R + Send,
R: Send,
{
Scope::<'scope>::scope(f)
scope_with_context(Context::global_context(), f)
}
impl<'scope> Scope<'scope> {
pub fn scope_with_context<'env, F, R>(context: &Arc<Context>, f: F) -> R
where
F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> R + Send,
R: Send,
{
context.run_in_worker(|worker| {
// SAFETY: we call complete() after creating this scope, which
// ensures that any jobs spawned from the scope exit before the
// scope closes.
let this = unsafe { Scope::from_context(context.clone()) };
this.complete(worker, || f(&this))
})
}
impl<'scope, 'env> Scope<'scope, 'env> {
fn wait_for_jobs(&self, worker: &WorkerThread) {
tracing::trace!("waiting for {} jobs to finish.", self.job_counter.count());
tracing::trace!("thread id: {:?}, jobs: {:?}", worker.index, unsafe {
worker.queue.as_ref_unchecked()
});
if self.job_counter.count() > 0 {
tracing::trace!("waiting for {} jobs to finish.", self.job_counter.count());
tracing::trace!("thread id: {:?}, jobs: {:?}", worker.index, unsafe {
worker.queue.as_ref_unchecked()
});
// set worker index in the job counter
self.job_counter.inner().set_worker_index(worker.index);
worker.wait_until_latch(self.job_counter.as_core_latch());
}
pub fn scope<F, R>(f: F) -> R
where
F: FnOnce(&Self) -> R + Send,
R: Send,
{
run_in_worker(|worker| {
// SAFETY: we call complete() after creating this scope, which
// ensures that any jobs spawned from the scope exit before the
// scope closes.
let this = unsafe { Self::from_context(worker.context.clone()) };
this.complete(worker, || f(&this))
})
}
fn scope_with_context<F, R>(context: Arc<Context>, f: F) -> R
where
F: FnOnce(&Self) -> R + Send,
R: Send,
{
context.run_in_worker(|worker| {
// SAFETY: we call complete() after creating this scope, which
// ensures that any jobs spawned from the scope exit before the
// scope closes.
let this = unsafe { Self::from_context(context.clone()) };
this.complete(worker, || f(&this))
})
// set worker index in the job counter
self.job_counter.inner().set_worker_index(worker.index);
worker.wait_until_latch(self.job_counter.as_core_latch());
}
}
/// should be called from within a worker thread.
@ -153,9 +142,9 @@ impl<'scope> Scope<'scope> {
});
}
pub fn spawn<F>(&self, f: F)
pub fn spawn<F>(&'scope self, f: F)
where
F: FnOnce(&Scope<'scope>) + Send,
F: FnOnce(&'scope Self) + Send,
{
self.context.run_in_worker(|worker| {
self.job_counter.increment();
@ -176,70 +165,81 @@ impl<'scope> Scope<'scope> {
});
}
pub fn spawn_future<T, F>(&self, future: F) -> async_task::Task<T>
pub fn spawn_future<T, F>(&'scope self, future: F) -> async_task::Task<T>
where
F: Future<Output = T> + Send + 'scope,
T: Send + 'scope,
{
self.context.run_in_worker(|worker| {
self.job_counter.increment();
let this = SendPtr::new_const(&self.job_counter).unwrap();
let future = async move {
let _guard = DropGuard::new(move || unsafe {
this.as_ref().decrement();
});
future.await
};
let schedule = move |runnable: Runnable| {
#[align(8)]
unsafe fn harness<T>(this: *const (), job: *const Job<T>) {
unsafe {
let runnable =
Runnable::<()>::from_raw(NonNull::new_unchecked(this.cast_mut()));
runnable.run();
// SAFETY: job was turned into raw
drop(Box::from_raw(job.cast_mut()));
}
}
let job = Box::new(Job::<T>::new(harness::<T>, runnable.into_raw()));
// casting into Job<()> here
worker.push_front(Box::into_raw(job) as _);
};
let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) };
runnable.schedule();
task
})
self.spawn_async_internal(move |_| future)
}
#[allow(dead_code)]
fn spawn_async<'a, T, Fut, Fn>(&'a self, f: Fn) -> async_task::Task<T>
pub fn spawn_async<T, Fut, Fn>(&'scope self, f: Fn) -> async_task::Task<T>
where
Fn: FnOnce(&Scope) -> Fut + Send + 'static,
Fut: Future<Output = T> + Send + 'static,
T: Send + 'static,
Fn: FnOnce(&'scope Self) -> Fut + Send + 'scope,
Fut: Future<Output = T> + Send + 'scope,
T: Send + 'scope,
{
let this = SendPtr::new_const(self).unwrap();
let future = async move { f(unsafe { this.as_ref() }).await };
self.spawn_future(future)
self.spawn_async_internal(f)
}
#[inline]
pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
fn spawn_async_internal<T, Fut, Fn>(&'scope self, f: Fn) -> async_task::Task<T>
where
Fn: FnOnce(&'scope Self) -> Fut + Send + 'scope,
Fut: Future<Output = T> + Send + 'scope,
T: Send + 'scope,
{
self.job_counter.increment();
let this = SendPtr::new_const(self).unwrap();
// let this = SendPtr::new_const(&self.job_counter).unwrap();
let future = async move {
// SAFETY: this is valid until we decrement the job counter.
unsafe {
let _guard = DropGuard::new(move || {
this.as_unchecked_ref().job_counter.decrement();
});
f(this.as_ref()).await
}
};
let schedule = move |runnable: Runnable| {
#[align(8)]
unsafe fn harness(this: *const (), job: *const Job) {
unsafe {
let runnable =
Runnable::<()>::from_raw(NonNull::new_unchecked(this.cast_mut()));
runnable.run();
// SAFETY: job was turned into raw
drop(Box::from_raw(job.cast_mut()));
}
}
let job = Box::new(Job::new(harness, runnable.into_raw()));
// casting into Job<()> here
WorkerThread::current_ref()
.expect("spawn_async_internal is run in workerthread.")
.push_front(Box::into_raw(job) as _);
};
let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) };
runnable.schedule();
task
}
#[inline]
pub fn join<A, B, RA, RB>(&'scope self, a: A, b: B) -> (RA, RB)
where
RA: Send,
RB: Send,
A: FnOnce(&Self) -> RA + Send,
B: FnOnce(&Self) -> RB + Send,
A: FnOnce(&'scope Self) -> RA + Send,
B: FnOnce(&'scope Self) -> RB + Send,
{
let worker = WorkerThread::current_ref().expect("join is run in workerthread.");
let this = SendPtr::new_const(self).unwrap();
@ -261,7 +261,60 @@ impl<'scope> Scope<'scope> {
context: ctx.clone(),
job_counter: CountLatch::new(WakeLatch::new(ctx, 0)),
panic: AtomicPtr::new(ptr::null_mut()),
_pd: PhantomData,
_scope: PhantomData,
_env: PhantomData,
}
}
}
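// Why two lifetimes: `'env` is the world outside the scope (data that spawned
// jobs may borrow), `'scope` is the scope's own lifetime, with `'env: 'scope`;
// the invariant `PhantomData` markers keep either lifetime from being shrunk.
// This is the `std::thread::scope` pattern. A sketch (using `ThreadPool::scope`
// as defined later in this commit):
fn scope_lifetimes_sketch(pool: &crate::ThreadPool) {
    let mut items = vec![1, 2, 3]; // lives for 'env, outside the scope
    pool.scope(|s| {
        // s: &'scope Scope<'scope, 'env> — the job may borrow from 'env
        s.spawn(|_| items.push(4));
    }); // the scope joins all its jobs before returning, ending the borrows
    assert_eq!(items.len(), 4);
}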
#[cfg(test)]
mod tests {
use std::sync::atomic::AtomicU8;
use tracing_test::traced_test;
use super::*;
use crate::ThreadPool;
#[test]
fn spawn() {
let pool = ThreadPool::new_with_threads(1);
let count = Arc::new(AtomicU8::new(0));
scope_with_context(&pool.context, |scope| {
scope.spawn(|_| {
count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
});
});
assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 1);
}
#[test]
#[traced_test]
fn join() {
let pool = ThreadPool::new_with_threads(1);
let a = pool.scope(|scope| {
let (a, b) = scope.join(|_| 3 + 4, |_| 5 + 6);
a + b
});
assert_eq!(a, 18);
}
#[test]
fn spawn_future() {
let pool = ThreadPool::new_with_threads(1);
let mut x = 0;
pool.scope(|scope| {
let task = scope.spawn_async(|_| async {
x += 1;
});
task.detach();
});
assert_eq!(x, 1);
}
}

View file

@ -1 +1,93 @@
use std::sync::Arc;
use crate::{Scope, context::Context, scope::scope_with_context};
pub struct ThreadPool {
pub(crate) context: Arc<Context>,
}
impl Drop for ThreadPool {
fn drop(&mut self) {
// Ensure that the context is properly cleaned up when the thread pool is dropped.
self.context.set_should_exit();
}
}
impl ThreadPool {
pub fn new_with_threads(num_threads: usize) -> Self {
let context = Context::new_with_threads(num_threads);
Self { context }
}
/// Creates a new thread pool with a thread per hardware thread.
pub fn new() -> Self {
let context = Context::new();
Self { context }
}
pub fn scope<'env, F, R>(&self, f: F) -> R
where
F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> R + Send,
R: Send,
{
scope_with_context(&self.context, f)
}
pub fn spawn<F>(&self, f: F)
where
F: FnOnce() + Send + 'static,
{
self.context.spawn(f)
}
pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
where
RA: Send,
RB: Send,
A: FnOnce() -> RA + Send,
B: FnOnce() -> RB + Send,
{
self.context.join(a, b)
}
}
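// Dropping the pool is the shutdown path: `Drop` calls `set_should_exit`,
// which flips the shared flag, sets every worker's heartbeat latch, and
// notifies `shared_job`, so parked workers wake up and exit. A sketch:
fn pool_lifecycle_sketch() {
    let pool = ThreadPool::new_with_threads(2);
    let (a, b) = pool.join(|| 3 + 4, || 5 * 6);
    assert_eq!((a, b), (7, 30));
} // `pool` dropped here: workers are woken and allowed to exit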
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn spawn_borrow() {
let pool = ThreadPool::new_with_threads(1);
let mut x = 0;
pool.scope(|scope| {
scope.spawn(|_| {
x += 1;
});
});
assert_eq!(x, 1);
}
#[test]
fn spawn_future() {
let pool = ThreadPool::new_with_threads(1);
let mut x = 0;
let task = pool.scope(|scope| {
let task = scope.spawn_async(|_| async {
x += 1;
});
task
});
futures::executor::block_on(task);
assert_eq!(x, 1);
}
#[test]
fn join() {
let pool = ThreadPool::new_with_threads(1);
let (a, b) = pool.join(|| 3 + 4, || 5 * 6);
assert_eq!(a, 7);
assert_eq!(b, 30);
}
}

View file

@ -93,6 +93,11 @@ impl<T> SendPtr<T> {
pub const unsafe fn new_const_unchecked(ptr: *const T) -> Self {
unsafe { Self::new_unchecked(ptr.cast_mut()) }
}
pub unsafe fn as_unchecked_ref(&self) -> &T {
// SAFETY: `self.0` is a valid non-null pointer.
unsafe { self.0.as_ref() }
}
}
/// A tagged atomic pointer that can store a pointer and a tag `BITS` wide in the same space
@ -402,3 +407,63 @@ pub fn available_parallelism() -> usize {
.map(|n| n.get())
.unwrap_or(1)
}
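// The invariant `TaggedAtomicPtr<T, BITS>` builds on, in miniature: a pointer
// to a type with alignment >= 2^BITS always has its BITS low bits zero, so a
// small tag can live in them. A sketch with BITS = 2 and a 4-aligned u32:
fn tag_packing_sketch() {
    let addr = Box::into_raw(Box::new(42u32)) as usize;
    assert_eq!(addr & 0b11, 0); // u32 is 4-aligned: the two low bits are free
    let word = addr | 0b10; // pack a 2-bit tag alongside the pointer
    assert_eq!(word & 0b11, 0b10); // recover the tag
    assert_eq!(word & !0b11, addr); // recover the pointer with the tag masked off
    // SAFETY: the address round-tripped through usize unchanged
    unsafe { drop(Box::from_raw((word & !0b11) as *mut u32)) };
}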
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn tagged_ptr_exchange() {
let ptr = Box::into_raw(Box::new(42u32));
let tagged_ptr = TaggedAtomicPtr::<u32, 2>::new(ptr, 0b11);
assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0b11);
assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr);
assert_eq!(
tagged_ptr
.compare_exchange_tag(0b11, 0b10, Ordering::Relaxed, Ordering::Relaxed)
.unwrap(),
0b11
);
assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0b10);
assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr);
unsafe {
_ = Box::from_raw(ptr);
}
}
#[test]
fn value_inline() {
assert!(SmallBox::<u32>::is_inline(), "u32 should be inline");
assert!(SmallBox::<u8>::is_inline(), "u8 should be inline");
assert!(
SmallBox::<Box<u32>>::is_inline(),
"Box<u32> should be inline"
);
assert!(
SmallBox::<[u32; 2]>::is_inline(),
"[u32; 2] should be inline"
);
assert!(
!SmallBox::<[u32; 3]>::is_inline(),
"[u32; 3] should not be inline"
);
assert!(SmallBox::<usize>::is_inline(), "usize should be inline");
#[repr(C, align(16))]
struct LargeType(u8);
assert!(
!SmallBox::<LargeType>::is_inline(),
"LargeType should not be inline"
);
#[repr(C, align(4))]
struct SmallType(u8);
assert!(
SmallBox::<SmallType>::is_inline(),
"SmallType should be inline"
);
}
}

View file

@ -6,11 +6,10 @@ use std::{
};
use crossbeam_utils::CachePadded;
use parking_lot_core::SpinWait;
use crate::{
context::{Context, Heartbeat},
job::{Job, JobList, JobResult},
job::{Job, JobResult, JobVec as JobList},
latch::{AsCoreLatch, CoreLatch, Probe},
util::DropGuard,
};
@ -19,7 +18,7 @@ pub struct WorkerThread {
pub(crate) context: Arc<Context>,
pub(crate) index: usize,
pub(crate) queue: UnsafeCell<JobList>,
heartbeat: Arc<CachePadded<Heartbeat>>,
pub(crate) heartbeat: Arc<CachePadded<Heartbeat>>,
pub(crate) join_count: Cell<u8>,
}
@ -39,11 +38,6 @@ impl WorkerThread {
join_count: Cell::new(0),
}
}
fn new() -> Self {
let context = Context::global_context().clone();
Self::new_in(context)
}
}
impl WorkerThread {
@ -72,7 +66,7 @@ impl WorkerThread {
let mut job = self.context.shared().pop_job();
'outer: loop {
let mut guard = loop {
if let Some(job) = job {
if let Some(job) = job.take() {
self.execute(job);
}
@ -83,9 +77,11 @@ impl WorkerThread {
break 'outer;
}
// TODO: also check the local queue?
match guard.pop_job() {
Some(job) => {
tracing::trace!("worker: popping job: {:?}", job);
Some(popped) => {
tracing::trace!("worker: popping job: {:?}", popped);
job = Some(popped);
// found job, continue inner loop
continue;
}
@ -107,6 +103,7 @@ impl WorkerThread {
#[inline(always)]
fn tick(&self) {
if self.heartbeat.is_pending() {
tracing::trace!("received heartbeat, thread id: {:?}", self.index);
self.heartbeat_cold();
}
}
@ -190,8 +187,11 @@ impl WorkerThread {
unsafe fn drop_in_place(this: *mut Self) {
unsafe {
this.drop_in_place();
drop(Box::from_raw(this));
// SAFETY: this is only called when the thread is exiting, so we can
// safely drop the thread. We use `drop_in_place` to prevent `Box`
// from creating a no-alias reference to the worker thread.
core::ptr::drop_in_place(this);
_ = Box::<core::mem::ManuallyDrop<Self>>::from_raw(this as _);
}
}
}
@ -258,11 +258,6 @@ impl WorkerThread {
assert!(!latch.probe());
'outer: while !latch.probe() {
// take a shared job, if it exists
if let Some(shared_job) = self.context.shared().jobs.remove(&self.index) {
self.execute(shared_job);
}
// process local jobs before locking shared context
while let Some(job) = self.pop_front() {
unsafe {
@ -271,8 +266,16 @@ impl WorkerThread {
self.execute(job);
}
// take a shared job, if it exists
if let Some(shared_job) = self.context.shared().jobs.remove(&self.index) {
self.execute(shared_job);
}
while !latch.probe() {
let job = self.context.shared().pop_job();
let job = {
let mut guard = self.context.shared();
guard.jobs.remove(&self.index).or_else(|| guard.pop_job())
};
match job {
Some(job) => {
@ -281,8 +284,6 @@ impl WorkerThread {
continue 'outer;
}
None => {
tracing::trace!("waiting for shared job, thread id: {:?}", self.index);
// TODO: wait on latch? if we have something that can
// signal being done, e.g. can be waited on instead of
// shared jobs, we should wait on it instead, but we
@ -297,6 +298,9 @@ impl WorkerThread {
// Yield? same as spinning, really, so just exit and let the upstream use wait
// std::thread::yield_now();
tracing::trace!("thread {:?} is sleeping", self.index);
latch.set_sleeping();
self.heartbeat.latch.wait_and_reset();
// since we were sleeping, the shared job can't be populated,
// so resuming the inner loop is fine.
@ -339,58 +343,4 @@ impl WorkerThread {
self.wait_until_latch_cold(latch)
}
}
#[inline]
fn wait_until_predicate<F>(&self, pred: F)
where
F: Fn() -> bool,
{
'outer: while !pred() {
// take a shared job, if it exists
if let Some(shared_job) = self.context.shared().jobs.remove(&self.index) {
self.execute(shared_job);
}
// process local jobs before locking shared context
while let Some(job) = self.pop_front() {
unsafe {
job.as_ref().set_pending();
}
self.execute(job);
}
while !pred() {
let mut guard = self.context.shared();
let mut _spin = SpinWait::new();
match guard.pop_job() {
Some(job) => {
drop(guard);
self.execute(job);
continue 'outer;
}
None => {
tracing::trace!("waiting for shared job, thread id: {:?}", self.index);
// TODO: wait on latch? if we have something that can
// signal being done, e.g. can be waited on instead of
// shared jobs, we should wait on it instead, but we
// would also want to receive shared jobs still?
// Spin? probably just wastes CPU time.
// self.context.shared_job.wait(&mut guard);
// if spin.spin() {
// // wait for more shared jobs.
// // self.context.shared_job.wait(&mut guard);
// return;
// }
// Yield? same as spinning, really, so just exit and let the upstream use wait
// std::thread::yield_now();
return;
}
}
}
}
return;
}
}

View file

@ -813,7 +813,7 @@ mod job {
}
}
/// call this when popping value from local queue
/// must be called before `execute()`
pub fn set_pending(&self) {
let mut spin = SpinWait::new();
loop {