From d6115359948d48defd7018e8fcd0444e1366de0a Mon Sep 17 00:00:00 2001 From: Janis Date: Thu, 19 Jun 2025 14:25:15 +0200 Subject: [PATCH] it compiles... --- src/praetor/mod.rs | 154 ++++++++++++++++++++++++++++++++------------- 1 file changed, 111 insertions(+), 43 deletions(-) diff --git a/src/praetor/mod.rs b/src/praetor/mod.rs index eefb23f..72cab6f 100644 --- a/src/praetor/mod.rs +++ b/src/praetor/mod.rs @@ -971,15 +971,16 @@ mod job { } use std::{ + any::Any, cell::UnsafeCell, collections::BTreeMap, future::Future, hint::cold_path, marker::PhantomData, mem::{self, MaybeUninit}, - ptr::NonNull, + ptr::{self, NonNull}, sync::{ - atomic::{AtomicBool, AtomicUsize, Ordering}, + atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering}, Arc, OnceLock, Weak, }, time::Duration, @@ -1027,6 +1028,12 @@ impl JobCounter { } } +impl crate::latch::Probe for JobCounter { + fn probe(&self) -> bool { + self.count() == 0 + } +} + struct WorkerThread { context: Arc, index: usize, @@ -1041,6 +1048,8 @@ pub struct Scope<'scope> { job_counter: JobCounter, // local threadpool context: Arc, + // panic error + panic: AtomicPtr>, // variant lifetime _pd: PhantomData, } @@ -1289,71 +1298,118 @@ impl WorkerThread { } } +pub fn scope<'scope, F, R>(f: F) -> R +where + F: FnOnce(&Scope<'scope>) -> R + Send, + R: Send, +{ + Scope::<'scope>::scope(f) +} + impl<'scope> Scope<'scope> { fn wait_for_jobs(&self) { - unsafe { - tracing::trace!("waiting for {} jobs to finish.", self.job_counter.count()); - self.job_counter.wait(); - } + tracing::trace!("waiting for {} jobs to finish.", self.job_counter.count()); + + let thread = WorkerThread::current_ref().unwrap(); + thread.wait_until_latch(&self.job_counter); } - pub fn scope(&self, f: F) -> R + pub fn scope(f: F) -> R where F: FnOnce(&Self) -> R + Send, R: Send, { - self.complete(|| f(self)) + run_in_worker(|thread| { + let this = Self::from_context(thread.context.clone()); + this.complete(|| f(&this)) + }) } + pub fn scope_with_context(context: Arc, f: F) -> R + where + F: FnOnce(&Self) -> R + Send, + R: Send, + { + context.run_in_worker(|_| { + let this = Self::from_context(context.clone()); + this.complete(|| f(&this)) + }) + } + + /// should be called from within a worker thread. fn complete(&self, f: F) -> R where F: FnOnce() -> R + Send, R: Send, { - use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe}; - #[repr(align(8))] - unsafe fn harness T, T>(this: *const (), job: *const Job) { - let f = unsafe { Box::from_raw(this.cast::().cast_mut()) }; - - let result = catch_unwind(AssertUnwindSafe(move || f())); - - let job = unsafe { Box::from_raw(job.cast_mut()) }; - job.complete(result); - } + use std::panic::{catch_unwind, AssertUnwindSafe}; + #[allow(dead_code)] fn make_job T, T>(f: F) -> Job { + #[repr(align(8))] + unsafe fn harness T, T>(this: *const (), job: *const Job) { + let f = unsafe { Box::from_raw(this.cast::().cast_mut()) }; + + let result = catch_unwind(AssertUnwindSafe(move || f())); + + let job = unsafe { Box::from_raw(job.cast_mut()) }; + job.complete(result); + } + Job::::new(harness::, unsafe { NonNull::new_unchecked(Box::into_raw(Box::new(f))).cast() }) } - let result = WorkerThread::with_in(&self.context, |worker| { - let mut job = make_job(f); - - unsafe { - _ = worker; - job.set_pending(); - Job::execute(NonNull::new_unchecked(&mut job)); + let result = match catch_unwind(AssertUnwindSafe(|| f())) { + Ok(val) => Some(val), + Err(payload) => { + self.panicked(payload); + None } - - // let this = SendPtr::new_const(self).unwrap(); - // - // worker.push_front(&job); - // - // match worker.wait_until(&job) { - // Some(result) => result, - // None => unsafe { - // let f = Box::::from_non_null(job.unwrap_this().cast()); - // Ok(f(this.as_ref())) - // }, - // } - - job.wait() - }); + }; self.wait_for_jobs(); + self.maybe_propagate_panic(); - result.into_result() + // SAFETY: if result panicked, we would have propagated the panic above. + result.unwrap() + } + + /// resumes the panic if one happened in this scope. + fn maybe_propagate_panic(&self) { + let err_ptr = self.panic.load(Ordering::Relaxed); + if !err_ptr.is_null() { + unsafe { + let err = Box::from_raw(err_ptr); + std::panic::resume_unwind(*err); + } + } + } + + /// stores the first panic that happened in this scope. + fn panicked(&self, err: Box) { + self.panic.load(Ordering::Relaxed).is_null().then(|| { + use core::mem::ManuallyDrop; + let mut boxed = ManuallyDrop::new(Box::new(err)); + + let err_ptr: *mut Box = &mut **boxed; + if self + .panic + .compare_exchange( + ptr::null_mut(), + err_ptr, + Ordering::SeqCst, + Ordering::Relaxed, + ) + .is_ok() + { + // we successfully set the panic, no need to drop + } else { + // drop the error, someone else already set it + _ = ManuallyDrop::into_inner(boxed); + } + }); } pub fn spawn(&self, f: F) @@ -1530,6 +1586,7 @@ impl<'scope> Scope<'scope> { context: ctx, join_count: AtomicUsize::new(0), job_counter: JobCounter::default(), + panic: AtomicPtr::new(ptr::null_mut()), _pd: PhantomData, } } @@ -1568,8 +1625,7 @@ impl ThreadPool { F: FnOnce(&Scope<'scope>) -> R + Send, R: Send, { - let scope = Scope::from_context(self.context.clone()); - scope.scope(f) + Scope::scope_with_context(self.context.clone(), f) } } @@ -1661,6 +1717,7 @@ impl Context { self.shared_job.notify_one(); } + /// Runs closure in this context, processing the worker's jobs while waiting for the result. fn run_in_worker_cross(self: &Arc, worker: &WorkerThread, f: F) -> T where F: FnOnce(&WorkerThread) -> T + Send, @@ -1690,6 +1747,7 @@ impl Context { t } + /// Run closure in this context, sleeping until the job is done. pub fn run_in_worker_cold(self: &Arc, f: F) -> T where F: FnOnce(&WorkerThread) -> T + Send, @@ -1720,6 +1778,7 @@ impl Context { t } + /// Run closure in this context. pub fn run_in_worker(self: &Arc, f: F) -> T where T: Send, @@ -1739,9 +1798,18 @@ impl Context { } } +fn run_in_worker(f: F) -> T +where + T: Send, + F: FnOnce(&WorkerThread) -> T + Send, +{ + Context::global().run_in_worker(f) +} + static GLOBAL_CONTEXT: OnceLock> = OnceLock::new(); const HEARTBEAT_INTERVAL: Duration = Duration::from_micros(100); +/// returns the number of available hardware threads, or 1 if it cannot be determined. fn available_parallelism() -> usize { std::thread::available_parallelism() .map(|n| n.get())