LOTS OF CHANGES: but! this works

This commit is contained in:
Janis 2025-07-01 20:09:52 +02:00
parent 69d3794ff1
commit 6e4f6a1285
7 changed files with 370 additions and 1137 deletions

View file

@ -24,10 +24,10 @@ fn nodes() -> impl Iterator<Item = (usize, usize)> {
#[divan::bench(args = nodes())]
fn no_overhead(bencher: Bencher, nodes: (usize, usize)) {
fn join_no_overhead<A, B, RA, RB>(scope: &Scope<'_, '_>, a: A, b: B) -> (RA, RB)
fn join_no_overhead<A, B, RA, RB>(scope: Scope<'_, '_>, a: A, b: B) -> (RA, RB)
where
A: FnOnce(&Scope<'_, '_>) -> RA + Send,
B: FnOnce(&Scope<'_, '_>) -> RB + Send,
A: FnOnce(Scope<'_, '_>) -> RA + Send,
B: FnOnce(Scope<'_, '_>) -> RB + Send,
RA: Send,
RB: Send,
{
@ -35,7 +35,7 @@ fn no_overhead(bencher: Bencher, nodes: (usize, usize)) {
}
#[inline]
fn sum(node: &Node, scope: &Scope<'_, '_>) -> u64 {
fn sum(node: &Node, scope: Scope<'_, '_>) -> u64 {
let (left, right) = join_no_overhead(
scope,
|s| node.left.as_deref().map(|n| sum(n, s)).unwrap_or_default(),
@ -57,7 +57,7 @@ fn no_overhead(bencher: Bencher, nodes: (usize, usize)) {
#[divan::bench(args = nodes())]
fn distaff_overhead(bencher: Bencher, nodes: (usize, usize)) {
fn sum<'scope, 'env>(node: &Node, scope: &'scope Scope<'scope, 'env>) -> u64 {
fn sum<'scope, 'env>(node: &Node, scope: Scope<'scope, 'env>) -> u64 {
let (left, right) = scope.join(
|s| node.left.as_deref().map(|n| sum(n, s)).unwrap_or_default(),
|s| node.right.as_deref().map(|n| sum(n, s)).unwrap_or_default(),

View file

@ -15,7 +15,6 @@ use crate::{
channel::{Parker, Sender},
heartbeat::HeartbeatList,
job::{HeapJob, Job2 as Job, SharedJob, StackJob},
latch::NopLatch,
util::DropGuard,
workerthread::{HeartbeatThread, WorkerThread},
};
@ -142,6 +141,7 @@ impl Context {
.iter()
.find(|(_, heartbeat)| heartbeat.is_waiting())
{
_ = i;
#[cfg(feature = "tracing")]
tracing::trace!("Notifying worker thread {} about job sharing", i);
sender.wake();
@ -160,15 +160,7 @@ impl Context {
// current thread is not in the same context, create a job and inject it into the other thread's context, then wait while working on our jobs.
// SAFETY: we are waiting on this latch in this thread.
let job = StackJob::new(
move || {
let worker = WorkerThread::current_ref()
.expect("WorkerThread::run_in_worker called outside of worker thread");
f(worker)
},
NopLatch,
);
let job = StackJob::new(move |worker: &WorkerThread| f(worker));
let job = Job::from_stackjob(&job);
@ -188,15 +180,7 @@ impl Context {
// current thread isn't a worker thread, create job and inject into context
let parker = Parker::new();
let job = StackJob::new(
move || {
let worker = WorkerThread::current_ref()
.expect("WorkerThread::run_in_worker called outside of worker thread");
f(worker)
},
NopLatch,
);
let job = StackJob::new(move |worker: &WorkerThread| f(worker));
let job = Job::from_stackjob(&job);
@ -247,7 +231,7 @@ impl Context {
where
F: FnOnce() + Send + 'static,
{
let job = Job::from_heapjob(Box::new(HeapJob::new(f)));
let job = Job::from_heapjob(HeapJob::new(|_: &WorkerThread| f()));
#[cfg(feature = "tracing")]
tracing::trace!("Context::spawn: spawning job: {:?}", job);
self.inject_job(job.share(None));
@ -364,19 +348,16 @@ mod tests {
let parker = Parker::new();
let job = StackJob::new(
{
let counter = counter.clone();
move || {
#[cfg(feature = "tracing")]
tracing::info!("Job running");
counter.fetch_add(1, Ordering::SeqCst);
let job = StackJob::new({
let counter = counter.clone();
move |_: &WorkerThread| {
#[cfg(feature = "tracing")]
tracing::info!("Job running");
counter.fetch_add(1, Ordering::SeqCst);
42
}
},
NopLatch,
);
42
}
});
let job = Job::from_stackjob(&job);

File diff suppressed because it is too large Load diff

View file

@ -6,7 +6,6 @@ use std::{hint::cold_path, sync::Arc};
use crate::{
context::Context,
job::{Job2 as Job, StackJob},
latch::NopLatch,
workerthread::WorkerThread,
};
@ -14,19 +13,19 @@ impl WorkerThread {
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
fn join_seq<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
where
A: FnOnce() -> RA,
B: FnOnce() -> RB,
A: FnOnce(&WorkerThread) -> RA,
B: FnOnce(&WorkerThread) -> RB,
{
let rb = b();
let ra = a();
let rb = b(self);
let ra = a(self);
(ra, rb)
}
pub(crate) fn join_heartbeat_every<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
where
A: FnOnce() -> RA + Send,
B: FnOnce() -> RB,
A: FnOnce(&WorkerThread) -> RA + Send,
B: FnOnce(&WorkerThread) -> RB,
RA: Send,
{
// self.join_heartbeat_every_inner::<A, B, RA, RB, 2>(a, b)
@ -34,12 +33,13 @@ impl WorkerThread {
}
/// This function must be called from a worker thread.
#[allow(dead_code)]
#[inline(always)]
fn join_heartbeat_every_inner<A, B, RA, RB, const TIMES: usize>(&self, a: A, b: B) -> (RA, RB)
where
RA: Send,
A: FnOnce() -> RA + Send,
B: FnOnce() -> RB,
A: FnOnce(&WorkerThread) -> RA + Send,
B: FnOnce(&WorkerThread) -> RB,
{
// SAFETY: each worker is only ever used by one thread, so this is safe.
let count = self.join_count.get();
@ -63,22 +63,22 @@ impl WorkerThread {
fn join_heartbeat<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
where
RA: Send,
A: FnOnce() -> RA + Send,
B: FnOnce() -> RB,
A: FnOnce(&WorkerThread) -> RA + Send,
B: FnOnce(&WorkerThread) -> RB,
{
use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
#[cfg(feature = "metrics")]
self.metrics.num_joins.fetch_add(1, Ordering::Relaxed);
let a = StackJob::new(a, NopLatch);
let a = StackJob::new(a);
let job = Job::from_stackjob(&a);
self.push_back(&job);
self.tick();
let rb = match catch_unwind(AssertUnwindSafe(|| b())) {
let rb = match catch_unwind(AssertUnwindSafe(|| b(self))) {
Ok(val) => val,
Err(payload) => {
#[cfg(feature = "tracing")]
@ -109,7 +109,7 @@ impl WorkerThread {
);
// the job was shared, but not yet stolen, so we get to run the
// job inline
unsafe { a.unwrap()() }
unsafe { a.unwrap()(self) }
}
}
} else {
@ -119,7 +119,7 @@ impl WorkerThread {
// SAFETY: we just popped the job from the queue, so it is safe to unwrap.
#[cfg(feature = "tracing")]
tracing::trace!("join_heartbeat: job was not shared, running a() inline");
a.unwrap()()
a.unwrap()(self)
}
};
@ -136,12 +136,13 @@ impl Context {
RB: Send,
{
// SAFETY: join_heartbeat_every is safe to call from a worker thread.
self.run_in_worker(move |worker| worker.join_heartbeat_every::<_, _, _, _>(a, b))
self.run_in_worker(move |worker| {
worker.join_heartbeat_every::<_, _, _, _>(|_| a(), |_| b())
})
}
}
/// run two closures potentially in parallel, in the global threadpool.
#[allow(dead_code)]
pub fn join<A, B, RA, RB>(a: A, b: B) -> (RA, RB)
where
A: FnOnce() -> RA + Send,

View file

@ -1,10 +1,12 @@
use std::{
any::Any,
marker::PhantomData,
panic::{AssertUnwindSafe, catch_unwind},
pin::{self, Pin},
ptr::{self, NonNull},
sync::{
Arc,
atomic::{AtomicPtr, Ordering},
atomic::{AtomicPtr, AtomicUsize, Ordering},
},
};
@ -30,6 +32,102 @@ use crate::{
// - when a join() job finishes, it's latch is set
// - when we wait for a join() job, we loop over the latch until it is set
// a Scope must keep track of:
// - The number of async jobs spawned, which is used to determine when the scope
// is complete.
// - A panic box, which is set when a job panics and is used to resume the panic
// when the scope is completed.
// - The Parker of the worker on which the scope was created, which is signaled
// when the last outstanding async job finishes.
// - The current worker thread in order to avoid having to query the
// thread-local storage.
struct ScopeInner {
outstanding_jobs: AtomicUsize,
parker: NonNull<crate::channel::Parker>,
panic: AtomicPtr<Box<dyn Any + Send + 'static>>,
}
unsafe impl Send for ScopeInner {}
unsafe impl Sync for ScopeInner {}
#[derive(Clone, Copy)]
pub struct Scope<'scope, 'env: 'scope> {
inner: SendPtr<ScopeInner>,
worker: SendPtr<WorkerThread>,
_scope: PhantomData<&'scope mut &'scope ()>,
_env: PhantomData<&'env mut &'env ()>,
}
impl ScopeInner {
fn from_worker(worker: &WorkerThread) -> Self {
Self {
outstanding_jobs: AtomicUsize::new(0),
parker: worker.heartbeat.parker().into(),
panic: AtomicPtr::new(ptr::null_mut()),
}
}
fn increment(&self) {
self.outstanding_jobs.fetch_add(1, Ordering::Relaxed);
}
fn decrement(&self) {
if self.outstanding_jobs.fetch_sub(1, Ordering::Relaxed) == 1 {
unsafe {
self.parker.as_ref().unpark();
}
}
}
fn panicked(&self, err: Box<dyn Any + Send + 'static>) {
unsafe {
let err = Box::into_raw(Box::new(err));
if !self
.panic
.compare_exchange(ptr::null_mut(), err, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
// someone else already set the panic, so we drop the error
_ = Box::from_raw(err);
}
}
}
fn maybe_propagate_panic(&self) {
let err = self.panic.swap(ptr::null_mut(), Ordering::AcqRel);
if err.is_null() {
return;
} else {
// SAFETY: we have exclusive access to the panic error, so we can safely resume it.
unsafe {
let err = *Box::from_raw(err);
std::panic::resume_unwind(err);
}
}
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
fn wait_for_jobs(&self) {
loop {
let count = self.outstanding_jobs.load(Ordering::Relaxed);
if count == 0 {
break;
}
#[cfg(feature = "tracing")]
tracing::trace!("waiting for {} jobs to finish.", count);
// wait until the parker is unparked
unsafe {
self.parker.as_ref().park();
}
// parking gives us AcqRel semantics.
}
}
}
// find below a sketch of an unbalanced tree:
// []
// / \
@ -52,7 +150,7 @@ use crate::{
// - another thread sharing a job
// - the heartbeat waking up the worker // does this make sense? if the thread was sleeping, it didn't have any work to share.
pub struct Scope<'scope, 'env: 'scope> {
pub struct Scope2<'scope, 'env: 'scope> {
// latch to wait on before the scope finishes
job_counter: CountLatch,
// local threadpool
@ -66,7 +164,7 @@ pub struct Scope<'scope, 'env: 'scope> {
pub fn scope<'env, F, R>(f: F) -> R
where
F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> R + Send,
F: for<'scope> FnOnce(Scope<'scope, 'env>) -> R + Send,
R: Send,
{
scope_with_context(Context::global_context(), f)
@ -74,43 +172,25 @@ where
pub fn scope_with_context<'env, F, R>(context: &Arc<Context>, f: F) -> R
where
F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> R + Send,
F: for<'scope> FnOnce(Scope<'scope, 'env>) -> R + Send,
R: Send,
{
context.run_in_worker(|worker| {
// SAFETY: we call complete() after creating this scope, which
// ensures that any jobs spawned from the scope exit before the
// scope closes.
let this = unsafe { Scope::from_context(context.clone()) };
this.complete(worker, || f(&this))
let inner = pin::pin!(ScopeInner::from_worker(worker));
let this = Scope::<'_, 'env>::new(worker, inner.as_ref());
this.complete(|| f(this))
})
}
impl<'scope, 'env> Scope<'scope, 'env> {
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
fn wait_for_jobs(&self, worker: &WorkerThread) {
self.job_counter.set_inner(worker.heartbeat.parker());
if self.job_counter.count() > 0 {
#[cfg(feature = "tracing")]
tracing::trace!("waiting for {} jobs to finish.", self.job_counter.count());
#[cfg(feature = "tracing")]
tracing::trace!(
"thread id: {:?}, jobs: {:?}",
worker.heartbeat.index(),
unsafe { worker.queue.as_ref_unchecked() }
);
// set worker index in the job counter
worker.wait_until_pred(|| self.job_counter.probe());
}
}
/// should be called from within a worker thread.
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
fn complete<F, R>(&self, worker: &WorkerThread, f: F) -> R
fn complete<F, R>(&self, f: F) -> R
where
F: FnOnce() -> R + Send,
R: Send,
F: FnOnce() -> R,
{
use std::panic::{AssertUnwindSafe, catch_unwind};
@ -122,76 +202,90 @@ impl<'scope, 'env> Scope<'scope, 'env> {
}
};
self.wait_for_jobs(worker);
self.maybe_propagate_panic();
let inner = self.inner();
inner.wait_for_jobs();
inner.maybe_propagate_panic();
// SAFETY: if result panicked, we would have propagated the panic above.
result.unwrap()
}
/// resumes the panic if one happened in this scope.
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
fn maybe_propagate_panic(&self) {
let err_ptr = self.panic.load(Ordering::Relaxed);
if !err_ptr.is_null() {
unsafe {
let err = Box::from_raw(err_ptr);
std::panic::resume_unwind(*err);
}
}
fn inner(&self) -> &ScopeInner {
unsafe { self.inner.as_ref() }
}
/// stores the first panic that happened in this scope.
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
fn panicked(&self, err: Box<dyn Any + Send + 'static>) {
#[cfg(feature = "tracing")]
tracing::debug!("panicked in scope, storing error: {:?}", err);
self.panic.load(Ordering::Relaxed).is_null().then(|| {
use core::mem::ManuallyDrop;
let mut boxed = ManuallyDrop::new(Box::new(err));
let err_ptr: *mut Box<dyn Any + Send + 'static> = &mut **boxed;
if self
.panic
.compare_exchange(
ptr::null_mut(),
err_ptr,
Ordering::SeqCst,
Ordering::Relaxed,
)
.is_ok()
{
// we successfully set the panic, no need to drop
} else {
// drop the error, someone else already set it
_ = ManuallyDrop::into_inner(boxed);
}
});
self.inner().panicked(err);
}
pub fn spawn<F>(&'scope self, f: F)
pub fn spawn<F>(&self, f: F)
where
F: FnOnce(&'scope Self) + Send,
F: FnOnce(Self) + Send,
{
self.job_counter.increment();
let inner = self.inner;
let this = SendPtr::new_const(self).unwrap();
unsafe {
inner.as_ref().increment();
}
let job = Job::from_heapjob(Box::new(HeapJob::new(move || unsafe {
use std::panic::{AssertUnwindSafe, catch_unwind};
if let Err(payload) = catch_unwind(AssertUnwindSafe(|| f(this.as_ref()))) {
this.as_unchecked_ref().panicked(payload);
struct SpawnedJob<F> {
f: F,
inner: SendPtr<ScopeInner>,
}
impl<F> SpawnedJob<F> {
fn new<'scope, 'env, T>(f: F, inner: SendPtr<ScopeInner>) -> Job
where
F: FnOnce(Scope<'scope, 'env>) -> T + Send,
'env: 'scope,
T: Send,
{
Job::from_harness(
Self::harness,
Box::into_non_null(Box::new(Self { f, inner })).cast(),
)
}
this.as_unchecked_ref().job_counter.decrement();
})));
self.context.inject_job(job.share(None));
unsafe fn harness<'scope, 'env, T>(
worker: &WorkerThread,
this: NonNull<()>,
_: Option<Sender>,
) where
F: FnOnce(Scope<'scope, 'env>) -> T + Send,
'env: 'scope,
T: Send,
{
let Self { f, inner } =
unsafe { *Box::<SpawnedJob<F>>::from_non_null(this.cast()) };
let scope = unsafe { Scope::<'scope, 'env>::new_unchecked(worker, inner) };
// SAFETY: we are in a worker thread, so the inner is valid.
(f)(scope);
unsafe { inner.as_ref().decrement() };
}
}
let job = SpawnedJob::new(
move |scope| {
if let Err(payload) = catch_unwind(AssertUnwindSafe(|| f(scope))) {
scope.inner().panicked(payload);
}
scope.inner().decrement();
},
self.inner,
);
self.context().inject_job(job.share(None));
// WorkerThread::current_ref()
// .expect("spawn is run in workerthread.")
// .push_front(job.as_ptr());
}
pub fn spawn_future<T, F>(&'scope self, future: F) -> async_task::Task<T>
pub fn spawn_future<T, F>(&self, future: F) -> async_task::Task<T>
where
F: Future<Output = T> + Send + 'scope,
T: Send + 'scope,
@ -200,9 +294,9 @@ impl<'scope, 'env> Scope<'scope, 'env> {
}
#[allow(dead_code)]
pub fn spawn_async<T, Fut, Fn>(&'scope self, f: Fn) -> async_task::Task<T>
pub fn spawn_async<T, Fut, Fn>(&self, f: Fn) -> async_task::Task<T>
where
Fn: FnOnce(&'scope Self) -> Fut + Send + 'scope,
Fn: FnOnce(Self) -> Fut + Send + 'scope,
Fut: Future<Output = T> + Send + 'scope,
T: Send + 'scope,
{
@ -210,25 +304,30 @@ impl<'scope, 'env> Scope<'scope, 'env> {
}
#[inline]
fn spawn_async_internal<T, Fut, Fn>(&'scope self, f: Fn) -> async_task::Task<T>
fn spawn_async_internal<T, Fut, Fn>(&self, f: Fn) -> async_task::Task<T>
where
Fn: FnOnce(&'scope Self) -> Fut + Send + 'scope,
Fn: FnOnce(Self) -> Fut + Send + 'scope,
Fut: Future<Output = T> + Send + 'scope,
T: Send + 'scope,
{
self.job_counter.increment();
self.inner().increment();
let this = SendPtr::new_const(self).unwrap();
// let this = SendPtr::new_const(&self.job_counter).unwrap();
// TODO: make sure this worker lasts long enough for the
// reference to remain valid for the duration of the future.
let scope = unsafe { Self::new_unchecked(self.worker.as_ref(), self.inner) };
let future = async move {
// SAFETY: this is valid until we decrement the job counter.
unsafe {
let _guard = DropGuard::new(move || {
this.as_unchecked_ref().job_counter.decrement();
this.as_ref().inner().decrement();
});
// TODO: handle panics here
f(this.as_ref()).await
f(scope).await
}
};
@ -244,7 +343,7 @@ impl<'scope, 'env> Scope<'scope, 'env> {
let job = Job::<()>::from_harness(harness, runnable.into_raw());
// casting into Job<()> here
self.context.inject_job(job.share(None));
self.context().inject_job(job.share(None));
// WorkerThread::current_ref()
// .expect("spawn_async_internal is run in workerthread.")
// .push_front(job);
@ -257,37 +356,140 @@ impl<'scope, 'env> Scope<'scope, 'env> {
task
}
pub fn join<A, B, RA, RB>(&'scope self, a: A, b: B) -> (RA, RB)
pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
where
RA: Send,
RB: Send,
A: FnOnce(&'scope Self) -> RA + Send,
B: FnOnce(&'scope Self) -> RB + Send,
A: FnOnce(Self) -> RA + Send,
B: FnOnce(Self) -> RB,
{
let worker = WorkerThread::current_ref().expect("join is run in workerthread.");
let this = SendPtr::new_const(self).unwrap();
use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
use std::{
cell::UnsafeCell,
mem::{self, ManuallyDrop},
};
worker.join_heartbeat_every::<_, _, _, _>(
let worker = self.worker();
struct ScopeJob<F> {
f: UnsafeCell<ManuallyDrop<F>>,
inner: SendPtr<ScopeInner>,
}
impl<F> ScopeJob<F> {
fn new(f: F, inner: SendPtr<ScopeInner>) -> Self {
Self {
f: UnsafeCell::new(ManuallyDrop::new(f)),
inner,
}
}
fn into_job<'scope, 'env, T>(&self) -> Job<T>
where
F: FnOnce(Scope<'scope, 'env>) -> T + Send,
'env: 'scope,
T: Send,
{
let this = this;
move || a(unsafe { this.as_ref() })
},
Job::from_harness(Self::harness, NonNull::from(self).cast())
}
unsafe fn unwrap(&self) -> F {
unsafe { ManuallyDrop::take(&mut *self.f.get()) }
}
unsafe fn harness<'scope, 'env, T>(
worker: &WorkerThread,
this: NonNull<()>,
sender: Option<Sender>,
) where
F: FnOnce(Scope<'scope, 'env>) -> T + Send,
'env: 'scope,
T: Send,
{
let this = this;
move || b(unsafe { this.as_ref() })
},
)
let this: &ScopeJob<F> = unsafe { this.cast().as_ref() };
let f = unsafe { this.unwrap() };
let scope = unsafe { Scope::<'scope, 'env>::new_unchecked(worker, this.inner) };
let sender: Sender<T> = unsafe { mem::transmute(sender) };
// SAFETY: we are in a worker thread, so the inner is valid.
sender.send(catch_unwind(AssertUnwindSafe(|| f(scope))));
}
}
let stack = ScopeJob::new(a, self.inner);
let job = ScopeJob::into_job(&stack);
worker.push_back(&job);
worker.tick();
let rb = match catch_unwind(AssertUnwindSafe(|| b(*self))) {
Ok(val) => val,
Err(payload) => {
#[cfg(feature = "tracing")]
tracing::debug!("join_heartbeat: b panicked, waiting for a to finish");
std::hint::cold_path();
// if b panicked, we need to wait for a to finish
let mut receiver = job.take_receiver();
worker.wait_until_pred(|| match &receiver {
Some(recv) => recv.poll().is_some(),
None => {
receiver = job.take_receiver();
false
}
});
resume_unwind(payload);
}
};
let ra = if let Some(recv) = job.take_receiver() {
match worker.wait_until_recv(recv) {
Some(t) => crate::util::unwrap_or_panic(t),
None => {
#[cfg(feature = "tracing")]
tracing::trace!(
"join_heartbeat: job was shared, but reclaimed, running a() inline"
);
// the job was shared, but not yet stolen, so we get to run the
// job inline
unsafe { stack.unwrap()(*self) }
}
}
} else {
worker.pop_back();
unsafe {
// SAFETY: we just popped the job from the queue, so it is safe to unwrap.
#[cfg(feature = "tracing")]
tracing::trace!("join_heartbeat: job was not shared, running a() inline");
stack.unwrap()(*self)
}
};
(ra, rb)
}
unsafe fn from_context(context: Arc<Context>) -> Self {
fn new(worker: &WorkerThread, inner: Pin<&'scope ScopeInner>) -> Self {
// SAFETY: we are creating a new scope, so the inner is valid.
unsafe { Self::new_unchecked(worker, SendPtr::new_const(&*inner).unwrap()) }
}
unsafe fn new_unchecked(worker: &WorkerThread, inner: SendPtr<ScopeInner>) -> Self {
Self {
context,
job_counter: CountLatch::new(ptr::null()),
panic: AtomicPtr::new(ptr::null_mut()),
inner,
worker: SendPtr::new_const(worker).unwrap(),
_scope: PhantomData,
_env: PhantomData,
}
}
pub fn context(&self) -> &Arc<Context> {
unsafe { &self.worker.as_ref().context }
}
pub fn worker(&self) -> &WorkerThread {
unsafe { self.worker.as_ref() }
}
}
#[cfg(test)]
@ -330,7 +532,7 @@ mod tests {
fn scope_join_many() {
let pool = ThreadPool::new_with_threads(1);
fn sum<'scope, 'env>(scope: &'scope Scope<'scope, 'env>, n: usize) -> usize {
fn sum<'scope, 'env>(scope: Scope<'scope, 'env>, n: usize) -> usize {
if n == 0 {
return 0;
}

View file

@ -32,7 +32,7 @@ impl ThreadPool {
pub fn scope<'env, F, R>(&self, f: F) -> R
where
F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> R + Send,
F: for<'scope> FnOnce(Scope<'scope, 'env>) -> R + Send,
R: Send,
{
scope_with_context(&self.context, f)

View file

@ -94,8 +94,7 @@ impl<T> SendPtr<T> {
unsafe { Self::new_unchecked(ptr.cast_mut()) }
}
pub unsafe fn as_unchecked_ref(&self) -> &T {
// SAFETY: `self.0` is a valid non-null pointer.
pub(crate) unsafe fn as_ref(&self) -> &T {
unsafe { self.0.as_ref() }
}
}