Compare commits

10 commits: 1363f20cfc ... 8b4eba5a19

| Author | SHA1 | Date |
|---|---|---|
| | 8b4eba5a19 | |
| | a1e1c90f90 | |
| | 5fae03dc06 | |
| | c4b4f9248a | |
| | 3b07565118 | |
| | bdbe207e7e | |
| | eb8fd314f5 | |
| | c3eb71dbb1 | |
| | 0db285a4a9 | |
| | 4742733683 | |

@@ -11,7 +11,6 @@ work-stealing = []
 prefer-local = []
 never-local = []
-
 
 [profile.bench]
 debug = true
 
@@ -30,7 +29,7 @@ parking_lot = {version = "0.12.3"}
 thread_local = "1.1.8"
 crossbeam = "0.8.4"
 st3 = "0.4"
-chili = "0.2.0"
+chili = "0.2.1"
 
 async-task = "4.7.1"
 
@@ -48,4 +47,5 @@ cfg-if = "1.0.0"
 [dev-dependencies]
 async-std = "1.13.0"
 tracing-test = "0.2.5"
 tracing-tracy = "0.11.4"
+distaff = {path = "distaff"}

@@ -12,9 +12,11 @@ parking_lot = {version = "0.12.3"}
 tracing = "0.1.40"
 parking_lot_core = "0.9.10"
 crossbeam-utils = "0.8.21"
+either = "1.15.0"
 
 async-task = "4.7.1"
 
 [dev-dependencies]
 tracing-test = "0.2.5"
 tracing-tracy = "0.11.4"
+futures = "0.3"

@@ -1,8 +1,8 @@
 use std::{
-    ptr::NonNull,
+    ptr::{self, NonNull},
     sync::{
-        Arc, OnceLock, Weak,
-        atomic::{AtomicU8, Ordering},
+        Arc, OnceLock,
+        atomic::{AtomicBool, Ordering},
     },
 };

@@ -13,73 +13,49 @@ use crossbeam_utils::CachePadded;
 use parking_lot::{Condvar, Mutex};
 
 use crate::{
-    job::{HeapJob, Job, StackJob},
-    latch::{LatchRef, MutexLatch, UnsafeWakeLatch},
+    heartbeat::HeartbeatList,
+    job::{HeapJob, JobSender, QueuedJob as Job, StackJob},
+    latch::{AsCoreLatch, MutexLatch, NopLatch, WorkerLatch},
     workerthread::{HeartbeatThread, WorkerThread},
 };
 
 pub struct Heartbeat {
-    heartbeat: AtomicU8,
     pub latch: MutexLatch,
 }
 
 impl Heartbeat {
-    pub const CLEAR: u8 = 0;
-    pub const PENDING: u8 = 1;
-    pub const SLEEPING: u8 = 2;
-
-    pub fn new() -> (Arc<CachePadded<Self>>, Weak<CachePadded<Self>>) {
-        let strong = Arc::new(CachePadded::new(Self {
-            heartbeat: AtomicU8::new(Self::CLEAR),
+    pub fn new() -> NonNull<CachePadded<Self>> {
+        let ptr = Box::new(CachePadded::new(Self {
             latch: MutexLatch::new(),
         }));
-        let weak = Arc::downgrade(&strong);
-
-        (strong, weak)
-    }
-
-    /// returns true if the heartbeat was previously sleeping.
-    pub fn set_pending(&self) -> bool {
-        let old = self.heartbeat.swap(Self::PENDING, Ordering::Relaxed);
-        old == Self::SLEEPING
-    }
-
-    pub fn clear(&self) {
-        self.heartbeat.store(Self::CLEAR, Ordering::Relaxed);
+        Box::into_non_null(ptr)
     }
 
     pub fn is_pending(&self) -> bool {
-        self.heartbeat.load(Ordering::Relaxed) == Self::PENDING
+        self.latch.as_core_latch().poll_heartbeat()
+    }
+
+    pub fn is_sleeping(&self) -> bool {
+        self.latch.as_core_latch().is_sleeping()
     }
 }
 
 pub struct Context {
     shared: Mutex<Shared>,
-    pub shared_job: Condvar,
+    should_exit: AtomicBool,
+    pub heartbeats: HeartbeatList,
 }
 
 pub(crate) struct Shared {
     pub jobs: BTreeMap<usize, NonNull<Job>>,
-    pub heartbeats: BTreeMap<usize, Weak<CachePadded<Heartbeat>>>,
     injected_jobs: Vec<NonNull<Job>>,
-    heartbeat_count: usize,
-    should_exit: bool,
 }
 
 unsafe impl Send for Shared {}
 
 impl Shared {
-    pub fn new_heartbeat(&mut self) -> (Arc<CachePadded<Heartbeat>>, usize) {
-        let index = self.heartbeat_count;
-        self.heartbeat_count = index.wrapping_add(1);
-
-        let (strong, weak) = Heartbeat::new();
-
-        self.heartbeats.insert(index, weak);
-
-        (strong, index)
-    }
-
     pub fn pop_job(&mut self) -> Option<NonNull<Job>> {
         // this is unlikely, so make the function cold?
         // TODO: profile this

@@ -95,10 +71,6 @@ impl Shared {
     unsafe fn pop_injected_job(&mut self) -> NonNull<Job> {
         self.injected_jobs.pop().unwrap()
     }
-
-    pub fn should_exit(&self) -> bool {
-        self.should_exit
-    }
 }
 
 impl Context {

@@ -111,12 +83,11 @@ impl Context {
         let this = Arc::new(Self {
             shared: Mutex::new(Shared {
                 jobs: BTreeMap::new(),
-                heartbeats: BTreeMap::new(),
                 injected_jobs: Vec::new(),
-                heartbeat_count: 0,
-                should_exit: false,
             }),
-            shared_job: Condvar::new(),
+            should_exit: AtomicBool::new(false),
+            heartbeats: HeartbeatList::new(),
         });
 
         tracing::trace!("Creating thread pool with {} threads", num_threads);

@@ -158,14 +129,12 @@ impl Context {
     }
 
     pub fn set_should_exit(&self) {
-        let mut shared = self.shared.lock();
-        shared.should_exit = true;
-        for (_, heartbeat) in shared.heartbeats.iter() {
-            if let Some(heartbeat) = heartbeat.upgrade() {
-                heartbeat.latch.set();
-            }
-        }
-        self.shared_job.notify_all();
+        self.should_exit.store(true, Ordering::Relaxed);
+        self.heartbeats.notify_all();
     }
 
+    pub fn should_exit(&self) -> bool {
+        self.should_exit.load(Ordering::Relaxed)
+    }
+
     pub fn new() -> Arc<Self> {

@@ -181,11 +150,26 @@ impl Context {
     pub fn inject_job(&self, job: NonNull<Job>) {
         let mut shared = self.shared.lock();
         shared.injected_jobs.push(job);
-        self.notify_shared_job();
+
+        unsafe {
+            // SAFETY: we are holding the shared lock, so it is safe to notify
+            self.notify_job_shared();
+        }
     }
 
-    pub fn notify_shared_job(&self) {
-        self.shared_job.notify_one();
+    // caller should hold the shared lock while calling this
+    pub unsafe fn notify_job_shared(&self) {
+        if let Some((i, sender)) = self
+            .heartbeats
+            .inner()
+            .iter()
+            .find(|(_, heartbeat)| heartbeat.is_waiting())
+        {
+            tracing::trace!("Notifying worker thread {} about job sharing", i);
+            sender.wake();
+        } else {
+            tracing::warn!("No worker found to notify about job sharing");
+        }
     }
 
     /// Runs closure in this context, processing the other context's worker's jobs while waiting for the result.

@@ -197,8 +181,6 @@ impl Context {
         // current thread is not in the same context, create a job and inject it into the other thread's context, then wait while working on our jobs.
 
-        // SAFETY: we are waiting on this latch in this thread.
-        let latch = unsafe { UnsafeWakeLatch::new(&raw const worker.heartbeat.latch) };
-
         let job = StackJob::new(
             move || {
                 let worker = WorkerThread::current_ref()

@@ -206,19 +188,16 @@ impl Context {
 
                 f(worker)
             },
-            LatchRef::new(&latch),
+            NopLatch,
         );
 
-        let job = job.as_job();
-        job.set_pending();
+        let job = Job::from_stackjob(&job, worker.heartbeat.raw_latch());
 
         self.inject_job(Into::into(&job));
 
-        worker.wait_until_latch(&latch);
-
-        let t = unsafe { job.transmute_ref::<T>().wait().into_result() };
-
-        t
+        let t = worker.wait_until_queued_job(&job).unwrap();
+
+        crate::util::unwrap_or_panic(t)
     }
 
     /// Run closure in this context, sleeping until the job is done.

@@ -227,10 +206,8 @@ impl Context {
         F: FnOnce(&WorkerThread) -> T + Send,
         T: Send,
     {
-        use crate::latch::MutexLatch;
-        // current thread isn't a worker thread, create job and inject into global context
-
-        let latch = MutexLatch::new();
+        // current thread isn't a worker thread, create job and inject into context
+        let latch = WorkerLatch::new();
 
         let job = StackJob::new(
             move || {

@@ -239,21 +216,19 @@ impl Context {
 
                 f(worker)
             },
-            LatchRef::new(&latch),
+            NopLatch,
         );
 
-        let job = job.as_job();
-        job.set_pending();
+        let job = Job::from_stackjob(&job, &raw const latch);
 
         self.inject_job(Into::into(&job));
-        latch.wait();
-
-        let t = unsafe { job.transmute_ref::<T>().wait().into_result() };
-
-        t
+
+        let recv = unsafe { job.as_receiver::<T>() };
+
+        crate::util::unwrap_or_panic(latch.wait_until(|| recv.poll()))
     }
 
     /// Run closure in this context.
     #[tracing::instrument(level = "trace", skip_all)]
     pub fn run_in_worker<T, F>(self: &Arc<Self>, f: F) -> T
     where
         T: Send,

@@ -285,12 +260,9 @@ impl Context {
     where
         F: FnOnce() + Send + 'static,
     {
-        let job = Box::new(HeapJob::new(f)).into_boxed_job();
+        let job = Job::from_heapjob(Box::new(HeapJob::new(f)), ptr::null());
         tracing::trace!("Context::spawn: spawning job: {:?}", job);
-        unsafe {
-            (&*job).set_pending();
-            self.inject_job(NonNull::new_unchecked(job));
-        }
+        self.inject_job(job);
     }
 
     pub fn spawn_future<T, F>(self: &Arc<Self>, future: F) -> async_task::Task<T>

@@ -300,24 +272,24 @@ impl Context {
     {
         let schedule = move |runnable: Runnable| {
             #[align(8)]
-            unsafe fn harness<T>(this: *const (), job: *const Job<T>) {
+            unsafe fn harness<T>(this: *const (), job: *const JobSender, _: *const WorkerLatch) {
                 unsafe {
                     let runnable =
                         Runnable::<()>::from_raw(NonNull::new_unchecked(this.cast_mut()));
                     runnable.run();
 
                     // SAFETY: job was turned into raw
-                    drop(Box::from_raw(job.cast_mut()));
+                    drop(Box::from_raw(job.cast::<JobSender<T>>().cast_mut()));
                 }
             }
 
-            let job = Box::new(Job::<T>::new(harness::<T>, runnable.into_raw()));
+            let job = Box::into_non_null(Box::new(Job::from_harness(
+                harness::<T>,
+                runnable.into_raw(),
+                ptr::null(),
+            )));
 
-            // casting into Job<()> here
-            unsafe {
-                job.set_pending();
-                self.inject_job(NonNull::new_unchecked(Box::into_raw(job) as *mut Job<()>));
-            }
+            self.inject_job(job);
         };
 
         let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) };

@@ -350,19 +322,23 @@ where
 
 #[cfg(test)]
 mod tests {
     use std::sync::atomic::AtomicU8;
 
     use tracing_test::traced_test;
 
     use super::*;
 
     #[test]
-    fn run_in_worker_test() {
+    #[cfg_attr(not(miri), traced_test)]
+    fn run_in_worker() {
         let ctx = Context::global_context().clone();
         let result = ctx.run_in_worker(|_| 42);
         assert_eq!(result, 42);
     }
 
     #[test]
-    fn spawn_future_test() {
+    #[cfg_attr(not(miri), traced_test)]
+    fn context_spawn_future() {
         let ctx = Context::global_context().clone();
         let task = ctx.spawn_future(async { 42 });

@@ -372,7 +348,8 @@ mod tests {
     }
 
     #[test]
-    fn spawn_async_test() {
+    #[cfg_attr(not(miri), traced_test)]
+    fn context_spawn_async() {
         let ctx = Context::global_context().clone();
         let task = ctx.spawn_async(|| async { 42 });

@@ -382,7 +359,8 @@ mod tests {
     }
 
     #[test]
-    fn spawn_test() {
+    #[cfg_attr(not(miri), traced_test)]
+    fn context_spawn() {
         let ctx = Context::global_context().clone();
         let counter = Arc::new(AtomicU8::new(0));
         let barrier = Arc::new(std::sync::Barrier::new(2));

@@ -399,4 +377,48 @@ mod tests {
         barrier.wait();
         assert_eq!(counter.load(Ordering::SeqCst), 1);
     }
+
+    #[test]
+    #[cfg_attr(not(miri), traced_test)]
+    fn inject_job_and_wake_worker() {
+        let ctx = Context::new_with_threads(1);
+        let counter = Arc::new(AtomicU8::new(0));
+
+        let waker = WorkerLatch::new();
+
+        let job = StackJob::new(
+            {
+                let counter = counter.clone();
+                move || {
+                    tracing::info!("Job running");
+                    counter.fetch_add(1, Ordering::SeqCst);
+
+                    42
+                }
+            },
+            NopLatch,
+        );
+
+        let job = Job::from_stackjob(&job, &raw const waker);
+
+        // wait for the worker to sleep
+        std::thread::sleep(std::time::Duration::from_millis(100));
+
+        ctx.heartbeats
+            .inner()
+            .iter_mut()
+            .next()
+            .map(|(_, heartbeat)| {
+                assert!(heartbeat.is_waiting());
+            });
+
+        ctx.inject_job(Into::into(&job));
+
+        // Wait for the job to be executed
+        let recv = unsafe { job.as_receiver::<i32>() };
+        let result = waker.wait_until(|| recv.poll());
+        let result = crate::util::unwrap_or_panic(result);
+        assert_eq!(result, 42);
+        assert_eq!(counter.load(Ordering::SeqCst), 1);
+    }
 }

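Taken together, the new `Context` API is exercised much like the tests above. A minimal sketch, not part of the diff: it assumes the `distaff` crate as diffed here, that `Context` is reachable at `distaff::context::Context` (the module path is an assumption), and it uses the `futures` dev-dependency for `block_on`.

```rust
// Sketch only; mirrors the tests above, paths and re-exports are assumptions.
use std::sync::{
    Arc,
    atomic::{AtomicU8, Ordering},
};

use distaff::context::Context; // assumed module path

fn main() {
    let ctx = Context::global_context().clone();

    // run a closure on a worker thread and wait for its result
    assert_eq!(ctx.run_in_worker(|_worker| 6 * 7), 42);

    // fire-and-forget spawn through a heap-allocated job
    let counter = Arc::new(AtomicU8::new(0));
    let c = counter.clone();
    ctx.spawn(move || {
        c.fetch_add(1, Ordering::SeqCst);
    });

    // futures are driven through async-task's Runnable/Task pair
    let task = ctx.spawn_future(async { 42 });
    assert_eq!(futures::executor::block_on(task), 42);
}
```
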
distaff/src/heartbeat.rs (new file, 238 lines)
@@ -0,0 +1,238 @@
+use std::{
+    collections::BTreeMap,
+    mem::ManuallyDrop,
+    ops::Deref,
+    ptr::NonNull,
+    sync::{
+        Arc,
+        atomic::{AtomicBool, Ordering},
+    },
+    time::Instant,
+};
+
+use parking_lot::Mutex;
+
+use crate::latch::WorkerLatch;
+
+#[derive(Debug, Clone)]
+pub struct HeartbeatList {
+    inner: Arc<Mutex<HeartbeatListInner>>,
+}
+
+impl HeartbeatList {
+    pub fn new() -> Self {
+        Self {
+            inner: Arc::new(Mutex::new(HeartbeatListInner::new())),
+        }
+    }
+
+    pub fn notify_nth(&self, n: usize) {
+        self.inner.lock().notify_nth(n);
+    }
+
+    pub fn notify_all(&self) {
+        let mut inner = self.inner.lock();
+        for (_, heartbeat) in inner.heartbeats.iter_mut() {
+            heartbeat.set();
+        }
+    }
+
+    pub fn len(&self) -> usize {
+        self.inner.lock().len()
+    }
+
+    pub fn new_heartbeat(&self) -> OwnedHeartbeatReceiver {
+        let (recv, _) = self.inner.lock().new_heartbeat();
+        OwnedHeartbeatReceiver {
+            list: self.clone(),
+            receiver: ManuallyDrop::new(recv),
+        }
+    }
+
+    pub fn inner(
+        &self,
+    ) -> parking_lot::lock_api::MappedMutexGuard<
+        '_,
+        parking_lot::RawMutex,
+        BTreeMap<u64, HeartbeatSender>,
+    > {
+        parking_lot::MutexGuard::map(self.inner.lock(), |inner| &mut inner.heartbeats)
+    }
+}
+
+#[derive(Debug)]
+struct HeartbeatListInner {
+    heartbeats: BTreeMap<u64, HeartbeatSender>,
+    heartbeat_index: u64,
+}
+
+impl HeartbeatListInner {
+    fn new() -> Self {
+        Self {
+            heartbeats: BTreeMap::new(),
+            heartbeat_index: 0,
+        }
+    }
+
+    fn notify_nth(&mut self, n: usize) {
+        if let Some((_, heartbeat)) = self.heartbeats.iter_mut().nth(n) {
+            heartbeat.set();
+        }
+    }
+
+    fn len(&self) -> usize {
+        self.heartbeats.len()
+    }
+
+    fn new_heartbeat(&mut self) -> (HeartbeatReceiver, u64) {
+        let heartbeat = Heartbeat::new(self.heartbeat_index);
+        let (recv, send, i) = heartbeat.into_recv_send();
+        self.heartbeats.insert(i, send);
+        self.heartbeat_index += 1;
+        (recv, i)
+    }
+
+    fn remove_heartbeat(&mut self, receiver: HeartbeatReceiver) {
+        if let Some(send) = self.heartbeats.remove(&receiver.i) {
+            _ = Heartbeat::from_recv_send(receiver, send);
+        }
+    }
+}
+
+pub struct OwnedHeartbeatReceiver {
+    list: HeartbeatList,
+    receiver: ManuallyDrop<HeartbeatReceiver>,
+}
+
+impl Deref for OwnedHeartbeatReceiver {
+    type Target = HeartbeatReceiver;
+
+    fn deref(&self) -> &Self::Target {
+        &self.receiver
+    }
+}
+
+impl Drop for OwnedHeartbeatReceiver {
+    fn drop(&mut self) {
+        // SAFETY:
+        // `AtomicBool` is `Sync` and `Send`, so it can be safely shared between threads.
+        unsafe {
+            let receiver = ManuallyDrop::take(&mut self.receiver);
+            self.list.inner.lock().remove_heartbeat(receiver);
+        }
+    }
+}
+
+#[derive(Debug)]
+pub struct Heartbeat {
+    ptr: NonNull<(AtomicBool, WorkerLatch)>,
+    i: u64,
+}
+
+#[derive(Debug)]
+pub struct HeartbeatReceiver {
+    ptr: NonNull<(AtomicBool, WorkerLatch)>,
+    i: u64,
+}
+
+unsafe impl Send for HeartbeatReceiver {}
+
+impl Drop for Heartbeat {
+    fn drop(&mut self) {
+        // SAFETY:
+        // `AtomicBool` is `Sync` and `Send`, so it can be safely shared between threads.
+        unsafe {
+            let _ = Box::from_raw(self.ptr.as_ptr());
+        }
+    }
+}
+
+#[derive(Debug)]
+pub struct HeartbeatSender {
+    ptr: NonNull<(AtomicBool, WorkerLatch)>,
+    pub last_heartbeat: Instant,
+}
+
+unsafe impl Send for HeartbeatSender {}
+
+impl Heartbeat {
+    fn new(i: u64) -> Heartbeat {
+        // SAFETY:
+        // `AtomicBool` is `Sync` and `Send`, so it can be safely shared between threads.
+        let ptr = NonNull::new(Box::into_raw(Box::new((
+            AtomicBool::new(true),
+            WorkerLatch::new(),
+        ))))
+        .unwrap();
+        Self { ptr, i }
+    }
+
+    pub fn from_recv_send(recv: HeartbeatReceiver, send: HeartbeatSender) -> Heartbeat {
+        // SAFETY:
+        // `AtomicBool` is `Sync` and `Send`, so it can be safely shared between threads.
+        _ = send;
+        let ptr = recv.ptr;
+        let i = recv.i;
+        Heartbeat { ptr, i }
+    }
+
+    pub fn into_recv_send(self) -> (HeartbeatReceiver, HeartbeatSender, u64) {
+        // don't drop the `Heartbeat` yet
+        let Self { ptr, i } = *ManuallyDrop::new(self);
+
+        (
+            HeartbeatReceiver { ptr, i },
+            HeartbeatSender {
+                ptr,
+                last_heartbeat: Instant::now(),
+            },
+            i,
+        )
+    }
+}
+
+impl HeartbeatReceiver {
+    pub fn take(&self) -> bool {
+        unsafe {
+            // SAFETY:
+            // `AtomicBool` is `Sync` and `Send`, so it can be safely shared between threads.
+            self.ptr.as_ref().0.swap(false, Ordering::Relaxed)
+        }
+    }
+
+    pub fn wait(&self) {
+        unsafe { self.ptr.as_ref().1.wait() };
+    }
+
+    pub fn raw_latch(&self) -> *const WorkerLatch {
+        unsafe { &raw const self.ptr.as_ref().1 }
+    }
+
+    pub fn latch(&self) -> &WorkerLatch {
+        unsafe { &self.ptr.as_ref().1 }
+    }
+
+    pub fn id(&self) -> usize {
+        self.ptr.as_ptr() as usize
+    }
+
+    pub fn index(&self) -> u64 {
+        self.i
+    }
+}
+
+impl HeartbeatSender {
+    pub fn set(&mut self) {
+        // SAFETY:
+        // `AtomicBool` is `Sync` and `Send`, so it can be safely shared between threads.
+        unsafe { self.ptr.as_ref().0.store(true, Ordering::Relaxed) };
+        self.last_heartbeat = Instant::now();
+    }
+
+    pub fn is_waiting(&self) -> bool {
+        unsafe { self.ptr.as_ref().1.is_waiting() }
+    }
+
+    pub fn wake(&self) {
+        unsafe { self.ptr.as_ref().1.wake() };
+    }
+}

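The sender/receiver pair above reduces to a shared `(AtomicBool, WorkerLatch)` allocation: the heartbeat thread periodically raises the flag, and the owning worker consumes it with a `swap`. A self-contained miniature of that flag protocol, using only std types rather than the crate's:

```rust
use std::sync::{
    Arc,
    atomic::{AtomicBool, Ordering},
};

fn main() {
    // plays the role of the boxed (AtomicBool, WorkerLatch) pair
    let flag = Arc::new(AtomicBool::new(true));

    // "HeartbeatSender::set": the heartbeat thread raises the flag on a period
    let sender = flag.clone();
    let beat = std::thread::spawn(move || {
        for _ in 0..3 {
            std::thread::sleep(std::time::Duration::from_millis(10));
            sender.store(true, Ordering::Relaxed);
        }
    });

    // "HeartbeatReceiver::take": the worker consumes the flag with a swap,
    // so each raised beat is observed at most once
    let mut ticks = 0;
    while !beat.is_finished() {
        if flag.swap(false, Ordering::Relaxed) {
            ticks += 1;
        }
        std::thread::yield_now();
    }
    beat.join().unwrap();
    assert!(ticks >= 1);
}
```
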
@@ -7,11 +7,22 @@ use core::{
     ptr::{self, NonNull},
     sync::atomic::Ordering,
 };
 use std::{
     cell::Cell,
     marker::PhantomData,
     mem::MaybeUninit,
+    ops::DerefMut,
     sync::atomic::{AtomicU8, AtomicU32, AtomicUsize},
 };
 
 use alloc::boxed::Box;
+use parking_lot::{Condvar, Mutex};
+use parking_lot_core::SpinWait;
 
-use crate::util::{SmallBox, TaggedAtomicPtr};
+use crate::{
+    latch::{Probe, WorkerLatch},
+    util::{DropGuard, SmallBox, TaggedAtomicPtr},
+};
 
 #[repr(u8)]
 #[derive(Debug, PartialEq, Eq, Clone, Copy)]

@@ -650,7 +661,7 @@ mod stackjob {
 
         let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f()));
 
-        tracing::trace!("job completed: {:?}", job);
+        tracing::trace!("stack job completed: {:?}", job);
 
         let job = unsafe { &*job.cast::<Job<T>>() };
         job.complete(result);

@@ -703,13 +714,20 @@ mod heapjob {
             let this = unsafe { Box::from_raw(this.cast::<HeapJob<F>>().cast_mut()) };
             let f = this.into_inner();
 
-            _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f()));
+            {
+                let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f()));
+
+                let job = unsafe { &*job.cast::<Job<T>>() };
+                job.complete(result);
+            }
 
             // drop job (this is fine because the job of a HeapJob is pure POD).
             unsafe {
                 ptr::drop_in_place(job);
             }
 
+            tracing::trace!("heap job completed: {:?}", job);
+
             // free box that was allocated at (1)
             _ = unsafe { Box::<ManuallyDrop<Job<T>>>::from_raw(job.cast()) };

@@ -752,7 +770,8 @@ mod tests {
         assert_eq!(result.into_result(), 7);
     }
 
-    #[test]
+    // #[test]
+    #[should_panic]
     fn job_lifecycle_panic() {
         let latch = AtomicLatch::new();
         let stack = StackJob::new(|| panic!("test panic"), LatchRef::new(&latch));

@@ -769,7 +788,7 @@ mod tests {
 
         // wait for the job to finish
         let result = unsafe { job.transmute_ref::<i32>().wait() };
-        assert!(result.into_inner().is_err());
+        std::panic::resume_unwind(result.into_inner().unwrap_err());
     }
 
     #[test]

@@ -970,3 +989,364 @@ mod tests {
         assert!(vec.is_empty());
     }
 }
+
+// A job, whether a `StackJob` or `HeapJob`, is turned into a `QueuedJob` when it is pushed to the job queue.
+#[repr(C)]
+pub struct QueuedJob {
+    /// The job's harness and state.
+    harness: TaggedAtomicPtr<usize, 3>,
+    // This is later invalidated by the Receiver/Sender, so it must be wrapped in a `MaybeUninit`.
+    // I'm not sure if it also must be inside of an `UnsafeCell`..
+    inner: Cell<MaybeUninit<QueueJobInner>>,
+}
+
+impl Debug for QueuedJob {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("QueuedJob")
+            .field("harness", &self.harness)
+            .field("inner", unsafe {
+                (&*self.inner.as_ptr()).assume_init_ref()
+            })
+            .finish()
+    }
+}
+
+#[repr(C)]
+#[derive(Debug, Copy, Clone)]
+struct QueueJobInner {
+    /// The job's value or `this` pointer. This is either a `StackJob` or `HeapJob`.
+    this: NonNull<()>,
+    /// The mutex to wake when the job is finished executing.
+    mutex: *const WorkerLatch,
+}
+
+/// A union that allows us to store either a `T` or a `U` without needing to know which one it is at runtime.
+/// The state must be tracked separately.
+union UnsafeVariant<T, U> {
+    t: ManuallyDrop<T>,
+    u: ManuallyDrop<U>,
+}
+
+// The processed job is the result of executing a job, it contains the result of the job or an error.
+#[repr(C)]
+struct JobChannel<T = ()> {
+    tag: TaggedAtomicPtr<usize, 3>,
+    value: UnsafeCell<MaybeUninit<UnsafeVariant<SmallBox<T>, Box<dyn Any + Send + 'static>>>>,
+}
+
+#[repr(transparent)]
+pub struct JobSender<T = ()> {
+    channel: JobChannel<T>,
+}
+#[repr(transparent)]
+pub struct JobReceiver<T = ()> {
+    channel: JobChannel<T>,
+}
+
+#[repr(C)]
+struct Job2 {}
+
+const EMPTY: usize = 0;
+const SHARED: usize = 1 << 2;
+const FINISHED: usize = 1 << 0;
+const ERROR: usize = 1 << 1;
+
+impl<T> JobSender<T> {
+    #[tracing::instrument(level = "trace", skip_all)]
+    pub fn send(&self, result: std::thread::Result<T>, mutex: *const WorkerLatch) {
+        // We want to lock here so that we can be sure that we wake the worker
+        // only if it was waiting, and not immediately after having received the
+        // result and waiting for further work:
+        // | thread 1  | thread 2  |
+        // |           |           |
+        // | send->    |           |
+        // | FINISHED  |           |
+        // |           | poll()    |
+        // |           | sleep()   |
+        // | wake()    |           |
+        // |           | !woken!   | // the worker has already received the result
+        // |           |           | // and is waiting for more work, it shouldn't
+        // |           |           | // be woken up here.
+        // | <-send    |           |
+        //
+        // if we lock, it looks like this:
+        // | thread 1  | thread 2  |
+        // |           |           |
+        // | send->    |           |
+        // | lock()    |           |
+        // | FINISHED  |           |
+        // |           | poll()    |
+        // |           | lock()->  | // thread 2 tries to lock.
+        // | wake()    |           | // the wake signal is ignored
+        // |           |           |
+        // | unlock()  |           |
+        // |           | l=lock()  | // thread 2 wakes up and receives the lock
+        // | <-send    | sleep(l)  | // thread 2 is now sleeping
+        //
+        // This concludes my TED talk on why we need to lock here.
+
+        let _guard = (!mutex.is_null()).then(|| {
+            // SAFETY: mutex is a valid pointer to a WorkerLatch
+            unsafe {
+                (&*mutex).lock();
+                DropGuard::new(|| {
+                    (&*mutex).wake();
+                    (&*mutex).unlock()
+                })
+            }
+        });
+
+        assert!(self.channel.tag.tag(Ordering::Acquire) & FINISHED == 0);
+
+        match result {
+            Ok(value) => {
+                let slot = unsafe { &mut *self.channel.value.get() };
+
+                slot.write(UnsafeVariant {
+                    t: ManuallyDrop::new(SmallBox::new(value)),
+                });
+
+                self.channel.tag.fetch_or_tag(FINISHED, Ordering::Release);
+            }
+            Err(payload) => {
+                let slot = unsafe { &mut *self.channel.value.get() };
+
+                slot.write(UnsafeVariant {
+                    u: ManuallyDrop::new(payload),
+                });
+
+                self.channel
+                    .tag
+                    .fetch_or_tag(FINISHED | ERROR, Ordering::Release);
+            }
+        }
+
+        // wake the worker waiting on the mutex and drop the guard
+    }
+}
+
+impl<T> JobReceiver<T> {
+    #[tracing::instrument(level = "trace", skip_all)]
+    pub fn is_finished(&self) -> bool {
+        self.channel.tag.tag(Ordering::Acquire) & FINISHED != 0
+    }
+
+    #[tracing::instrument(level = "trace", skip_all)]
+    pub fn poll(&self) -> Option<std::thread::Result<T>> {
+        let tag = self.channel.tag.take_tag(Ordering::Acquire);
+
+        if tag & FINISHED == 0 {
+            return None;
+        }
+
+        // SAFETY: if we received a non-EMPTY tag, the value must be initialized.
+        // because we atomically set the tag to EMPTY, we can be sure that we're the only ones accessing the value.
+        let slot = unsafe { (&mut *self.channel.value.get()).assume_init_mut() };
+
+        if tag & ERROR != 0 {
+            // job failed, return the error
+            let err = unsafe { ManuallyDrop::take(&mut slot.u) };
+            Some(Err(err))
+        } else {
+            // job succeeded, return the value
+            let value = unsafe { ManuallyDrop::take(&mut slot.t) };
+            Some(Ok(value.into_inner()))
+        }
+    }
+}
+
+impl QueuedJob {
+    fn new(
+        harness: TaggedAtomicPtr<usize, 3>,
+        this: NonNull<()>,
+        mutex: *const WorkerLatch,
+    ) -> Self {
+        let this = Self {
+            harness,
+            inner: Cell::new(MaybeUninit::new(QueueJobInner { this, mutex })),
+        };
+
+        tracing::trace!("new queued job: {:?}", this);
+
+        this
+    }
+
+    pub fn from_stackjob<F, T, L>(job: &StackJob<F, L>, mutex: *const WorkerLatch) -> Self
+    where
+        F: FnOnce() -> T + Send,
+        T: Send,
+    {
+        #[align(8)]
+        #[tracing::instrument(level = "trace", skip_all, name = "stack_job_harness")]
+        unsafe fn harness<F, T, L>(
+            this: *const (),
+            sender: *const JobSender,
+            mutex: *const WorkerLatch,
+        ) where
+            F: FnOnce() -> T + Send,
+            T: Send,
+        {
+            use std::panic::{AssertUnwindSafe, catch_unwind};
+
+            let f = unsafe { (*this.cast::<StackJob<F, L>>()).unwrap() };
+            let result = catch_unwind(AssertUnwindSafe(|| f()));
+
+            unsafe {
+                (&*(sender as *const JobSender<T>)).send(result, mutex);
+            }
+        }
+
+        Self::new(
+            TaggedAtomicPtr::new(harness::<F, T, L> as *mut usize, EMPTY),
+            unsafe { NonNull::new_unchecked(job as *const _ as *mut ()) },
+            mutex,
+        )
+    }
+
+    pub fn from_heapjob<F, T>(job: Box<HeapJob<F>>, mutex: *const WorkerLatch) -> NonNull<Self>
+    where
+        F: FnOnce() -> T + Send,
+        T: Send,
+    {
+        #[align(8)]
+        #[tracing::instrument(level = "trace", skip_all, name = "heap_job_harness")]
+        unsafe fn harness<F, T>(
+            this: *const (),
+            sender: *const JobSender,
+            mutex: *const WorkerLatch,
+        ) where
+            F: FnOnce() -> T + Send,
+            T: Send,
+        {
+            use std::panic::{AssertUnwindSafe, catch_unwind};
+
+            // expect MIRI to complain about this, but it is actually correct.
+            // because I am so much smarter than MIRI, naturally, obviously.
+            // unbox the job, which was allocated at (2)
+            let f = unsafe { (*Box::from_raw(this.cast::<HeapJob<F>>().cast_mut())).into_inner() };
+            let result = catch_unwind(AssertUnwindSafe(|| f()));
+
+            unsafe {
+                (&*(sender as *const JobSender<T>)).send(result, mutex);
+            }
+
+            // drop the job, which was allocated at (1)
+            _ = unsafe { Box::<ManuallyDrop<JobSender>>::from_raw(sender as *mut _) };
+        }
+
+        // (1) allocate box for job
+        Box::into_non_null(Box::new(Self::new(
+            TaggedAtomicPtr::new(harness::<F, T> as *mut usize, EMPTY),
+            // (2) convert job into a pointer
+            unsafe { NonNull::new_unchecked(Box::into_raw(job) as *mut ()) },
+            mutex,
+        )))
+    }
+
+    pub fn from_harness(
+        harness: unsafe fn(*const (), *const JobSender, *const WorkerLatch),
+        this: NonNull<()>,
+        mutex: *const WorkerLatch,
+    ) -> Self {
+        Self::new(
+            TaggedAtomicPtr::new(harness as *mut usize, EMPTY),
+            this,
+            mutex,
+        )
+    }
+
+    pub fn set_shared(&self) {
+        self.harness.fetch_or_tag(SHARED, Ordering::Relaxed);
+    }
+
+    pub fn is_shared(&self) -> bool {
+        self.harness.tag(Ordering::Relaxed) & SHARED != 0
+    }
+
+    pub unsafe fn as_receiver<T>(&self) -> &JobReceiver<T> {
+        unsafe { mem::transmute::<&QueuedJob, &JobReceiver<T>>(self) }
+    }
+
+    /// this function will drop `_self` and execute the job.
+    #[tracing::instrument(level = "trace", skip_all)]
+    pub unsafe fn execute(_self: *mut Self) {
+        let (harness, this, sender, mutex) = unsafe {
+            let job = &*_self;
+            tracing::debug!("executing queued job: {:?}", job);
+
+            let harness: unsafe fn(*const (), *const JobSender, *const WorkerLatch) =
+                mem::transmute(job.harness.ptr(Ordering::Relaxed));
+            let sender = mem::transmute::<*const Self, *const JobSender>(_self);
+
+            let QueueJobInner { this, mutex } =
+                job.inner.replace(MaybeUninit::uninit()).assume_init();
+
+            (harness, this, sender, mutex)
+        };
+
+        unsafe {
+            // past this point, `_self` may no longer be a valid pointer to a `QueuedJob`.
+            (harness)(this.as_ptr(), sender, mutex);
+        }
+    }
+}
+
+impl Probe for QueuedJob {
+    fn probe(&self) -> bool {
+        self.harness.tag(Ordering::Relaxed) & FINISHED != 0
+    }
+}
+
+impl Probe for JobReceiver {
+    fn probe(&self) -> bool {
+        self.channel.tag.tag(Ordering::Relaxed) & FINISHED != 0
+    }
+}
+
+pub use queuedjobqueue::JobQueue;
+
+mod queuedjobqueue {
+    //! Basically `JobVec`, but for `QueuedJob`s.
+
+    use std::collections::VecDeque;
+
+    use super::*;
+
+    #[derive(Debug)]
+    pub struct JobQueue {
+        jobs: VecDeque<NonNull<QueuedJob>>,
+    }
+
+    impl JobQueue {
+        pub fn new() -> Self {
+            Self {
+                jobs: VecDeque::new(),
+            }
+        }
+
+        pub fn push_front(&mut self, job: *const QueuedJob) {
+            self.jobs
+                .push_front(unsafe { NonNull::new_unchecked(job as *mut _) });
+        }
+
+        pub fn push_back(&mut self, job: *const QueuedJob) {
+            self.jobs
+                .push_back(unsafe { NonNull::new_unchecked(job as *mut _) });
+        }
+
+        pub fn pop_front(&mut self) -> Option<NonNull<QueuedJob>> {
+            self.jobs.pop_front()
+        }
+
+        pub fn pop_back(&mut self) -> Option<NonNull<QueuedJob>> {
+            self.jobs.pop_back()
+        }
+
+        pub fn is_empty(&self) -> bool {
+            self.jobs.is_empty()
+        }
+
+        pub fn len(&self) -> usize {
+            self.jobs.len()
+        }
+    }
+}

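The `JobChannel` above packs the EMPTY/FINISHED/ERROR state into tag bits next to the value slot. A simplified, self-contained rendering of the same one-shot protocol, with a plain `AtomicU8` standing in for `TaggedAtomicPtr` and no `SmallBox` (names here are illustrative, not the crate's):

```rust
use std::{
    cell::UnsafeCell,
    mem::MaybeUninit,
    sync::atomic::{AtomicU8, Ordering},
};

const FINISHED: u8 = 1 << 0;
const ERROR: u8 = 1 << 1;

// One-shot slot: the writer publishes with a Release tag update, the reader
// claims with an Acquire swap back to EMPTY, mirroring send()/poll() above.
struct OneShot<T> {
    tag: AtomicU8,
    value: UnsafeCell<MaybeUninit<T>>,
}

// SAFETY: the tag protocol ensures only one side touches `value` at a time.
unsafe impl<T: Send> Sync for OneShot<T> {}

impl<T> OneShot<T> {
    fn new() -> Self {
        Self {
            tag: AtomicU8::new(0),
            value: UnsafeCell::new(MaybeUninit::uninit()),
        }
    }

    fn send_ok(&self, value: T) {
        unsafe { (*self.value.get()).write(value) };
        // Release publishes the write above together with the flag
        self.tag.fetch_or(FINISHED, Ordering::Release);
    }

    fn send_err(&self) {
        self.tag.fetch_or(FINISHED | ERROR, Ordering::Release);
    }

    fn poll(&self) -> Option<Result<T, ()>> {
        // swap back to EMPTY so only one caller ever takes the value
        let tag = self.tag.swap(0, Ordering::Acquire);
        if tag & FINISHED == 0 {
            return None;
        }
        if tag & ERROR != 0 {
            return Some(Err(()));
        }
        Some(Ok(unsafe { (*self.value.get()).assume_init_read() }))
    }
}

fn main() {
    let slot = OneShot::new();
    assert!(slot.poll().is_none());
    slot.send_ok(42);
    assert_eq!(slot.poll(), Some(Ok(42)));

    let failed = OneShot::<i32>::new();
    failed.send_err();
    assert_eq!(failed.poll(), Some(Err(())));
}
```
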
@@ -1,22 +1,23 @@
-use std::{hint::cold_path, ptr::NonNull, sync::Arc};
+use std::{hint::cold_path, sync::Arc};
 
 use crate::{
     context::Context,
-    job::{JobState, StackJob},
-    latch::{AsCoreLatch, LatchRef, UnsafeWakeLatch, WakeLatch},
-    util::SendPtr,
+    job::{QueuedJob as Job, StackJob},
+    latch::NopLatch,
     workerthread::WorkerThread,
 };
 
 impl WorkerThread {
     #[inline]
+    #[tracing::instrument(level = "trace", skip_all)]
     fn join_seq<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
     where
         RA: Send,
         RB: Send,
-        A: FnOnce() -> RA + Send,
-        B: FnOnce() -> RB + Send,
+        A: FnOnce() -> RA,
+        B: FnOnce() -> RB,
     {
-        let span = tracing::trace_span!("join_seq");
-        let _guard = span.enter();
-
         let rb = b();
         let ra = a();

@@ -25,6 +26,7 @@ impl WorkerThread {
 
     /// This function must be called from a worker thread.
     #[inline]
+    #[tracing::instrument(level = "trace", skip_all)]
     pub(crate) fn join_heartbeat_every<A, B, RA, RB, const TIMES: usize>(
         &self,
         a: A,

@@ -32,20 +34,24 @@ impl WorkerThread {
     ) -> (RA, RB)
     where
         RA: Send,
         RB: Send,
         A: FnOnce() -> RA + Send,
-        B: FnOnce() -> RB + Send,
+        B: FnOnce() -> RB,
     {
         // SAFETY: each worker is only ever used by one thread, so this is safe.
         let count = self.join_count.get();
+        let queue_len = unsafe { self.queue.as_ref_unchecked().len() };
         self.join_count.set(count.wrapping_add(1) % TIMES as u8);
 
         // TODO: add counter to job queue, check for low job count to decide whether to use heartbeat or seq.
         // see: chili
 
-        // SAFETY: this function runs in a worker thread, so we can access the queue safely.
-        if count == 0 || unsafe { self.queue.as_ref_unchecked().len() } < 3 {
+        if count == 0 || queue_len < 3 {
             cold_path();
+            tracing::trace!(
+                queue_len = queue_len,
+                "join_heartbeat_every: using heartbeat join",
+            );
             self.join_heartbeat(a, b)
         } else {
             self.join_seq(a, b)

@@ -54,65 +60,54 @@ impl WorkerThread {
 
     /// This function must be called from a worker thread.
     #[inline]
+    #[tracing::instrument(level = "trace", skip_all)]
     fn join_heartbeat<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
     where
         RA: Send,
         RB: Send,
         A: FnOnce() -> RA + Send,
-        B: FnOnce() -> RB + Send,
+        B: FnOnce() -> RB,
     {
         use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
 
-        // SAFETY: this thread's heartbeat latch is valid until the job sets it
-        // because we will be waiting on it.
-        let latch = unsafe { UnsafeWakeLatch::new(&raw const (*self.heartbeat).latch) };
-
-        let a = StackJob::new(
-            move || {
-                // TODO: bench whether tick'ing here is good.
-                // turns out this actually costs a lot of time, likely because of the thread local check.
-                // WorkerThread::current_ref()
-                //     .expect("stackjob is run in workerthread.")
-                //     .tick();
-                a()
-            },
-            LatchRef::new(&latch),
-        );
-
-        let job = a.as_job();
-        self.push_front(&job);
+        let a = StackJob::new(a, NopLatch);
+        let job = Job::from_stackjob(&a, self.heartbeat.raw_latch());
+
+        self.push_back(&job);
 
         self.tick();
 
         let rb = match catch_unwind(AssertUnwindSafe(|| b())) {
             Ok(val) => val,
             Err(payload) => {
+                tracing::debug!("join_heartbeat: b panicked, waiting for a to finish");
                 cold_path();
                 // if b panicked, we need to wait for a to finish
-                self.wait_until_latch(&latch);
+                self.wait_until_latch(&job);
                 resume_unwind(payload);
             }
         };
 
-        let ra = if job.state() == JobState::Empty as u8 {
-            // remove job from the queue, so it doesn't get run again.
-            // job.unlink();
-            //SAFETY: we are in a worker thread, so we can safely access the queue.
-            unsafe {
-                self.queue.as_mut_unchecked().remove(&job);
-            }
+        let ra = if !job.is_shared() {
+            tracing::trace!("join_heartbeat: job is not shared, running a() inline");
+            // we pushed the job to the back of the queue, any `join`s called by `b` on this worker thread will have already popped their job, or seen it be executed.
+            self.pop_back();
 
             // a is allowed to panic here, because we already finished b.
             unsafe { a.unwrap()() }
         } else {
-            match self.wait_until_job::<RA>(unsafe { job.transmute_ref() }, latch.as_core_latch()) {
-                Some(t) => t.into_result(), // propagate panic here
-                // the job was shared, but not yet stolen, so we get to run the
-                // job inline
-                None => unsafe { a.unwrap()() },
+            match self.wait_until_queued_job(&job) {
+                Some(t) => crate::util::unwrap_or_panic(t),
+                None => {
+                    tracing::trace!(
+                        "join_heartbeat: job was shared, but reclaimed, running a() inline"
+                    );
+                    // the job was shared, but not yet stolen, so we get to run the
+                    // job inline
+                    unsafe { a.unwrap()() }
+                }
             }
         };
 
+        drop(a);
         (ra, rb)
     }
 }

@@ -121,10 +116,10 @@ impl Context {
     #[inline]
     pub fn join<A, B, RA, RB>(self: &Arc<Self>, a: A, b: B) -> (RA, RB)
     where
-        RA: Send,
-        RB: Send,
         A: FnOnce() -> RA + Send,
         B: FnOnce() -> RB + Send,
+        RA: Send,
+        RB: Send,
     {
         // SAFETY: join_heartbeat_every is safe to call from a worker thread.
         self.run_in_worker(move |worker| worker.join_heartbeat_every::<_, _, _, _, 64>(a, b))

@@ -135,10 +130,10 @@ impl Context {
 #[allow(dead_code)]
 pub fn join<A, B, RA, RB>(a: A, b: B) -> (RA, RB)
 where
-    RA: Send,
-    RB: Send,
     A: FnOnce() -> RA + Send,
     B: FnOnce() -> RB + Send,
+    RA: Send,
+    RB: Send,
 {
     join_in(Context::global_context().clone(), a, b)
 }

@@ -147,10 +142,10 @@ where
 #[allow(dead_code)]
 fn join_in<A, B, RA, RB>(context: Arc<Context>, a: A, b: B) -> (RA, RB)
 where
-    RA: Send,
-    RB: Send,
     A: FnOnce() -> RA + Send,
     B: FnOnce() -> RB + Send,
+    RA: Send,
+    RB: Send,
 {
     context.join(a, b)
 }

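For reference, the heartbeat join above is consumed like any fork-join primitive. A usage sketch; it assumes the free `join` defined in this file is re-exported at the crate root, which this diff does not show:

```rust
// Classic recursive fork-join over the heartbeat scheduler (sketch).
fn fib(n: u64) -> u64 {
    if n < 2 {
        return n;
    }
    // `a` is pushed onto the worker's queue and may be stolen on a heartbeat;
    // `b` runs inline, matching join_heartbeat's reclaim-or-wait flow above.
    let (a, b) = distaff::join(|| fib(n - 1), || fib(n - 2));
    a + b
}

fn main() {
    assert_eq!(fib(20), 6765);
}
```
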
@ -2,7 +2,15 @@ use core::{
|
|||
marker::PhantomData,
|
||||
sync::atomic::{AtomicUsize, Ordering},
|
||||
};
|
||||
use std::sync::{Arc, atomic::AtomicU8};
|
||||
use std::{
|
||||
cell::UnsafeCell,
|
||||
mem,
|
||||
ops::DerefMut,
|
||||
sync::{
|
||||
Arc,
|
||||
atomic::{AtomicPtr, AtomicU8},
|
||||
},
|
||||
};
|
||||
|
||||
use parking_lot::{Condvar, Mutex};
|
||||
|
||||
|
@ -30,6 +38,8 @@ impl AtomicLatch {
|
|||
pub const UNSET: u8 = 0;
|
||||
pub const SET: u8 = 1;
|
||||
pub const SLEEPING: u8 = 2;
|
||||
pub const WAKEUP: u8 = 4;
|
||||
pub const HEARTBEAT: u8 = 8;
|
||||
|
||||
#[inline]
|
||||
pub const fn new() -> Self {
|
||||
|
@ -45,24 +55,58 @@ impl AtomicLatch {
|
|||
}
|
||||
|
||||
#[inline]
|
||||
pub fn reset(&self) {
|
||||
self.inner.store(Self::UNSET, Ordering::Release);
|
||||
pub fn unset(&self) {
|
||||
self.inner.fetch_and(!Self::SET, Ordering::Release);
|
||||
}
|
||||
|
||||
pub fn reset(&self) -> u8 {
|
||||
self.inner.swap(Self::UNSET, Ordering::Release)
|
||||
}
|
||||
|
||||
pub fn get(&self) -> u8 {
|
||||
self.inner.load(Ordering::Acquire)
|
||||
}
|
||||
|
||||
pub fn set_sleeping(&self) {
|
||||
self.inner.store(Self::SLEEPING, Ordering::Release);
|
||||
pub fn poll_heartbeat(&self) -> bool {
|
||||
self.inner.fetch_and(!Self::HEARTBEAT, Ordering::Relaxed) & Self::HEARTBEAT
|
||||
== Self::HEARTBEAT
|
||||
}
|
||||
|
||||
/// returns true if the latch was already set.
|
||||
pub fn set_sleeping(&self) -> bool {
|
||||
self.inner.fetch_or(Self::SLEEPING, Ordering::Relaxed) & Self::SET == Self::SET
|
||||
}
|
||||
|
||||
pub fn is_sleeping(&self) -> bool {
|
||||
self.inner.load(Ordering::Relaxed) & Self::SLEEPING == Self::SLEEPING
|
||||
}
|
||||
|
||||
pub fn is_heartbeat(&self) -> bool {
|
||||
self.inner.load(Ordering::Relaxed) & Self::HEARTBEAT == Self::HEARTBEAT
|
||||
}
|
||||
|
||||
pub fn is_wakeup(&self) -> bool {
|
||||
self.inner.load(Ordering::Relaxed) & Self::WAKEUP == Self::WAKEUP
|
||||
}
|
||||
|
||||
pub fn is_set(&self) -> bool {
|
||||
self.inner.load(Ordering::Relaxed) & Self::SET == Self::SET
|
||||
}
|
||||
|
||||
pub fn set_wakeup(&self) {
|
||||
self.inner.fetch_or(Self::WAKEUP, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn set_heartbeat(&self) {
|
||||
self.inner.fetch_or(Self::HEARTBEAT, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// returns true if the latch was previously sleeping.
|
||||
#[inline]
|
||||
pub unsafe fn set(this: *const Self) -> bool {
|
||||
unsafe {
|
||||
let old = (*this).inner.swap(Self::SET, Ordering::Release);
|
||||
old == Self::SLEEPING
|
||||
let old = (*this).inner.fetch_or(Self::SET, Ordering::Relaxed);
|
||||
old & Self::SLEEPING == Self::SLEEPING
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -79,7 +123,7 @@ impl Latch for AtomicLatch {
|
|||
impl Probe for AtomicLatch {
|
||||
#[inline]
|
||||
fn probe(&self) -> bool {
|
||||
self.inner.load(Ordering::Acquire) == Self::SET
|
||||
self.inner.load(Ordering::Relaxed) & Self::SET != 0
|
||||
}
|
||||
}
|
||||
impl AsCoreLatch for AtomicLatch {
|
||||
|
@ -153,80 +197,29 @@ impl Probe for NopLatch {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct ThreadWakeLatch {
|
||||
waker: Mutex<Option<std::thread::Thread>>,
|
||||
}
|
||||
|
||||
impl ThreadWakeLatch {
|
||||
#[inline]
|
||||
pub const fn new() -> Self {
|
||||
Self {
|
||||
waker: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn reset(&self) {
|
||||
let mut waker = self.waker.lock();
|
||||
*waker = None;
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn set_waker(&self, thread: std::thread::Thread) {
|
||||
let mut waker = self.waker.lock();
|
||||
*waker = Some(thread);
|
||||
}
|
||||
|
||||
pub unsafe fn wait(&self) {
|
||||
assert!(
|
||||
self.waker.lock().replace(std::thread::current()).is_none(),
|
||||
"ThreadWakeLatch can only be waited once per thread"
|
||||
);
|
||||
|
||||
std::thread::park();
|
||||
}
|
||||
}
|
||||
|
||||
impl Latch for ThreadWakeLatch {
|
||||
#[inline]
|
||||
unsafe fn set_raw(this: *const Self) {
|
||||
unsafe {
|
||||
if let Some(thread) = (&*this).waker.lock().take() {
|
||||
thread.unpark();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Probe for ThreadWakeLatch {
|
||||
#[inline]
|
||||
fn probe(&self) -> bool {
|
||||
self.waker.lock().is_some()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CountLatch<L: Latch> {
|
||||
pub struct CountLatch {
|
||||
count: AtomicUsize,
|
||||
inner: L,
|
||||
inner: AtomicPtr<WorkerLatch>,
|
||||
}
|
||||
|
||||
impl<L: Latch> CountLatch<L> {
|
||||
impl CountLatch {
|
||||
#[inline]
|
||||
pub const fn new(inner: L) -> Self {
|
||||
pub const fn new(inner: *const WorkerLatch) -> Self {
|
||||
Self {
|
||||
count: AtomicUsize::new(0),
|
||||
inner,
|
||||
inner: AtomicPtr::new(inner as *mut WorkerLatch),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_inner(&self, inner: *const WorkerLatch) {
|
||||
self.inner
|
||||
.store(inner as *mut WorkerLatch, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn count(&self) -> usize {
|
||||
self.count.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn inner(&self) -> &L {
|
||||
&self.inner
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn increment(&self) {
|
||||
self.count.fetch_add(1, Ordering::Release);
|
||||
|
@ -234,63 +227,76 @@ impl<L: Latch> CountLatch<L> {
|
|||
|
||||
#[inline]
|
||||
pub fn decrement(&self) {
|
||||
if self.count.fetch_sub(1, Ordering::Release) == 1 {
|
||||
unsafe {
|
||||
Latch::set_raw(&self.inner);
|
||||
unsafe {
|
||||
Latch::set_raw(self);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Latch for CountLatch {
|
||||
#[inline]
|
||||
unsafe fn set_raw(this: *const Self) {
|
||||
unsafe {
|
||||
if (&*this).count.fetch_sub(1, Ordering::Relaxed) == 1 {
|
||||
tracing::trace!("CountLatch set_raw: count was 1, setting inner latch");
|
||||
// If the count was 1, we need to set the inner latch.
|
||||
let inner = (*this).inner.load(Ordering::Relaxed);
|
||||
if !inner.is_null() {
|
||||
(&*inner).wake();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<L: Latch> Latch for CountLatch<L> {
|
||||
#[inline]
|
||||
unsafe fn set_raw(this: *const Self) {
|
||||
unsafe {
|
||||
let this = &*this;
|
||||
this.decrement();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<L: Latch + Probe> Probe for CountLatch<L> {
|
||||
impl Probe for CountLatch {
|
||||
#[inline]
|
||||
fn probe(&self) -> bool {
|
||||
self.count.load(Ordering::Relaxed) == 0
|
||||
}
|
||||
}
|
||||
|
||||
impl<L: Latch + AsCoreLatch> AsCoreLatch for CountLatch<L> {
|
||||
#[inline]
|
||||
fn as_core_latch(&self) -> &CoreLatch {
|
||||
self.inner.as_core_latch()
|
||||
}
|
||||
pub struct MutexLatch {
|
||||
inner: AtomicLatch,
|
||||
lock: Mutex<()>,
|
||||
condvar: Condvar,
|
||||
}
|
||||
|
||||
pub struct MutexLatch {
|
||||
inner: Mutex<bool>,
|
||||
condvar: Condvar,
|
||||
unsafe impl Send for MutexLatch {}
|
||||
unsafe impl Sync for MutexLatch {}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) enum WakeResult {
|
||||
Wake,
|
||||
Heartbeat,
|
||||
Set,
|
||||
}
|
||||
|
||||
impl MutexLatch {
|
||||
#[inline]
|
||||
pub const fn new() -> Self {
|
||||
Self {
|
||||
inner: Mutex::new(false),
|
||||
inner: AtomicLatch::new(),
|
||||
lock: Mutex::new(()),
|
||||
condvar: Condvar::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn reset(&self) {
|
||||
let mut guard = self.inner.lock();
|
||||
*guard = false;
|
||||
let _guard = self.lock.lock();
|
||||
// SAFETY: inner is atomic, so we can safely access it.
|
||||
self.inner.reset();
|
||||
}
|
||||
|
||||
pub fn wait(&self) {
|
||||
let mut guard = self.inner.lock();
|
||||
while !*guard {
|
||||
pub fn wait_and_reset(&self) {
|
||||
// SAFETY: inner is locked by the mutex, so we can safely access it.
|
||||
let mut guard = self.lock.lock();
|
||||
while !self.inner.probe() {
|
||||
self.condvar.wait(&mut guard);
|
||||
}
|
||||
|
||||
self.inner.reset();
|
||||
}
|
||||
|
||||
pub fn set(&self) {
|
||||
|
@ -298,22 +304,17 @@ impl MutexLatch {
|
|||
Latch::set_raw(self);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn wait_and_reset(&self) {
|
||||
let mut guard = self.inner.lock();
|
||||
while !*guard {
|
||||
self.condvar.wait(&mut guard);
|
||||
}
|
||||
*guard = false;
|
||||
}
|
||||
}
|
||||
|
||||
impl Latch for MutexLatch {
|
||||
#[inline]
|
||||
unsafe fn set_raw(this: *const Self) {
|
||||
// SAFETY: `this` is valid until the guard is dropped.
|
||||
unsafe {
|
||||
*(&*this).inner.lock() = true;
|
||||
(&*this).condvar.notify_all();
|
||||
let this = &*this;
|
||||
let _guard = this.lock.lock();
|
||||
Latch::set_raw(&this.inner);
|
||||
this.condvar.notify_all();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -321,112 +322,266 @@ impl Latch for MutexLatch {
|
|||
impl Probe for MutexLatch {
|
||||
#[inline]
|
||||
fn probe(&self) -> bool {
|
||||
*self.inner.lock()
|
||||
}
|
||||
}
|
||||
|
||||
/// Must only be `set` from a worker thread.
|
||||
pub struct WakeLatch {
|
||||
inner: AtomicLatch,
|
||||
worker_index: AtomicUsize,
|
||||
}
|
||||
|
||||
impl WakeLatch {
|
||||
pub fn new(worker_index: usize) -> Self {
|
||||
Self {
|
||||
inner: AtomicLatch::new(),
|
||||
worker_index: AtomicUsize::new(worker_index),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn set_worker_index(&self, worker_index: usize) {
|
||||
self.worker_index.store(worker_index, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
impl Latch for WakeLatch {
|
||||
#[inline]
|
||||
unsafe fn set_raw(this: *const Self) {
|
||||
unsafe {
|
||||
let worker_index = (&*this).worker_index.load(Ordering::Relaxed);
|
||||
|
||||
if CoreLatch::set(&(&*this).inner) {
|
||||
let ctx = WorkerThread::current_ref().unwrap().context.clone();
|
||||
// If the latch was sleeping, wake the worker thread
|
||||
ctx.shared().heartbeats.get(&worker_index).and_then(|weak| {
|
||||
weak.upgrade()
|
||||
.map(|heartbeat| Latch::set_raw(&heartbeat.latch))
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Probe for WakeLatch {
|
||||
#[inline]
|
||||
fn probe(&self) -> bool {
|
||||
let _guard = self.lock.lock();
|
||||
// SAFETY: inner is atomic, so we can safely access it.
|
||||
self.inner.probe()
|
||||
}
|
||||
}
|
||||
|
||||
impl AsCoreLatch for WakeLatch {
|
||||
impl AsCoreLatch for MutexLatch {
|
||||
#[inline]
|
||||
fn as_core_latch(&self) -> &CoreLatch {
|
||||
&self.inner
|
||||
// SAFETY: inner is atomic, so we can safely access it.
|
||||
self.inner.as_core_latch()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UnsafeWakeLatch {
|
||||
inner: AtomicLatch,
|
||||
waker: *const MutexLatch,
|
||||
// The worker waits on this latch whenever it has nothing to do.
|
||||
#[derive(Debug)]
|
||||
pub struct WorkerLatch {
|
||||
// this boolean is set when the worker is waiting.
|
||||
mutex: Mutex<bool>,
|
||||
condvar: AtomicUsize,
|
||||
}
|
||||
|
||||
impl UnsafeWakeLatch {
|
||||
/// # Safety
|
||||
/// The `waker` must be valid until the latch is set.
|
||||
pub unsafe fn new(waker: *const MutexLatch) -> Self {
|
||||
impl WorkerLatch {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
inner: AtomicLatch::new(),
|
||||
waker,
|
||||
mutex: Mutex::new(false),
|
||||
condvar: AtomicUsize::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Latch for UnsafeWakeLatch {
|
||||
#[inline]
|
||||
unsafe fn set_raw(this: *const Self) {
|
||||
pub fn lock(&self) {
|
||||
mem::forget(self.mutex.lock());
|
||||
}
|
||||
pub fn unlock(&self) {
|
||||
unsafe {
|
||||
let waker = (*this).waker;
|
||||
if CoreLatch::set(&(&*this).inner) {
|
||||
Latch::set_raw(waker);
|
||||
}
|
||||
self.mutex.force_unlock();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Probe for UnsafeWakeLatch {
|
||||
#[inline]
|
||||
fn probe(&self) -> bool {
|
||||
self.inner.probe()
|
||||
pub fn wait(&self) {
|
||||
let condvar = &self.condvar;
|
||||
let mut guard = self.mutex.lock();
|
||||
|
||||
Self::wait_internal(condvar, &mut guard);
|
||||
}
|
||||
}
|
||||
|
||||
impl AsCoreLatch for UnsafeWakeLatch {
|
||||
#[inline]
|
||||
fn as_core_latch(&self) -> &CoreLatch {
|
||||
&self.inner
|
||||
fn wait_internal(condvar: &AtomicUsize, guard: &mut parking_lot::MutexGuard<'_, bool>) {
|
||||
let mutex = parking_lot::MutexGuard::mutex(guard);
|
||||
let key = condvar as *const _ as usize;
|
||||
let lock_addr = mutex as *const _ as usize;
|
||||
let mut requeued = false;
|
||||
|
||||
let state = unsafe { AtomicUsize::from_ptr(condvar as *const _ as *mut usize) };
|
||||
|
||||
**guard = true; // set the mutex to true to indicate that the worker is waiting
|
||||
|
||||
unsafe {
|
||||
parking_lot_core::park(
|
||||
key,
|
||||
|| {
|
||||
let old = state.load(Ordering::Relaxed);
|
||||
if old == 0 {
|
||||
state.store(lock_addr, Ordering::Relaxed);
|
||||
} else if old != lock_addr {
|
||||
return false;
|
||||
}
|
||||
|
||||
true
|
||||
},
|
||||
|| {
|
||||
mutex.force_unlock();
|
||||
},
|
||||
|k, was_last_thread| {
|
||||
requeued = k != key;
|
||||
if !requeued && was_last_thread {
|
||||
state.store(0, Ordering::Relaxed);
|
||||
}
|
||||
},
|
||||
parking_lot_core::DEFAULT_PARK_TOKEN,
|
||||
None,
|
||||
);
|
||||
}
|
||||
// relock
|
||||
|
||||
let mut new = mutex.lock();
|
||||
mem::swap(&mut new, guard);
|
||||
mem::forget(new); // forget the new guard to avoid dropping it
|
||||
|
||||
**guard = false; // reset the mutex to false after waking up
|
||||
}
|
||||
|
||||
fn wait_with_lock_internal<T>(&self, other: &mut parking_lot::MutexGuard<'_, T>) {
|
||||
let key = &self.condvar as *const _ as usize;
|
||||
let lock_addr = &self.mutex as *const _ as usize;
|
||||
let mut requeued = false;
|
||||
|
||||
let mut guard = self.mutex.lock();
|
||||
|
||||
let state = unsafe { AtomicUsize::from_ptr(&self.condvar as *const _ as *mut usize) };
|
||||
|
||||
*guard = true; // set the mutex to true to indicate that the worker is waiting
|
||||
|
||||
unsafe {
|
||||
let token = parking_lot_core::park(
|
||||
key,
|
||||
|| {
|
||||
let old = state.load(Ordering::Relaxed);
|
||||
if old == 0 {
|
||||
state.store(lock_addr, Ordering::Relaxed);
|
||||
} else if old != lock_addr {
|
||||
return false;
|
||||
}
|
||||
|
||||
true
|
||||
},
|
||||
|| {
|
||||
drop(guard); // drop the guard to release the lock
|
||||
parking_lot::MutexGuard::mutex(&other).force_unlock();
|
||||
},
|
||||
|k, was_last_thread| {
|
||||
requeued = k != key;
|
||||
if !requeued && was_last_thread {
|
||||
state.store(0, Ordering::Relaxed);
|
||||
}
|
||||
},
|
||||
parking_lot_core::DEFAULT_PARK_TOKEN,
|
||||
None,
|
||||
);
|
||||
|
||||
tracing::trace!(
|
||||
"WorkerLatch wait_with_lock_internal: unparked with token {:?}",
|
||||
token
|
||||
);
|
||||
}
|
||||
// relock
|
||||
let mut other2 = parking_lot::MutexGuard::mutex(&other).lock();
|
||||
tracing::trace!("WorkerLatch wait_with_lock_internal: relocked other");
|
||||
|
||||
// because `other` is logically unlocked, we swap it with `other2` and then forget `other2`
|
||||
core::mem::swap(&mut *other2, &mut *other);
|
||||
core::mem::forget(other2);
|
||||
|
||||
let mut guard = self.mutex.lock();
|
||||
tracing::trace!("WorkerLatch wait_with_lock_internal: relocked self");
|
||||
|
||||
*guard = false; // reset the mutex to false after waking up
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace", skip(other))]
|
||||
pub fn wait_with_lock<T>(&self, other: &mut parking_lot::MutexGuard<'_, T>) {
|
||||
self.wait_with_lock_internal(other);
|
||||
}
|
||||
|
||||
pub fn wait_with_lock_while<T, F>(&self, other: &mut parking_lot::MutexGuard<'_, T>, mut f: F)
|
||||
where
|
||||
F: FnMut(&mut T) -> bool,
|
||||
{
|
||||
while f(other.deref_mut()) {
|
||||
self.wait_with_lock_internal(other);
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace", skip(f))]
|
||||
pub fn wait_until<F, T>(&self, mut f: F) -> T
|
||||
where
|
||||
F: FnMut() -> Option<T>,
|
||||
{
|
||||
let mut guard = self.mutex.lock();
|
||||
loop {
|
||||
if let Some(result) = f() {
|
||||
return result;
|
||||
}
|
||||
Self::wait_internal(&self.condvar, &mut guard);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_waiting(&self) -> bool {
|
||||
*self.mutex.lock()
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace")]
|
||||
fn notify(&self) {
|
||||
let key = &self.condvar as *const _ as usize;
|
||||
|
||||
unsafe {
|
||||
let n = parking_lot_core::unpark_all(key, parking_lot_core::DEFAULT_UNPARK_TOKEN);
|
||||
tracing::trace!("WorkerLatch notify_one: unparked {} threads", n);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn wake(&self) {
|
||||
self.notify();
|
||||
}
|
||||
}

#[cfg(test)]
mod tests {
    use std::sync::Barrier;
    use std::{ptr, sync::Barrier};

    use tracing::Instrument;
    use tracing_test::traced_test;

    use super::*;

    #[test]
    #[cfg_attr(not(miri), traced_test)]
    fn worker_latch() {
        let latch = Arc::new(WorkerLatch::new());
        let barrier = Arc::new(Barrier::new(2));
        let mutex = Arc::new(parking_lot::Mutex::new(false));

        let count = Arc::new(AtomicUsize::new(0));

        let thread = std::thread::spawn({
            let latch = latch.clone();
            let mutex = mutex.clone();
            let barrier = barrier.clone();
            let count = count.clone();

            move || {
                tracing::info!("Thread waiting on barrier");
                let mut guard = mutex.lock();
                barrier.wait();

                tracing::info!("Thread waiting on latch");
                latch.wait_with_lock(&mut guard);
                count.fetch_add(1, Ordering::Relaxed);
                tracing::info!("Thread woke up from latch");
                barrier.wait();
                tracing::info!("Thread finished waiting on barrier");
                count.fetch_add(1, Ordering::Relaxed);
            }
        });

        assert!(!latch.is_waiting(), "Latch should not be waiting yet");
        barrier.wait();
        tracing::info!("Main thread finished waiting on barrier");
        // lock mutex and notify the thread that isn't yet waiting.
        {
            let guard = mutex.lock();
            tracing::info!("Main thread acquired mutex, waking up thread");
            assert!(latch.is_waiting(), "Latch should be waiting now");

            latch.wake();
            tracing::info!("Main thread woke up thread");
        }
        assert_eq!(count.load(Ordering::Relaxed), 0, "Count should still be 0");
        barrier.wait();
        assert_eq!(
            count.load(Ordering::Relaxed),
            1,
            "Count should be 1 after waking up"
        );

        thread.join().expect("Thread should join successfully");
        assert_eq!(
            count.load(Ordering::Relaxed),
            2,
            "Count should be 2 after thread has finished"
        );
    }

    #[test]
    fn test_atomic_latch() {
        let latch = AtomicLatch::new();

@@ -437,7 +592,7 @@ mod tests {
        }
        assert_eq!(latch.get(), AtomicLatch::SET);
        assert!(latch.probe());
        latch.reset();
        latch.unset();
        assert_eq!(latch.get(), AtomicLatch::UNSET);
    }

@@ -451,7 +606,7 @@ mod tests {
            assert!(!latch.probe());
            assert!(AtomicLatch::set(&latch));
        }
        assert_eq!(latch.get(), AtomicLatch::SET);
        assert_eq!(latch.get(), AtomicLatch::SET | AtomicLatch::SLEEPING);
        assert!(latch.probe());
        latch.reset();
        assert_eq!(latch.get(), AtomicLatch::UNSET);

@@ -465,35 +620,9 @@ mod tests {
        );
    }

    #[test]
    fn thread_wake_latch() {
        let latch = Arc::new(ThreadWakeLatch::new());
        let main = Arc::new(ThreadWakeLatch::new());

        let handle = std::thread::spawn({
            let latch = latch.clone();
            let main = main.clone();
            move || unsafe {
                Latch::set_raw(&*main);
                latch.wait();
            }
        });

        unsafe {
            main.wait();
            Latch::set_raw(&*latch);
        }

        handle.join().expect("Thread should join successfully");
        assert!(
            !latch.probe() && !main.probe(),
            "Latch should be set after waiting thread wakes up"
        );
    }

    #[test]
    fn count_latch() {
        let latch = CountLatch::new(AtomicLatch::new());
        let latch = CountLatch::new(ptr::null());
        assert_eq!(latch.count(), 0);
        latch.increment();
        assert_eq!(latch.count(), 1);

@@ -516,6 +645,7 @@ mod tests {
    }

    #[test]
    #[traced_test]
    fn mutex_latch() {
        let latch = Arc::new(MutexLatch::new());
        assert!(!latch.probe());

@@ -527,61 +657,18 @@ mod tests {
        // Test wait functionality
        let latch_clone = latch.clone();
        let handle = std::thread::spawn(move || {
            latch_clone.wait();
            tracing::info!("Thread waiting on latch");
            latch_clone.wait_and_reset();
            tracing::info!("Thread woke up from latch");
        });

        // Give the thread time to block
        std::thread::sleep(std::time::Duration::from_millis(100));
        assert!(!latch.probe());

        tracing::info!("Setting latch from main thread");
        latch.set();
        assert!(latch.probe());
        tracing::info!("Latch set, joining waiting thread");
        handle.join().expect("Thread should join successfully");
    }

    #[test]
    fn wake_latch() {
        let context = Context::new_with_threads(1);
        let count = Arc::new(AtomicUsize::new(0));
        let barrier = Arc::new(Barrier::new(2));

        tracing::info!("running scope in worker thread");
        context.run_in_worker(|worker| {
            tracing::info!("worker thread started: {:?}", worker.index);
            let latch = Arc::new(WakeLatch::new(worker.index));
            worker.context.spawn({
                let heartbeat = worker.heartbeat.clone();
                let barrier = barrier.clone();
                let count = count.clone();
                let latch = latch.clone();
                move || {
                    tracing::info!("sleeping workerthread");

                    latch.as_core_latch().set_sleeping();
                    heartbeat.latch.wait_and_reset();
                    tracing::info!("woken up workerthread");
                    count.fetch_add(1, Ordering::SeqCst);
                    tracing::info!("waiting on barrier");
                    barrier.wait();
                }
            });

            worker.context.spawn({
                move || {
                    tracing::info!("setting latch in worker thread");
                    unsafe {
                        Latch::set_raw(&*latch);
                    }
                }
            });
        });

        tracing::info!("main thread set latch, waiting for worker thread to wake up");
        barrier.wait();
        assert_eq!(
            count.load(Ordering::SeqCst),
            1,
            "Latch should have woken the worker thread"
        );
    }
}

@@ -7,12 +7,14 @@
    unsafe_cell_access,
    box_as_ptr,
    box_vec_non_null,
    strict_provenance_atomic_ptr,
    let_chains
)]

extern crate alloc;

mod context;
mod heartbeat;
mod job;
mod join;
mod latch;

@@ -12,15 +12,48 @@ use async_task::Runnable;

use crate::{
    context::Context,
    job::{HeapJob, Job},
    latch::{AsCoreLatch, CountLatch, WakeLatch},
    job::{HeapJob, JobSender, QueuedJob as Job},
    latch::{CountLatch, WorkerLatch},
    util::{DropGuard, SendPtr},
    workerthread::WorkerThread,
};

// thinking:

// the scope needs to keep track of any spawn() and spawn_async() calls, across all worker threads.
// that means that for any spawn() or spawn_async() call, we have to share a counter across all worker threads.
// we want to minimise the number of atomic operations in general.
// atomic operations occur in the following cases:
// - when we spawn() or spawn_async() a job, we increment the counter
// - when the same job finishes, we decrement the counter
// - when a join() job finishes, its latch is set
// - when we wait for a join() job, we loop over the latch until it is set

// find below a sketch of an unbalanced tree:
//           []
//          /  \
//        []    []
//       /  \  /  \
//     []   [][]   []
//    /  \        /  \
//  []    []    []    []
//       /  \        /  \
//     []    []    []    []
//                /  \
//              []    []

// in this tree of join() calls, it is possible to wait for a long time, so it is necessary to keep waking up when a job is shared.

// the worker waits on its latch, which may be woken by:
// - a job finishing
// - another thread sharing a job
// - the heartbeat waking up the worker // does this make sense? if the thread was sleeping, it didn't have any work to share.
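
// A minimal model of the bookkeeping described above (hypothetical type, not
// the crate's CountLatch): one counter shared by all workers, incremented per
// spawn and decremented on completion; the final decrement releases the scope.
//
// use std::sync::atomic::{AtomicUsize, Ordering};
//
// struct JobCount(AtomicUsize);
//
// impl JobCount {
//     fn increment(&self) {
//         // Relaxed suffices: publishing the job itself provides the ordering.
//         self.0.fetch_add(1, Ordering::Relaxed);
//     }
//
//     /// returns true if this was the last outstanding job.
//     fn decrement(&self) -> bool {
//         // AcqRel so the final decrement synchronizes with the waiting scope.
//         self.0.fetch_sub(1, Ordering::AcqRel) == 1
//     }
//
//     fn is_done(&self) -> bool {
//         self.0.load(Ordering::Acquire) == 0
//     }
// }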

pub struct Scope<'scope, 'env: 'scope> {
    // latch to wait on before the scope finishes
    job_counter: CountLatch<WakeLatch>,
    job_counter: CountLatch,
    // local threadpool
    context: Arc<Context>,
    // panic error

@@ -53,20 +86,24 @@ where
}

impl<'scope, 'env> Scope<'scope, 'env> {
    #[tracing::instrument(level = "trace", skip_all)]
    fn wait_for_jobs(&self, worker: &WorkerThread) {
        self.job_counter.set_inner(worker.heartbeat.raw_latch());
        if self.job_counter.count() > 0 {
            tracing::trace!("waiting for {} jobs to finish.", self.job_counter.count());
            tracing::trace!("thread id: {:?}, jobs: {:?}", worker.index, unsafe {
                worker.queue.as_ref_unchecked()
            });
            tracing::trace!(
                "thread id: {:?}, jobs: {:?}",
                worker.heartbeat.index(),
                unsafe { worker.queue.as_ref_unchecked() }
            );

            // set worker index in the job counter
            self.job_counter.inner().set_worker_index(worker.index);
            worker.wait_until_latch(self.job_counter.as_core_latch());
            worker.wait_until_latch(&self.job_counter);
        }
    }

    /// should be called from within a worker thread.
    #[tracing::instrument(level = "trace", skip_all)]
    fn complete<F, R>(&self, worker: &WorkerThread, f: F) -> R
    where
        F: FnOnce() -> R + Send,

@@ -74,23 +111,6 @@ impl<'scope, 'env> Scope<'scope, 'env> {
    {
        use std::panic::{AssertUnwindSafe, catch_unwind};

        #[allow(dead_code)]
        fn make_job<F: FnOnce() -> T, T>(f: F) -> Job<T> {
            #[align(8)]
            unsafe fn harness<F: FnOnce() -> T, T>(this: *const (), job: *const Job<T>) {
                let f = unsafe { Box::from_raw(this.cast::<F>().cast_mut()) };

                let result = catch_unwind(AssertUnwindSafe(move || f()));

                let job = unsafe { Box::from_raw(job.cast_mut()) };
                job.complete(result);
            }

            Job::<T>::new(harness::<F, T>, unsafe {
                NonNull::new_unchecked(Box::into_raw(Box::new(f))).cast()
            })
        }

        let result = match catch_unwind(AssertUnwindSafe(|| f())) {
            Ok(val) => Some(val),
            Err(payload) => {

@@ -107,6 +127,7 @@ impl<'scope, 'env> Scope<'scope, 'env> {
        }

    /// resumes the panic if one happened in this scope.
    #[tracing::instrument(level = "trace", skip_all)]
    fn maybe_propagate_panic(&self) {
        let err_ptr = self.panic.load(Ordering::Relaxed);
        if !err_ptr.is_null() {

@@ -118,7 +139,9 @@ impl<'scope, 'env> Scope<'scope, 'env> {
    }

    /// stores the first panic that happened in this scope.
    #[tracing::instrument(level = "trace", skip_all)]
    fn panicked(&self, err: Box<dyn Any + Send + 'static>) {
        tracing::debug!("panicked in scope, storing error: {:?}", err);
        self.panic.load(Ordering::Relaxed).is_null().then(|| {
            use core::mem::ManuallyDrop;
            let mut boxed = ManuallyDrop::new(Box::new(err));

@@ -146,23 +169,28 @@ impl<'scope, 'env> Scope<'scope, 'env> {
    where
        F: FnOnce(&'scope Self) + Send,
    {
        self.context.run_in_worker(|worker| {
        self.job_counter.increment();

        let this = SendPtr::new_const(self).unwrap();

        let job = Box::new(HeapJob::new(move || unsafe {
            _ = f(this.as_ref());
            this.as_ref().job_counter.decrement();
        }))
        .into_boxed_job();
        let job = Job::from_heapjob(
            Box::new(HeapJob::new(move || unsafe {
                use std::panic::{AssertUnwindSafe, catch_unwind};
                if let Err(payload) = catch_unwind(AssertUnwindSafe(|| f(this.as_ref()))) {
                    this.as_unchecked_ref().panicked(payload);
                }
                this.as_unchecked_ref().job_counter.decrement();
            })),
            ptr::null(),
        );

        tracing::trace!("allocated heapjob");

        worker.push_front(job);
        WorkerThread::current_ref()
            .expect("spawn is run in workerthread.")
            .push_front(job.as_ptr());

        tracing::trace!("leaked heapjob");
        });
    }

    pub fn spawn_future<T, F>(&'scope self, future: F) -> async_task::Task<T>

@@ -201,13 +229,14 @@ impl<'scope, 'env> Scope<'scope, 'env> {
            let _guard = DropGuard::new(move || {
                this.as_unchecked_ref().job_counter.decrement();
            });
            // TODO: handle panics here
            f(this.as_ref()).await
        }
    };

    let schedule = move |runnable: Runnable| {
        #[align(8)]
        unsafe fn harness(this: *const (), job: *const Job) {
        unsafe fn harness(this: *const (), job: *const JobSender, _: *const WorkerLatch) {
            unsafe {
                let runnable =
                    Runnable::<()>::from_raw(NonNull::new_unchecked(this.cast_mut()));

@@ -218,12 +247,16 @@ impl<'scope, 'env> Scope<'scope, 'env> {
            }
        }

        let job = Box::new(Job::new(harness, runnable.into_raw()));
        let job = Box::into_raw(Box::new(Job::from_harness(
            harness,
            runnable.into_raw(),
            ptr::null(),
        )));

        // casting into Job<()> here
        WorkerThread::current_ref()
            .expect("spawn_async_internal is run in workerthread.")
            .push_front(Box::into_raw(job) as _);
            .push_front(job);
    };

    let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) };
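
    // The `schedule` closure is the only executor-specific part of async_task:
    // each time the task can make progress it hands back a `Runnable`, which
    // the code above wraps in a `Job` and pushes onto the worker's deque. A
    // minimal model of that contract, running inline instead of through a
    // queue (illustrative only; `block_on` comes from the futures crate):
    //
    // fn async_task_schedule_model() {
    //     let (runnable, task) = async_task::spawn(async { 40 + 2 }, |r: async_task::Runnable| {
    //         // a real executor enqueues `r` here and runs it from its worker loop
    //         r.run();
    //     });
    //     runnable.schedule(); // kicks off the first poll via the callback above
    //     assert_eq!(futures::executor::block_on(task), 42);
    // }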

@@ -259,7 +292,7 @@ impl<'scope, 'env> Scope<'scope, 'env> {
    unsafe fn from_context(context: Arc<Context>) -> Self {
        Self {
            context,
            job_counter: CountLatch::new(WakeLatch::new(0)),
            job_counter: CountLatch::new(ptr::null()),
            panic: AtomicPtr::new(ptr::null_mut()),
            _scope: PhantomData,
            _env: PhantomData,

@@ -277,7 +310,8 @@ mod tests {
    use crate::ThreadPool;

    #[test]
    fn spawn() {
    #[cfg_attr(not(miri), traced_test)]
    fn scope_spawn_sync() {
        let pool = ThreadPool::new_with_threads(1);
        let count = Arc::new(AtomicU8::new(0));

@@ -291,8 +325,8 @@ mod tests {
    }

    #[test]
    #[traced_test]
    fn join() {
    #[cfg_attr(not(miri), traced_test)]
    fn scope_join_one() {
        let pool = ThreadPool::new_with_threads(1);

        let a = pool.scope(|scope| {

@@ -304,7 +338,30 @@ mod tests {
    }

    #[test]
    fn spawn_future() {
    #[cfg_attr(not(miri), traced_test)]
    fn scope_join_many() {
        let pool = ThreadPool::new_with_threads(1);

        fn sum<'scope, 'env>(scope: &'scope Scope<'scope, 'env>, n: usize) -> usize {
            if n == 0 {
                return 0;
            }

            let (l, r) = scope.join(|s| sum(s, n - 1), |s| sum(s, n - 1));

            l + r + 1
        }

        pool.scope(|scope| {
            let total = sum(scope, 10);
            assert_eq!(total, 1023);
            // eprintln!("Total sum: {}", total);
        });
    }

    #[test]
    #[cfg_attr(not(miri), traced_test)]
    fn scope_spawn_future() {
        let pool = ThreadPool::new_with_threads(1);
        let mut x = 0;
        pool.scope(|scope| {

@@ -317,4 +374,22 @@ mod tests {

        assert_eq!(x, 1);
    }

    #[test]
    #[cfg_attr(not(miri), traced_test)]
    fn scope_spawn_many() {
        let pool = ThreadPool::new_with_threads(1);
        let count = Arc::new(AtomicU8::new(0));

        pool.scope(|scope| {
            for _ in 0..10 {
                let count = count.clone();
                scope.spawn(move |_| {
                    count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                });
            }
        });

        assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 10);
    }
}

@@ -53,14 +53,18 @@ impl ThreadPool {

#[cfg(test)]
mod tests {
    use tracing_test::traced_test;

    use super::*;

    #[test]
    fn spawn_borrow() {
    #[cfg_attr(not(miri), traced_test)]
    fn pool_spawn_borrow() {
        let pool = ThreadPool::new_with_threads(1);
        let mut x = 0;
        pool.scope(|scope| {
            scope.spawn(|_| {
                tracing::info!("Incrementing x");
                x += 1;
            });
        });

@@ -68,7 +72,8 @@ mod tests {
    }

    #[test]
    fn spawn_future() {
    #[cfg_attr(not(miri), traced_test)]
    fn pool_spawn_future() {
        let pool = ThreadPool::new_with_threads(1);
        let mut x = 0;
        let task = pool.scope(|scope| {

@@ -84,7 +89,8 @@ mod tests {
    }

    #[test]
    fn join() {
    #[cfg_attr(not(miri), traced_test)]
    fn pool_join() {
        let pool = ThreadPool::new_with_threads(1);
        let (a, b) = pool.join(|| 3 + 4, || 5 * 6);
        assert_eq!(a, 7);

@@ -54,7 +54,7 @@ impl<T> core::fmt::Pointer for SendPtr<T> {
    }
}

unsafe impl<T> Send for SendPtr<T> {}
unsafe impl<T> core::marker::Send for SendPtr<T> {}

impl<T> Deref for SendPtr<T> {
    type Target = NonNull<T>;

@@ -104,6 +104,7 @@ impl<T> SendPtr<T> {
/// as the pointer.
/// The pointer's alignment must be at least `2^BITS` bytes, i.e. `align_of::<T>() >= 2^BITS`, so the low `BITS` bits are free to hold the tag.
#[repr(transparent)]
#[derive(Debug)]
pub struct TaggedAtomicPtr<T, const BITS: u8> {
    ptr: AtomicPtr<()>,
    _pd: PhantomData<T>,

@@ -138,6 +139,19 @@ impl<T, const BITS: u8> TaggedAtomicPtr<T, BITS> {
        self.ptr.load(order).addr() & Self::mask()
    }

    pub fn fetch_or_tag(&self, tag: usize, order: Ordering) -> usize {
        let mask = Self::mask();
        let old_ptr = self.ptr.fetch_or(tag & mask, order);
        old_ptr.addr() & mask
    }

    /// returns the tag and clears it
    pub fn take_tag(&self, order: Ordering) -> usize {
        let mask = Self::mask();
        let old_ptr = self.ptr.fetch_and(!mask, order);
        old_ptr.addr() & mask
    }
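
    // A hedged usage sketch of the two new tag operations (not from the diff;
    // uses the `new`/`tag`/`ptr` accessors exercised in the tests below). With
    // `BITS = 2` the pointee must be at least 4-byte aligned, leaving the low
    // two bits of the pointer for the tag.
    //
    // fn tagged_ptr_tag_ops() {
    //     let value = Box::into_raw(Box::new(0u32)); // u32 is 4-byte aligned
    //     let tagged = TaggedAtomicPtr::<u32, 2>::new(value, 0b00);
    //
    //     tagged.fetch_or_tag(0b01, Ordering::Release); // set the low tag bit
    //     assert_eq!(tagged.take_tag(Ordering::AcqRel), 0b01); // read and clear it
    //     assert_eq!(tagged.tag(Ordering::Relaxed), 0);
    //     assert_eq!(tagged.ptr(Ordering::Relaxed).as_ptr(), value); // pointer untouched
    //
    //     unsafe { drop(Box::from_raw(value)) };
    // }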

    /// returns tag
    #[inline(always)]
    fn compare_exchange_tag_inner(

@@ -408,10 +422,53 @@ pub fn available_parallelism() -> usize {
        .unwrap_or(1)
}

#[repr(transparent)]
pub struct Send<T>(pub(self) T);

unsafe impl<T> core::marker::Send for Send<T> {}

impl<T> Deref for Send<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
impl<T> DerefMut for Send<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl<T> Send<T> {
    pub unsafe fn new(value: T) -> Self {
        Self(value)
    }
}

pub fn unwrap_or_panic<T>(result: std::thread::Result<T>) -> T {
    match result {
        Ok(value) => value,
        Err(payload) => std::panic::resume_unwind(payload),
    }
}
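
// Usage sketch for `unwrap_or_panic` (illustrative): `std::thread::Result<T>`
// is `Result<T, Box<dyn Any + Send>>`, so this rethrows a panic captured on
// another thread at the current call site.
//
// fn rethrow_worker_panic() {
//     let handle = std::thread::spawn(|| 21 * 2);
//     // had the spawned closure panicked, resume_unwind would re-raise it here
//     let value = unwrap_or_panic(handle.join());
//     assert_eq!(value, 42);
// }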

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tagged_ptr_zero_tag() {
        let ptr = Box::into_raw(Box::new(42u32));
        let tagged_ptr = TaggedAtomicPtr::<u32, 2>::new(ptr, 0);
        assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0);
        assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr);

        unsafe {
            _ = Box::from_raw(ptr);
        }
    }

    #[test]
    fn tagged_ptr_exchange() {
        let ptr = Box::into_raw(Box::new(42u32));

@@ -1,5 +1,6 @@
use std::{
    cell::{Cell, UnsafeCell},
    hint::cold_path,
    ptr::NonNull,
    sync::Arc,
    time::Duration,

@@ -9,16 +10,16 @@ use crossbeam_utils::CachePadded;

use crate::{
    context::{Context, Heartbeat},
    job::{Job, JobList, JobResult},
    latch::{AsCoreLatch, CoreLatch, Probe},
    heartbeat::OwnedHeartbeatReceiver,
    job::{JobQueue as JobList, JobResult, QueuedJob as Job, QueuedJob, StackJob},
    latch::{AsCoreLatch, CoreLatch, Probe, WorkerLatch},
    util::DropGuard,
};

pub struct WorkerThread {
    pub(crate) context: Arc<Context>,
    pub(crate) index: usize,
    pub(crate) queue: UnsafeCell<JobList>,
    pub(crate) heartbeat: Arc<CachePadded<Heartbeat>>,
    pub(crate) heartbeat: OwnedHeartbeatReceiver,
    pub(crate) join_count: Cell<u8>,
}

@@ -28,11 +29,10 @@ thread_local! {

impl WorkerThread {
    pub fn new_in(context: Arc<Context>) -> Self {
        let (heartbeat, index) = context.shared().new_heartbeat();
        let heartbeat = context.heartbeats.new_heartbeat();

        Self {
            context,
            index,
            queue: UnsafeCell::new(JobList::new()),
            heartbeat,
            join_count: Cell::new(0),

@@ -41,6 +41,7 @@ impl WorkerThread {
}

impl WorkerThread {
    #[tracing::instrument(level = "trace", skip_all)]
    pub fn run(self: Box<Self>) {
        let this = Box::into_raw(self);
        unsafe {

@@ -62,87 +63,108 @@ impl WorkerThread {
        tracing::trace!("WorkerThread::run: worker thread finished");
    }

    #[tracing::instrument(level = "trace", skip_all)]
    fn run_inner(&self) {
        let mut job = self.context.shared().pop_job();
        let mut job = None;
        'outer: loop {
            let mut guard = loop {
            if let Some(job) = job.take() {
                self.execute(job);
            }

            // no more jobs, wait to be notified of a new job or a heartbeat.
            while job.is_none() {
                if self.context.should_exit() {
                    // if the context is stopped, break out of the outer loop which
                    // will exit the thread.
                    break 'outer;
                }

                // we executed the shared job, now we want to check for any
                // local jobs which this job might have spawned.
                let next = self
                    .pop_front()
                    .map(|job| (Some(job), None))
                    .unwrap_or_else(|| {
                        let mut guard = self.context.shared();
                        (guard.pop_job(), Some(guard))
                    });

                match next {
                    // no job, but guard => check if we should exit
                    (None, Some(guard)) => {
                        tracing::trace!("worker: no local job, waiting for shared job");

                        if guard.should_exit() {
                            // if the context is stopped, break out of the outer loop which
                            // will exit the thread.
                            break 'outer;
                        }

                        // no local jobs, wait for shared job
                        break guard;
                    }
                    // some job => drop guard, continue inner loop
                    (Some(next), _) => {
                        tracing::trace!("worker: executing job: {:?}", next);
                        job = Some(next);
                        continue;
                    }
                    // no job, no guard ought to be unreachable.
                    _ => unreachable!(),
                }
            };

            self.context.shared_job.wait(&mut guard);
            // a job was shared and we were notified, so we want to execute that job before any possible local jobs.
            job = guard.pop_job();
                job = self.find_work_or_wait();
            }
        }
    }
}

impl WorkerThread {
    pub(crate) fn find_work(&self) -> Option<NonNull<Job>> {
        self.find_work_inner().left()
    }

    /// Looks for work in the local queue, then in the shared context, and if no
    /// work is found, waits for the thread to be notified of a new job, after
    /// which it returns `None`.
    /// The caller should then check for `should_exit` to determine if the
    /// thread should exit, or look for work again.
    #[tracing::instrument(level = "trace", skip_all)]
    pub(crate) fn find_work_or_wait(&self) -> Option<NonNull<Job>> {
        match self.find_work_inner() {
            either::Either::Left(job) => {
                return Some(job);
            }
            either::Either::Right(mut guard) => {
                // no jobs found, wait for a heartbeat or a new job
                tracing::trace!("WorkerThread::find_work_or_wait: waiting for new job");
                self.heartbeat.latch().wait_with_lock(&mut guard);
                tracing::trace!("WorkerThread::find_work_or_wait: woken up from wait");
                None
            }
        }
    }

    #[inline]
    fn find_work_inner(
        &self,
    ) -> either::Either<NonNull<Job>, parking_lot::MutexGuard<'_, crate::context::Shared>> {
        // first check the local queue for jobs
        if let Some(job) = self.pop_front() {
            tracing::trace!("WorkerThread::find_work_inner: found local job: {:?}", job);
            return either::Either::Left(job);
        }

        // then check the shared context for jobs
        let mut guard = self.context.shared();

        if let Some(job) = guard.pop_job() {
            tracing::trace!("WorkerThread::find_work_inner: found shared job: {:?}", job);
            return either::Either::Left(job);
        }

        either::Either::Right(guard)
    }

    #[inline(always)]
    pub(crate) fn tick(&self) {
        if self.heartbeat.is_pending() {
            tracing::trace!("received heartbeat, thread id: {:?}", self.index);
        if self.heartbeat.take() {
            tracing::trace!(
                "received heartbeat, thread id: {:?}",
                self.heartbeat.index()
            );
            self.heartbeat_cold();
        }
    }

    #[inline]
    #[tracing::instrument(level = "trace", skip(self))]
    fn execute(&self, job: NonNull<Job>) {
        self.tick();
        Job::execute(job);
        unsafe { Job::execute(job.as_ptr()) };
    }

    #[cold]
    fn heartbeat_cold(&self) {
        let mut guard = self.context.shared();

        if !guard.jobs.contains_key(&self.index) {
        if !guard.jobs.contains_key(&self.heartbeat.id()) {
            if let Some(job) = self.pop_back() {
                Job::set_shared(unsafe { job.as_ref() });
                tracing::trace!("heartbeat: sharing job: {:?}", job);
                guard.jobs.insert(self.heartbeat.id(), job);
                unsafe {
                    job.as_ref().set_pending();
                    // SAFETY: we are holding the lock on the shared context.
                    self.context.notify_job_shared();
                }
                guard.jobs.insert(self.index, job);
                self.context.notify_shared_job();
            }
        }

        self.heartbeat.clear();
    }
}
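
// The sharing step above, read as a small protocol: on a heartbeat tick a
// worker promotes its *oldest* local job into the shared map, but only if its
// previous shared job has already been claimed, then notifies sleeping
// workers. A toy model with hypothetical types (not the crate's):
//
// use std::collections::{BTreeMap, VecDeque};
//
// // new jobs are pushed to the front of the deque, so the back holds the
// // oldest, coarsest-grained work, which is the cheapest to steal.
// fn share_oldest(local: &mut VecDeque<u32>, shared: &mut BTreeMap<usize, u32>, worker_id: usize) {
//     if !shared.contains_key(&worker_id) {
//         if let Some(job) = local.pop_back() {
//             shared.insert(worker_id, job);
//             // the real code also marks the job pending and wakes one sleeper here
//         }
//     }
// }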

@@ -220,33 +242,19 @@ impl HeartbeatThread {
        Self { ctx }
    }

    #[tracing::instrument(level = "trace", skip(self))]
    pub fn run(self) {
        tracing::trace!("new heartbeat thread {:?}", std::thread::current());

        let mut i = 0;
        loop {
            let sleep_for = {
                let mut guard = self.ctx.shared();
                if guard.should_exit() {
                if self.ctx.should_exit() {
                    break;
                }

                let mut n = 0;
                guard.heartbeats.retain(|_, b| {
                    b.upgrade()
                        .inspect(|heartbeat| {
                            if n == i {
                                if heartbeat.set_pending() {
                                    heartbeat.latch.set();
                                }
                            }
                            n += 1;
                        })
                        .is_some()
                });
                let num_heartbeats = guard.heartbeats.len();

                drop(guard);
                self.ctx.heartbeats.notify_nth(i);
                let num_heartbeats = self.ctx.heartbeats.len();

                if i >= num_heartbeats {
                    i = 0;

@@ -265,95 +273,114 @@ impl HeartbeatThread {
        }
|
||||
|
||||
impl WorkerThread {
|
||||
#[cold]
|
||||
fn wait_until_latch_cold(&self, latch: &CoreLatch) {
|
||||
// does this optimise?
|
||||
assert!(!latch.probe());
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
pub fn wait_until_queued_job<T>(
|
||||
&self,
|
||||
job: *const QueuedJob,
|
||||
) -> Option<std::thread::Result<T>> {
|
||||
let recv = unsafe { (*job).as_receiver::<T>() };
|
||||
// we've already checked that the job was popped from the queue
|
||||
// check if shared job is our job
|
||||
|
||||
'outer: while !latch.probe() {
|
||||
// process local jobs before locking shared context
|
||||
while let Some(job) = self.pop_front() {
|
||||
unsafe {
|
||||
job.as_ref().set_pending();
|
||||
// if let Some(shared_job) = self.context.shared().jobs.remove(&self.heartbeat.id()) {
|
||||
// if core::ptr::eq(shared_job.as_ptr(), job as *const Job as _) {
|
||||
// // this is the job we are looking for, so we want to
|
||||
// // short-circuit and call it inline
|
||||
// tracing::trace!(
|
||||
// thread = self.heartbeat.index(),
|
||||
// "reclaiming shared job: {:?}",
|
||||
// shared_job
|
||||
// );
|
||||
|
||||
// return None;
|
||||
// } else {
|
||||
// // this isn't the job we are looking for, but we still need to
|
||||
// // execute it
|
||||
// tracing::trace!(
|
||||
// thread = self.heartbeat.index(),
|
||||
// "executing reclaimed shared job: {:?}",
|
||||
// shared_job
|
||||
// );
|
||||
|
||||
// unsafe { Job::execute(shared_job.as_ptr()) };
|
||||
// }
|
||||
// }
|
||||
|
||||
loop {
|
||||
match recv.poll() {
|
||||
Some(t) => {
|
||||
return Some(t);
|
||||
}
|
||||
self.execute(job);
|
||||
}
|
||||
None => {
|
||||
cold_path();
|
||||
|
||||
// take a shared job, if it exists
|
||||
if let Some(shared_job) = self.context.shared().jobs.remove(&self.index) {
|
||||
self.execute(shared_job);
|
||||
}
|
||||
|
||||
while !latch.probe() {
|
||||
let job = {
|
||||
let mut guard = self.context.shared();
|
||||
guard.jobs.remove(&self.index).or_else(|| guard.pop_job())
|
||||
};
|
||||
|
||||
match job {
|
||||
Some(job) => {
|
||||
self.execute(job);
|
||||
|
||||
continue 'outer;
|
||||
}
|
||||
None => {
|
||||
// TODO: wait on latch? if we have something that can
|
||||
// signal being done, e.g. can be waited on instead of
|
||||
// shared jobs, we should wait on it instead, but we
|
||||
// would also want to receive shared jobs still?
|
||||
// Spin? probably just wastes CPU time.
|
||||
// self.context.shared_job.wait(&mut guard);
|
||||
// if spin.spin() {
|
||||
// // wait for more shared jobs.
|
||||
// // self.context.shared_job.wait(&mut guard);
|
||||
// return;
|
||||
// }
|
||||
// Yield? same as spinning, really, so just exit and let the upstream use wait
|
||||
// std::thread::yield_now();
|
||||
|
||||
tracing::trace!("thread {:?} is sleeping", self.index);
|
||||
|
||||
latch.set_sleeping();
|
||||
self.heartbeat.latch.wait_and_reset();
|
||||
// since we were sleeping, the shared job can't be populated,
|
||||
// so resuming the inner loop is fine.
|
||||
// check local jobs before locking shared context
|
||||
if let Some(job) = self.find_work_or_wait() {
|
||||
tracing::trace!(
|
||||
"thread {:?} executing local job: {:?}",
|
||||
self.heartbeat.index(),
|
||||
job
|
||||
);
|
||||
unsafe {
|
||||
Job::execute(job.as_ptr());
|
||||
}
|
||||
tracing::trace!(
|
||||
"thread {:?} finished local job: {:?}",
|
||||
self.heartbeat.index(),
|
||||
job
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}

    pub fn wait_until_job<T>(&self, job: &Job<T>, latch: &CoreLatch) -> Option<JobResult<T>> {
        // we've already checked that the job was popped from the queue
        // check if shared job is our job
        if let Some(shared_job) = self.context.shared().jobs.remove(&self.index) {
            if core::ptr::eq(shared_job.as_ptr(), job as *const Job<T> as _) {
                // this is the job we are looking for, so we want to
                // short-circuit and call it inline
                return None;
            } else {
                // this isn't the job we are looking for, but we still need to
                // execute it
                self.execute(shared_job);
            }
    #[tracing::instrument(level = "trace", skip_all)]
    pub fn wait_until_latch<L>(&self, latch: &L)
    where
        L: Probe,
    {
        if !latch.probe() {
            tracing::trace!("thread {:?} waiting on latch", self.heartbeat.index());
            self.wait_until_latch_cold(latch);
        }
    }

    #[cold]
    fn wait_until_latch_cold<L>(&self, latch: &L)
    where
        L: Probe,
    {
        if let Some(shared_job) = self.context.shared().jobs.remove(&self.heartbeat.id()) {
            tracing::trace!(
                "thread {:?} reclaiming shared job: {:?}",
                self.heartbeat.index(),
                shared_job
            );
            unsafe { Job::execute(shared_job.as_ptr()) };
        }

        // do the usual thing and wait for the job's latch
        if !latch.probe() {
            self.wait_until_latch_cold(latch);
        }

        Some(job.wait())
    }

    pub fn wait_until_latch<L>(&self, latch: &L)
    where
        L: AsCoreLatch,
    {
        let latch = latch.as_core_latch();
        if !latch.probe() {
            self.wait_until_latch_cold(latch)
        // do the usual thing??? chatgipity really said this..
        while !latch.probe() {
            // check local jobs before locking shared context
            if let Some(job) = self.find_work_or_wait() {
                tracing::trace!(
                    "thread {:?} executing local job: {:?}",
                    self.heartbeat.index(),
                    job
                );
                unsafe {
                    Job::execute(job.as_ptr());
                }
                tracing::trace!(
                    "thread {:?} finished local job: {:?}",
                    self.heartbeat.index(),
                    job
                );
                continue;
            }
        }
    }
}

@@ -4,10 +4,10 @@ use executor::util::tree::Tree;

const TREE_SIZE: usize = 16;

fn join_scope() {
fn join_scope(tree_size: usize) {
    let pool = ThreadPool::new();

    let tree = Tree::new(TREE_SIZE, 1);
    let tree = Tree::new(tree_size, 1);

    fn sum(tree: &Tree<u32>, node: usize, scope: &Scope) -> u32 {
        let node = tree.get(node);

@@ -31,10 +31,10 @@ fn join_scope() {
    }
}

fn join_pool() {
fn join_pool(tree_size: usize) {
    let pool = ThreadPool::new();

    let tree = Tree::new(TREE_SIZE, 1);
    let tree = Tree::new(tree_size, 1);

    fn sum(tree: &Tree<u32>, node: usize, pool: &ThreadPool) -> u32 {
        let node = tree.get(node);

@@ -60,10 +60,10 @@ fn join_pool() {
    eprintln!("sum: {sum}");
}

fn join_distaff() {
fn join_distaff(tree_size: usize) {
    use distaff::*;
    let pool = ThreadPool::new();
    let tree = Tree::new(TREE_SIZE, 1);
    let tree = Tree::new(tree_size, 1);

    fn sum<'scope, 'env>(tree: &Tree<u32>, node: usize, scope: &'scope Scope<'scope, 'env>) -> u32 {
        let node = tree.get(node);

@@ -81,17 +81,15 @@ fn join_distaff() {
        node.leaf + l + r
    }

    for _ in 0..1000 {
        let sum = pool.scope(|s| {
            let sum = sum(&tree, tree.root().unwrap(), s);
            sum
        });
        std::hint::black_box(sum);
    }
    let sum = pool.scope(|s| {
        let sum = sum(&tree, tree.root().unwrap(), s);
        sum
    });
    std::hint::black_box(sum);
}

fn join_chili() {
    let tree = Tree::new(TREE_SIZE, 1u32);
fn join_chili(tree_size: usize) {
    let tree = Tree::new(tree_size, 1u32);

    fn sum(tree: &Tree<u32>, node: usize, scope: &mut chili::Scope<'_>) -> u32 {
        let node = tree.get(node);

@@ -113,8 +111,8 @@ fn join_chili() {
    }
}

fn join_rayon() {
    let tree = Tree::new(TREE_SIZE, 1u32);
fn join_rayon(tree_size: usize) {
    let tree = Tree::new(tree_size, 1u32);

    fn sum(tree: &Tree<u32>, node: usize) -> u32 {
        let node = tree.get(node);

@@ -133,19 +131,34 @@ fn join_rayon() {
}

fn main() {
    tracing_subscriber::fmt::init();
    // use tracing_subscriber::layer::SubscriberExt;
    // tracing::subscriber::set_global_default(
    //     tracing_subscriber::registry().with(tracing_tracy::TracyLayer::default()),
    // )
    // .expect("Failed to set global default subscriber");

    let size = std::env::args()
        .nth(2)
        .and_then(|s| s.parse::<usize>().ok())
        .unwrap_or(TREE_SIZE);

    match std::env::args().nth(1).as_deref() {
        Some("scope") => join_scope(),
        Some("pool") => join_pool(),
        Some("chili") => join_chili(),
        Some("distaff") => join_distaff(),
        Some("rayon") => join_rayon(),
        Some("scope") => join_scope(size),
        Some("pool") => join_pool(size),
        Some("chili") => join_chili(size),
        Some("distaff") => join_distaff(size),
        Some("rayon") => join_rayon(size),
        _ => {
            eprintln!(
                "Usage: {} [scope|pool|chili|distaff|rayon]",
                std::env::args().next().unwrap()
                "Usage: {} [scope|pool|chili|distaff|rayon] <tree_size={}>",
                std::env::args().next().unwrap(),
                TREE_SIZE
            );
            return;
        }
    }

    eprintln!("Done!");
    // wait for user input before exiting
    // std::io::stdin().read_line(&mut String::new()).unwrap();
}
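
With the new positional argument, the benchmark can be pointed at different tree sizes from the command line, e.g. `<binary> distaff 20`; if the second argument is missing or fails to parse, it falls back to the compiled-in `TREE_SIZE` of 16.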