todo: separate workerthread and scope logic, add scope type with lifetime

This commit is contained in:
Janis 2025-03-08 12:24:02 +01:00
parent b069f0cc87
commit 3458a900ee
2 changed files with 63 additions and 48 deletions

0
.cargo/config.toml Normal file
View file

View file

@ -3,7 +3,6 @@ mod util {
cell::UnsafeCell, cell::UnsafeCell,
marker::PhantomData, marker::PhantomData,
mem::ManuallyDrop, mem::ManuallyDrop,
num::NonZero,
ops::{Deref, DerefMut}, ops::{Deref, DerefMut},
ptr::NonNull, ptr::NonNull,
sync::atomic::{AtomicPtr, Ordering}, sync::atomic::{AtomicPtr, Ordering},
@ -58,12 +57,23 @@ mod util {
} }
impl<T> SendPtr<T> { impl<T> SendPtr<T> {
pub fn new(ptr: *mut T) -> Option<Self> { pub const fn new(ptr: *mut T) -> Option<Self> {
NonNull::new(ptr).map(Self) match NonNull::new(ptr) {
Some(ptr) => Some(Self(ptr)),
None => None,
}
} }
pub unsafe fn new_unchecked(ptr: *mut T) -> Self { pub const unsafe fn new_unchecked(ptr: *mut T) -> Self {
unsafe { Self(NonNull::new_unchecked(ptr)) } unsafe { Self(NonNull::new_unchecked(ptr)) }
} }
pub const fn new_const(ptr: *const T) -> Option<Self> {
Self::new(ptr.cast_mut())
}
pub const unsafe fn new_const_unchecked(ptr: *const T) -> Self {
Self::new_unchecked(ptr.cast_mut())
}
} }
// Miri doesn't like tagging pointers that it doesn't know the alignment of. // Miri doesn't like tagging pointers that it doesn't know the alignment of.
@ -174,7 +184,7 @@ mod util {
let ptr = ptr.cast::<()>(); let ptr = ptr.cast::<()>();
loop { loop {
let old = self.0.load(failure); let old = self.0.load(failure);
let new = ptr.with_addr((ptr.addr() & !mask) | (old.addr() & mask)); let new = ptr.map_addr(|addr| (addr & !mask) | (old.addr() & mask));
if self if self
.0 .0
.compare_exchange_weak(old, new, success, failure) .compare_exchange_weak(old, new, success, failure)
@ -603,10 +613,7 @@ mod job {
unsafe impl<T> Send for Job<T> {} unsafe impl<T> Send for Job<T> {}
impl<T> Job<T> { impl<T> Job<T> {
pub fn new( pub fn new(harness: unsafe fn(*const (), *const Job<T>), this: NonNull<()>) -> Job<T> {
harness: unsafe fn(*const (), *const Job<T>, &super::Scope),
this: NonNull<()>,
) -> Job<T> {
Self { Self {
harness_and_state: TaggedAtomicPtr::new( harness_and_state: TaggedAtomicPtr::new(
unsafe { mem::transmute(harness) }, unsafe { mem::transmute(harness) },
@ -748,19 +755,18 @@ mod job {
} }
} }
pub fn execute(job: NonNull<Self>, scope: &super::Scope) { pub fn execute(job: NonNull<Self>) {
// SAFETY: self is non-null // SAFETY: self is non-null
unsafe { unsafe {
let this = job.as_ref(); let this = job.as_ref();
let (ptr, state) = this.harness_and_state.ptr_and_tag(Ordering::Relaxed); let (ptr, state) = this.harness_and_state.ptr_and_tag(Ordering::Relaxed);
debug_assert_eq!(state, JobState::Pending as usize); debug_assert_eq!(state, JobState::Pending as usize);
let harness: unsafe fn(*const (), *const Self, scope: &super::Scope) = let harness: unsafe fn(*const (), *const Self) = mem::transmute(ptr.as_ptr());
mem::transmute(ptr.as_ptr());
let this = (*this.val_or_this.get()).this; let this = (*this.val_or_this.get()).this;
harness(this.as_ptr().cast(), job.as_ptr(), scope); harness(this.as_ptr().cast(), job.as_ptr());
} }
} }
@ -829,19 +835,19 @@ mod job {
#[allow(dead_code)] #[allow(dead_code)]
pub fn into_boxed_job<T>(self: Box<Self>) -> Pin<Box<Job<()>>> pub fn into_boxed_job<T>(self: Box<Self>) -> Pin<Box<Job<()>>>
where where
F: FnOnce(&super::Scope) -> T + Send, F: FnOnce() -> T + Send,
T: Send, T: Send,
{ {
#[repr(align(8))] #[repr(align(8))]
unsafe fn harness<F, T>(this: *const (), job: *const Job<()>, scope: &super::Scope) unsafe fn harness<F, T>(this: *const (), job: *const Job<()>)
where where
F: FnOnce(&super::Scope) -> T + Send, F: FnOnce() -> T + Send,
T: Sized + Send, T: Sized + Send,
{ {
let this = unsafe { Box::from_raw(this.cast::<HeapJob<F>>().cast_mut()) }; let this = unsafe { Box::from_raw(this.cast::<HeapJob<F>>().cast_mut()) };
let f = this.f; let f = this.f;
_ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(scope))); _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f()));
_ = unsafe { Box::from_raw(job.cast_mut()) }; _ = unsafe { Box::from_raw(job.cast_mut()) };
} }
@ -871,19 +877,19 @@ mod job {
pub fn as_job<T>(self: Pin<&Self>) -> Job<()> pub fn as_job<T>(self: Pin<&Self>) -> Job<()>
where where
F: FnOnce(&super::Scope) -> T + Send, F: FnOnce() -> T + Send,
T: Send, T: Send,
{ {
#[repr(align(8))] #[repr(align(8))]
unsafe fn harness<F, T>(this: *const (), job: *const Job<()>, scope: &super::Scope) unsafe fn harness<F, T>(this: *const (), job: *const Job<()>)
where where
F: FnOnce(&super::Scope) -> T + Send, F: FnOnce() -> T + Send,
T: Sized + Send, T: Sized + Send,
{ {
let this = unsafe { &*this.cast::<StackJob<F>>() }; let this = unsafe { &*this.cast::<StackJob<F>>() };
let f = unsafe { this.unwrap() }; let f = unsafe { this.unwrap() };
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(scope))); let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f()));
let job_ref = unsafe { &*job.cast::<Job<T>>() }; let job_ref = unsafe { &*job.cast::<Job<T>>() };
job_ref.complete(result); job_ref.complete(result);
@ -900,6 +906,7 @@ use std::{
cell::{Cell, UnsafeCell}, cell::{Cell, UnsafeCell},
collections::BTreeMap, collections::BTreeMap,
future::Future, future::Future,
marker::PhantomData,
mem::{self, MaybeUninit}, mem::{self, MaybeUninit},
pin::{pin, Pin}, pin::{pin, Pin},
ptr::NonNull, ptr::NonNull,
@ -946,7 +953,14 @@ impl JobCounter {
} }
} }
pub struct Scope { struct WorkerThread {
context: Arc<Context>,
index: usize,
heartbeat: Arc<CachePadded<AtomicBool>>,
queue: UnsafeCell<JobList>,
}
pub struct Scope<'scope> {
join_count: Cell<usize>, join_count: Cell<usize>,
context: Arc<Context>, context: Arc<Context>,
index: usize, index: usize,
@ -954,13 +968,14 @@ pub struct Scope {
queue: UnsafeCell<JobList>, queue: UnsafeCell<JobList>,
job_counter: JobCounter, job_counter: JobCounter,
_pd: PhantomData<&'scope ()>,
} }
thread_local! { thread_local! {
static SCOPE: UnsafeCell<Option<NonNull<Scope>>> = const { UnsafeCell::new(None) }; static SCOPE: UnsafeCell<Option<NonNull<Scope<'static>>>> = const { UnsafeCell::new(None) };
} }
impl Scope { impl<'scope> Scope<'scope> {
/// locks shared context /// locks shared context
#[allow(dead_code)] #[allow(dead_code)]
fn new() -> Self { fn new() -> Self {
@ -979,6 +994,7 @@ impl Scope {
join_count: Cell::new(0), join_count: Cell::new(0),
queue: UnsafeCell::new(JobList::new()), queue: UnsafeCell::new(JobList::new()),
job_counter: JobCounter::default(), job_counter: JobCounter::default(),
_pd: PhantomData,
} }
} }
@ -1033,22 +1049,22 @@ impl Scope {
Self::with_in(Context::global(), f) Self::with_in(Context::global(), f)
} }
unsafe fn set_current(scope: *const Scope) { unsafe fn set_current(scope: *const Scope<'static>) {
SCOPE.with(|ptr| unsafe { SCOPE.with(|ptr| unsafe {
_ = (&mut *ptr.get()).insert(NonNull::new_unchecked(scope.cast_mut())); _ = (&mut *ptr.get()).insert(NonNull::new_unchecked(scope.cast_mut()));
}) })
} }
unsafe fn unset_current() -> Option<NonNull<Scope>> { unsafe fn unset_current() -> Option<NonNull<Scope<'static>>> {
SCOPE.with(|ptr| unsafe { (&mut *ptr.get()).take() }) SCOPE.with(|ptr| unsafe { (&mut *ptr.get()).take() })
} }
#[allow(dead_code)] #[allow(dead_code)]
fn current() -> Option<NonNull<Scope>> { fn current() -> Option<NonNull<Scope<'scope>>> {
SCOPE.with(|ptr| unsafe { *ptr.get() }) SCOPE.with(|ptr| unsafe { *ptr.get() })
} }
fn current_ref<'a>() -> Option<&'a Scope> { fn current_ref<'a>() -> Option<&'a Scope<'scope>> {
SCOPE.with(|ptr| unsafe { (&*ptr.get()).map(|ptr| ptr.as_ref()) }) SCOPE.with(|ptr| unsafe { (&*ptr.get()).map(|ptr| ptr.as_ref()) })
} }
@ -1084,19 +1100,17 @@ impl Scope {
} }
} }
pub fn spawn<'a, F>(&self, f: F) pub fn spawn<F>(&self, f: F)
where where
F: FnOnce(&Scope) + Send + 'a, F: FnOnce(&Scope<'scope>) + Send + 'scope,
{ {
self.job_counter.increment(); self.job_counter.increment();
let this = unsafe { let this = SendPtr::new_const(self).unwrap();
SendPtr::new_unchecked(&self.job_counter as *const JobCounter as *mut JobCounter)
};
let job = HeapJob::new(move |scope: &Scope| unsafe { let job = HeapJob::new(move || unsafe {
f(scope); f(this.as_ref());
this.as_ref().decrement(); this.as_ref().job_counter.decrement();
}) })
.into_boxed_job(); .into_boxed_job();
@ -1104,15 +1118,14 @@ impl Scope {
mem::forget(job); mem::forget(job);
} }
pub fn spawn_future<'a, T, F>(&'a self, future: F) -> async_task::Task<T> pub fn spawn_future<T, F>(&self, future: F) -> async_task::Task<T>
where where
F: Future<Output = T> + Send + 'a, F: Future<Output = T> + Send + 'scope,
T: Send + 'a, T: Send + 'scope,
{ {
self.job_counter.increment(); self.job_counter.increment();
let this = let this = SendPtr::new_const(&self.job_counter).unwrap();
unsafe { SendPtr::new_unchecked(&raw const self.job_counter as *mut JobCounter) };
let future = async move { let future = async move {
let _guard = DropGuard::new(move || unsafe { let _guard = DropGuard::new(move || unsafe {
@ -1121,10 +1134,10 @@ impl Scope {
future.await future.await
}; };
let this = SendPtr::new(&raw const *self as *mut Self).unwrap(); let this = SendPtr::new_const(self).unwrap();
let schedule = move |runnable: Runnable| { let schedule = move |runnable: Runnable| {
#[repr(align(8))] #[repr(align(8))]
unsafe fn harness<T>(this: *const (), job: *const Job<T>, _: &Scope) { unsafe fn harness<T>(this: *const (), job: *const Job<T>) {
let runnable = Runnable::<()>::from_raw(NonNull::new_unchecked(this.cast_mut())); let runnable = Runnable::<()>::from_raw(NonNull::new_unchecked(this.cast_mut()));
runnable.run(); runnable.run();
@ -1153,7 +1166,7 @@ impl Scope {
Fut: Future<Output = T> + Send + 'static, Fut: Future<Output = T> + Send + 'static,
T: Send + 'static, T: Send + 'static,
{ {
let this = SendPtr::new(self as *const Self as *mut Self).unwrap(); let this = SendPtr::new_const(self).unwrap();
let future = async move { f(unsafe { this.as_ref() }).await }; let future = async move { f(unsafe { this.as_ref() }).await };
self.spawn_future(future) self.spawn_future(future)
@ -1209,7 +1222,9 @@ impl Scope {
A: FnOnce(&Self) -> RA + Send, A: FnOnce(&Self) -> RA + Send,
B: FnOnce(&Self) -> RB + Send, B: FnOnce(&Self) -> RB + Send,
{ {
let a = pin!(StackJob::new(move |scope: &Scope| { let this = SendPtr::new_const(self).unwrap();
let a = pin!(StackJob::new(move || unsafe {
let scope = this.as_ref();
scope.tick(); scope.tick();
a(scope) a(scope)
@ -1225,14 +1240,14 @@ impl Scope {
job.unlink(); job.unlink();
} }
unsafe { a.unwrap()(self) } unsafe { a.unwrap()() }
} else { } else {
match self.wait_until::<RA>(unsafe { match self.wait_until::<RA>(unsafe {
mem::transmute::<Pin<&Job<()>>, Pin<&Job<RA>>>(job.as_ref()) mem::transmute::<Pin<&Job<()>>, Pin<&Job<RA>>>(job.as_ref())
}) { }) {
Some(Ok(t)) => t, Some(Ok(t)) => t,
Some(Err(payload)) => std::panic::resume_unwind(payload), Some(Err(payload)) => std::panic::resume_unwind(payload),
None => unsafe { a.unwrap()(self) }, None => unsafe { a.unwrap()() },
} }
}; };
@ -1250,7 +1265,7 @@ impl Scope {
#[inline] #[inline]
fn execute(&self, job: NonNull<Job>) { fn execute(&self, job: NonNull<Job>) {
self.tick(); self.tick();
Job::execute(job, self); Job::execute(job);
} }
#[cold] #[cold]