sdkkkkkkkkk

This commit is contained in:
Janis 2025-06-28 16:25:19 +02:00
parent 2a0372a8a0
commit 9cc125e558
8 changed files with 101 additions and 75 deletions

View file

@ -56,7 +56,7 @@ mod tree {
}
}
const TREE_SIZE: usize = 16;
const TREE_SIZE: usize = 8;
#[bench]
fn join_melange(b: &mut Bencher) {
@ -210,10 +210,12 @@ fn join_distaff(b: &mut Bencher) {
}
b.iter(move || {
pool.scope(|s| {
let sum = pool.scope(|s| {
let sum = sum(&tree, tree.root().unwrap(), s);
// eprintln!("{sum}");
assert_ne!(sum, 0);
sum
});
eprintln!("{sum}");
});
eprintln!("Done with distaff join");
}

View file

@ -16,6 +16,7 @@ use crate::{
heartbeat::HeartbeatList,
job::{HeapJob, JobSender, QueuedJob as Job, StackJob},
latch::{AsCoreLatch, MutexLatch, NopLatch, WorkerLatch},
util::DropGuard,
workerthread::{HeartbeatThread, WorkerThread},
};
@ -80,6 +81,8 @@ impl Context {
}
pub fn new_with_threads(num_threads: usize) -> Arc<Self> {
tracing::trace!("Creating context with {} threads", num_threads);
let this = Arc::new(Self {
shared: Mutex::new(Shared {
jobs: BTreeMap::new(),
@ -90,8 +93,6 @@ impl Context {
heartbeats: HeartbeatList::new(),
});
tracing::trace!("Creating thread pool with {} threads", num_threads);
// Create a barrier to synchronize the worker threads and the heartbeat thread
let barrier = Arc::new(std::sync::Barrier::new(num_threads + 2));
@ -104,8 +105,7 @@ impl Context {
.spawn(move || {
let worker = Box::new(WorkerThread::new_in(ctx));
barrier.wait();
worker.run();
worker.run(barrier);
})
.expect("Failed to spawn worker thread");
}
@ -117,8 +117,7 @@ impl Context {
std::thread::Builder::new()
.name("heartbeat-thread".to_string())
.spawn(move || {
barrier.wait();
HeartbeatThread::new(ctx).run();
HeartbeatThread::new(ctx).run(barrier);
})
.expect("Failed to spawn heartbeat thread");
}
@ -234,6 +233,9 @@ impl Context {
T: Send,
F: FnOnce(&WorkerThread) -> T + Send,
{
let _guard = DropGuard::new(|| {
tracing::trace!("run_in_worker: finished");
});
match WorkerThread::current_ref() {
Some(worker) => {
// check if worker is in the same context

View file

@ -1054,6 +1054,7 @@ const ERROR: usize = 1 << 1;
impl<T> JobSender<T> {
#[tracing::instrument(level = "trace", skip_all)]
pub fn send(&self, result: std::thread::Result<T>, mutex: *const WorkerLatch) {
tracing::trace!("sending job ({:?}) result", &raw const *self);
// We want to lock here so that we can be sure that we wake the worker
// only if it was waiting, and not immediately after having received the
// result and waiting for further work:

View file

@ -460,7 +460,7 @@ impl WorkerLatch {
tracing::trace!("WorkerLatch wait_with_lock_internal: relocked other");
// because `other` is logically unlocked, we swap it with `other2` and then forget `other2`
core::mem::swap(&mut *other2, &mut *other);
core::mem::swap(&mut other2, other);
core::mem::forget(other2);
let mut guard = self.mutex.lock();
@ -546,11 +546,12 @@ mod tests {
tracing::info!("Thread waiting on latch");
latch.wait_with_lock(&mut guard);
count.fetch_add(1, Ordering::Relaxed);
count.fetch_add(1, Ordering::SeqCst);
tracing::info!("Thread woke up from latch");
barrier.wait();
barrier.wait();
tracing::info!("Thread finished waiting on barrier");
count.fetch_add(1, Ordering::Relaxed);
count.fetch_add(1, Ordering::SeqCst);
}
});
@ -566,17 +567,18 @@ mod tests {
latch.wake();
tracing::info!("Main thread woke up thread");
}
assert_eq!(count.load(Ordering::Relaxed), 0, "Count should still be 0");
assert_eq!(count.load(Ordering::SeqCst), 0, "Count should still be 0");
barrier.wait();
assert_eq!(
count.load(Ordering::Relaxed),
count.load(Ordering::SeqCst),
1,
"Count should be 1 after waking up"
);
barrier.wait();
thread.join().expect("Thread should join successfully");
assert_eq!(
count.load(Ordering::Relaxed),
count.load(Ordering::SeqCst),
2,
"Count should be 2 after thread has finished"
);
@ -645,7 +647,7 @@ mod tests {
}
#[test]
#[traced_test]
#[cfg_attr(not(miri), traced_test)]
fn mutex_latch() {
let latch = Arc::new(MutexLatch::new());
assert!(!latch.probe());

View file

@ -8,6 +8,7 @@
box_as_ptr,
box_vec_non_null,
strict_provenance_atomic_ptr,
likely_unlikely,
let_chains
)]

View file

@ -184,13 +184,10 @@ impl<'scope, 'env> Scope<'scope, 'env> {
ptr::null(),
);
tracing::trace!("allocated heapjob");
WorkerThread::current_ref()
.expect("spawn is run in workerthread.")
.push_front(job.as_ptr());
tracing::trace!("leaked heapjob");
self.context.inject_job(job);
// WorkerThread::current_ref()
// .expect("spawn is run in workerthread.")
// .push_front(job.as_ptr());
}
pub fn spawn_future<T, F>(&'scope self, future: F) -> async_task::Task<T>
@ -247,16 +244,17 @@ impl<'scope, 'env> Scope<'scope, 'env> {
}
}
let job = Box::into_raw(Box::new(Job::from_harness(
let job = Box::into_non_null(Box::new(Job::from_harness(
harness,
runnable.into_raw(),
ptr::null(),
)));
// casting into Job<()> here
WorkerThread::current_ref()
.expect("spawn_async_internal is run in workerthread.")
.push_front(job);
self.context.inject_job(job);
// WorkerThread::current_ref()
// .expect("spawn_async_internal is run in workerthread.")
// .push_front(job);
};
let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) };

View file

@ -2,7 +2,7 @@ use std::{
cell::{Cell, UnsafeCell},
hint::cold_path,
ptr::NonNull,
sync::Arc,
sync::{Arc, Barrier},
time::Duration,
};
@ -42,7 +42,7 @@ impl WorkerThread {
impl WorkerThread {
#[tracing::instrument(level = "trace", skip_all)]
pub fn run(self: Box<Self>) {
pub fn run(self: Box<Self>, barrier: Arc<Barrier>) {
let this = Box::into_raw(self);
unsafe {
Self::set_current(this);
@ -56,6 +56,7 @@ impl WorkerThread {
tracing::trace!("WorkerThread::run: starting worker thread");
barrier.wait();
unsafe {
(&*this).run_inner();
}
@ -106,6 +107,34 @@ impl WorkerThread {
tracing::trace!("WorkerThread::find_work_or_wait: waiting for new job");
self.heartbeat.latch().wait_with_lock(&mut guard);
tracing::trace!("WorkerThread::find_work_or_wait: woken up from wait");
None
}
}
}
#[tracing::instrument(level = "trace", skip_all)]
pub(crate) fn find_work_or_wait_unless<F>(&self, mut pred: F) -> Option<NonNull<Job>>
where
F: FnMut(&mut crate::context::Shared) -> bool,
{
match self.find_work_inner() {
either::Either::Left(job) => {
return Some(job);
}
either::Either::Right(mut guard) => {
// check the predicate while holding the lock
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// this is very important, because the lock must be held when
// notifying us of the result of a job we scheduled.
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
if !pred(std::ops::DerefMut::deref_mut(&mut guard)) {
// no jobs found, wait for a heartbeat or a new job
tracing::trace!("WorkerThread::find_work_or_wait_unless: waiting for new job");
self.heartbeat.latch().wait_with_lock(&mut guard);
tracing::trace!("WorkerThread::find_work_or_wait_unless: woken up from wait");
}
None
}
}
@ -146,8 +175,8 @@ impl WorkerThread {
#[inline]
#[tracing::instrument(level = "trace", skip(self))]
fn execute(&self, job: NonNull<Job>) {
self.tick();
unsafe { Job::execute(job.as_ptr()) };
self.tick();
}
#[cold]
@ -243,8 +272,9 @@ impl HeartbeatThread {
}
#[tracing::instrument(level = "trace", skip(self))]
pub fn run(self) {
pub fn run(self, barrier: Arc<Barrier>) {
tracing::trace!("new heartbeat thread {:?}", std::thread::current());
barrier.wait();
let mut i = 0;
loop {
@ -282,6 +312,10 @@ impl WorkerThread {
// we've already checked that the job was popped from the queue
// check if shared job is our job
// skip checking if the job hasn't yet been claimed, because the
// overhead of waking a thread is so much bigger that it might never get
// the chance to actually claim it.
// if let Some(shared_job) = self.context.shared().jobs.remove(&self.heartbeat.id()) {
// if core::ptr::eq(shared_job.as_ptr(), job as *const Job as _) {
// // this is the job we are looking for, so we want to
@ -306,34 +340,22 @@ impl WorkerThread {
// }
// }
loop {
match recv.poll() {
Some(t) => {
return Some(t);
}
None => {
cold_path();
let mut out = recv.poll();
// check local jobs before locking shared context
if let Some(job) = self.find_work_or_wait() {
tracing::trace!(
"thread {:?} executing local job: {:?}",
self.heartbeat.index(),
job
);
unsafe {
Job::execute(job.as_ptr());
}
tracing::trace!(
"thread {:?} finished local job: {:?}",
self.heartbeat.index(),
job
);
continue;
}
while std::hint::unlikely(out.is_none()) {
if let Some(job) = self.find_work_or_wait_unless(|_| {
out = recv.poll();
out.is_some()
}) {
unsafe {
Job::execute(job.as_ptr());
}
}
out = recv.poll();
}
out
}
#[tracing::instrument(level = "trace", skip_all)]
@ -365,20 +387,10 @@ impl WorkerThread {
// do the usual thing??? chatgipity really said this..
while !latch.probe() {
// check local jobs before locking shared context
if let Some(job) = self.find_work_or_wait() {
tracing::trace!(
"thread {:?} executing local job: {:?}",
self.heartbeat.index(),
job
);
if let Some(job) = self.find_work_or_wait_unless(|_| latch.probe()) {
unsafe {
Job::execute(job.as_ptr());
}
tracing::trace!(
"thread {:?} finished local job: {:?}",
self.heartbeat.index(),
job
);
continue;
}
}

View file

@ -62,7 +62,8 @@ fn join_pool(tree_size: usize) {
fn join_distaff(tree_size: usize) {
use distaff::*;
let pool = ThreadPool::new();
let pool = ThreadPool::new_with_threads(6);
let tree = Tree::new(tree_size, 1);
fn sum<'scope, 'env>(tree: &Tree<u32>, node: usize, scope: &'scope Scope<'scope, 'env>) -> u32 {
@ -81,11 +82,13 @@ fn join_distaff(tree_size: usize) {
node.leaf + l + r
}
let sum = pool.scope(|s| {
let sum = sum(&tree, tree.root().unwrap(), s);
sum
});
std::hint::black_box(sum);
for _ in 0..1000 {
let sum = pool.scope(|s| {
let sum = sum(&tree, tree.root().unwrap(), s);
sum
});
std::hint::black_box(sum);
}
}
fn join_chili(tree_size: usize) {
@ -131,12 +134,17 @@ fn join_rayon(tree_size: usize) {
}
fn main() {
tracing_subscriber::fmt::init();
// use tracing_subscriber::layer::SubscriberExt;
// tracing::subscriber::set_global_default(
// tracing_subscriber::registry().with(tracing_tracy::TracyLayer::default()),
// tracing_subscriber::registry()
// .with(tracing_tracy::TracyLayer::default()),
// )
// .expect("Failed to set global default subscriber");
// eprintln!("Press Enter to start profiling...");
// std::io::stdin().read_line(&mut String::new()).unwrap();
let size = std::env::args()
.nth(2)
.and_then(|s| s.parse::<usize>().ok())
@ -158,7 +166,7 @@ fn main() {
}
}
eprintln!("Done!");
// wait for user input before exiting
// eprintln!("Done!");
// // wait for user input before exiting
// std::io::stdin().read_line(&mut String::new()).unwrap();
}