diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 0000000..e69de29 diff --git a/src/praetor/mod.rs b/src/praetor/mod.rs index 5e74683..03b5b90 100644 --- a/src/praetor/mod.rs +++ b/src/praetor/mod.rs @@ -3,7 +3,6 @@ mod util { cell::UnsafeCell, marker::PhantomData, mem::ManuallyDrop, - num::NonZero, ops::{Deref, DerefMut}, ptr::NonNull, sync::atomic::{AtomicPtr, Ordering}, @@ -58,12 +57,23 @@ mod util { } impl SendPtr { - pub fn new(ptr: *mut T) -> Option { - NonNull::new(ptr).map(Self) + pub const fn new(ptr: *mut T) -> Option { + match NonNull::new(ptr) { + Some(ptr) => Some(Self(ptr)), + None => None, + } } - pub unsafe fn new_unchecked(ptr: *mut T) -> Self { + pub const unsafe fn new_unchecked(ptr: *mut T) -> Self { unsafe { Self(NonNull::new_unchecked(ptr)) } } + + pub const fn new_const(ptr: *const T) -> Option { + Self::new(ptr.cast_mut()) + } + + pub const unsafe fn new_const_unchecked(ptr: *const T) -> Self { + Self::new_unchecked(ptr.cast_mut()) + } } // Miri doesn't like tagging pointers that it doesn't know the alignment of. @@ -174,7 +184,7 @@ mod util { let ptr = ptr.cast::<()>(); loop { let old = self.0.load(failure); - let new = ptr.with_addr((ptr.addr() & !mask) | (old.addr() & mask)); + let new = ptr.map_addr(|addr| (addr & !mask) | (old.addr() & mask)); if self .0 .compare_exchange_weak(old, new, success, failure) @@ -603,10 +613,7 @@ mod job { unsafe impl Send for Job {} impl Job { - pub fn new( - harness: unsafe fn(*const (), *const Job, &super::Scope), - this: NonNull<()>, - ) -> Job { + pub fn new(harness: unsafe fn(*const (), *const Job), this: NonNull<()>) -> Job { Self { harness_and_state: TaggedAtomicPtr::new( unsafe { mem::transmute(harness) }, @@ -748,19 +755,18 @@ mod job { } } - pub fn execute(job: NonNull, scope: &super::Scope) { + pub fn execute(job: NonNull) { // SAFETY: self is non-null unsafe { let this = job.as_ref(); let (ptr, state) = this.harness_and_state.ptr_and_tag(Ordering::Relaxed); debug_assert_eq!(state, JobState::Pending as usize); - let harness: unsafe fn(*const (), *const Self, scope: &super::Scope) = - mem::transmute(ptr.as_ptr()); + let harness: unsafe fn(*const (), *const Self) = mem::transmute(ptr.as_ptr()); let this = (*this.val_or_this.get()).this; - harness(this.as_ptr().cast(), job.as_ptr(), scope); + harness(this.as_ptr().cast(), job.as_ptr()); } } @@ -829,19 +835,19 @@ mod job { #[allow(dead_code)] pub fn into_boxed_job(self: Box) -> Pin>> where - F: FnOnce(&super::Scope) -> T + Send, + F: FnOnce() -> T + Send, T: Send, { #[repr(align(8))] - unsafe fn harness(this: *const (), job: *const Job<()>, scope: &super::Scope) + unsafe fn harness(this: *const (), job: *const Job<()>) where - F: FnOnce(&super::Scope) -> T + Send, + F: FnOnce() -> T + Send, T: Sized + Send, { let this = unsafe { Box::from_raw(this.cast::>().cast_mut()) }; let f = this.f; - _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(scope))); + _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f())); _ = unsafe { Box::from_raw(job.cast_mut()) }; } @@ -871,19 +877,19 @@ mod job { pub fn as_job(self: Pin<&Self>) -> Job<()> where - F: FnOnce(&super::Scope) -> T + Send, + F: FnOnce() -> T + Send, T: Send, { #[repr(align(8))] - unsafe fn harness(this: *const (), job: *const Job<()>, scope: &super::Scope) + unsafe fn harness(this: *const (), job: *const Job<()>) where - F: FnOnce(&super::Scope) -> T + Send, + F: FnOnce() -> T + Send, T: Sized + Send, { let this = unsafe { &*this.cast::>() }; let f = unsafe { this.unwrap() }; - let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(scope))); + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f())); let job_ref = unsafe { &*job.cast::>() }; job_ref.complete(result); @@ -900,6 +906,7 @@ use std::{ cell::{Cell, UnsafeCell}, collections::BTreeMap, future::Future, + marker::PhantomData, mem::{self, MaybeUninit}, pin::{pin, Pin}, ptr::NonNull, @@ -946,7 +953,14 @@ impl JobCounter { } } -pub struct Scope { +struct WorkerThread { + context: Arc, + index: usize, + heartbeat: Arc>, + queue: UnsafeCell, +} + +pub struct Scope<'scope> { join_count: Cell, context: Arc, index: usize, @@ -954,13 +968,14 @@ pub struct Scope { queue: UnsafeCell, job_counter: JobCounter, + _pd: PhantomData<&'scope ()>, } thread_local! { - static SCOPE: UnsafeCell>> = const { UnsafeCell::new(None) }; + static SCOPE: UnsafeCell>>> = const { UnsafeCell::new(None) }; } -impl Scope { +impl<'scope> Scope<'scope> { /// locks shared context #[allow(dead_code)] fn new() -> Self { @@ -979,6 +994,7 @@ impl Scope { join_count: Cell::new(0), queue: UnsafeCell::new(JobList::new()), job_counter: JobCounter::default(), + _pd: PhantomData, } } @@ -1033,22 +1049,22 @@ impl Scope { Self::with_in(Context::global(), f) } - unsafe fn set_current(scope: *const Scope) { + unsafe fn set_current(scope: *const Scope<'static>) { SCOPE.with(|ptr| unsafe { _ = (&mut *ptr.get()).insert(NonNull::new_unchecked(scope.cast_mut())); }) } - unsafe fn unset_current() -> Option> { + unsafe fn unset_current() -> Option>> { SCOPE.with(|ptr| unsafe { (&mut *ptr.get()).take() }) } #[allow(dead_code)] - fn current() -> Option> { + fn current() -> Option>> { SCOPE.with(|ptr| unsafe { *ptr.get() }) } - fn current_ref<'a>() -> Option<&'a Scope> { + fn current_ref<'a>() -> Option<&'a Scope<'scope>> { SCOPE.with(|ptr| unsafe { (&*ptr.get()).map(|ptr| ptr.as_ref()) }) } @@ -1084,19 +1100,17 @@ impl Scope { } } - pub fn spawn<'a, F>(&self, f: F) + pub fn spawn(&self, f: F) where - F: FnOnce(&Scope) + Send + 'a, + F: FnOnce(&Scope<'scope>) + Send + 'scope, { self.job_counter.increment(); - let this = unsafe { - SendPtr::new_unchecked(&self.job_counter as *const JobCounter as *mut JobCounter) - }; + let this = SendPtr::new_const(self).unwrap(); - let job = HeapJob::new(move |scope: &Scope| unsafe { - f(scope); - this.as_ref().decrement(); + let job = HeapJob::new(move || unsafe { + f(this.as_ref()); + this.as_ref().job_counter.decrement(); }) .into_boxed_job(); @@ -1104,15 +1118,14 @@ impl Scope { mem::forget(job); } - pub fn spawn_future<'a, T, F>(&'a self, future: F) -> async_task::Task + pub fn spawn_future(&self, future: F) -> async_task::Task where - F: Future + Send + 'a, - T: Send + 'a, + F: Future + Send + 'scope, + T: Send + 'scope, { self.job_counter.increment(); - let this = - unsafe { SendPtr::new_unchecked(&raw const self.job_counter as *mut JobCounter) }; + let this = SendPtr::new_const(&self.job_counter).unwrap(); let future = async move { let _guard = DropGuard::new(move || unsafe { @@ -1121,10 +1134,10 @@ impl Scope { future.await }; - let this = SendPtr::new(&raw const *self as *mut Self).unwrap(); + let this = SendPtr::new_const(self).unwrap(); let schedule = move |runnable: Runnable| { #[repr(align(8))] - unsafe fn harness(this: *const (), job: *const Job, _: &Scope) { + unsafe fn harness(this: *const (), job: *const Job) { let runnable = Runnable::<()>::from_raw(NonNull::new_unchecked(this.cast_mut())); runnable.run(); @@ -1153,7 +1166,7 @@ impl Scope { Fut: Future + Send + 'static, T: Send + 'static, { - let this = SendPtr::new(self as *const Self as *mut Self).unwrap(); + let this = SendPtr::new_const(self).unwrap(); let future = async move { f(unsafe { this.as_ref() }).await }; self.spawn_future(future) @@ -1209,7 +1222,9 @@ impl Scope { A: FnOnce(&Self) -> RA + Send, B: FnOnce(&Self) -> RB + Send, { - let a = pin!(StackJob::new(move |scope: &Scope| { + let this = SendPtr::new_const(self).unwrap(); + let a = pin!(StackJob::new(move || unsafe { + let scope = this.as_ref(); scope.tick(); a(scope) @@ -1225,14 +1240,14 @@ impl Scope { job.unlink(); } - unsafe { a.unwrap()(self) } + unsafe { a.unwrap()() } } else { match self.wait_until::(unsafe { mem::transmute::>, Pin<&Job>>(job.as_ref()) }) { Some(Ok(t)) => t, Some(Err(payload)) => std::panic::resume_unwind(payload), - None => unsafe { a.unwrap()(self) }, + None => unsafe { a.unwrap()() }, } }; @@ -1250,7 +1265,7 @@ impl Scope { #[inline] fn execute(&self, job: NonNull) { self.tick(); - Job::execute(job, self); + Job::execute(job); } #[cold]