some sigsev somewhere idk where or why..........

This commit is contained in:
Janis 2025-07-04 15:43:13 +02:00
parent b635ea5579
commit 0836c7c958
7 changed files with 332 additions and 355 deletions

View file

@ -15,15 +15,22 @@ use crate::{
channel::{Parker, Sender},
heartbeat::HeartbeatList,
job::{HeapJob, Job2 as Job, SharedJob, StackJob},
queue::ReceiverToken,
util::DropGuard,
workerthread::{HeartbeatThread, WorkerThread},
};
pub struct Context {
shared: Mutex<Shared>,
pub shared_job: Condvar,
should_exit: AtomicBool,
pub heartbeats: HeartbeatList,
pub(crate) queue: Arc<crate::queue::Queue<Message>>,
}
pub(crate) enum Message {
Shared(SharedJob),
Finished(werkzeug::util::Send<NonNull<std::thread::Result<()>>>),
Exit,
ScopeFinished,
}
pub(crate) struct Shared {
@ -52,22 +59,14 @@ impl Shared {
}
impl Context {
pub(crate) fn shared(&self) -> parking_lot::MutexGuard<'_, Shared> {
self.shared.lock()
}
pub fn new_with_threads(num_threads: usize) -> Arc<Self> {
#[cfg(feature = "tracing")]
tracing::trace!("Creating context with {} threads", num_threads);
let this = Arc::new(Self {
shared: Mutex::new(Shared {
jobs: BTreeMap::new(),
injected_jobs: Vec::new(),
}),
shared_job: Condvar::new(),
should_exit: AtomicBool::new(false),
heartbeats: HeartbeatList::new(),
queue: crate::queue::Queue::new(),
});
// Create a barrier to synchronize the worker threads and the heartbeat thread
@ -106,7 +105,7 @@ impl Context {
pub fn set_should_exit(&self) {
self.should_exit.store(true, Ordering::Relaxed);
self.heartbeats.notify_all();
self.queue.as_sender().broadcast_with(|| Message::Exit);
}
pub fn should_exit(&self) -> bool {
@ -124,31 +123,7 @@ impl Context {
}
pub fn inject_job(&self, job: SharedJob) {
let mut shared = self.shared.lock();
shared.injected_jobs.push(job);
unsafe {
// SAFETY: we are holding the shared lock, so it is safe to notify
self.notify_job_shared();
}
}
/// caller should hold the shared lock while calling this
pub unsafe fn notify_job_shared(&self) {
let heartbeats = self.heartbeats.inner();
if let Some((i, sender)) = heartbeats
.iter()
.find(|(_, heartbeat)| heartbeat.is_waiting())
.or_else(|| heartbeats.iter().next())
{
_ = i;
#[cfg(feature = "tracing")]
tracing::trace!("Notifying worker thread {} about job sharing", i);
sender.wake();
} else {
#[cfg(feature = "tracing")]
tracing::warn!("No worker found to notify about job sharing");
}
self.queue.as_sender().anycast(Message::Shared(job));
}
/// Runs closure in this context, processing the other context's worker's jobs while waiting for the result.
@ -164,9 +139,9 @@ impl Context {
let job = Job::from_stackjob(&job);
self.inject_job(job.share(Some(worker.heartbeat.parker())));
self.inject_job(job.share(Some(worker.receiver.get_token())));
let t = worker.wait_until_shared_job(&job).unwrap();
let t = worker.wait_until_shared_job(&job);
crate::util::unwrap_or_panic(t)
}
@ -178,17 +153,27 @@ impl Context {
T: Send,
{
// current thread isn't a worker thread, create job and inject into context
let parker = Parker::new();
let recv = self.queue.new_receiver();
let job = StackJob::new(move |worker: &WorkerThread| f(worker));
let job = Job::from_stackjob(&job);
self.inject_job(job.share(Some(&parker)));
self.inject_job(job.share(Some(recv.get_token())));
let recv = job.take_receiver().unwrap();
crate::util::unwrap_or_panic(recv.recv())
loop {
match recv.recv() {
Message::Finished(send) => {
break crate::util::unwrap_or_panic(unsafe {
*Box::from_non_null(send.0.cast::<std::thread::Result<T>>())
});
}
msg @ Message::Shared(_) => {
self.queue.as_sender().anycast(msg);
}
_ => {}
}
}
}
/// Run closure in this context.
@ -244,7 +229,7 @@ impl Context {
{
let schedule = move |runnable: Runnable| {
#[align(8)]
unsafe fn harness<T>(_: &WorkerThread, this: NonNull<()>, _: Option<Sender>) {
unsafe fn harness<T>(_: &WorkerThread, this: NonNull<()>, _: Option<ReceiverToken>) {
unsafe {
let runnable = Runnable::<()>::from_raw(this);
runnable.run();
@ -347,6 +332,7 @@ mod tests {
let counter = Arc::new(AtomicU8::new(0));
let parker = Parker::new();
let receiver = ctx.queue.new_receiver();
let job = StackJob::new({
let counter = counter.clone();
@ -372,12 +358,16 @@ mod tests {
assert!(heartbeat.is_waiting());
});
ctx.inject_job(job.share(Some(&parker)));
ctx.inject_job(job.share(Some(receiver.get_token())));
// Wait for the job to be executed
let recv = job.take_receiver().unwrap();
let result = recv.recv();
let result = crate::util::unwrap_or_panic(result);
assert!(job.is_shared());
let Message::Finished(werkzeug::util::Send(result)) = receiver.recv() else {
panic!("Expected a finished message");
};
let result = unsafe { *Box::from_non_null(result.cast()) };
let result = crate::util::unwrap_or_panic::<i32>(result);
assert_eq!(result, 42);
assert_eq!(counter.load(Ordering::SeqCst), 1);
}

View file

@ -11,6 +11,8 @@ use alloc::boxed::Box;
use crate::{
WorkerThread,
channel::{Parker, Sender},
context::Message,
queue::ReceiverToken,
};
#[repr(transparent)]
@ -44,13 +46,13 @@ impl<F> HeapJob<F> {
}
type JobHarness =
unsafe fn(&WorkerThread, this: NonNull<()>, sender: Option<crate::channel::Sender>);
unsafe fn(&WorkerThread, this: NonNull<()>, sender: Option<crate::queue::ReceiverToken>);
#[repr(C)]
pub struct Job2<T = ()> {
harness: JobHarness,
harness: Cell<Option<JobHarness>>,
this: NonNull<()>,
receiver: Cell<Option<crate::channel::Receiver<T>>>,
_phantom: core::marker::PhantomData<fn(T)>,
}
impl<T> Debug for Job2<T> {
@ -66,7 +68,7 @@ impl<T> Debug for Job2<T> {
pub struct SharedJob {
harness: JobHarness,
this: NonNull<()>,
sender: Option<crate::channel::Sender>,
sender: Option<crate::queue::ReceiverToken>,
}
unsafe impl Send for SharedJob {}
@ -74,9 +76,9 @@ unsafe impl Send for SharedJob {}
impl<T: Send> Job2<T> {
fn new(harness: JobHarness, this: NonNull<()>) -> Self {
let this = Self {
harness,
harness: Cell::new(Some(harness)),
this,
receiver: Cell::new(None),
_phantom: core::marker::PhantomData,
};
#[cfg(feature = "tracing")]
@ -85,25 +87,25 @@ impl<T: Send> Job2<T> {
this
}
pub fn share(&self, parker: Option<&Parker>) -> SharedJob {
pub fn share(&self, parker: Option<crate::queue::ReceiverToken>) -> SharedJob {
#[cfg(feature = "tracing")]
tracing::trace!("sharing job: {:?}", self);
let (sender, receiver) = parker
.map(|parker| crate::channel::channel::<T>(parker.into()))
.unzip();
// let (sender, receiver) = parker
// .map(|parker| crate::channel::channel::<T>(parker.into()))
// .unzip();
self.receiver.set(receiver);
// self.receiver.set(receiver);
SharedJob {
harness: self.harness,
harness: self.harness.take().unwrap(),
this: self.this,
sender: unsafe { mem::transmute(sender) },
sender: parker,
}
}
pub fn take_receiver(&self) -> Option<crate::channel::Receiver<T>> {
self.receiver.take()
pub fn is_shared(&self) -> bool {
self.harness.get().is_none()
}
pub fn from_stackjob<F>(job: &StackJob<F>) -> Self
@ -115,15 +117,17 @@ impl<T: Send> Job2<T> {
feature = "tracing",
tracing::instrument(level = "trace", skip_all, name = "stack_job_harness")
)]
unsafe fn harness<F, T>(worker: &WorkerThread, this: NonNull<()>, sender: Option<Sender>)
where
unsafe fn harness<F, T>(
worker: &WorkerThread,
this: NonNull<()>,
sender: Option<ReceiverToken>,
) where
F: FnOnce(&WorkerThread) -> T + Send,
T: Send,
{
use std::panic::{AssertUnwindSafe, catch_unwind};
let f = unsafe { this.cast::<StackJob<F>>().as_ref().unwrap() };
let sender: Sender<T> = unsafe { mem::transmute(sender) };
// #[cfg(feature = "metrics")]
// if worker.heartbeat.parker() == mutex {
@ -134,7 +138,18 @@ impl<T: Send> Job2<T> {
// tracing::trace!("job sent to self");
// }
sender.send(catch_unwind(AssertUnwindSafe(|| f(worker))));
let result = catch_unwind(AssertUnwindSafe(|| f(worker)));
if let Some(token) = sender {
worker.context.queue.as_sender().unicast(
Message::Finished(werkzeug::util::Send(
// SAFETY: T is guaranteed to be `Sized`, so
// `NonNull<T>` is the same size for any `T`.
Box::into_non_null(Box::new(result)).cast(),
)),
token,
);
}
}
Self::new(harness::<F, T>, NonNull::from(job).cast())
@ -149,8 +164,11 @@ impl<T: Send> Job2<T> {
feature = "tracing",
tracing::instrument(level = "trace", skip_all, name = "heap_job_harness")
)]
unsafe fn harness<F, T>(worker: &WorkerThread, this: NonNull<()>, sender: Option<Sender>)
where
unsafe fn harness<F, T>(
worker: &WorkerThread,
this: NonNull<()>,
sender: Option<ReceiverToken>,
) where
F: FnOnce(&WorkerThread) -> T + Send,
T: Send,
{
@ -162,9 +180,15 @@ impl<T: Send> Job2<T> {
let f = unsafe { (*Box::from_non_null(this.cast::<HeapJob<F>>())).into_inner() };
let result = catch_unwind(AssertUnwindSafe(|| f(worker)));
let sender: Option<Sender<T>> = unsafe { mem::transmute(sender) };
if let Some(sender) = sender {
sender.send(result);
if let Some(token) = sender {
worker.context.queue.as_sender().unicast(
Message::Finished(werkzeug::util::Send(
// SAFETY: T is guaranteed to be `Sized`, so
// `NonNull<T>` is the same size for any `T`.
Box::into_non_null(Box::new(result)).cast(),
)),
token,
);
}
}
@ -179,10 +203,6 @@ impl<T: Send> Job2<T> {
pub fn from_harness(harness: JobHarness, this: NonNull<()>) -> Self {
Self::new(harness, this)
}
pub fn is_shared(&self) -> bool {
unsafe { (&*self.receiver.as_ptr()).is_some() }
}
}
impl SharedJob {

View file

@ -117,32 +117,16 @@ impl WorkerThread {
cold_path();
// if b panicked, we need to wait for a to finish
let mut receiver = job.take_receiver();
self.wait_until_pred(|| match &receiver {
Some(recv) => recv.poll().is_some(),
None => {
receiver = job.take_receiver();
false
}
});
if job.is_shared() {
_ = self.wait_until_recv::<RA>();
}
resume_unwind(payload);
}
};
let ra = if let Some(recv) = job.take_receiver() {
match self.wait_until_recv(recv) {
Some(t) => crate::util::unwrap_or_panic(t),
None => {
#[cfg(feature = "tracing")]
tracing::trace!(
"join_heartbeat: job was shared, but reclaimed, running a() inline"
);
// the job was shared, but not yet stolen, so we get to run the
// job inline
a.run_inline(self)
}
}
let ra = if job.is_shared() {
crate::util::unwrap_or_panic(self.wait_until_recv())
} else {
self.pop_back();
@ -183,41 +167,23 @@ impl WorkerThread {
cold_path();
// if b panicked, we need to wait for a to finish
let mut receiver = job.take_receiver();
self.wait_until_pred(|| match &receiver {
Some(recv) => recv.poll().is_some(),
None => {
receiver = job.take_receiver();
false
}
});
if job.is_shared() {
_ = self.wait_until_recv::<RA>();
}
resume_unwind(payload);
}
};
let ra = if let Some(recv) = job.take_receiver() {
match self.wait_until_recv(recv) {
Some(t) => crate::util::unwrap_or_panic(t),
None => {
#[cfg(feature = "tracing")]
tracing::trace!(
"join_heartbeat: job was shared, but reclaimed, running a() inline"
);
// the job was shared, but not yet stolen, so we get to run the
// job inline
unsafe { a.unwrap()(self) }
}
}
let ra = if job.is_shared() {
crate::util::unwrap_or_panic(self.wait_until_recv())
} else {
self.pop_back();
unsafe {
// SAFETY: we just popped the job from the queue, so it is safe to unwrap.
#[cfg(feature = "tracing")]
tracing::trace!("join_heartbeat: job was not shared, running a() inline");
a.unwrap()(self)
}
// SAFETY: we just popped the job from the queue, so it is safe to unwrap.
#[cfg(feature = "tracing")]
tracing::trace!("join_heartbeat: job was not shared, running a() inline");
a.run_inline(self)
};
(ra, rb)

View file

@ -11,7 +11,7 @@ use std::{
},
};
use crossbeam_utils::CachePadded;
use werkzeug::CachePadded;
use werkzeug::ptr::TaggedAtomicPtr;
@ -19,8 +19,7 @@ use werkzeug::ptr::TaggedAtomicPtr;
// After being woken up from waiting on a message, the receiver will look up the index of the message in the queue and return it.
struct QueueInner<T> {
parked: HashSet<ReceiverToken>,
owned: HashMap<ReceiverToken, CachePadded<Slot<T>>>,
receivers: HashMap<ReceiverToken, CachePadded<(Slot<T>, bool)>>,
messages: Vec<T>,
_phantom: std::marker::PhantomData<T>,
}
@ -38,13 +37,13 @@ enum SlotKey {
Indexed(usize),
}
struct Receiver<T> {
pub struct Receiver<T> {
queue: Arc<Queue<T>>,
lock: Pin<Box<(AtomicU32, PhantomPinned)>>,
}
#[repr(transparent)]
struct Sender<T> {
pub struct Sender<T> {
queue: Arc<Queue<T>>,
}
@ -79,7 +78,7 @@ impl<T> Slot<T> {
.or_else(|| {
if self
.next_and_state
.swap_tag(0, Ordering::Acquire, Ordering::Relaxed)
.swap_tag(0, Ordering::AcqRel, Ordering::Relaxed)
== 1
{
// SAFETY: The value is only initialized when the state is set to 1.
@ -126,6 +125,12 @@ impl<T> Slot<T> {
Ok(_) => break next,
Err(other) => {
// next was allocated under us, so we need to drop the slot we just allocated again.
#[cfg(feature = "tracing")]
tracing::trace!(
"Slot::alloc_next: next was allocated under us, dropping it. ours: {:p}, other: {:p}",
next,
other
);
_ = unsafe { Box::from_raw(next) };
break other;
}
@ -171,15 +176,14 @@ impl<T> Drop for Slot<T> {
/// A token that can be used to identify a specific receiver in a queue.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ReceiverToken(werkzeug::util::Send<*const u32>);
pub struct ReceiverToken(werkzeug::util::Send<NonNull<u32>>);
impl<T> Queue<T> {
pub fn new() -> Arc<Self> {
Arc::new(Self {
inner: UnsafeCell::new(QueueInner {
parked: HashSet::new(),
messages: Vec::new(),
owned: HashMap::new(),
receivers: HashMap::new(),
_phantom: PhantomData,
}),
lock: AtomicU32::new(0),
@ -207,8 +211,8 @@ impl<T> Queue<T> {
let _guard = recv.queue.lock();
recv.queue
.inner()
.owned
.insert(token, CachePadded::new(Slot::new()));
.receivers
.insert(token, CachePadded::new((Slot::new(), false)));
drop(_guard);
recv
@ -229,25 +233,27 @@ impl<T> Queue<T> {
}
impl<T> QueueInner<T> {
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
fn poll(&mut self, token: ReceiverToken) -> Option<T> {
// check if someone has sent a message to this receiver
let slot = self.owned.get(&token).unwrap();
let CachePadded((slot, _)) = self.receivers.get(&token)?;
unsafe { slot.pop() }.or_else(|| {
// if the slot is empty, we can check the indexed messages
#[cfg(feature = "tracing")]
tracing::trace!("QueueInner::poll: checking open messages");
self.messages.pop()
})
}
}
impl<T> Receiver<T> {
fn get_token(&self) -> ReceiverToken {
pub fn get_token(&self) -> ReceiverToken {
// the token is just the pointer to the lock of this receiver.
// the lock is pinned, so it's address is stable across calls to `receive`.
ReceiverToken(werkzeug::util::Send(
&self.lock.0 as *const AtomicU32 as *const u32,
))
ReceiverToken(werkzeug::util::Send(NonNull::from(&self.lock.0).cast()))
}
}
@ -259,12 +265,13 @@ impl<T> Drop for Receiver<T> {
let queue = self.queue.inner();
// remove the receiver from the queue
_ = queue.owned.remove(&self.get_token());
_ = queue.receivers.remove(&self.get_token());
}
}
}
impl<T: Send> Receiver<T> {
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn recv(&self) -> T {
let token = self.get_token();
@ -275,22 +282,23 @@ impl<T: Send> Receiver<T> {
// check if someone has sent a message to this receiver
if let Some(t) = queue.poll(token) {
queue.parked.remove(&token);
queue.receivers.get_mut(&token).unwrap().1 = false; // mark the slot as not parked
return t;
}
// there was no message for this receiver, so we need to park it
queue.parked.insert(token);
queue.receivers.get_mut(&token).unwrap().1 = true; // mark the slot as parked
// wait for a message to be sent to this receiver
drop(_guard);
unsafe {
let lock = werkzeug::sync::Lock::from_ptr(token.0.into_inner().cast_mut());
let lock = werkzeug::sync::Lock::from_ptr(token.0.into_inner().as_ptr());
lock.wait();
}
}
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn try_recv(&self) -> Option<T> {
let token = self.get_token();
@ -306,23 +314,30 @@ impl<T: Send> Receiver<T> {
impl<T: Send> Sender<T> {
/// Sends a message to one of the receivers in the queue, or makes it
/// available to any receiver that will park in the future.
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn anycast(&self, value: T) {
// look for a receiver that is parked
let _guard = self.queue.lock();
let queue = self.queue.inner();
if let Some((token, slot)) = queue.parked.iter().find_map(|token| {
// ensure the slot is available
if let Some((token, slot)) =
queue
.owned
.get(token)
.and_then(|s| if !s.is_set() { Some((*token, s)) } else { None })
}) {
.receivers
.iter()
.find_map(|(token, CachePadded((slot, is_parked)))| {
// ensure the slot is available
if *is_parked && !slot.is_set() {
Some((*token, slot))
} else {
None
}
})
{
// we found a receiver that is parked, so we can send the message to it
unsafe {
slot.value.as_mut_unchecked().write(value);
slot.next_and_state
.set_tag(1, Ordering::Release, Ordering::Relaxed);
werkzeug::sync::Lock::from_ptr(token.0.into_inner().cast_mut()).wake_one();
werkzeug::sync::Lock::from_ptr(token.0.into_inner().as_ptr()).wake_one();
}
return;
@ -336,12 +351,13 @@ impl<T: Send> Sender<T> {
}
/// Sends a message to a specific receiver, waking it if it is parked.
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn unicast(&self, value: T, receiver: ReceiverToken) -> Result<(), T> {
// lock the queue
let _guard = self.queue.lock();
let queue = self.queue.inner();
let Some(slot) = queue.owned.get_mut(&receiver) else {
let Some(CachePadded((slot, is_parked))) = queue.receivers.get_mut(&receiver) else {
return Err(value);
};
@ -350,16 +366,17 @@ impl<T: Send> Sender<T> {
}
// check if the receiver is parked
if queue.parked.contains(&receiver) {
if *is_parked {
// wake the receiver
unsafe {
werkzeug::sync::Lock::from_ptr(receiver.0.into_inner().cast_mut()).wake_one();
werkzeug::sync::Lock::from_ptr(receiver.0.into_inner().as_ptr()).wake_one();
}
}
Ok(())
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn broadcast(&self, value: T)
where
T: Clone,
@ -369,16 +386,41 @@ impl<T: Send> Sender<T> {
let queue = self.queue.inner();
// send the message to all receivers
for (token, slot) in queue.owned.iter() {
for (token, CachePadded((slot, is_parked))) in queue.receivers.iter() {
// SAFETY: The slot is owned by this receiver.
unsafe { slot.push(value.clone()) };
// check if the receiver is parked
if queue.parked.contains(token) {
if *is_parked {
// wake the receiver
unsafe {
werkzeug::sync::Lock::from_ptr(token.0.into_inner().cast_mut()).wake_one();
werkzeug::sync::Lock::from_ptr(token.0.into_inner().as_ptr()).wake_one();
}
}
}
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn broadcast_with<F>(&self, mut f: F)
where
F: FnMut() -> T,
{
// lock the queue
let _guard = self.queue.lock();
let queue = self.queue.inner();
// send the message to all receivers
for (token, CachePadded((slot, is_parked))) in queue.receivers.iter() {
// SAFETY: The slot is owned by this receiver.
unsafe { slot.push(f()) };
// check if the receiver is parked
if *is_parked {
// wake the receiver
unsafe {
werkzeug::sync::Lock::from_ptr(token.0.into_inner().as_ptr()).wake_one();
}
}
}
@ -481,6 +523,47 @@ mod tests {
}
println!("All threads have exited.");
}
#[test]
fn drop_slot() {
// Test that dropping a slot does not cause a double free or panic
let slot = Slot::<i32>::new();
unsafe {
slot.push(42);
drop(slot);
}
}
#[test]
fn drop_slot_chain() {
struct DropCheck<'a>(&'a AtomicU32);
impl Drop for DropCheck<'_> {
fn drop(&mut self) {
self.0.fetch_sub(1, Ordering::SeqCst);
}
}
impl<'a> DropCheck<'a> {
fn new(counter: &'a AtomicU32) -> Self {
counter.fetch_add(1, Ordering::SeqCst);
Self(counter)
}
}
let counter = AtomicU32::new(0);
let slot = Slot::<DropCheck>::new();
for _ in 0..10 {
unsafe {
slot.push(DropCheck::new(&counter));
}
}
assert_eq!(counter.load(Ordering::SeqCst), 10);
drop(slot);
assert_eq!(
counter.load(Ordering::SeqCst),
0,
"All DropCheck instances should have been dropped"
);
}
}
// struct AtomicLIFO<T> {

View file

@ -11,15 +11,17 @@ use std::{
};
use async_task::Runnable;
use werkzeug::util;
use crate::{
channel::Sender,
context::Context,
context::{Context, Message},
job::{
HeapJob, Job2 as Job,
HeapJob, Job2 as Job, SharedJob,
traits::{InlineJob, IntoJob},
},
latch::{CountLatch, Probe},
queue::ReceiverToken,
util::{DropGuard, SendPtr},
workerthread::WorkerThread,
};
@ -47,7 +49,7 @@ use crate::{
struct ScopeInner {
outstanding_jobs: AtomicUsize,
parker: NonNull<crate::channel::Parker>,
parker: ReceiverToken,
panic: AtomicPtr<Box<dyn Any + Send + 'static>>,
}
@ -66,7 +68,7 @@ impl ScopeInner {
fn from_worker(worker: &WorkerThread) -> Self {
Self {
outstanding_jobs: AtomicUsize::new(0),
parker: worker.heartbeat.parker().into(),
parker: worker.receiver.get_token(),
panic: AtomicPtr::new(ptr::null_mut()),
}
}
@ -75,11 +77,13 @@ impl ScopeInner {
self.outstanding_jobs.fetch_add(1, Ordering::Relaxed);
}
fn decrement(&self) {
fn decrement(&self, worker: &WorkerThread) {
if self.outstanding_jobs.fetch_sub(1, Ordering::Relaxed) == 1 {
unsafe {
self.parker.as_ref().unpark();
}
worker
.context
.queue
.as_sender()
.unicast(Message::ScopeFinished, self.parker);
}
}
@ -196,19 +200,38 @@ impl<'scope, 'env> Scope<'scope, 'env> {
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
fn wait_for_jobs(&self) {
#[cfg(feature = "tracing")]
tracing::trace!(
"waiting for {} jobs to finish.",
self.inner().outstanding_jobs.load(Ordering::Relaxed)
);
self.worker().wait_until_pred(|| {
// SAFETY: we are in a worker thread, so the inner is valid.
loop {
let count = self.inner().outstanding_jobs.load(Ordering::Relaxed);
#[cfg(feature = "tracing")]
tracing::trace!("waiting for {} jobs to finish.", count);
count == 0
});
if count == 0 {
break;
}
match self.worker().receiver.recv() {
Message::Shared(shared_job) => unsafe {
SharedJob::execute(shared_job, self.worker());
},
Message::Finished(util::Send(result)) => {
#[cfg(feature = "tracing")]
tracing::error!(
"received result when waiting for jobs to finish: {:p}.",
result
);
}
Message::Exit => {}
Message::ScopeFinished => {
#[cfg(feature = "tracing")]
tracing::trace!("scope finished, decrementing outstanding jobs.");
assert_eq!(self.inner().outstanding_jobs.load(Ordering::Acquire), 0);
break;
}
}
}
}
fn decrement(&self) {
self.inner().decrement(self.worker());
}
fn inner(&self) -> &ScopeInner {
@ -246,7 +269,7 @@ impl<'scope, 'env> Scope<'scope, 'env> {
unsafe fn harness<'scope, 'env, T>(
worker: &WorkerThread,
this: NonNull<()>,
_: Option<Sender>,
_: Option<ReceiverToken>,
) where
F: FnOnce(Scope<'scope, 'env>) -> T + Send,
'env: 'scope,
@ -268,7 +291,7 @@ impl<'scope, 'env> Scope<'scope, 'env> {
scope.inner().panicked(payload);
}
scope.inner().decrement();
scope.decrement();
},
self.inner,
);
@ -309,7 +332,7 @@ impl<'scope, 'env> Scope<'scope, 'env> {
let future = async move {
let _guard = DropGuard::new(move || {
scope.inner().decrement();
scope.decrement();
});
// TODO: handle panics here
@ -318,7 +341,7 @@ impl<'scope, 'env> Scope<'scope, 'env> {
let schedule = move |runnable: Runnable| {
#[align(8)]
unsafe fn harness(_: &WorkerThread, this: NonNull<()>, _: Option<Sender>) {
unsafe fn harness(_: &WorkerThread, this: NonNull<()>, _: Option<ReceiverToken>) {
unsafe {
let runnable = Runnable::<()>::from_raw(this.cast());
runnable.run();
@ -384,7 +407,7 @@ impl<'scope, 'env> Scope<'scope, 'env> {
unsafe fn harness<'scope, 'env, T>(
worker: &WorkerThread,
this: NonNull<()>,
sender: Option<Sender>,
sender: Option<ReceiverToken>,
) where
F: FnOnce(Scope<'scope, 'env>) -> T + Send,
'env: 'scope,
@ -393,10 +416,14 @@ impl<'scope, 'env> Scope<'scope, 'env> {
let this: &ScopeJob<F> = unsafe { this.cast().as_ref() };
let f = unsafe { this.unwrap() };
let scope = unsafe { Scope::<'scope, 'env>::new_unchecked(worker, this.inner) };
let sender: Sender<T> = unsafe { mem::transmute(sender) };
// SAFETY: we are in a worker thread, so the inner is valid.
sender.send(catch_unwind(AssertUnwindSafe(|| f(scope))));
_ = worker.context.queue.as_sender().unicast(
Message::Finished(werkzeug::util::Send(
Box::into_non_null(Box::new(catch_unwind(AssertUnwindSafe(|| f(scope)))))
.cast(),
)),
sender.unwrap(),
);
}
}
@ -528,13 +555,18 @@ mod tests {
#[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
fn scope_join_one() {
let pool = ThreadPool::new_with_threads(1);
let count = AtomicU8::new(0);
let a = pool.scope(|scope| {
let (a, b) = scope.join(|_| 3 + 4, |_| 5 + 6);
let (a, b) = scope.join(
|_| count.fetch_add(1, Ordering::Relaxed) + 4,
|_| count.fetch_add(2, Ordering::Relaxed) + 6,
);
a + b
});
assert_eq!(a, 18);
assert_eq!(a, 12);
assert_eq!(count.load(Ordering::Relaxed), 3);
}
#[test]
@ -555,7 +587,7 @@ mod tests {
pool.scope(|scope| {
let total = sum(scope, 10);
assert_eq!(total, 1023);
// eprintln!("Total sum: {}", total);
eprintln!("Total sum: {}", total);
});
}

View file

@ -2,6 +2,7 @@ use std::sync::Arc;
use crate::{Scope, context::Context, scope::scope_with_context};
#[repr(transparent)]
pub struct ThreadPool {
pub(crate) context: Arc<Context>,
}
@ -9,7 +10,7 @@ pub struct ThreadPool {
impl Drop for ThreadPool {
fn drop(&mut self) {
// TODO: Ensure that the context is properly cleaned up when the thread pool is dropped.
// self.context.set_should_exit();
self.context.set_should_exit();
}
}
@ -25,9 +26,9 @@ impl ThreadPool {
Self { context }
}
pub fn global() -> Self {
let context = Context::global_context().clone();
Self { context }
pub fn global() -> &'static Self {
// SAFETY: ThreadPool is a transparent wrapper around Arc<Context>,
unsafe { core::mem::transmute(Context::global_context()) }
}
pub fn scope<'env, F, R>(&self, f: F) -> R

View file

@ -8,18 +8,21 @@ use std::{
time::Duration,
};
use crossbeam_utils::CachePadded;
#[cfg(feature = "metrics")]
use werkzeug::CachePadded;
use crate::{
channel::Receiver,
context::Context,
context::{Context, Message},
heartbeat::OwnedHeartbeatReceiver,
job::{Job2 as Job, JobQueue as JobList, SharedJob},
queue,
util::DropGuard,
};
pub struct WorkerThread {
pub(crate) context: Arc<Context>,
pub(crate) receiver: queue::Receiver<Message>,
pub(crate) queue: UnsafeCell<JobList>,
pub(crate) heartbeat: OwnedHeartbeatReceiver,
pub(crate) join_count: Cell<u8>,
@ -37,6 +40,7 @@ impl WorkerThread {
let heartbeat = context.heartbeats.new_heartbeat();
Self {
receiver: context.queue.new_receiver(),
context,
queue: UnsafeCell::new(JobList::new()),
heartbeat,
@ -82,85 +86,30 @@ impl WorkerThread {
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
fn run_inner(&self) {
let mut job = None;
'outer: loop {
if let Some(job) = job.take() {
self.execute(job);
loop {
if self.context.should_exit() {
break;
}
// no more jobs, wait to be notified of a new job or a heartbeat.
while job.is_none() {
if self.context.should_exit() {
// if the context is stopped, break out of the outer loop which
// will exit the thread.
break 'outer;
match self.receiver.recv() {
Message::Shared(shared_job) => {
self.execute(shared_job);
}
job = self.find_work_or_wait();
Message::Finished(werkzeug::util::Send(ptr)) => {
#[cfg(feature = "tracing")]
tracing::error!(
"WorkerThread::run_inner: received finished message: {:?}",
ptr
);
}
Message::Exit => break,
Message::ScopeFinished => {}
}
}
}
}
impl WorkerThread {
/// Looks for work in the local queue, then in the shared context, and if no
/// work is found, waits for the thread to be notified of a new job, after
/// which it returns `None`.
/// The caller should then check for `should_exit` to determine if the
/// thread should exit, or look for work again.
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub(crate) fn find_work_or_wait(&self) -> Option<SharedJob> {
if let Some(job) = self.find_work() {
return Some(job);
}
#[cfg(feature = "tracing")]
tracing::trace!("waiting for new job");
self.heartbeat.parker().park();
#[cfg(feature = "tracing")]
tracing::trace!("woken up from wait");
None
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub(crate) fn find_work_or_wait_unless<F>(&self, mut pred: F) -> Option<SharedJob>
where
F: FnMut() -> bool,
{
if let Some(job) = self.find_work() {
return Some(job);
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// Check the predicate while holding the lock. This is very important,
// because the lock must be held when notifying us of the result of a
// job we scheduled.
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// no jobs found, wait for a heartbeat or a new job
#[cfg(feature = "tracing")]
tracing::trace!(worker = self.heartbeat.index(), "waiting for new job");
if !pred() {
self.heartbeat.parker().park();
}
#[cfg(feature = "tracing")]
tracing::trace!(worker = self.heartbeat.index(), "woken up from wait");
None
}
fn find_work(&self) -> Option<SharedJob> {
let mut guard = self.context.shared();
if let Some(job) = guard.pop_job() {
#[cfg(feature = "metrics")]
self.metrics.num_jobs_stolen.fetch_add(1, Ordering::Relaxed);
#[cfg(feature = "tracing")]
tracing::trace!("WorkerThread::find_work_inner: found shared job: {:?}", job);
return Some(job);
}
None
}
pub(crate) fn tick(&self) {
if self.heartbeat.take() {
#[cfg(feature = "metrics")]
@ -182,25 +131,19 @@ impl WorkerThread {
#[cold]
fn heartbeat_cold(&self) {
let mut guard = self.context.shared();
if let Some(job) = self.pop_back() {
#[cfg(feature = "tracing")]
tracing::trace!("heartbeat: sharing job: {:?}", job);
if !guard.jobs.contains_key(&self.heartbeat.id()) {
if let Some(job) = self.pop_back() {
#[cfg(feature = "tracing")]
tracing::trace!("heartbeat: sharing job: {:?}", job);
#[cfg(feature = "metrics")]
self.metrics.num_jobs_shared.fetch_add(1, Ordering::Relaxed);
#[cfg(feature = "metrics")]
self.metrics.num_jobs_shared.fetch_add(1, Ordering::Relaxed);
unsafe {
guard.jobs.insert(
self.heartbeat.id(),
job.as_ref().share(Some(self.heartbeat.parker())),
);
// SAFETY: we are holding the lock on the shared context.
self.context.notify_job_shared();
}
}
self.context
.queue
.as_sender()
.anycast(Message::Shared(unsafe {
job.as_ref().share(Some(self.receiver.get_token()))
}));
}
}
}
@ -307,87 +250,29 @@ impl HeartbeatThread {
impl WorkerThread {
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
pub fn wait_until_shared_job<T: Send>(&self, job: &Job<T>) -> Option<std::thread::Result<T>> {
let recv = (*job).take_receiver().unwrap();
let mut out = recv.poll();
while std::hint::unlikely(out.is_none()) {
if let Some(job) = self.find_work() {
unsafe {
SharedJob::execute(job, self);
pub fn wait_until_shared_job<T: Send>(&self, job: &Job<T>) -> std::thread::Result<T> {
loop {
match self.receiver.recv() {
Message::Shared(shared_job) => unsafe {
SharedJob::execute(shared_job, self);
},
Message::Finished(send) => {
break unsafe { *Box::from_non_null(send.0.cast()) };
}
Message::Exit | Message::ScopeFinished => {}
}
out = recv.poll();
}
out
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn wait_until_recv<T: Send>(&self, recv: Receiver<T>) -> Option<std::thread::Result<T>> {
if self
.context
.shared()
.jobs
.remove(&self.heartbeat.id())
.is_some()
{
#[cfg(feature = "tracing")]
tracing::trace!("reclaiming shared job");
return None;
}
while recv.is_empty() {
if let Some(job) = self.find_work() {
unsafe {
SharedJob::execute(job, self);
}
continue;
}
recv.wait();
}
Some(recv.recv())
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
pub fn wait_until_pred<F>(&self, mut pred: F)
where
F: FnMut() -> bool,
{
if !pred() {
#[cfg(feature = "tracing")]
tracing::trace!("thread {:?} waiting on predicate", self.heartbeat.index());
self.wait_until_latch_cold(pred);
}
}
#[cold]
fn wait_until_latch_cold<F>(&self, mut pred: F)
where
F: FnMut() -> bool,
{
if let Some(shared_job) = self.context.shared().jobs.remove(&self.heartbeat.id()) {
#[cfg(feature = "tracing")]
tracing::trace!(
"thread {:?} reclaiming shared job: {:?}",
self.heartbeat.index(),
shared_job
);
unsafe { SharedJob::execute(shared_job, self) };
}
// do the usual thing and wait for the job's latch
// do the usual thing??? chatgipity really said this..
while !pred() {
// check local jobs before locking shared context
if let Some(job) = self.find_work() {
unsafe {
SharedJob::execute(job, self);
}
pub fn wait_until_recv<T: Send>(&self) -> std::thread::Result<T> {
loop {
match self.receiver.recv() {
Message::Shared(shared_job) => unsafe {
SharedJob::execute(shared_job, self);
},
Message::Finished(send) => break unsafe { *Box::from_non_null(send.0.cast()) },
Message::Exit | Message::ScopeFinished => {}
}
}
}