Compare commits: `main...per-thread` (33 commits)

Commits (SHA1):
fad57b74e9, 26b6ef264c, 268879d97e, 7c6e338b77, 0836c7c958,
b635ea5579, d1244026ca, 41166898ff, f8aa8d9615, edf25e407f,
6e4f6a1285, 69d3794ff1, f384f61f81, 19ef21e2ef, 38ce1de3ac,
09166a8eb7, 228aa4d544, 6fe5351e59, 9cc125e558, 2a0372a8a0,
8b4eba5a19, a1e1c90f90, 5fae03dc06, c4b4f9248a, 3b07565118,
bdbe207e7e, eb8fd314f5, c3eb71dbb1, 0db285a4a9, 4742733683,
1363f20cfc, a3b9222ed9, ed4acbfbd7
```diff
@@ -11,7 +11,6 @@ work-stealing = []
 prefer-local = []
 never-local = []

 [profile.bench]
 debug = true

@@ -30,7 +29,7 @@ parking_lot = {version = "0.12.3"}
 thread_local = "1.1.8"
 crossbeam = "0.8.4"
 st3 = "0.4"
-chili = "0.2.0"
+chili = "0.2.1"

 async-task = "4.7.1"

@@ -48,3 +47,5 @@ cfg-if = "1.0.0"
 [dev-dependencies]
 async-std = "1.13.0"
 tracing-test = "0.2.5"
+tracing-tracy = "0.11.4"
+distaff = {path = "distaff", features = ["tracing"]}
```
```diff
@@ -56,7 +56,7 @@ mod tree {
     }
 }

-const TREE_SIZE: usize = 16;
+const TREE_SIZE: usize = 8;

 #[bench]
 fn join_melange(b: &mut Bencher) {

@@ -184,3 +184,34 @@ fn join_rayon(b: &mut Bencher) {
         assert_ne!(sum(&tree, tree.root().unwrap()), 0);
     });
 }
+
+#[bench]
+fn join_distaff(b: &mut Bencher) {
+    use distaff::*;
+    let pool = ThreadPool::new();
+    let tree = tree::Tree::new(TREE_SIZE, 1u32);
+
+    fn sum<'scope, 'env>(tree: &tree::Tree<u32>, node: usize, scope: Scope<'scope, 'env>) -> u32 {
+        let node = tree.get(node);
+        let (l, r) = scope.join(
+            |s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(),
+            |s| {
+                node.right
+                    .map(|node| sum(tree, node, s))
+                    .unwrap_or_default()
+            },
+        );
+
+        node.leaf + l + r
+    }
+
+    b.iter(move || {
+        let sum = pool.scope(|s| {
+            let sum = sum(&tree, tree.root().unwrap(), s);
+            assert_ne!(sum, 0);
+            sum
+        });
+        std::hint::black_box(sum);
+    });
+    eprintln!("Done with distaff join");
+}
```
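For orientation, here is a minimal sketch of the API this new benchmark exercises, assuming the `ThreadPool::new`/`scope`/`join` shapes that appear in this diff (this sketch is mine, not part of the changeset):

```rust
use distaff::ThreadPool;

fn main() {
    let pool = ThreadPool::new();
    // `scope` returns the closure's result; `join` runs both closures,
    // potentially in parallel, and returns their results as a tuple.
    let sum = pool.scope(|s| {
        let (a, b) = s.join(|_| 1 + 2, |_| 3 + 4);
        a + b
    });
    assert_eq!(sum, 10);
}
```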
```diff
@@ -3,14 +3,40 @@ name = "distaff"
version = "0.1.0"
edition = "2024"

[profile.bench]
opt-level = 0
debug = true

[profile.release]
debug = true

[features]
default = []
tracing = ["dep:tracing"]
std = []
metrics = []

[dependencies]
parking_lot = {version = "0.12.3"}
tracing = "0.1.40"
atomic-wait = "1.1.0"
tracing = {version = "0.1", optional = true}
parking_lot_core = "0.9.10"
crossbeam-utils = "0.8.21"
either = "1.15.0"

werkzeug = {path = "../../werkzeug", features = ["std", "nightly"]}

async-task = "4.7.1"

[dev-dependencies]
tracing-test = {version = "0.2"}
tracing-tracy = {version = "0.11"}
futures = "0.3"

divan = "0.1.14"
rayon = "1.10.0"
chili = {path = "../../chili"}

[[bench]]
name = "overhead"
harness = false
```
distaff/benches/overhead.rs (new file, 119 lines):

```rust
use distaff::{Scope, ThreadPool};
use divan::Bencher;

struct Node {
    val: u64,
    left: Option<Box<Node>>,
    right: Option<Box<Node>>,
}

impl Node {
    pub fn tree(layers: usize) -> Self {
        Self {
            val: 1,
            left: (layers != 1).then(|| Box::new(Self::tree(layers - 1))),
            right: (layers != 1).then(|| Box::new(Self::tree(layers - 1))),
        }
    }
}

const LAYERS: &[usize] = &[10, 24];
fn nodes() -> impl Iterator<Item = (usize, usize)> {
    LAYERS.iter().map(|&l| (l, (1 << l) - 1))
}

#[divan::bench(args = nodes())]
fn no_overhead(bencher: Bencher, nodes: (usize, usize)) {
    fn join_no_overhead<A, B, RA, RB>(scope: Scope<'_, '_>, a: A, b: B) -> (RA, RB)
    where
        A: FnOnce(Scope<'_, '_>) -> RA + Send,
        B: FnOnce(Scope<'_, '_>) -> RB + Send,
        RA: Send,
        RB: Send,
    {
        (a(scope), b(scope))
    }

    #[inline]
    fn sum(node: &Node, scope: Scope<'_, '_>) -> u64 {
        let (left, right) = join_no_overhead(
            scope,
            |s| node.left.as_deref().map(|n| sum(n, s)).unwrap_or_default(),
            |s| node.right.as_deref().map(|n| sum(n, s)).unwrap_or_default(),
        );

        node.val + left + right
    }

    let tree = Node::tree(nodes.0);
    let pool = ThreadPool::global();

    bencher.bench_local(move || {
        pool.scope(|scope| {
            assert_eq!(sum(&tree, scope), nodes.1 as u64);
        });
    });
}

#[divan::bench(args = nodes())]
fn distaff_overhead(bencher: Bencher, nodes: (usize, usize)) {
    fn sum<'scope, 'env>(node: &Node, scope: Scope<'scope, 'env>) -> u64 {
        let (left, right) = scope.join(
            |s| node.left.as_deref().map(|n| sum(n, s)).unwrap_or_default(),
            |s| node.right.as_deref().map(|n| sum(n, s)).unwrap_or_default(),
        );

        node.val + left + right
    }

    let tree = Node::tree(nodes.0);
    let pool = ThreadPool::global();

    bencher.bench_local(move || {
        pool.scope(|scope| {
            assert_eq!(sum(&tree, scope), nodes.1 as u64);
        });
    });
}

#[divan::bench(args = nodes())]
fn rayon_overhead(bencher: Bencher, nodes: (usize, usize)) {
    fn sum(node: &Node) -> u64 {
        let (left, right) = rayon::join(
            || node.left.as_deref().map(sum).unwrap_or_default(),
            || node.right.as_deref().map(sum).unwrap_or_default(),
        );

        node.val + left + right
    }

    let tree = Node::tree(nodes.0);

    bencher.bench_local(move || {
        assert_eq!(sum(&tree), nodes.1 as u64);
    });
}

#[divan::bench(args = nodes())]
fn chili_overhead(bencher: Bencher, nodes: (usize, usize)) {
    use chili::Scope;
    fn sum(node: &Node, scope: &mut Scope<'_>) -> u64 {
        let (left, right) = scope.join(
            |s| node.left.as_deref().map(|n| sum(n, s)).unwrap_or_default(),
            |s| node.right.as_deref().map(|n| sum(n, s)).unwrap_or_default(),
        );

        node.val + left + right
    }

    let tree = Node::tree(nodes.0);
    let mut scope = Scope::global();

    bencher.bench_local(move || {
        assert_eq!(sum(&tree, &mut scope), nodes.1 as u64);
    });
}

fn main() {
    divan::main();
}
```
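The `nodes()` helper above pairs each layer count with the node count of a full binary tree of that depth, which is what every benchmark's `assert_eq!` checks against. A quick sanity check of that arithmetic (an illustration of mine, not part of the diff):

```rust
fn main() {
    // A full binary tree with l layers has 2^l - 1 nodes,
    // which is what `nodes()` computes as `(1 << l) - 1`.
    assert_eq!((1u64 << 10) - 1, 1_023);
    assert_eq!((1u64 << 24) - 1, 16_777_215);
}
```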
distaff/src/channel.rs (new file, 205 lines):

```rust
// This file is taken from [`chili`]

use std::{
    cell::UnsafeCell,
    ptr::NonNull,
    sync::{
        Arc,
        atomic::{AtomicU8, AtomicU32, Ordering},
    },
    thread,
};

enum State {
    Pending,
    Waiting,
    Ready,
    Taken,
}

pub use werkzeug::sync::Parker;

use crate::queue::Queue;

#[derive(Debug)]
#[repr(C)]
struct Channel<T = ()> {
    state: AtomicU8,
    /// Can only be written by the `Receiver` and read by the `Sender` if
    /// `state` is `State::Waiting`.
    waiting_thread: NonNull<Parker>,
    /// Can only be written by the `Sender` and read by the `Receiver` if
    /// `state` is `State::Ready`.
    val: UnsafeCell<Option<Box<thread::Result<T>>>>,
}

impl<T> Channel<T> {
    fn new(waiting_thread: NonNull<Parker>) -> Self {
        Self {
            state: AtomicU8::new(State::Pending as u8),
            waiting_thread,
            val: UnsafeCell::new(None),
        }
    }
}

#[derive(Debug)]
pub struct Receiver<T = ()>(Arc<Channel<T>>);

impl<T: Send> Receiver<T> {
    pub fn is_empty(&self) -> bool {
        self.0.state.load(Ordering::Acquire) != State::Ready as u8
    }

    pub fn sender(&self) -> Sender<T> {
        Sender(self.0.clone())
    }

    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    pub fn wait(&self) {
        loop {
            match self.0.state.compare_exchange(
                State::Pending as u8,
                State::Waiting as u8,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => {
                    // SAFETY:
                    // The `waiting_thread` is set to the current thread's parker
                    // before we park it.
                    unsafe {
                        let thread = self.0.waiting_thread.as_ref();
                        thread.park();
                    }

                    // we might have been woken up because of a shared job.
                    // In that case, we need to check the state again.
                    if self
                        .0
                        .state
                        .compare_exchange(
                            State::Waiting as u8,
                            State::Pending as u8,
                            Ordering::AcqRel,
                            Ordering::Acquire,
                        )
                        .is_ok()
                    {
                        continue;
                    } else {
                        // The state was changed to `State::Ready` by the `Sender`, so we can return.
                        break;
                    }
                }
                Err(state) if state == State::Ready as u8 => {
                    // The channel is ready, so we can return immediately.
                    return;
                }
                _ => {
                    panic!("Receiver is already waiting or consumed.");
                }
            }
        }
    }

    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    pub fn poll(&self) -> Option<thread::Result<T>> {
        if self
            .0
            .state
            .compare_exchange(
                State::Ready as u8,
                State::Taken as u8,
                Ordering::AcqRel,
                Ordering::Acquire,
            )
            .is_ok()
        {
            unsafe { Some(self.take()) }
        } else {
            None
        }
    }

    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    pub fn recv(self) -> thread::Result<T> {
        self.wait();

        // SAFETY:
        // To arrive here, either `state` is `State::Ready` or the above
        // `compare_exchange` succeeded, the thread was parked and then
        // unparked by the `Sender` *after* the `state` was set to
        // `State::Ready`.
        //
        // In either case, this thread now has unique access to `val`.
        assert_eq!(
            self.0.state.swap(State::Taken as u8, Ordering::Acquire),
            State::Ready as u8
        );

        unsafe { self.take() }
    }

    unsafe fn take(&self) -> thread::Result<T> {
        let result = unsafe { (*self.0.val.get()).take().map(|b| *b).unwrap() };

        result
    }
}

#[derive(Debug)]
#[repr(transparent)]
pub struct Sender<T = ()>(Arc<Channel<T>>);

impl<T: Send> Sender<T> {
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    pub fn send(self, val: thread::Result<T>) {
        // SAFETY:
        // Only this thread can write to `val` and none can read it
        // yet.
        unsafe {
            *self.0.val.get() = Some(Box::new(val));
        }

        if self.0.state.swap(State::Ready as u8, Ordering::AcqRel) == State::Waiting as u8 {
            // SAFETY:
            // A `Receiver` already wrote its thread to `waiting_thread`
            // *before* setting the `state` to `State::Waiting`.
            unsafe {
                let thread = self.0.waiting_thread.as_ref();
                thread.unpark();
            }
        }
    }

    pub unsafe fn parker(&self) -> &Parker {
        unsafe { self.0.waiting_thread.as_ref() }
    }

    /// The caller must ensure that this function or `send` is only ever called once.
    pub unsafe fn send_as_ref(&self, val: thread::Result<T>) {
        // SAFETY:
        // Only this thread can write to `val` and none can read it
        // yet.
        unsafe {
            *self.0.val.get() = Some(Box::new(val));
        }

        if self.0.state.swap(State::Ready as u8, Ordering::AcqRel) == State::Waiting as u8 {
            // SAFETY:
            // A `Receiver` already wrote its thread to `waiting_thread`
            // *before* setting the `state` to `State::Waiting`.
            unsafe {
                let thread = self.0.waiting_thread.as_ref();
                thread.unpark();
            }
        }
    }
}

pub fn channel<T: Send>(thread: NonNull<Parker>) -> (Sender<T>, Receiver<T>) {
    let channel = Arc::new(Channel::new(thread));

    (Sender(channel.clone()), Receiver(channel))
}
```
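A rough sketch of how this one-shot channel is driven, kept single-threaded for clarity. It assumes the `channel`, `Parker`, `Sender::send`, and `Receiver::poll` items shown above are reachable from the calling code; in the crate itself they may be private:

```rust
use std::ptr::NonNull;

fn demo() {
    // The receiving side owns the Parker; the channel only stores a raw
    // pointer to it, so it must outlive both halves.
    let parker = Parker::new();
    let (tx, rx) = channel::<u32>(NonNull::from(&parker));

    assert!(rx.is_empty()); // state is still `Pending`

    // `send` boxes the result, flips the state to `Ready`, and would
    // unpark the receiver if it had observed `State::Waiting`.
    tx.send(Ok(42));

    // `poll` is the non-blocking take: `Ready` -> `Taken`.
    assert_eq!(rx.poll().unwrap().unwrap(), 42);
}
```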
```diff
@@ -1,92 +1,56 @@
use std::{
    cell::UnsafeCell,
    marker::PhantomPinned,
    mem::{self, ManuallyDrop},
    panic::{AssertUnwindSafe, catch_unwind},
    pin::Pin,
    ptr::NonNull,
    sync::{
        Arc, OnceLock, Weak,
        atomic::{AtomicU8, Ordering},
        Arc, OnceLock,
        atomic::{AtomicBool, Ordering},
    },
};

use alloc::collections::BTreeMap;

use crossbeam_utils::CachePadded;
use parking_lot::{Condvar, Mutex};
use async_task::Runnable;

use crate::{
    job::{Job, StackJob},
    latch::{LatchRef, MutexLatch, WakeLatch},
    channel::{Parker, Sender},
    heartbeat::HeartbeatList,
    job::{HeapJob, Job2 as Job, SharedJob, StackJob},
    queue::ReceiverToken,
    util::DropGuard,
    workerthread::{HeartbeatThread, WorkerThread},
};

pub struct Heartbeat {
    heartbeat: AtomicU8,
    pub latch: MutexLatch,
}

impl Heartbeat {
    pub const CLEAR: u8 = 0;
    pub const PENDING: u8 = 1;
    pub const SLEEPING: u8 = 2;

    pub fn new() -> (Arc<CachePadded<Self>>, Weak<CachePadded<Self>>) {
        let strong = Arc::new(CachePadded::new(Self {
            heartbeat: AtomicU8::new(Self::CLEAR),
            latch: MutexLatch::new(),
        }));
        let weak = Arc::downgrade(&strong);

        (strong, weak)
    }

    /// returns true if the heartbeat was previously sleeping.
    pub fn set_pending(&self) -> bool {
        let old = self.heartbeat.swap(Self::PENDING, Ordering::Relaxed);
        old == Self::SLEEPING
    }

    pub fn clear(&self) {
        self.heartbeat.store(Self::CLEAR, Ordering::Relaxed);
    }

    pub fn is_pending(&self) -> bool {
        self.heartbeat.load(Ordering::Relaxed) == Self::PENDING
    }

    pub fn is_sleeping(&self) -> bool {
        self.heartbeat.load(Ordering::Relaxed) == Self::SLEEPING
    }
}

pub struct Context {
    shared: Mutex<Shared>,
    pub shared_job: Condvar,
    should_exit: AtomicBool,
    pub heartbeats: HeartbeatList,
    pub(crate) queue: Arc<crate::queue::Queue<Message>>,
    pub(crate) heartbeat: Parker,
}

pub(crate) enum Message {
    Shared(SharedJob),
    WakeUp,
    Exit,
    ScopeFinished,
}

pub(crate) struct Shared {
    pub jobs: BTreeMap<usize, NonNull<Job>>,
    pub heartbeats: BTreeMap<usize, Weak<CachePadded<Heartbeat>>>,
    injected_jobs: Vec<NonNull<Job>>,
    heartbeat_count: usize,
    should_exit: bool,
    pub jobs: BTreeMap<usize, SharedJob>,
    injected_jobs: Vec<SharedJob>,
}

unsafe impl Send for Shared {}

impl Shared {
    pub fn new_heartbeat(&mut self) -> (Arc<CachePadded<Heartbeat>>, usize) {
        let index = self.heartbeat_count;
        self.heartbeat_count = index.wrapping_add(1);

        let (strong, weak) = Heartbeat::new();

        self.heartbeats.insert(index, weak);

        (strong, index)
    }

    pub fn pop_job(&mut self) -> Option<NonNull<Job>> {
    pub fn pop_job(&mut self) -> Option<SharedJob> {
        // this is unlikely, so make the function cold?
        // TODO: profile this
        if !self.injected_jobs.is_empty() {
            // SAFETY: we checked that injected_jobs is not empty
            unsafe { return Some(self.pop_injected_job()) };
        } else {
            self.jobs.pop_first().map(|(_, job)| job)

@@ -94,34 +58,22 @@ impl Shared {
    }

    #[cold]
    unsafe fn pop_injected_job(&mut self) -> NonNull<Job> {
    unsafe fn pop_injected_job(&mut self) -> SharedJob {
        self.injected_jobs.pop().unwrap()
    }

    pub fn should_exit(&self) -> bool {
        self.should_exit
    }
}

impl Context {
    #[inline]
    pub fn shared(&self) -> parking_lot::MutexGuard<'_, Shared> {
        self.shared.lock()
    }

    pub fn new_with_threads(num_threads: usize) -> Arc<Self> {
        let this = Arc::new(Self {
            shared: Mutex::new(Shared {
                jobs: BTreeMap::new(),
                heartbeats: BTreeMap::new(),
                injected_jobs: Vec::new(),
                heartbeat_count: 0,
                should_exit: false,
            }),
            shared_job: Condvar::new(),
        });
        #[cfg(feature = "tracing")]
        tracing::trace!("Creating context with {} threads", num_threads);

        tracing::trace!("Creating thread pool with {} threads", num_threads);
        let this = Arc::new(Self {
            should_exit: AtomicBool::new(false),
            heartbeats: HeartbeatList::new(),
            queue: crate::queue::Queue::new(),
            heartbeat: Parker::new(),
        });

        // Create a barrier to synchronize the worker threads and the heartbeat thread
        let barrier = Arc::new(std::sync::Barrier::new(num_threads + 2));

@@ -135,8 +87,7 @@ impl Context {
            .spawn(move || {
                let worker = Box::new(WorkerThread::new_in(ctx));

                barrier.wait();
                worker.run();
                worker.run(barrier);
            })
            .expect("Failed to spawn worker thread");
    }

@@ -148,8 +99,7 @@ impl Context {
        std::thread::Builder::new()
            .name("heartbeat-thread".to_string())
            .spawn(move || {
                barrier.wait();
                HeartbeatThread::new(ctx).run();
                HeartbeatThread::new(ctx, num_threads).run(barrier);
            })
            .expect("Failed to spawn heartbeat thread");
    }

@@ -159,6 +109,15 @@ impl Context {
        this
    }

    pub fn set_should_exit(&self) {
        self.should_exit.store(true, Ordering::Relaxed);
        self.queue.as_sender().broadcast_with(|| Message::Exit);
    }

    pub fn should_exit(&self) -> bool {
        self.should_exit.load(Ordering::Relaxed)
    }

    pub fn new() -> Arc<Self> {
        Self::new_with_threads(crate::util::available_parallelism())
    }

@@ -169,14 +128,8 @@ impl Context {
        GLOBAL_CONTEXT.get_or_init(|| Self::new())
    }

    pub fn inject_job(&self, job: NonNull<Job>) {
        let mut shared = self.shared.lock();
        shared.injected_jobs.push(job);
        self.notify_shared_job();
    }

    pub fn notify_shared_job(&self) {
        self.shared_job.notify_one();
    pub fn inject_job(&self, job: SharedJob) {
        self.queue.as_sender().anycast(Message::Shared(job));
    }

    /// Runs closure in this context, processing the other context's worker's jobs while waiting for the result.

@@ -187,28 +140,21 @@ impl Context {
    {
        // current thread is not in the same context, create a job and inject it into the other thread's context, then wait while working on our jobs.

        let latch = WakeLatch::new(self.clone(), worker.index);
        // SAFETY: we are waiting on this latch in this thread.
        let _pinned = StackJob::new(move |worker: &WorkerThread| f(worker));
        let job = unsafe { Pin::new_unchecked(&_pinned) };

        let job = StackJob::new(
            move || {
                let worker = WorkerThread::current_ref()
                    .expect("WorkerThread::run_in_worker called outside of worker thread");
        let job = Job::from_stackjob(&job);
        unsafe {
            self.inject_job(job.share(Some(worker.receiver.get_token().as_parker())));
        }

                f(worker)
            },
            LatchRef::new(&latch),
        );
        let t = worker.wait_until_recv(job.take_receiver().expect("Job should have a receiver"));

        let job = job.as_job();
        job.set_pending();
        // touch the job to ensure it is dropped after we are done with it.
        drop(_pinned);

        self.inject_job(Into::into(&job));

        worker.wait_until_latch(&latch);

        let t = unsafe { job.transmute_ref::<T>().wait().into_result() };

        t
        crate::util::unwrap_or_panic(t)
    }

    /// Run closure in this context, sleeping until the job is done.

@@ -217,52 +163,102 @@ impl Context {
        F: FnOnce(&WorkerThread) -> T + Send,
        T: Send,
    {
        use crate::latch::MutexLatch;
        // current thread isn't a worker thread, create job and inject into global context
        // current thread isn't a worker thread, create job and inject into context
        let parker = Parker::new();

        let latch = MutexLatch::new();
        struct CrossJob<F> {
            f: UnsafeCell<ManuallyDrop<F>>,
            _pin: PhantomPinned,
        }

        let job = StackJob::new(
            move || {
                let worker = WorkerThread::current_ref()
                    .expect("WorkerThread::run_in_worker called outside of worker thread");
        impl<F> CrossJob<F> {
            fn new(f: F) -> Self {
                Self {
                    f: UnsafeCell::new(ManuallyDrop::new(f)),
                    _pin: PhantomPinned,
                }
            }

                f(worker)
            },
            LatchRef::new(&latch),
        );
            fn into_job<T>(self: &Self) -> Job<T>
            where
                F: FnOnce(&WorkerThread) -> T + Send,
                T: Send,
            {
                Job::from_harness(Self::harness, NonNull::from(&*self).cast())
            }

        let job = job.as_job();
        job.set_pending();
            unsafe fn unwrap(&self) -> F {
                unsafe { ManuallyDrop::take(&mut *self.f.get()) }
            }

        self.inject_job(Into::into(&job));
        latch.wait();
            #[align(8)]
            unsafe fn harness<T>(worker: &WorkerThread, this: NonNull<()>, sender: Option<Sender>)
            where
                F: FnOnce(&WorkerThread) -> T + Send,
                T: Send,
            {
                let this: &CrossJob<F> = unsafe { this.cast().as_ref() };
                let f = unsafe { this.unwrap() };
                let sender: Option<Sender<T>> = unsafe { mem::transmute(sender) };

        let t = unsafe { job.transmute_ref::<T>().wait().into_result() };
                let result = catch_unwind(AssertUnwindSafe(|| f(worker)));

        t
                let sender = sender.unwrap();
                unsafe {
                    sender.send_as_ref(result);
                    worker
                        .context
                        .queue
                        .as_sender()
                        .unicast(Message::WakeUp, ReceiverToken::from_parker(sender.parker()));
                }
            }
        }

        let pinned = CrossJob::new(move |worker: &WorkerThread| f(worker));
        let job2 = pinned.into_job();

        self.inject_job(job2.share(Some(&parker)));

        let recv = job2.take_receiver().unwrap();

        let out = crate::util::unwrap_or_panic(recv.recv());

        // touch the job to ensure it is dropped after we are done with it.
        drop(pinned);
        drop(parker);

        out
    }

    /// Run closure in this context.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    pub fn run_in_worker<T, F>(self: &Arc<Self>, f: F) -> T
    where
        T: Send,
        F: FnOnce(&WorkerThread) -> T + Send,
    {
        let _guard = DropGuard::new(|| {
            #[cfg(feature = "tracing")]
            tracing::trace!("run_in_worker: finished");
        });
        match WorkerThread::current_ref() {
            Some(worker) => {
                // check if worker is in the same context
                if Arc::ptr_eq(&worker.context, self) {
                    #[cfg(feature = "tracing")]
                    tracing::trace!("run_in_worker: current thread");
                    f(worker)
                } else {
                    // current thread is a worker for a different context
                    #[cfg(feature = "tracing")]
                    tracing::trace!("run_in_worker: cross-context");
                    self.run_in_worker_cross(worker, f)
                }
            }
            None => {
                // current thread is not a worker for any context
                #[cfg(feature = "tracing")]
                tracing::trace!("run_in_worker: inject into context");
                self.run_in_worker_cold(f)
            }

@@ -270,6 +266,56 @@ impl Context {
        }
    }

impl Context {
    pub fn spawn<F>(self: &Arc<Self>, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        let job = Job::from_heapjob(HeapJob::new(|_: &WorkerThread| f()));
        #[cfg(feature = "tracing")]
        tracing::trace!("Context::spawn: spawning job: {:?}", job);
        self.inject_job(job.share(None));
    }

    pub fn spawn_future<T, F>(self: &Arc<Self>, future: F) -> async_task::Task<T>
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        let schedule = move |runnable: Runnable| {
            #[align(8)]
            unsafe fn harness<T>(_: &WorkerThread, this: NonNull<()>, _: Option<Sender>) {
                unsafe {
                    let runnable = Runnable::<()>::from_raw(this);
                    runnable.run();
                }
            }

            let job = Job::<T>::from_harness(harness::<T>, runnable.into_raw());

            self.inject_job(job.share(None));
        };

        let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) };

        runnable.schedule();

        task
    }

    #[allow(dead_code)]
    fn spawn_async<T, Fut, Fn>(self: &Arc<Self>, f: Fn) -> async_task::Task<T>
    where
        Fn: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        let future = async move { f().await };

        self.spawn_future(future)
    }
}

pub fn run_in_worker<T, F>(f: F) -> T
where
    T: Send,

@@ -277,3 +323,105 @@ where
{
    Context::global_context().run_in_worker(f)
}

#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicU8;

    use super::*;

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn run_in_worker() {
        let ctx = Context::global_context().clone();
        let result = ctx.run_in_worker(|_| 42);
        assert_eq!(result, 42);
    }

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn context_spawn_future() {
        let ctx = Context::global_context().clone();
        let task = ctx.spawn_future(async { 42 });

        // Wait for the task to complete
        let result = futures::executor::block_on(task);
        assert_eq!(result, 42);
    }

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn context_spawn_async() {
        let ctx = Context::global_context().clone();
        let task = ctx.spawn_async(|| async { 42 });

        // Wait for the task to complete
        let result = futures::executor::block_on(task);
        assert_eq!(result, 42);
    }

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn context_spawn() {
        let ctx = Context::global_context().clone();
        let counter = Arc::new(AtomicU8::new(0));
        let barrier = Arc::new(std::sync::Barrier::new(2));

        ctx.spawn({
            let counter = counter.clone();
            let barrier = barrier.clone();
            move || {
                counter.fetch_add(1, Ordering::SeqCst);
                barrier.wait();
            }
        });

        barrier.wait();
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn inject_job_and_wake_worker() {
        let ctx = Context::new_with_threads(1);
        let counter = Arc::new(AtomicU8::new(0));

        let parker = Parker::new();

        let job = StackJob::new({
            let counter = counter.clone();
            move |_: &WorkerThread| {
                #[cfg(feature = "tracing")]
                tracing::info!("Job running");
                counter.fetch_add(1, Ordering::SeqCst);

                42
            }
        });

        let job = Job::from_stackjob(&job);

        // wait for the worker to sleep
        std::thread::sleep(std::time::Duration::from_millis(100));

        ctx.heartbeats
            .inner()
            .iter_mut()
            .next()
            .map(|(_, heartbeat)| {
                assert!(heartbeat.is_waiting());
            });

        ctx.inject_job(job.share(Some(&parker)));

        // Wait for the job to be executed
        let recv = job.take_receiver().expect("Job should have a receiver");
        let Some(result) = recv.poll() else {
            panic!("Expected a finished message");
        };

        let result = crate::util::unwrap_or_panic::<i32>(result);
        assert_eq!(result, 42);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
}
```
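One design note on `spawn_future` above: the `schedule` closure wraps every `Runnable` wake-up in a fresh `Job`, so each poll of the future runs as an ordinary injected pool job. Usage mirrors the `context_spawn_future` test; a hedged sketch of mine, assuming the same crate-internal API:

```rust
#[test]
fn spawn_future_sketch() {
    let ctx = Context::global_context().clone();
    // Each wake-up of the future is re-injected as a pool job by the
    // `schedule` closure inside `spawn_future`.
    let task = ctx.spawn_future(async { 6 * 7 });
    // `block_on` is only a convenient way to wait synchronously here;
    // the future itself is polled by the worker threads.
    assert_eq!(futures::executor::block_on(task), 42);
}
```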
distaff/src/heartbeat.rs (new file, 232 lines):

```rust
use std::{
    collections::BTreeMap,
    mem::ManuallyDrop,
    ops::Deref,
    ptr::NonNull,
    sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    },
    time::Instant,
};

use parking_lot::Mutex;

use crate::channel::Parker;

#[derive(Debug, Clone)]
pub struct HeartbeatList {
    inner: Arc<Mutex<HeartbeatListInner>>,
}

impl HeartbeatList {
    pub fn new() -> Self {
        Self {
            inner: Arc::new(Mutex::new(HeartbeatListInner::new())),
        }
    }

    pub fn notify_nth(&self, n: usize) {
        #[cfg(feature = "tracing")]
        tracing::trace!("notifying worker-{}", n);
        self.inner.lock().notify_nth(n);
    }

    pub fn notify_all(&self) {
        let mut inner = self.inner.lock();
        for (_, heartbeat) in inner.heartbeats.iter_mut() {
            heartbeat.set();
        }
    }

    pub fn len(&self) -> usize {
        self.inner.lock().len()
    }

    pub fn new_heartbeat(&self) -> OwnedHeartbeatReceiver {
        let (recv, _) = self.inner.lock().new_heartbeat();
        OwnedHeartbeatReceiver {
            list: self.clone(),
            receiver: ManuallyDrop::new(recv),
        }
    }

    pub fn inner(
        &self,
    ) -> parking_lot::lock_api::MappedMutexGuard<
        '_,
        parking_lot::RawMutex,
        BTreeMap<u64, HeartbeatSender>,
    > {
        parking_lot::MutexGuard::map(self.inner.lock(), |inner| &mut inner.heartbeats)
    }
}

#[derive(Debug)]
struct HeartbeatListInner {
    heartbeats: BTreeMap<u64, HeartbeatSender>,
    heartbeat_index: u64,
}

impl HeartbeatListInner {
    fn new() -> Self {
        Self {
            heartbeats: BTreeMap::new(),
            heartbeat_index: 0,
        }
    }

    fn notify_nth(&mut self, n: usize) {
        if let Some((_, heartbeat)) = self.heartbeats.iter_mut().nth(n) {
            heartbeat.set();
        }
    }

    fn len(&self) -> usize {
        self.heartbeats.len()
    }

    fn new_heartbeat(&mut self) -> (HeartbeatReceiver, u64) {
        let heartbeat = Heartbeat::new(self.heartbeat_index);
        let (recv, send, i) = heartbeat.into_recv_send();
        self.heartbeats.insert(i, send);
        self.heartbeat_index += 1;
        (recv, i)
    }

    fn remove_heartbeat(&mut self, receiver: HeartbeatReceiver) {
        if let Some(send) = self.heartbeats.remove(&receiver.i) {
            _ = Heartbeat::from_recv_send(receiver, send);
        }
    }
}

pub struct OwnedHeartbeatReceiver {
    list: HeartbeatList,
    receiver: ManuallyDrop<HeartbeatReceiver>,
}

impl Deref for OwnedHeartbeatReceiver {
    type Target = HeartbeatReceiver;

    fn deref(&self) -> &Self::Target {
        &self.receiver
    }
}

impl Drop for OwnedHeartbeatReceiver {
    fn drop(&mut self) {
        // SAFETY:
        // `AtomicBool` is `Sync` and `Send`, so it can be safely shared between threads.
        unsafe {
            let receiver = ManuallyDrop::take(&mut self.receiver);
            self.list.inner.lock().remove_heartbeat(receiver);
        }
    }
}

#[derive(Debug)]
pub struct Heartbeat {
    ptr: NonNull<(AtomicBool, Parker)>,
    i: u64,
}

#[derive(Debug)]
pub struct HeartbeatReceiver {
    ptr: NonNull<(AtomicBool, Parker)>,
    i: u64,
}

unsafe impl Send for HeartbeatReceiver {}

impl Drop for Heartbeat {
    fn drop(&mut self) {
        // SAFETY:
        // `AtomicBool` is `Sync` and `Send`, so it can be safely shared between threads.
        unsafe {
            let _ = Box::from_raw(self.ptr.as_ptr());
        }
    }
}

#[derive(Debug)]
pub struct HeartbeatSender {
    ptr: NonNull<(AtomicBool, Parker)>,
    pub last_heartbeat: Instant,
}

unsafe impl Send for HeartbeatSender {}

impl Heartbeat {
    fn new(i: u64) -> Heartbeat {
        // SAFETY:
        // `AtomicBool` is `Sync` and `Send`, so it can be safely shared between threads.
        let ptr = NonNull::new(Box::into_raw(Box::new((
            AtomicBool::new(true),
            Parker::new(),
        ))))
        .unwrap();
        Self { ptr, i }
    }

    pub fn from_recv_send(recv: HeartbeatReceiver, send: HeartbeatSender) -> Heartbeat {
        // SAFETY:
        // `AtomicBool` is `Sync` and `Send`, so it can be safely shared between threads.
        _ = send;
        let ptr = recv.ptr;
        let i = recv.i;
        Heartbeat { ptr, i }
    }

    pub fn into_recv_send(self) -> (HeartbeatReceiver, HeartbeatSender, u64) {
        // don't drop the `Heartbeat` yet
        let Self { ptr, i } = *ManuallyDrop::new(self);

        (
            HeartbeatReceiver { ptr, i },
            HeartbeatSender {
                ptr,
                last_heartbeat: Instant::now(),
            },
            i,
        )
    }
}

impl HeartbeatReceiver {
    pub fn take(&self) -> bool {
        unsafe {
            // SAFETY:
            // `AtomicBool` is `Sync` and `Send`, so it can be safely shared between threads.
            self.ptr.as_ref().0.swap(false, Ordering::Relaxed)
        }
    }

    pub fn parker(&self) -> &Parker {
        unsafe { &self.ptr.as_ref().1 }
    }

    pub fn id(&self) -> usize {
        self.ptr.as_ptr() as usize
    }

    pub fn index(&self) -> u64 {
        self.i
    }
}

impl HeartbeatSender {
    pub fn set(&mut self) {
        // SAFETY:
        // `AtomicBool` is `Sync` and `Send`, so it can be safely shared between threads.
        unsafe { self.ptr.as_ref().0.store(true, Ordering::Relaxed) };
        // self.last_heartbeat = Instant::now();
    }

    pub fn is_waiting(&self) -> bool {
        unsafe { self.ptr.as_ref().1.is_parked() }
    }
    pub fn wake(&self) {
        unsafe { self.ptr.as_ref().1.unpark() };
    }
}
```
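A small sketch of the heartbeat handshake implemented above (crate-internal API, shown single-threaded for clarity): the heartbeat thread sets the flag with `notify_nth`, and the owning worker consumes it with `take`:

```rust
fn demo() {
    let list = HeartbeatList::new();
    let hb = list.new_heartbeat(); // removed from the list again when dropped

    // The flag starts out set (`AtomicBool::new(true)` in `Heartbeat::new`),
    // so a fresh worker observes one beat immediately.
    assert!(hb.take());
    assert!(!hb.take()); // `take` swaps the flag back to false

    // Normally the dedicated heartbeat thread does this on a timer:
    list.notify_nth(0);
    assert!(hb.take());
}
```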
@ -1,666 +1,361 @@
|
|||
use core::{
|
||||
any::Any,
|
||||
cell::UnsafeCell,
|
||||
fmt::Debug,
|
||||
hint::cold_path,
|
||||
mem::{self, ManuallyDrop},
|
||||
ptr::{self, NonNull},
|
||||
sync::atomic::Ordering,
|
||||
ptr::NonNull,
|
||||
};
|
||||
use std::{cell::Cell, marker::PhantomPinned};
|
||||
|
||||
use alloc::boxed::Box;
|
||||
use parking_lot_core::SpinWait;
|
||||
|
||||
use crate::util::{SmallBox, TaggedAtomicPtr};
|
||||
|
||||
#[repr(u8)]
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub enum JobState {
|
||||
Empty,
|
||||
Locked = 1,
|
||||
Pending,
|
||||
Finished,
|
||||
// Inline = 1 << (u8::BITS - 1),
|
||||
// IsError = 1 << (u8::BITS - 2),
|
||||
}
|
||||
|
||||
impl JobState {
|
||||
#[allow(dead_code)]
|
||||
const MASK: u8 = 0; // Self::Inline as u8 | Self::IsError as u8;
|
||||
|
||||
fn from_u8(v: u8) -> Option<Self> {
|
||||
match v {
|
||||
0 => Some(Self::Empty),
|
||||
1 => Some(Self::Locked),
|
||||
2 => Some(Self::Pending),
|
||||
3 => Some(Self::Finished),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub use joblist::JobList;
|
||||
|
||||
mod joblist {
|
||||
use core::{fmt::Debug, ptr::NonNull};
|
||||
|
||||
use alloc::boxed::Box;
|
||||
|
||||
use super::Job;
|
||||
|
||||
// the list looks like this:
|
||||
// head <-> job1 <-> job2 <-> ... <-> jobN <-> tail
|
||||
pub struct JobList {
|
||||
// these cannot be boxes because boxes are noalias.
|
||||
head: NonNull<Job>,
|
||||
tail: NonNull<Job>,
|
||||
// the number of jobs in the list.
|
||||
// this is used to judge whether or not to join sync or async.
|
||||
job_count: usize,
|
||||
}
|
||||
|
||||
impl JobList {
|
||||
pub fn new() -> Self {
|
||||
let head = Box::into_raw(Box::new(Job::empty()));
|
||||
let tail = Box::into_raw(Box::new(Job::empty()));
|
||||
|
||||
// head and tail point at themselves
|
||||
unsafe {
|
||||
(&*head).link_mut().prev = None;
|
||||
(&*head).link_mut().next = Some(NonNull::new_unchecked(tail));
|
||||
|
||||
(&*tail).link_mut().prev = Some(NonNull::new_unchecked(head));
|
||||
(&*tail).link_mut().next = None;
|
||||
|
||||
Self {
|
||||
head: NonNull::new_unchecked(head),
|
||||
tail: NonNull::new_unchecked(tail),
|
||||
job_count: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn head(&self) -> NonNull<Job> {
|
||||
self.head
|
||||
}
|
||||
fn tail(&self) -> NonNull<Job> {
|
||||
self.tail
|
||||
}
|
||||
|
||||
/// `job` must be valid until it is removed from the list.
|
||||
pub unsafe fn push_front<T>(&mut self, job: *const Job<T>) {
|
||||
self.job_count += 1;
|
||||
let headlink = unsafe { self.head.as_ref().link_mut() };
|
||||
|
||||
let next = headlink.next.unwrap();
|
||||
let next_link = unsafe { next.as_ref().link_mut() };
|
||||
|
||||
let job_ptr = unsafe { NonNull::new_unchecked(job as _) };
|
||||
|
||||
headlink.next = Some(job_ptr);
|
||||
next_link.prev = Some(job_ptr);
|
||||
|
||||
let job_link = unsafe { job_ptr.as_ref().link_mut() };
|
||||
job_link.next = Some(next);
|
||||
job_link.prev = Some(self.head);
|
||||
}
|
||||
|
||||
/// `job` must be valid until it is removed from the list.
|
||||
pub unsafe fn push_back<T>(&mut self, job: *const Job<T>) {
|
||||
self.job_count += 1;
|
||||
let taillink = unsafe { self.tail.as_ref().link_mut() };
|
||||
|
||||
let prev = taillink.prev.unwrap();
|
||||
let prev_link = unsafe { prev.as_ref().link_mut() };
|
||||
|
||||
let job_ptr = unsafe { NonNull::new_unchecked(job as _) };
|
||||
|
||||
taillink.prev = Some(job_ptr);
|
||||
prev_link.next = Some(job_ptr);
|
||||
|
||||
let job_link = unsafe { job_ptr.as_ref().link_mut() };
|
||||
job_link.prev = Some(prev);
|
||||
job_link.next = Some(self.tail);
|
||||
}
|
||||
|
||||
pub fn pop_front(&mut self) -> Option<NonNull<Job>> {
|
||||
self.job_count -= 1;
|
||||
|
||||
let headlink = unsafe { self.head.as_ref().link_mut() };
|
||||
|
||||
// SAFETY: headlink.next is guaranteed to be Some.
|
||||
let job = headlink.next.unwrap();
|
||||
let job_link = unsafe { job.as_ref().link_mut() };
|
||||
|
||||
// short-circuit here if the job is the tail
|
||||
let next = job_link.next?;
|
||||
let next_link = unsafe { next.as_ref().link_mut() };
|
||||
|
||||
headlink.next = Some(next);
|
||||
next_link.prev = Some(self.head);
|
||||
|
||||
Some(job)
|
||||
}
|
||||
|
||||
pub fn pop_back(&mut self) -> Option<NonNull<Job>> {
|
||||
self.job_count -= 1;
|
||||
|
||||
let taillink = unsafe { self.tail.as_ref().link_mut() };
|
||||
|
||||
// SAFETY: taillink.prev is guaranteed to be Some.
|
||||
let job = taillink.prev.unwrap();
|
||||
let job_link = unsafe { job.as_ref().link_mut() };
|
||||
|
||||
// short-circuit here if the job is the head
|
||||
let prev = job_link.prev?;
|
||||
let prev_link = unsafe { prev.as_ref().link_mut() };
|
||||
|
||||
taillink.prev = Some(prev);
|
||||
prev_link.next = Some(self.tail);
|
||||
|
||||
Some(job)
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.job_count == 0
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.job_count
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for JobList {
|
||||
fn drop(&mut self) {
|
||||
// Need to drop the head and tail, which were allocated on the heap.
|
||||
// elements of the list are managed externally.
|
||||
unsafe {
|
||||
drop((Box::from_non_null(self.head), Box::from_non_null(self.tail)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for JobList {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("JobList")
|
||||
.field("head", &self.head)
|
||||
.field("tail", &self.tail)
|
||||
.field("job_count", &self.job_count)
|
||||
.field_with("jobs", |f| {
|
||||
let mut jobs = f.debug_list();
|
||||
|
||||
// SAFETY: head.next is guaranteed to be non-null and valid
|
||||
let mut job = unsafe { self.head.as_ref().link_mut().next.unwrap() };
|
||||
|
||||
while job != self.tail {
|
||||
let job_ref = unsafe { job.as_ref() };
|
||||
jobs.entry(job_ref);
|
||||
|
||||
// SAFETY: job is guaranteed to be non-null and valid
|
||||
// only the tail has a next of None
|
||||
job = unsafe { job_ref.link_mut().next.unwrap() };
|
||||
}
|
||||
|
||||
jobs.finish()
|
||||
})
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
}
|
||||
use crate::{
|
||||
WorkerThread,
|
||||
channel::{Parker, Receiver, Sender},
|
||||
context::Message,
|
||||
queue::ReceiverToken,
|
||||
};
|
||||
|
||||
#[repr(transparent)]
|
||||
pub struct JobResult<T> {
|
||||
inner: std::thread::Result<T>,
|
||||
pub struct StackJob<F> {
|
||||
f: UnsafeCell<ManuallyDrop<F>>,
|
||||
}
|
||||
|
||||
impl<T> JobResult<T> {
|
||||
pub fn new(result: std::thread::Result<T>) -> Self {
|
||||
Self { inner: result }
|
||||
}
|
||||
|
||||
/// convert JobResult into a thread result.
|
||||
#[allow(dead_code)]
|
||||
pub fn into_inner(self) -> std::thread::Result<T> {
|
||||
self.inner
|
||||
}
|
||||
|
||||
// unwraps the result, propagating panics
|
||||
pub fn into_result(self) -> T {
|
||||
match self.inner {
|
||||
Ok(val) => val,
|
||||
Err(payload) => {
|
||||
cold_path();
|
||||
|
||||
std::panic::resume_unwind(payload);
|
||||
// #[cfg(feature = "std")]
|
||||
// {
|
||||
// std::panic::resume_unwind(err);
|
||||
// }
|
||||
// #[cfg(not(feature = "std"))]
|
||||
// {
|
||||
// // in no-std, we just panic with the error
|
||||
// // TODO: figure out how to propagate the error
|
||||
// panic!("Job failed: {:?}", payload);
|
||||
// }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
struct Link<T> {
|
||||
prev: Option<NonNull<T>>,
|
||||
next: Option<NonNull<T>>,
|
||||
}
|
||||
|
||||
// `Link` is invariant over `T`
|
||||
impl<T> Clone for Link<T> {
|
||||
fn clone(&self) -> Self {
|
||||
impl<F> StackJob<F> {
|
||||
pub fn new(f: F) -> Self {
|
||||
Self {
|
||||
prev: self.prev.clone(),
|
||||
next: self.next.clone(),
|
||||
f: UnsafeCell::new(ManuallyDrop::new(f)),
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn unwrap(&self) -> F {
|
||||
unsafe { ManuallyDrop::take(&mut *self.f.get()) }
|
||||
}
|
||||
}
|
||||
pub struct HeapJob<F> {
|
||||
f: F,
|
||||
}
|
||||
|
||||
impl<F> HeapJob<F> {
|
||||
pub fn new(f: F) -> Box<Self> {
|
||||
Box::new(Self { f })
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> F {
|
||||
self.f
|
||||
}
|
||||
}
|
||||
|
||||
// `Link` is invariant over `T`
|
||||
impl<T> Copy for Link<T> {}
|
||||
type JobHarness = unsafe fn(&WorkerThread, this: NonNull<()>, sender: Option<Sender>);
|
||||
|
||||
struct Thread;
|
||||
|
||||
union ValueOrThis<T> {
|
||||
uninit: (),
|
||||
value: ManuallyDrop<SmallBox<T>>,
|
||||
this: NonNull<()>,
|
||||
#[repr(C)]
|
||||
pub struct Job2<T = ()> {
|
||||
inner: UnsafeCell<Job2Inner<T>>,
|
||||
}
|
||||
|
||||
union LinkOrError<T> {
|
||||
link: Link<T>,
|
||||
waker: ManuallyDrop<Option<std::thread::Thread>>,
|
||||
error: ManuallyDrop<Option<Box<dyn Any + Send + 'static>>>,
|
||||
impl<T> Debug for Job2<T> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Job2").field("inner", &self.inner).finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
pub struct Job<T = ()> {
|
||||
/// stores the job's harness as a *const usize
|
||||
harness_and_state: TaggedAtomicPtr<usize, 3>,
|
||||
/// `this` before `execute()` is called, or `value` after `execute()`
|
||||
value_or_this: UnsafeCell<ValueOrThis<T>>,
|
||||
/// `link` before `execute()` is called, or `error` after `execute()`
|
||||
error_or_link: UnsafeCell<LinkOrError<Job>>,
|
||||
pub enum Job2Inner<T = ()> {
|
||||
Local {
|
||||
harness: JobHarness,
|
||||
this: NonNull<()>,
|
||||
_pin: PhantomPinned,
|
||||
},
|
||||
Shared {
|
||||
receiver: Cell<Option<Receiver<T>>>,
|
||||
},
|
||||
}
|
||||
|
||||
unsafe impl<T> Send for Job<T> {}
|
||||
|
||||
impl<T> Debug for Job<T> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let state = JobState::from_u8(self.harness_and_state.tag(Ordering::Relaxed) as u8).unwrap();
|
||||
let mut debug = f.debug_struct("Job");
|
||||
debug.field("state", &state).field_with("harness", |f| {
|
||||
write!(f, "{:?}", self.harness_and_state.ptr(Ordering::Relaxed))
|
||||
});
|
||||
|
||||
match state {
|
||||
JobState::Empty => {
|
||||
debug
|
||||
.field_with("this", |f| {
|
||||
write!(f, "{:?}", unsafe { &(&*self.value_or_this.get()).this })
|
||||
})
|
||||
.field_with("link", |f| {
|
||||
write!(f, "{:?}", unsafe { &(&*self.error_or_link.get()).link })
|
||||
});
|
||||
}
|
||||
JobState::Locked => {
|
||||
#[derive(Debug)]
|
||||
struct Locked;
|
||||
debug.field("locked", &Locked);
|
||||
}
|
||||
JobState::Pending => {
|
||||
debug
|
||||
.field_with("this", |f| {
|
||||
write!(f, "{:?}", unsafe { &(&*self.value_or_this.get()).this })
|
||||
})
|
||||
.field_with("waker", |f| {
|
||||
write!(f, "{:?}", unsafe { &(&*self.error_or_link.get()).waker })
|
||||
});
|
||||
}
|
||||
JobState::Finished => {
|
||||
let err = unsafe { &(&*self.error_or_link.get()).error };
|
||||
|
||||
let result = match err.as_ref() {
|
||||
Some(err) => Err(err),
|
||||
None => Ok(unsafe { (&*self.value_or_this.get()).value.0.as_ptr() }),
|
||||
};
|
||||
|
||||
debug.field("result", &result);
|
||||
}
|
||||
}
|
||||
|
||||
debug.finish()
|
||||
}
|
||||
#[derive(Debug)]
|
||||
pub struct SharedJob {
|
||||
harness: JobHarness,
|
||||
this: NonNull<()>,
|
||||
sender: Option<Sender<()>>,
|
||||
}
|
||||
|
||||
impl<T> Job<T> {
|
||||
pub fn empty() -> Job<T> {
|
||||
Self {
|
||||
harness_and_state: TaggedAtomicPtr::new(ptr::dangling_mut(), JobState::Empty as usize),
|
||||
value_or_this: UnsafeCell::new(ValueOrThis {
|
||||
this: NonNull::dangling(),
|
||||
unsafe impl Send for SharedJob {}
|
||||
|
||||
impl<T: Send> Job2<T> {
|
||||
fn new(harness: JobHarness, this: NonNull<()>) -> Self {
|
||||
let this = Self {
|
||||
inner: UnsafeCell::new(Job2Inner::Local {
|
||||
harness: harness,
|
||||
this,
|
||||
_pin: PhantomPinned,
|
||||
}),
|
||||
error_or_link: UnsafeCell::new(LinkOrError {
|
||||
link: Link {
|
||||
prev: None,
|
||||
next: None,
|
||||
},
|
||||
}),
|
||||
// _phantom: PhantomPinned,
|
||||
}
|
||||
}
|
||||
pub fn new(harness: unsafe fn(*const (), *const Job<T>), this: NonNull<()>) -> Job<T> {
|
||||
Self {
|
||||
harness_and_state: TaggedAtomicPtr::new(
|
||||
unsafe { mem::transmute(harness) },
|
||||
JobState::Empty as usize,
|
||||
),
|
||||
value_or_this: UnsafeCell::new(ValueOrThis { this }),
|
||||
error_or_link: UnsafeCell::new(LinkOrError {
|
||||
link: Link {
|
||||
prev: None,
|
||||
next: None,
|
||||
},
|
||||
}),
|
||||
// _phantom: PhantomPinned,
|
||||
}
|
||||
};
|
||||
|
||||
this
|
||||
}
|
||||
|
||||
// Job is passed around type-erased as `Job<()>`, to complete the job we
|
||||
// need to cast it back to the original type.
|
||||
pub unsafe fn transmute_ref<U>(&self) -> &Job<U> {
|
||||
unsafe { mem::transmute::<&Job<T>, &Job<U>>(self) }
|
||||
}
|
||||
pub fn share(&self, parker: Option<&Parker>) -> SharedJob {
|
||||
let (sender, receiver) = parker
|
||||
.map(|parker| crate::channel::channel::<T>(parker.into()))
|
||||
.unzip();
|
||||
|
||||
#[inline]
|
||||
unsafe fn link_mut(&self) -> &mut Link<Job> {
|
||||
unsafe { &mut (&mut *self.error_or_link.get()).link }
|
||||
}
|
||||
|
||||
/// assumes job is in a `JobList`
|
||||
pub unsafe fn unlink(&self) {
|
||||
unsafe {
|
||||
let mut dummy = None;
|
||||
let Link { prev, next } = *self.link_mut();
|
||||
|
||||
*prev
|
||||
.map(|ptr| &mut ptr.as_ref().link_mut().next)
|
||||
.unwrap_or(&mut dummy) = next;
|
||||
*next
|
||||
.map(|ptr| &mut ptr.as_ref().link_mut().prev)
|
||||
.unwrap_or(&mut dummy) = prev;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn state(&self) -> u8 {
|
||||
self.harness_and_state.tag(Ordering::Relaxed) as u8
|
||||
}
|
||||
|
||||
pub fn wait(&self) -> JobResult<T> {
|
||||
let mut spin = SpinWait::new();
|
||||
loop {
|
||||
match self.harness_and_state.compare_exchange_weak_tag(
|
||||
JobState::Pending as usize,
|
||||
JobState::Locked as usize,
|
||||
Ordering::Acquire,
|
||||
Ordering::Relaxed,
|
||||
) {
|
||||
// if still pending, sleep until completed
|
||||
Ok(state) => {
|
||||
debug_assert_eq!(state, JobState::Pending as usize);
|
||||
unsafe {
|
||||
*(&mut *self.error_or_link.get()).waker = Some(std::thread::current());
|
||||
}
|
||||
|
||||
self.harness_and_state.set_tag(
|
||||
JobState::Pending as usize,
|
||||
Ordering::Release,
|
||||
Ordering::Relaxed,
|
||||
);
|
||||
|
||||
std::thread::park();
|
||||
spin.reset();
|
||||
|
||||
// after sleeping, state should be `Finished`
|
||||
}
|
||||
Err(state) => {
|
||||
// job finished under us, check if it was successful
|
||||
if state == JobState::Finished as usize {
|
||||
let err = unsafe { (&mut *self.error_or_link.get()).error.take() };
|
||||
|
||||
let result: std::thread::Result<T> = if let Some(err) = err {
|
||||
cold_path();
|
||||
Err(err)
|
||||
} else {
|
||||
let val = unsafe {
|
||||
ManuallyDrop::take(&mut (&mut *self.value_or_this.get()).value)
|
||||
};
|
||||
|
||||
Ok(val.into_inner())
|
||||
};
|
||||
|
||||
return JobResult::new(result);
|
||||
} else {
|
||||
// spin until lock is released.
|
||||
tracing::trace!("spin-waiting for job: {:?}", self);
|
||||
spin.spin();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// call this when popping value from local queue
|
||||
pub fn set_pending(&self) {
|
||||
let mut spin = SpinWait::new();
|
||||
loop {
|
||||
match self.harness_and_state.compare_exchange_weak_tag(
|
||||
JobState::Empty as usize,
|
||||
JobState::Pending as usize,
|
||||
Ordering::Acquire,
|
||||
Ordering::Relaxed,
|
||||
) {
|
||||
Ok(state) => {
|
||||
debug_assert_eq!(state, JobState::Empty as usize);
|
||||
// set waker to None
|
||||
unsafe {
|
||||
(&mut *self.error_or_link.get()).waker = ManuallyDrop::new(None);
|
||||
}
|
||||
return;
|
||||
}
|
||||
Err(_) => {
|
||||
// debug_assert_ne!(state, JobState::Empty as usize);
|
||||
|
||||
tracing::error!("######## what the sigma?");
|
||||
spin.spin();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn execute(job: NonNull<Self>) {
|
||||
tracing::trace!("executing job: {:?}", job);
|
||||
|
||||
// SAFETY: self is non-null
|
||||
unsafe {
|
||||
let this = job.as_ref();
|
||||
let (ptr, state) = this.harness_and_state.ptr_and_tag(Ordering::Relaxed);
|
||||
|
||||
debug_assert_eq!(state, JobState::Pending as usize);
|
||||
let harness: unsafe fn(*const (), *const Self) = mem::transmute(ptr.as_ptr());
|
||||
|
||||
let this = (*this.value_or_this.get()).this;
|
||||
|
||||
harness(this.as_ptr().cast(), job.as_ptr());
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn complete(&self, result: std::thread::Result<T>) {
|
||||
let mut spin = SpinWait::new();
|
||||
loop {
|
||||
match self.harness_and_state.compare_exchange_weak_tag(
|
||||
JobState::Pending as usize,
|
||||
JobState::Locked as usize,
|
||||
Ordering::Acquire,
|
||||
Ordering::Relaxed,
|
||||
) {
|
||||
Ok(state) => {
|
||||
debug_assert_eq!(state, JobState::Pending as usize);
|
||||
break;
|
||||
}
|
||||
Err(_) => {
|
||||
// debug_assert_ne!(state, JobState::Pending as usize);
|
||||
spin.spin();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let waker = unsafe { (&mut *self.error_or_link.get()).waker.take() };
|
||||
|
||||
match result {
|
||||
Ok(val) => unsafe {
|
||||
(&mut *self.value_or_this.get()).value = ManuallyDrop::new(SmallBox::new(val));
|
||||
(&mut *self.error_or_link.get()).error = ManuallyDrop::new(None);
|
||||
},
|
||||
Err(err) => unsafe {
|
||||
(&mut *self.value_or_this.get()).uninit = ();
|
||||
(&mut *self.error_or_link.get()).error = ManuallyDrop::new(Some(err));
|
||||
},
|
||||
}
|
||||
|
||||
if let Some(thread) = waker {
|
||||
thread.unpark();
|
||||
}
|
||||
|
||||
self.harness_and_state.set_tag(
|
||||
JobState::Finished as usize,
|
||||
Ordering::Release,
|
||||
Ordering::Relaxed,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
mod stackjob {
|
||||
use crate::latch::Latch;
|
||||
|
||||
use super::*;
|
||||
|
||||
pub struct StackJob<F, L> {
|
||||
latch: L,
|
||||
f: UnsafeCell<ManuallyDrop<F>>,
|
||||
}
|
||||
|
||||
impl<F, L> StackJob<F, L> {
|
||||
pub fn new(f: F, latch: L) -> Self {
|
||||
Self {
|
||||
latch,
|
||||
f: UnsafeCell::new(ManuallyDrop::new(f)),
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn unwrap(&self) -> F {
|
||||
unsafe { ManuallyDrop::take(&mut *self.f.get()) }
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, L> StackJob<F, L>
|
||||
where
|
||||
L: Latch,
|
||||
{
|
||||
pub fn as_job<T>(&self) -> Job<()>
|
||||
where
|
||||
F: FnOnce() -> T + Send,
|
||||
T: Send,
|
||||
{
|
||||
#[align(8)]
|
||||
unsafe fn harness<F, T, L: Latch>(this: *const (), job: *const Job<()>)
|
||||
where
|
||||
F: FnOnce() -> T + Send,
|
||||
T: Sized + Send,
|
||||
{
|
||||
let this = unsafe { &*this.cast::<StackJob<F, L>>() };
|
||||
let f = unsafe { this.unwrap() };
|
||||
|
||||
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f()));
|
||||
|
||||
let job = unsafe { &*job.cast::<Job<T>>() };
|
||||
job.complete(result);
|
||||
|
||||
unsafe {
|
||||
Latch::set_raw(&this.latch);
|
||||
}
|
||||
}
|
||||
|
||||
Job::new(harness::<F, T, L>, unsafe {
|
||||
NonNull::new_unchecked(self as *const _ as *mut ())
|
||||
// self.receiver.set(receiver);
|
||||
if let Job2Inner::Local {
|
||||
harness,
|
||||
this,
|
||||
_pin: _,
|
||||
} = unsafe {
|
||||
self.inner.replace(Job2Inner::Shared {
|
||||
receiver: Cell::new(receiver),
|
||||
})
|
||||
} {
|
||||
// SAFETY: `this` is a valid pointer to the job.
|
||||
unsafe {
|
||||
SharedJob {
|
||||
harness,
|
||||
this,
|
||||
sender: mem::transmute(sender), // Convert `Option<Sender<T>>` to `Option<Sender<()>>`
|
||||
}
|
||||
}
|
||||
} else {
|
||||
panic!("Job2 is already shared");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mod heapjob {
    use super::*;

    pub struct HeapJob<F> {
        f: F,
    pub fn take_receiver(&self) -> Option<Receiver<T>> {
        unsafe {
            if let Job2Inner::Shared { receiver } = self.inner.as_ref_unchecked() {
                receiver.take()
            } else {
                None
            }
        }
    }

    impl<F> HeapJob<F> {
        pub fn new(f: F) -> Self {
            Self { f }
        }

        pub fn into_inner(self) -> F {
            self.f
        }

        pub fn into_boxed_job<T>(self: Box<Self>) -> *mut Job<()>
    pub fn from_stackjob<F>(job: &StackJob<F>) -> Self
    where
        F: FnOnce(&WorkerThread) -> T + Send,
    {
        #[align(8)]
        #[cfg_attr(
            feature = "tracing",
            tracing::instrument(level = "trace", skip_all, name = "stack_job_harness")
        )]
        unsafe fn harness<F, T>(worker: &WorkerThread, this: NonNull<()>, sender: Option<Sender>)
        where
            F: FnOnce() -> T + Send,
            F: FnOnce(&WorkerThread) -> T + Send,
            T: Send,
        {
            #[align(8)]
            unsafe fn harness<F, T>(this: *const (), job: *const Job<()>)
            where
                F: FnOnce() -> T + Send,
                T: Send,
            {
                let job = job.cast_mut();
            use std::panic::{AssertUnwindSafe, catch_unwind};
            let sender: Option<Sender<T>> = unsafe { mem::transmute(sender) };

                // turn `this`, which was allocated at (2), into a box.
                // miri flags this as a use-after-free, but because we ensure
                // that the scope lives as long as any of its jobs, this is
                // fine as far as I can tell.
                let this = unsafe { Box::from_raw(this.cast::<HeapJob<F>>().cast_mut()) };
                let f = this.into_inner();
            let f = unsafe { this.cast::<StackJob<F>>().as_ref().unwrap() };

                _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f()));
            // #[cfg(feature = "metrics")]
            // if worker.heartbeat.parker() == mutex {
            //     worker
            //         .metrics
            //         .num_sent_to_self
            //         .fetch_add(1, Ordering::Relaxed);
            //     tracing::trace!("job sent to self");
            // }

                // drop the job (this is fine because the job of a HeapJob is pure POD).
            let result = catch_unwind(AssertUnwindSafe(|| f(worker)));

            if let Some(sender) = sender {
                unsafe {
                    ptr::drop_in_place(job);
                    sender.send_as_ref(result);
                    worker
                        .context
                        .queue
                        .as_sender()
                        .unicast(Message::WakeUp, ReceiverToken::from_parker(sender.parker()));
                }

                // free the box that was allocated at (1)
                _ = unsafe { Box::<ManuallyDrop<Job<T>>>::from_raw(job.cast()) };
            }
        }

        // (1) allocate a box for the job
        Box::into_raw(Box::new(Job::new(harness::<F, T>, {
            // (2) convert self into a pointer
            Box::into_non_null(self).cast()
        })))
        Self::new(harness::<F, T>, NonNull::from(job).cast())
    }

    pub fn from_heapjob<F>(job: Box<HeapJob<F>>) -> Self
    where
        F: FnOnce(&WorkerThread) -> T + Send,
    {
        #[align(8)]
        #[cfg_attr(
            feature = "tracing",
            tracing::instrument(level = "trace", skip_all, name = "heap_job_harness")
        )]
        unsafe fn harness<F, T>(worker: &WorkerThread, this: NonNull<()>, sender: Option<Sender>)
        where
            F: FnOnce(&WorkerThread) -> T + Send,
            T: Send,
        {
            use std::panic::{AssertUnwindSafe, catch_unwind};
            let sender: Option<Sender<T>> = unsafe { mem::transmute(sender) };

            // miri is expected to complain about this, but it is actually correct:
            // unbox the job, which was allocated at (2).
            let f = unsafe { (*Box::from_non_null(this.cast::<HeapJob<F>>())).into_inner() };

            let result = catch_unwind(AssertUnwindSafe(|| f(worker)));
            if let Some(sender) = sender {
                unsafe {
                    sender.send_as_ref(result);
                    _ = worker
                        .context
                        .queue
                        .as_sender()
                        .unicast(Message::WakeUp, ReceiverToken::from_parker(sender.parker()));
                }
            }
        }

        // (1) allocate a box for the job
        Self::new(
            harness::<F, T>,
            // (2) convert the job into a pointer
            Box::into_non_null(job).cast(),
        )
    }

    pub fn from_harness(harness: JobHarness, this: NonNull<()>) -> Self {
        Self::new(harness, this)
    }
}
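
// The heap-job life cycle above reduces to a raw-pointer ownership handoff:
// allocate with `Box`, erase to a pointer at (2), and reclaim exactly once in
// the harness. A stable-Rust sketch of the same (1)/(2) pairing (the nightly
// code above uses `Box::into_non_null`/`Box::from_non_null` instead):
struct Payload(String);

fn main() {
    // (1)+(2): allocate, then erase to a unit pointer that can cross the queue.
    let raw: *mut () = Box::into_raw(Box::new(Payload("job data".into()))).cast();

    // ... the pointer travels through the job queue as an opaque `*mut ()` ...

    // In the harness: un-erase and re-box, which frees the allocation on drop.
    let payload = unsafe { Box::from_raw(raw.cast::<Payload>()) };
    println!("{}", payload.0);
}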

impl SharedJob {
    pub unsafe fn execute(self, worker: &WorkerThread) {
        #[cfg(feature = "tracing")]
        tracing::trace!("executing shared job: {:?}", self);

        let Self {
            harness,
            this,
            sender,
        } = self;

        unsafe {
            (harness)(worker, this, sender);
        }

        #[cfg(feature = "tracing")]
        tracing::trace!("finished executing shared job: {:?}", this);
    }
}

pub use queuedjobqueue::JobQueue;

mod queuedjobqueue {
    //! Basically `JobVec`, but for `QueuedJob`s.

    // TODO: use `NonNull`s here and rely on Into/From for `&T`

    use std::{collections::VecDeque, ptr::NonNull};

    use super::Job2 as Job;

    #[derive(Debug)]
    pub struct JobQueue {
        jobs: VecDeque<NonNull<Job>>,
    }

    impl JobQueue {
        pub fn new() -> Self {
            Self {
                jobs: VecDeque::new(),
            }
        }

        pub fn push_front(&mut self, job: *const Job) {
            self.jobs
                .push_front(unsafe { NonNull::new_unchecked(job as *mut _) });
        }

        pub fn push_back(&mut self, job: *const Job) {
            self.jobs
                .push_back(unsafe { NonNull::new_unchecked(job as *mut _) });
        }

        pub fn pop_front(&mut self) -> Option<NonNull<Job>> {
            self.jobs.pop_front()
        }

        pub fn pop_back(&mut self) -> Option<NonNull<Job>> {
            self.jobs.pop_back()
        }

        pub fn is_empty(&self) -> bool {
            self.jobs.is_empty()
        }

        pub fn len(&self) -> usize {
            self.jobs.len()
        }
    }
}

pub use heapjob::HeapJob;
pub use stackjob::StackJob;

pub mod traits {
    use std::{cell::UnsafeCell, mem::ManuallyDrop};

    use crate::WorkerThread;

    use super::{HeapJob, Job2, StackJob};

    pub trait IntoJob<T> {
        fn into_job(self) -> Job2<T>;
    }

    pub trait InlineJob<T>: IntoJob<T> {
        fn run_inline(self, worker: &WorkerThread) -> T;
    }

    impl<F, T> IntoJob<T> for F
    where
        F: FnOnce(&WorkerThread) -> T + Send,
        T: Send,
    {
        fn into_job(self) -> Job2<T> {
            Job2::from_heapjob(HeapJob::new(self))
        }
    }

    impl<F, T> IntoJob<T> for &UnsafeCell<ManuallyDrop<F>>
    where
        F: FnOnce(&WorkerThread) -> T + Send,
        T: Send,
    {
        fn into_job(self) -> Job2<T> {
            Job2::from_stackjob(unsafe { std::mem::transmute::<Self, &StackJob<F>>(self) })
        }
    }

    impl<F, T> InlineJob<T> for &UnsafeCell<ManuallyDrop<F>>
    where
        F: FnOnce(&WorkerThread) -> T + Send,
        T: Send,
    {
        fn run_inline(self, worker: &WorkerThread) -> T {
            unsafe { ManuallyDrop::take(&mut *self.get())(worker) }
        }
    }

    impl<F, T> IntoJob<T> for &StackJob<F>
    where
        F: FnOnce(&WorkerThread) -> T + Send,
        T: Send,
    {
        fn into_job(self) -> Job2<T> {
            Job2::from_stackjob(self)
        }
    }

    impl<F, T> InlineJob<T> for &StackJob<F>
    where
        F: FnOnce(&WorkerThread) -> T + Send,
        T: Send,
    {
        fn run_inline(self, worker: &WorkerThread) -> T {
            unsafe { self.unwrap()(worker) }
        }
    }
}
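
// A distilled model of the `IntoJob`/`InlineJob` split above: every suitable
// closure can be converted into a queueable job, but only jobs whose storage
// the caller still owns can additionally be run inline, skipping the queue.
// `IntoTask`/`InlineTask` are illustrative stand-ins, not this crate's API.
trait IntoTask<T> {
    fn into_task(self) -> Box<dyn FnOnce() -> T>;
}

trait InlineTask<T>: IntoTask<T> {
    fn run_inline(self) -> T;
}

impl<F, T> IntoTask<T> for F
where
    F: FnOnce() -> T + 'static,
{
    fn into_task(self) -> Box<dyn FnOnce() -> T> {
        Box::new(self)
    }
}

impl<F, T> InlineTask<T> for F
where
    F: FnOnce() -> T + 'static,
{
    fn run_inline(self) -> T {
        self()
    }
}

fn main() {
    let queued = (|| 1 + 1).into_task(); // heap-allocated, queueable
    assert_eq!(queued(), 2);
    assert_eq!((|| 40 + 2).run_inline(), 42); // no allocation, direct call
}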

@ -1,48 +1,59 @@
use std::hint::cold_path;
#[cfg(feature = "metrics")]
use std::sync::atomic::Ordering;

use std::{hint::cold_path, pin::Pin, sync::Arc};

use crate::{
    job::{JobState, StackJob},
    latch::{AsCoreLatch, LatchRef, WakeLatch},
    context::Context,
    job::{
        Job2 as Job, StackJob,
        traits::{InlineJob, IntoJob},
    },
    workerthread::WorkerThread,
};

impl WorkerThread {
    #[inline]
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    fn join_seq<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        RA: Send,
        RB: Send,
        A: FnOnce() -> RA + Send,
        B: FnOnce() -> RB + Send,
        A: FnOnce(&WorkerThread) -> RA,
        B: FnOnce(&WorkerThread) -> RB,
    {
        let rb = b();
        let ra = a();
        let rb = b(self);
        let ra = a(self);

        (ra, rb)
    }

    pub(crate) fn join_heartbeat_every<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        A: FnOnce(&WorkerThread) -> RA + Send,
        B: FnOnce(&WorkerThread) -> RB,
        RA: Send,
    {
        // self.join_heartbeat_every_inner::<A, B, RA, RB, 2>(a, b)
        self.join_heartbeat(a, b)
    }

    /// This function must be called from a worker thread.
    #[inline]
    pub(crate) fn join_heartbeat_every<A, B, RA, RB, const TIMES: usize>(
        &self,
        a: A,
        b: B,
    ) -> (RA, RB)
    #[allow(dead_code)]
    #[inline(always)]
    fn join_heartbeat_every_inner<A, B, RA, RB, const TIMES: usize>(&self, a: A, b: B) -> (RA, RB)
    where
        RA: Send,
        RB: Send,
        A: FnOnce() -> RA + Send,
        B: FnOnce() -> RB + Send,
        A: FnOnce(&WorkerThread) -> RA + Send,
        B: FnOnce(&WorkerThread) -> RB,
    {
        // SAFETY: each worker is only ever used by one thread, so this is safe.
        let count = self.join_count.get();
        let queue_len = unsafe { self.queue.as_ref_unchecked().len() };
        self.join_count.set(count.wrapping_add(1) % TIMES as u8);

        // TODO: add a counter to the job queue and check for a low job count
        // to decide whether to use the heartbeat or the sequential path; see: chili

        // SAFETY: this function runs in a worker thread, so we can access the queue safely.
        if count == 0 || unsafe { self.queue.as_ref_unchecked().len() } < 3 {
        if count == 0 || queue_len < 3 {
            cold_path();
            self.join_heartbeat(a, b)
        } else {

@ -51,60 +62,172 @@ impl WorkerThread {
    }

    /// This function must be called from a worker thread.
    #[inline]
    fn join_heartbeat<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    #[allow(dead_code)]
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    pub(crate) fn join_heartbeat2_every<A, B, RA, RB, const TIMES: usize>(
        &self,
        a: A,
        b: B,
    ) -> (RA, RB)
    where
        RA: Send,
        RB: Send,
        A: FnOnce() -> RA + Send,
        B: FnOnce() -> RB + Send,
        A: InlineJob<RA> + Copy,
        B: FnOnce(&WorkerThread) -> RB,
    {
        // SAFETY: each worker is only ever used by one thread, so this is safe.
        let count = self.join_count.get();
        let queue_len = unsafe { self.queue.as_ref_unchecked().len() };
        self.join_count.set(count.wrapping_add(1) % TIMES as u8);

        // TODO: add a counter to the job queue and check for a low job count
        // to decide whether to use the heartbeat or the sequential path; see: chili

        // SAFETY: this function runs in a worker thread, so we can access the queue safely.
        if count == 0 || queue_len < 3 {
            self.join_heartbeat2(a, b)
        } else {
            (a.run_inline(self), b(self))
        }
    }

    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    pub(crate) fn join_heartbeat2<RA, A, B, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        B: FnOnce(&WorkerThread) -> RB,
        A: InlineJob<RA> + Copy,
        RA: Send,
    {
        use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};

        let latch = WakeLatch::new(self.context.clone(), self.index);
        let a = StackJob::new(
            move || {
                // TODO: bench whether ticking here is good.
                // Turns out this actually costs a lot of time, likely because of the thread-local check.
                // WorkerThread::current_ref()
                //     .expect("stackjob is run in workerthread.")
                //     .tick();
                #[cfg(feature = "metrics")]
                self.metrics.num_joins.fetch_add(1, Ordering::Relaxed);

                a()
            },
            LatchRef::new(&latch),
        );
        let _pinned = a.into_job();
        let job = unsafe { Pin::new_unchecked(&_pinned) };

        let job = a.as_job();
        self.push_front(&job);
        self.push_back(&*job);

        let rb = match catch_unwind(AssertUnwindSafe(|| b())) {
        self.tick();

        // let rb = b(self);
        let rb = match catch_unwind(AssertUnwindSafe(|| b(self))) {
            Ok(val) => val,
            Err(payload) => {
                #[cfg(feature = "tracing")]
                tracing::debug!("join_heartbeat: b panicked, waiting for a to finish");
                cold_path();

                // if b panicked, we need to wait for a to finish
                self.wait_until_latch(&latch);
                if let Some(recv) = job.take_receiver() {
                    _ = self.wait_until_recv(recv);
                }

                resume_unwind(payload);
            }
        };

        let ra = if job.state() == JobState::Empty as u8 {
            unsafe {
                job.unlink();
            }

            // a is allowed to panic here, because we already finished b.
            unsafe { a.unwrap()() }
        let ra = if let Some(recv) = job.take_receiver() {
            crate::util::unwrap_or_panic(self.wait_until_recv(recv))
        } else {
            match self.wait_until_job::<RA>(unsafe { job.transmute_ref() }, latch.as_core_latch()) {
                Some(t) => t.into_result(), // propagate panic here
                // the job was shared, but not yet stolen, so we get to run the
                // job inline
                None => unsafe { a.unwrap()() },
            self.pop_back();

            // SAFETY: we just popped the job from the queue, so it is safe to unwrap.
            #[cfg(feature = "tracing")]
            tracing::trace!("join_heartbeat: job was not shared, running a() inline");
            a.run_inline(self)
        };

        // touch the job to ensure it is not dropped while we are still using it.
        drop(_pinned);

        (ra, rb)
    }

    /// This function must be called from a worker thread.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    fn join_heartbeat<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        RA: Send,
        A: FnOnce(&WorkerThread) -> RA + Send,
        B: FnOnce(&WorkerThread) -> RB,
    {
        use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};

        #[cfg(feature = "metrics")]
        self.metrics.num_joins.fetch_add(1, Ordering::Relaxed);

        let a = StackJob::new(a);
        let job = Job::from_stackjob(&a);

        self.push_back(&job);

        self.tick();

        let rb = match catch_unwind(AssertUnwindSafe(|| b(self))) {
            Ok(val) => val,
            Err(payload) => {
                #[cfg(feature = "tracing")]
                tracing::debug!("join_heartbeat: b panicked, waiting for a to finish");
                cold_path();

                // if b panicked, we need to wait for a to finish
                if let Some(recv) = job.take_receiver() {
                    _ = self.wait_until_recv(recv);
                }

                resume_unwind(payload);
            }
        };

        drop(a);
        let ra = if let Some(recv) = job.take_receiver() {
            crate::util::unwrap_or_panic(self.wait_until_recv(recv))
        } else {
            self.pop_back();

            // SAFETY: we just popped the job from the queue, so it is safe to unwrap.
            #[cfg(feature = "tracing")]
            tracing::trace!("join_heartbeat: job was not shared, running a() inline");
            a.run_inline(self)
        };

        (ra, rb)
    }
}

impl Context {
    pub fn join<A, B, RA, RB>(self: &Arc<Self>, a: A, b: B) -> (RA, RB)
    where
        A: FnOnce() -> RA + Send,
        B: FnOnce() -> RB + Send,
        RA: Send,
        RB: Send,
    {
        // SAFETY: join_heartbeat_every is safe to call from a worker thread.
        self.run_in_worker(move |worker| {
            worker.join_heartbeat_every::<_, _, _, _>(|_| a(), |_| b())
        })
    }
}

/// Runs two closures, potentially in parallel, in the global threadpool.
pub fn join<A, B, RA, RB>(a: A, b: B) -> (RA, RB)
where
    A: FnOnce() -> RA + Send,
    B: FnOnce() -> RB + Send,
    RA: Send,
    RB: Send,
{
    join_in(Context::global_context().clone(), a, b)
}

/// Runs two closures, potentially in parallel, in the given threadpool context.
#[allow(dead_code)]
fn join_in<A, B, RA, RB>(context: Arc<Context>, a: A, b: B) -> (RA, RB)
where
    A: FnOnce() -> RA + Send,
    B: FnOnce() -> RB + Send,
    RA: Send,
    RB: Send,
{
    context.join(a, b)
}
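
// A usage sketch of the public `join` defined above: recursive slice summation
// where every split defers the parallel/sequential decision to the heartbeat
// scheduler. This assumes the crate's global context is initialized lazily on
// first use, as `Context::global_context()` suggests.
fn sum(slice: &[u64]) -> u64 {
    if slice.len() < 4096 {
        return slice.iter().sum();
    }
    let (left, right) = slice.split_at(slice.len() / 2);
    let (l, r) = distaff::join(|| sum(left), || sum(right));
    l + r
}

fn main() {
    let data: Vec<u64> = (0..1_000_000).collect();
    assert_eq!(sum(&data), (0..1_000_000).sum());
}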

@ -2,11 +2,11 @@ use core::{
    marker::PhantomData,
    sync::atomic::{AtomicUsize, Ordering},
};
use std::sync::{Arc, atomic::AtomicU8};
use std::sync::atomic::{AtomicPtr, AtomicU8};

use parking_lot::{Condvar, Mutex};

use crate::context::Context;
use crate::channel::Parker;

pub trait Latch {
    unsafe fn set_raw(this: *const Self);

@ -30,6 +30,8 @@ impl AtomicLatch {
    pub const UNSET: u8 = 0;
    pub const SET: u8 = 1;
    pub const SLEEPING: u8 = 2;
    pub const WAKEUP: u8 = 4;
    pub const HEARTBEAT: u8 = 8;

    #[inline]
    pub const fn new() -> Self {

@ -37,21 +39,66 @@ impl AtomicLatch {
            inner: AtomicU8::new(Self::UNSET),
        }
    }

    pub const fn new_set() -> Self {
        Self {
            inner: AtomicU8::new(Self::SET),
        }
    }

    #[inline]
    pub fn reset(&self) {
        self.inner.store(Self::UNSET, Ordering::Release);
    pub fn unset(&self) {
        self.inner.fetch_and(!Self::SET, Ordering::Release);
    }

    pub fn reset(&self) -> u8 {
        self.inner.swap(Self::UNSET, Ordering::Release)
    }

    pub fn get(&self) -> u8 {
        self.inner.load(Ordering::Acquire)
    }

    pub fn poll_heartbeat(&self) -> bool {
        self.inner.fetch_and(!Self::HEARTBEAT, Ordering::Relaxed) & Self::HEARTBEAT
            == Self::HEARTBEAT
    }

    /// Returns true if the latch was already set.
    pub fn set_sleeping(&self) -> bool {
        self.inner.fetch_or(Self::SLEEPING, Ordering::Relaxed) & Self::SET == Self::SET
    }

    pub fn is_sleeping(&self) -> bool {
        self.inner.load(Ordering::Relaxed) & Self::SLEEPING == Self::SLEEPING
    }

    pub fn is_heartbeat(&self) -> bool {
        self.inner.load(Ordering::Relaxed) & Self::HEARTBEAT == Self::HEARTBEAT
    }

    pub fn is_wakeup(&self) -> bool {
        self.inner.load(Ordering::Relaxed) & Self::WAKEUP == Self::WAKEUP
    }

    pub fn is_set(&self) -> bool {
        self.inner.load(Ordering::Relaxed) & Self::SET == Self::SET
    }

    pub fn set_wakeup(&self) {
        self.inner.fetch_or(Self::WAKEUP, Ordering::Relaxed);
    }

    pub fn set_heartbeat(&self) {
        self.inner.fetch_or(Self::HEARTBEAT, Ordering::Relaxed);
    }

    /// Returns true if the latch was previously sleeping.
    #[inline]
    pub unsafe fn set(this: *const Self) -> bool {
        unsafe {
            let old = (*this).inner.swap(Self::SET, Ordering::Release);
            old == Self::SLEEPING
            let old = (*this).inner.fetch_or(Self::SET, Ordering::Relaxed);
            old & Self::SLEEPING == Self::SLEEPING
        }
    }
}
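
// A worked example of the bitflag protocol above, reduced to the two bits that
// interact: a worker flags SLEEPING before parking, and a setter that observes
// SLEEPING in the old value knows it must also wake the worker.
use std::sync::atomic::{AtomicU8, Ordering};

const SET: u8 = 1;
const SLEEPING: u8 = 2;

fn main() {
    let state = AtomicU8::new(0);

    // worker: advertise intent to sleep before actually parking
    state.fetch_or(SLEEPING, Ordering::Relaxed);

    // setter: OR in SET, and report whether a wakeup is required
    let old = state.fetch_or(SET, Ordering::Relaxed);
    let must_wake = old & SLEEPING == SLEEPING;

    assert!(must_wake);
    assert_eq!(state.load(Ordering::Relaxed), SET | SLEEPING);
}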

@ -68,7 +115,7 @@ impl Latch for AtomicLatch {
impl Probe for AtomicLatch {
    #[inline]
    fn probe(&self) -> bool {
        self.inner.load(Ordering::Acquire) == Self::SET
        self.inner.load(Ordering::Relaxed) & Self::SET != 0
    }
}
impl AsCoreLatch for AtomicLatch {

@ -142,80 +189,28 @@ impl Probe for NopLatch {
    }
}

pub struct ThreadWakeLatch {
    waker: Mutex<Option<std::thread::Thread>>,
}

impl ThreadWakeLatch {
    #[inline]
    pub const fn new() -> Self {
        Self {
            waker: Mutex::new(None),
        }
    }

    #[inline]
    pub fn reset(&self) {
        let mut waker = self.waker.lock();
        *waker = None;
    }

    #[inline]
    pub fn set_waker(&self, thread: std::thread::Thread) {
        let mut waker = self.waker.lock();
        *waker = Some(thread);
    }

    pub unsafe fn wait(&self) {
        assert!(
            self.waker.lock().replace(std::thread::current()).is_none(),
            "ThreadWakeLatch can only be waited on once per thread"
        );

        std::thread::park();
    }
}

impl Latch for ThreadWakeLatch {
    #[inline]
    unsafe fn set_raw(this: *const Self) {
        unsafe {
            if let Some(thread) = (&*this).waker.lock().take() {
                thread.unpark();
            }
        }
    }
}

impl Probe for ThreadWakeLatch {
    #[inline]
    fn probe(&self) -> bool {
        self.waker.lock().is_some()
    }
}

pub struct CountLatch<L: Latch> {
pub struct CountLatch {
    count: AtomicUsize,
    inner: L,
    inner: AtomicPtr<Parker>,
}

impl<L: Latch> CountLatch<L> {
impl CountLatch {
    #[inline]
    pub const fn new(inner: L) -> Self {
    pub const fn new(inner: *const Parker) -> Self {
        Self {
            count: AtomicUsize::new(0),
            inner,
            inner: AtomicPtr::new(inner as *mut Parker),
        }
    }

    pub fn set_inner(&self, inner: *const Parker) {
        self.inner.store(inner as *mut Parker, Ordering::Relaxed);
    }

    pub fn count(&self) -> usize {
        self.count.load(Ordering::Relaxed)
    }

    pub fn inner(&self) -> &L {
        &self.inner
    }

    #[inline]
    pub fn increment(&self) {
        self.count.fetch_add(1, Ordering::Release);

@ -223,63 +218,77 @@ impl<L: Latch> CountLatch<L> {

    #[inline]
    pub fn decrement(&self) {
        if self.count.fetch_sub(1, Ordering::Release) == 1 {
            unsafe {
                Latch::set_raw(&self.inner);
        unsafe {
            Latch::set_raw(self);
        }
    }
}

impl Latch for CountLatch {
    #[inline]
    unsafe fn set_raw(this: *const Self) {
        unsafe {
            if (&*this).count.fetch_sub(1, Ordering::Relaxed) == 1 {
                #[cfg(feature = "tracing")]
                tracing::trace!("CountLatch set_raw: count was 1, setting inner latch");
                // If the count was 1, we need to set the inner latch.
                let inner = (*this).inner.load(Ordering::Relaxed);
                if !inner.is_null() {
                    (&*inner).unpark();
                }
            }
        }
    }
}

impl<L: Latch> Latch for CountLatch<L> {
    #[inline]
    unsafe fn set_raw(this: *const Self) {
        unsafe {
            let this = &*this;
            this.decrement();
        }
    }
}

impl<L: Latch + Probe> Probe for CountLatch<L> {
impl Probe for CountLatch {
    #[inline]
    fn probe(&self) -> bool {
        self.inner.probe()
    }
}

impl<L: Latch + AsCoreLatch> AsCoreLatch for CountLatch<L> {
    #[inline]
    fn as_core_latch(&self) -> &CoreLatch {
        self.inner.as_core_latch()
        self.count.load(Ordering::Relaxed) == 0
    }
}
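
// The CountLatch pattern above in miniature, using std primitives: each job
// decrements a shared counter with Release ordering, and whichever job brings
// it to zero unparks the owner, which re-checks with Acquire after each wakeup.
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread;

fn main() {
    let count = Arc::new(AtomicUsize::new(3));
    let owner = thread::current();

    let handles: Vec<_> = (0..3)
        .map(|_| {
            let count = Arc::clone(&count);
            let owner = owner.clone();
            thread::spawn(move || {
                // ... job body runs here ...
                if count.fetch_sub(1, Ordering::Release) == 1 {
                    owner.unpark(); // the last job wakes the waiting owner
                }
            })
        })
        .collect();

    while count.load(Ordering::Acquire) != 0 {
        thread::park(); // spurious wakeups are fine: the loop re-checks
    }
    for handle in handles {
        handle.join().unwrap();
    }
}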

pub struct MutexLatch {
    inner: Mutex<bool>,
    inner: AtomicLatch,
    lock: Mutex<()>,
    condvar: Condvar,
}

unsafe impl Send for MutexLatch {}
unsafe impl Sync for MutexLatch {}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum WakeResult {
    Wake,
    Heartbeat,
    Set,
}

impl MutexLatch {
    #[inline]
    pub const fn new() -> Self {
        Self {
            inner: Mutex::new(false),
            inner: AtomicLatch::new(),
            lock: Mutex::new(()),
            condvar: Condvar::new(),
        }
    }

    #[inline]
    pub fn reset(&self) {
        let mut guard = self.inner.lock();
        *guard = false;
        let _guard = self.lock.lock();
        // SAFETY: inner is atomic, so we can safely access it.
        self.inner.reset();
    }

    pub fn wait(&self) {
        let mut guard = self.inner.lock();
        while !*guard {
    pub fn wait_and_reset(&self) {
        // SAFETY: inner is locked by the mutex, so we can safely access it.
        let mut guard = self.lock.lock();
        while !self.inner.probe() {
            self.condvar.wait(&mut guard);
        }

        self.inner.reset();
    }

    pub fn set(&self) {

@ -287,22 +296,17 @@ impl MutexLatch {
            Latch::set_raw(self);
        }
    }

    pub fn wait_and_reset(&self) {
        let mut guard = self.inner.lock();
        while !*guard {
            self.condvar.wait(&mut guard);
        }
        *guard = false;
    }
}

impl Latch for MutexLatch {
    #[inline]
    unsafe fn set_raw(this: *const Self) {
        // SAFETY: `this` is valid until the guard is dropped.
        unsafe {
            *(&*this).inner.lock() = true;
            (&*this).condvar.notify_all();
            let this = &*this;
            let _guard = this.lock.lock();
            Latch::set_raw(&this.inner);
            this.condvar.notify_all();
        }
    }
}

@ -310,58 +314,224 @@ impl Latch for MutexLatch {
impl Probe for MutexLatch {
    #[inline]
    fn probe(&self) -> bool {
        *self.inner.lock()
    }
}

pub struct WakeLatch {
    inner: AtomicLatch,
    context: Arc<Context>,
    worker_index: AtomicUsize,
}

impl WakeLatch {
    pub fn new(context: Arc<Context>, worker_index: usize) -> Self {
        Self {
            inner: AtomicLatch::new(),
            context,
            worker_index: AtomicUsize::new(worker_index),
        }
    }

    pub(crate) fn set_worker_index(&self, worker_index: usize) {
        self.worker_index.store(worker_index, Ordering::Relaxed);
    }
}

impl Latch for WakeLatch {
    #[inline]
    unsafe fn set_raw(this: *const Self) {
        unsafe {
            let ctx = (&*this).context.clone();
            let worker_index = (&*this).worker_index.load(Ordering::Relaxed);

            if CoreLatch::set(&(&*this).inner) {
                // If the latch was sleeping, wake the worker thread
                ctx.shared().heartbeats.get(&worker_index).and_then(|weak| {
                    weak.upgrade()
                        .map(|heartbeat| Latch::set_raw(&heartbeat.latch))
                });
            }
        }
    }
}

impl Probe for WakeLatch {
    #[inline]
    fn probe(&self) -> bool {
        let _guard = self.lock.lock();
        // SAFETY: inner is atomic, so we can safely access it.
        self.inner.probe()
    }
}

impl AsCoreLatch for WakeLatch {
impl AsCoreLatch for MutexLatch {
    #[inline]
    fn as_core_latch(&self) -> &CoreLatch {
        &self.inner
        // SAFETY: inner is atomic, so we can safely access it.
        self.inner.as_core_latch()
    }
}

// The worker waits on this latch whenever it has nothing to do.
#[derive(Debug)]
pub struct WorkerLatch {
    // this boolean is set when the worker is waiting.
    mutex: Mutex<bool>,
    condvar: Condvar,
}

impl WorkerLatch {
    pub fn new() -> Self {
        Self {
            mutex: Mutex::new(false),
            condvar: Condvar::new(),
        }
    }

    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(
            level = "trace",
            skip_all,
            fields(this = self as *const Self as usize)
        )
    )]
    pub fn lock(&self) -> parking_lot::MutexGuard<'_, bool> {
        #[cfg(feature = "tracing")]
        tracing::trace!("acquiring mutex..");
        let guard = self.mutex.lock();
        #[cfg(feature = "tracing")]
        tracing::trace!("mutex acquired.");

        guard
    }

    pub unsafe fn force_unlock(&self) {
        unsafe {
            self.mutex.force_unlock();
        }
    }

    pub fn wait(&self) {
        let condvar = &self.condvar;
        let mut guard = self.lock();

        Self::wait_internal(condvar, &mut guard);
    }

    fn wait_internal(condvar: &Condvar, guard: &mut parking_lot::MutexGuard<'_, bool>) {
        **guard = true; // set the mutex to true to indicate that the worker is waiting
        //condvar.wait_for(guard, std::time::Duration::from_micros(100));
        condvar.wait(guard);
        **guard = false;
    }

    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(level = "trace", skip_all, fields(
            this = self as *const Self as usize,
        )))]
    pub fn wait_unless<F>(&self, mut f: F)
    where
        F: FnMut() -> bool,
    {
        let mut guard = self.lock();
        if !f() {
            Self::wait_internal(&self.condvar, &mut guard);
        }
    }

    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(level = "trace", skip_all, fields(
            this = self as *const Self as usize,
        )))]
    pub fn wait_until<F, T>(&self, mut f: F) -> T
    where
        F: FnMut() -> Option<T>,
    {
        let mut guard = self.lock();
        loop {
            if let Some(result) = f() {
                return result;
            }
            Self::wait_internal(&self.condvar, &mut guard);
        }
    }

    pub fn is_waiting(&self) -> bool {
        *self.mutex.lock()
    }

    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(level = "trace", skip_all, fields(
            this = self as *const Self as usize,
        )))]
    fn notify(&self) {
        let n = self.condvar.notify_all();
        #[cfg(feature = "tracing")]
        tracing::trace!("WorkerLatch notify: notified {} threads", n);
    }

    pub fn wake(&self) {
        self.notify();
    }
}
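
// The wait/notify pairing above, reduced to its core: the waiter loops on a
// predicate and treats every condvar wakeup (including spurious ones) as a
// hint to re-check, while the waker publishes its change under the same mutex
// before notifying. A sketch only; it elides WorkerLatch's "is waiting" flag.
use std::sync::Arc;
use std::thread;

use parking_lot::{Condvar, Mutex};

fn main() {
    let latch = Arc::new((Mutex::new(false), Condvar::new()));

    let waker = {
        let latch = Arc::clone(&latch);
        thread::spawn(move || {
            let (mutex, condvar) = &*latch;
            *mutex.lock() = true; // publish under the lock, like Latch::set_raw above
            condvar.notify_all();
        })
    };

    let (mutex, condvar) = &*latch;
    let mut done = mutex.lock();
    while !*done {
        condvar.wait(&mut done); // re-check after every wakeup
    }
    drop(done);

    waker.join().unwrap();
}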

#[cfg(test)]
mod tests {
    use std::{ptr, sync::Arc};

    use super::*;

    #[test]
    fn test_atomic_latch() {
        let latch = AtomicLatch::new();
        assert_eq!(latch.get(), AtomicLatch::UNSET);
        unsafe {
            assert!(!latch.probe());
            AtomicLatch::set_raw(&latch);
        }
        assert_eq!(latch.get(), AtomicLatch::SET);
        assert!(latch.probe());
        latch.unset();
        assert_eq!(latch.get(), AtomicLatch::UNSET);
    }

    #[test]
    fn core_latch_sleep() {
        let latch = AtomicLatch::new();
        assert_eq!(latch.get(), AtomicLatch::UNSET);
        latch.set_sleeping();
        assert_eq!(latch.get(), AtomicLatch::SLEEPING);
        unsafe {
            assert!(!latch.probe());
            assert!(AtomicLatch::set(&latch));
        }
        assert_eq!(latch.get(), AtomicLatch::SET | AtomicLatch::SLEEPING);
        assert!(latch.probe());
        latch.reset();
        assert_eq!(latch.get(), AtomicLatch::UNSET);
    }

    #[test]
    fn nop_latch() {
        assert!(
            core::mem::size_of::<NopLatch>() == 0,
            "NopLatch should be zero-sized"
        );
    }

    #[test]
    fn count_latch() {
        let latch = CountLatch::new(ptr::null());
        assert_eq!(latch.count(), 0);
        latch.increment();
        assert_eq!(latch.count(), 1);
        assert!(!latch.probe());
        latch.increment();
        assert_eq!(latch.count(), 2);
        assert!(!latch.probe());

        unsafe {
            Latch::set_raw(&latch);
        }
        assert!(!latch.probe());
        assert_eq!(latch.count(), 1);

        unsafe {
            Latch::set_raw(&latch);
        }
        assert!(latch.probe());
        assert_eq!(latch.count(), 0);
    }

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn mutex_latch() {
        let latch = Arc::new(MutexLatch::new());
        assert!(!latch.probe());
        latch.set();
        assert!(latch.probe());
        latch.reset();
        assert!(!latch.probe());

        // Test wait functionality
        let latch_clone = latch.clone();
        let handle = std::thread::spawn(move || {
            #[cfg(feature = "tracing")]
            tracing::info!("Thread waiting on latch");
            latch_clone.wait_and_reset();
            #[cfg(feature = "tracing")]
            tracing::info!("Thread woke up from latch");
        });

        // Give the thread time to block
        std::thread::sleep(std::time::Duration::from_millis(100));
        assert!(!latch.probe());

        #[cfg(feature = "tracing")]
        tracing::info!("Setting latch from main thread");
        latch.set();
        #[cfg(feature = "tracing")]
        tracing::info!("Latch set, joining waiting thread");
        handle.join().expect("Thread should join successfully");
    }
}

@ -7,16 +7,30 @@
    unsafe_cell_access,
    box_as_ptr,
    box_vec_non_null,
    strict_provenance_atomic_ptr,
    btree_extract_if,
    likely_unlikely,
    let_chains
)]

extern crate alloc;

mod channel;
mod context;
mod heartbeat;
mod job;
mod join;
mod latch;
#[cfg(feature = "metrics")]
mod metrics;
mod queue;
mod scope;
mod threadpool;
pub mod util;
mod workerthread;

pub use context::run_in_worker;
pub use join::join;
pub use scope::{Scope, scope};
pub use threadpool::ThreadPool;
pub use workerthread::WorkerThread;

12
distaff/src/metrics.rs
Normal file

@ -0,0 +1,12 @@
use std::sync::atomic::AtomicU32;

#[derive(Debug, Default)]
pub(crate) struct WorkerMetrics {
    pub(crate) num_jobs_shared: AtomicU32,
    pub(crate) num_heartbeats: AtomicU32,
    pub(crate) num_joins: AtomicU32,
    pub(crate) num_jobs_reclaimed: AtomicU32,
    pub(crate) num_jobs_executed: AtomicU32,
    pub(crate) num_jobs_stolen: AtomicU32,
    pub(crate) num_sent_to_self: AtomicU32,
}
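
// A hedged sketch of how per-worker counters like these are typically used:
// workers bump them with Relaxed RMWs on the hot path, and a reporting thread
// takes an approximate snapshot with Relaxed loads, which is fine because the
// fields are independent monotonic counters. `Counters` is illustrative only.
use std::sync::atomic::{AtomicU32, Ordering};

#[derive(Debug, Default)]
struct Counters {
    jobs_executed: AtomicU32,
    heartbeats: AtomicU32,
}

fn main() {
    let c = Counters::default();
    c.jobs_executed.fetch_add(1, Ordering::Relaxed); // hot path
    c.heartbeats.fetch_add(1, Ordering::Relaxed);

    // reporting: a cheap, possibly slightly stale, snapshot
    println!(
        "executed={} heartbeats={}",
        c.jobs_executed.load(Ordering::Relaxed),
        c.heartbeats.load(Ordering::Relaxed),
    );
}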

655
distaff/src/queue.rs
Normal file

@ -0,0 +1,655 @@
use std::{
    cell::UnsafeCell,
    collections::HashMap,
    marker::{PhantomData, PhantomPinned},
    mem::{self, MaybeUninit},
    pin::Pin,
    ptr::{self, NonNull},
    sync::{
        Arc,
        atomic::{AtomicU32, Ordering},
    },
};

use werkzeug::CachePadded;
use werkzeug::sync::Parker;

use werkzeug::ptr::TaggedAtomicPtr;

// A queue with multiple receivers and multiple producers, where a producer can
// send a message to any one of the receivers (any-cast) or to one specific
// receiver (uni-cast). After being woken up from waiting on a message, the
// receiver will look up the index of the message in the queue and return it.

struct QueueInner<T> {
    receivers: HashMap<ReceiverToken, CachePadded<(Slot<T>, bool)>>,
    messages: Vec<T>,
    _phantom: std::marker::PhantomData<T>,
}

pub struct Queue<T> {
    inner: UnsafeCell<QueueInner<T>>,
    lock: AtomicU32,
}

unsafe impl<T> Send for Queue<T> {}
unsafe impl<T> Sync for Queue<T> where T: Send {}

enum SlotKey {
    Owned(ReceiverToken),
    Indexed(usize),
}

pub struct Receiver<T> {
    queue: Arc<Queue<T>>,
    lock: Pin<Box<(Parker, PhantomPinned)>>,
}

#[repr(transparent)]
pub struct Sender<T> {
    queue: Arc<Queue<T>>,
}

// TODO: make this a linked list of slots so we can queue multiple messages for
// a single receiver
const SLOT_ALIGN: u8 = core::mem::align_of::<usize>().ilog2() as u8;
struct Slot<T> {
    value: UnsafeCell<MaybeUninit<T>>,
    next_and_state: TaggedAtomicPtr<Self, SLOT_ALIGN>,
    _phantom: PhantomData<Self>,
}

impl<T> Slot<T> {
    fn new() -> Self {
        Self {
            value: UnsafeCell::new(MaybeUninit::uninit()),
            next_and_state: TaggedAtomicPtr::new(ptr::null_mut(), 0), // 0 means empty
            _phantom: PhantomData,
        }
    }

    fn is_set(&self) -> bool {
        self.next_and_state.tag(Ordering::Acquire) == 1
    }

    unsafe fn pop(&self) -> Option<T> {
        NonNull::new(self.next_and_state.ptr(Ordering::Acquire))
            .and_then(|next| {
                // SAFETY: The next slot is a valid pointer to a Slot<T> that was allocated by us.
                unsafe { next.as_ref().pop() }
            })
            .or_else(|| {
                if self
                    .next_and_state
                    .swap_tag(0, Ordering::AcqRel, Ordering::Relaxed)
                    == 1
                {
                    // SAFETY: The value is only initialized when the state is set to 1.
                    Some(unsafe { self.value.as_ref_unchecked().assume_init_read() })
                } else {
                    None
                }
            })
    }

    /// This operation isn't atomic.
    unsafe fn pop_front(&self) -> Option<T> {
        // SAFETY: The caller must ensure that they have exclusive access to the slot.
        if self.is_set() {
            let next = self.next_ptr();
            unsafe { (next.as_ref()).pop_front() }
        } else {
            // SAFETY: The value is only initialized when the state is set to 1.
            if self.next_and_state.tag(Ordering::Acquire) == 1 {
                Some(unsafe { self.value.as_ref_unchecked().assume_init_read() })
            } else {
                None
            }
        }
    }

    /// The caller must ensure that they have exclusive access to the slot.
    unsafe fn push(&self, value: T) {
        if self.is_set() {
            let next = self.next_ptr();
            unsafe {
                (next.as_ref()).push(value);
            }
        } else {
            // SAFETY: The value is only initialized when the state is set to 1.
            unsafe { self.value.as_mut_unchecked().write(value) };
            self.next_and_state
                .set_tag(1, Ordering::Release, Ordering::Relaxed);
        }
    }

    fn next_ptr(&self) -> NonNull<Slot<T>> {
        if let Some(next) = NonNull::new(self.next_and_state.ptr(Ordering::Acquire)) {
            next.cast()
        } else {
            self.alloc_next()
        }
    }

    fn alloc_next(&self) -> NonNull<Slot<T>> {
        let next = Box::into_raw(Box::new(Slot::new()));

        let next = loop {
            match self.next_and_state.compare_exchange_weak_ptr(
                ptr::null_mut(),
                next,
                Ordering::Release,
                Ordering::Acquire,
            ) {
                Ok(_) => break next,
                Err(other) => {
                    if other.is_null() {
                        eprintln!("Slot::alloc_next: other is null, retrying");
                        continue;
                    }
                    // next was allocated under us, so we need to drop the slot we just allocated again.
                    #[cfg(feature = "tracing")]
                    tracing::trace!(
                        "Slot::alloc_next: next was allocated under us, dropping it. ours: {:p}, other: {:p}",
                        next,
                        other
                    );
                    _ = unsafe { Box::from_raw(next) };
                    break other;
                }
            }
        };

        unsafe {
            // SAFETY: The next slot is a valid pointer to a Slot<T> that was allocated by us.
            NonNull::new_unchecked(next)
        }
    }
}

impl<T> Drop for Slot<T> {
    fn drop(&mut self) {
        // drop the next chain
        if let Some(next) = NonNull::new(self.next_and_state.swap_ptr(
            ptr::null_mut(),
            Ordering::Release,
            Ordering::Relaxed,
        )) {
            // SAFETY: The next slot is a valid pointer to a Slot<T> that was allocated by us.
            // Drop it in place, then free the allocation without running Drop again.
            unsafe {
                next.drop_in_place();
                _ = Box::<mem::ManuallyDrop<Self>>::from_non_null(next.cast());
            }
        }

        // SAFETY: The value is only initialized when the state is set to 1.
        if mem::needs_drop::<T>() && self.next_and_state.tag(Ordering::Acquire) == 1 {
            unsafe { self.value.as_mut_unchecked().assume_init_drop() };
        }
    }
}

// const BLOCK_SIZE: usize = 8;
// struct Block<T> {
//     next: AtomicPtr<Block<T>>,
//     slots: [CachePadded<Slot<T>>; BLOCK_SIZE],
// }

/// A token that can be used to identify a specific receiver in a queue.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ReceiverToken(werkzeug::util::Send<NonNull<u32>>);

impl ReceiverToken {
    pub fn as_ptr(&self) -> *mut u32 {
        self.0.into_inner().as_ptr()
    }

    pub unsafe fn as_parker(&self) -> &Parker {
        // SAFETY: The pointer is guaranteed to be valid and aligned, as it comes from a pinned Parker.
        unsafe { Parker::from_ptr(self.as_ptr()) }
    }

    pub unsafe fn from_parker(parker: &Parker) -> Self {
        // SAFETY: The pointer is guaranteed to be valid and aligned, as it comes from a pinned Parker.
        let ptr = NonNull::from(parker).cast::<u32>();
        ReceiverToken(werkzeug::util::Send(ptr))
    }
}

impl<T> Queue<T> {
    pub fn new() -> Arc<Self> {
        Arc::new(Self {
            inner: UnsafeCell::new(QueueInner {
                messages: Vec::new(),
                receivers: HashMap::new(),
                _phantom: PhantomData,
            }),
            lock: AtomicU32::new(0),
        })
    }

    pub fn new_sender(self: &Arc<Self>) -> Sender<T> {
        Sender {
            queue: self.clone(),
        }
    }

    pub fn num_receivers(self: &Arc<Self>) -> usize {
        let _guard = self.lock();
        self.inner().receivers.len()
    }

    pub fn as_sender(self: &Arc<Self>) -> &Sender<T> {
        unsafe { mem::transmute::<&Arc<Self>, &Sender<T>>(self) }
    }

    pub fn new_receiver(self: &Arc<Self>) -> Receiver<T> {
        let recv = Receiver {
            queue: self.clone(),
            lock: Box::pin((Parker::new(), PhantomPinned)),
        };

        // allocate a slot for the receiver
        let token = recv.get_token();
        let _guard = recv.queue.lock();
        recv.queue
            .inner()
            .receivers
            .insert(token, CachePadded::new((Slot::new(), false)));

        drop(_guard);
        recv
    }

    fn lock(&self) -> impl Drop {
        unsafe {
            let lock = werkzeug::sync::Lock::from_ptr(&self.lock as *const _ as _);
            lock.lock();
            werkzeug::drop_guard::DropGuard::new(|| lock.unlock())
        }
    }

    fn inner(&self) -> &mut QueueInner<T> {
        // SAFETY: The inner queue is only accessed while the queue is locked.
        unsafe { &mut *self.inner.get() }
    }
}

impl<T> QueueInner<T> {
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    fn poll(&mut self, token: ReceiverToken) -> Option<T> {
        // check if someone has sent a message to this receiver
        let CachePadded((slot, _)) = self.receivers.get(&token)?;

        unsafe { slot.pop() }.or_else(|| {
            // if the slot is empty, we can check the indexed messages
            #[cfg(feature = "tracing")]
            tracing::trace!("QueueInner::poll: checking open messages");

            self.messages.pop()
        })
    }
}

impl<T> Receiver<T> {
    pub fn get_token(&self) -> ReceiverToken {
        // The token is just the pointer to the lock of this receiver.
        // The lock is pinned, so its address is stable across calls to `receive`.

        ReceiverToken(werkzeug::util::Send(NonNull::from(&self.lock.0).cast()))
    }
}

impl<T> Drop for Receiver<T> {
    fn drop(&mut self) {
        if mem::needs_drop::<T>() {
            // lock the queue
            let _guard = self.queue.lock();
            let queue = self.queue.inner();

            // remove the receiver from the queue
            _ = queue.receivers.remove(&self.get_token());
        }
    }
}

impl<T: Send> Receiver<T> {
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    pub fn recv(&self) -> T {
        let token = self.get_token();

        loop {
            // lock the queue
            let _guard = self.queue.lock();
            let queue = self.queue.inner();

            // check if someone has sent a message to this receiver
            if let Some(t) = queue.poll(token) {
                queue.receivers.get_mut(&token).unwrap().1 = false; // mark the slot as not parked
                return t;
            }

            // there was no message for this receiver, so we need to park it
            queue.receivers.get_mut(&token).unwrap().1 = true; // mark the slot as parked

            self.lock.0.park_with_callback(move || {
                // Drop the lock guard after having set the lock state to waiting.
                // This avoids a deadlock if the sender tries to send a message
                // while the receiver is in the process of parking (I think).
                drop(_guard);
            });
        }
    }

    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    pub fn try_recv(&self) -> Option<T> {
        let token = self.get_token();

        // lock the queue
        let _guard = self.queue.lock();
        let queue = self.queue.inner();

        // check if someone has sent a message to this receiver
        queue.poll(token)
    }
}
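
// The recv() loop above, reduced to std primitives: publish the message under
// a lock, then unpark; the receiver re-checks after every wakeup. std's park
// tokens make the "unpark before park" race harmless, which is the same
// property the Parker-with-callback dance above is preserving.
use std::sync::{Arc, Mutex};
use std::thread;

fn main() {
    let slot = Arc::new(Mutex::new(None::<i32>));
    let receiver = thread::current();

    let sender = {
        let slot = Arc::clone(&slot);
        thread::spawn(move || {
            *slot.lock().unwrap() = Some(42); // publish under the lock
            receiver.unpark(); // then wake the (possibly parked) receiver
        })
    };

    let value = loop {
        if let Some(v) = slot.lock().unwrap().take() {
            break v;
        }
        thread::park(); // a stored unpark token makes this return immediately
    };

    assert_eq!(value, 42);
    sender.join().unwrap();
}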

impl<T: Send> Sender<T> {
    /// Sends a message to one of the receivers in the queue, or makes it
    /// available to any receiver that will park in the future.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    pub fn anycast(&self, value: T) {
        let _guard = self.queue.lock();

        // SAFETY: The queue is locked, so we can safely access the inner queue.
        match unsafe { self.try_anycast_inner(value) } {
            Ok(_) => {}
            Err(value) => {
                #[cfg(feature = "tracing")]
                tracing::trace!(
                    "Queue::anycast: no parked receiver found, adding message to indexed slots"
                );

                // no parked receiver was found, so we add the message to the indexed slots
                let queue = self.queue.inner();
                queue.messages.push(value);

                // waking up a parked receiver is not necessary here, as any
                // receivers that don't have a free slot are currently waking up.
            }
        }
    }

    pub fn try_anycast(&self, value: T) -> Result<(), T> {
        // lock the queue
        let _guard = self.queue.lock();

        // SAFETY: The queue is locked, so we can safely access the inner queue.
        unsafe { self.try_anycast_inner(value) }
    }

    /// The caller must hold the lock on the queue for the duration of this function.
    unsafe fn try_anycast_inner(&self, value: T) -> Result<(), T> {
        // look for a receiver that is parked
        let queue = self.queue.inner();
        if let Some((token, slot)) =
            queue
                .receivers
                .iter()
                .find_map(|(token, CachePadded((slot, is_parked)))| {
                    // ensure the slot is available
                    if *is_parked && !slot.is_set() {
                        Some((*token, slot))
                    } else {
                        None
                    }
                })
        {
            // we found a receiver that is parked, so we can send the message to it
            unsafe {
                slot.value.as_mut_unchecked().write(value);
                slot.next_and_state
                    .set_tag(1, Ordering::Release, Ordering::Relaxed);
                Parker::from_ptr(token.0.into_inner().as_ptr()).unpark();
            }

            Ok(())
        } else {
            Err(value)
        }
    }

    /// Sends a message to a specific receiver, waking it if it is parked.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    pub fn unicast(&self, value: T, receiver: ReceiverToken) -> Result<(), T> {
        // lock the queue
        let _guard = self.queue.lock();
        let queue = self.queue.inner();

        let Some(CachePadded((slot, _))) = queue.receivers.get_mut(&receiver) else {
            return Err(value);
        };

        unsafe {
            slot.push(value);
        }

        // wake the receiver
        unsafe {
            Parker::from_ptr(receiver.0.into_inner().as_ptr()).unpark();
        }

        Ok(())
    }

    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    pub fn broadcast(&self, value: T)
    where
        T: Clone,
    {
        // lock the queue
        let _guard = self.queue.lock();
        let queue = self.queue.inner();

        // send the message to all receivers
        for (token, CachePadded((slot, _))) in queue.receivers.iter() {
            // SAFETY: The slot is owned by this receiver.
            unsafe { slot.push(value.clone()) };

            // wake the receiver
            unsafe {
                Parker::from_ptr(token.0.into_inner().as_ptr()).unpark();
            }
        }
    }

    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    pub fn broadcast_with<F>(&self, mut f: F)
    where
        F: FnMut() -> T,
    {
        // lock the queue
        let _guard = self.queue.lock();
        let queue = self.queue.inner();

        // send the message to all receivers
        for (token, CachePadded((slot, _))) in queue.receivers.iter() {
            // SAFETY: The slot is owned by this receiver.
            unsafe { slot.push(f()) };

            // wake the receiver, whether or not it is parked
            unsafe {
                Parker::from_ptr(token.0.into_inner().as_ptr()).unpark();
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_queue() {
        let queue = Queue::<i32>::new();

        let sender = queue.new_sender();
        let receiver1 = queue.new_receiver();
        let receiver2 = queue.new_receiver();

        let token2 = receiver2.get_token();

        sender.anycast(42);

        assert_eq!(receiver1.recv(), 42);

        sender.unicast(100, token2).unwrap();
        assert_eq!(receiver1.try_recv(), None);
        assert_eq!(receiver2.recv(), 100);
    }

    #[test]
    fn queue_broadcast() {
        let queue = Queue::<i32>::new();

        let sender = queue.new_sender();
        let receiver1 = queue.new_receiver();
        let receiver2 = queue.new_receiver();

        sender.broadcast(42);

        assert_eq!(receiver1.recv(), 42);
        assert_eq!(receiver2.recv(), 42);
    }

    #[test]
    fn queue_multiple_messages() {
        let queue = Queue::<i32>::new();

        let sender = queue.new_sender();
        let receiver = queue.new_receiver();

        sender.anycast(1);
        sender.unicast(2, receiver.get_token()).unwrap();

        assert_eq!(receiver.recv(), 2);
        assert_eq!(receiver.recv(), 1);
    }

    #[test]
    fn queue_threaded() {
        #[derive(Debug, Clone, Copy)]
        enum Message {
            Send(i32),
            Exit,
        }

        let queue = Queue::<Message>::new();

        let sender = queue.new_sender();

        let threads = (0..5)
            .map(|_| {
                let queue_clone = queue.clone();
                let receiver = queue_clone.new_receiver();

                std::thread::spawn(move || {
                    loop {
                        match receiver.recv() {
                            Message::Send(value) => {
                                println!("Receiver {:?} received: {}", receiver.get_token(), value);
                            }
                            Message::Exit => {
                                println!("Exiting thread");
                                break;
                            }
                        }
                    }
                })
            })
            .collect::<Vec<_>>();

        // Send messages to the receivers
        for i in 0..10 {
            sender.anycast(Message::Send(i));
        }

        // Send exit messages to all receivers
        sender.broadcast(Message::Exit);
        for thread in threads {
            thread.join().unwrap();
        }
        println!("All threads have exited.");
    }

    #[test]
    fn drop_slot() {
        // Test that dropping a slot does not cause a double free or panic
        let slot = Slot::<i32>::new();
        unsafe {
            slot.push(42);
            drop(slot);
        }
    }

    #[test]
    fn drop_slot_chain() {
        struct DropCheck<'a>(&'a AtomicU32);
        impl Drop for DropCheck<'_> {
            fn drop(&mut self) {
                self.0.fetch_sub(1, Ordering::SeqCst);
            }
        }

        impl<'a> DropCheck<'a> {
            fn new(counter: &'a AtomicU32) -> Self {
                counter.fetch_add(1, Ordering::SeqCst);
                Self(counter)
            }
        }
        let counter = AtomicU32::new(0);
        let slot = Slot::<DropCheck>::new();
        for _ in 0..10 {
            unsafe {
                slot.push(DropCheck::new(&counter));
            }
        }
        assert_eq!(counter.load(Ordering::SeqCst), 10);
        drop(slot);
        assert_eq!(
            counter.load(Ordering::SeqCst),
            0,
            "All DropCheck instances should have been dropped"
        );
    }

    #[test]
    fn send_self() {
        // Test that sending a message to self works
        let queue = Queue::<i32>::new();
        let sender = queue.new_sender();
        let receiver = queue.new_receiver();

        sender.unicast(42, receiver.get_token()).unwrap();
        assert_eq!(receiver.recv(), 42);
    }

    #[test]
    fn send_self_many() {
        // Test that sending multiple messages to self works
        let queue = Queue::<i32>::new();
        let sender = queue.new_sender();
        let receiver = queue.new_receiver();

        for i in 0..10 {
            sender.unicast(i, receiver.get_token()).unwrap();
        }

        for i in (0..10).rev() {
            assert_eq!(receiver.recv(), i);
        }
    }
}
|
|
@ -1,107 +1,187 @@
|
|||
use std::{
|
||||
any::Any,
|
||||
marker::PhantomData,
|
||||
marker::{PhantomData, PhantomPinned},
|
||||
panic::{AssertUnwindSafe, catch_unwind},
|
||||
pin::{self, Pin},
|
||||
ptr::{self, NonNull},
|
||||
sync::{
|
||||
Arc,
|
||||
atomic::{AtomicPtr, Ordering},
|
||||
atomic::{AtomicPtr, AtomicUsize, Ordering},
|
||||
},
|
||||
};
|
||||
|
||||
use async_task::Runnable;
|
||||
use werkzeug::util;
|
||||
|
||||
use crate::{
|
||||
context::{Context, run_in_worker},
|
||||
job::{HeapJob, Job},
|
||||
latch::{AsCoreLatch, CountLatch, WakeLatch},
|
||||
channel::Sender,
|
||||
context::{Context, Message},
|
||||
job::{
|
||||
HeapJob, Job2 as Job, SharedJob,
|
||||
traits::{InlineJob, IntoJob},
|
||||
},
|
||||
latch::{CountLatch, Probe},
|
||||
queue::ReceiverToken,
|
||||
util::{DropGuard, SendPtr},
|
||||
workerthread::WorkerThread,
|
||||
};
|
||||
|
||||
pub struct Scope<'scope> {
|
||||
// thinking:
|
||||
|
||||
// the scope needs to keep track of any spawn() and spawn_async() calls, across all worker threads.
|
||||
// that means, that for any spawn() or spawn_async() calls, we have to share a counter across all worker threads.
|
||||
// we want to minimise the number of atomic operations in general.
|
||||
// atomic operations occur in the following cases:
|
||||
// - when we spawn() or spawn_async() a job, we increment the counter
|
||||
// - when the same job finishes, we decrement the counter
|
||||
// - when a join() job finishes, it's latch is set
|
||||
// - when we wait for a join() job, we loop over the latch until it is set
|
||||
|
||||
// a Scope must keep track of:
|
||||
// - The number of async jobs spawned, which is used to determine when the scope
|
||||
// is complete.
|
||||
// - A panic box, which is set when a job panics and is used to resume the panic
|
||||
// when the scope is completed.
|
||||
// - The Parker of the worker on which the scope was created, which is signaled
|
||||
// when the last outstanding async job finishes.
|
||||
// - The current worker thread in order to avoid having to query the
|
||||
// thread-local storage.
|
||||
|
||||
struct ScopeInner {
|
||||
outstanding_jobs: AtomicUsize,
|
||||
parker: ReceiverToken,
|
||||
panic: AtomicPtr<Box<dyn Any + Send + 'static>>,
|
||||
}
|
||||
|
||||
unsafe impl Send for ScopeInner {}
|
||||
unsafe impl Sync for ScopeInner {}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct Scope<'scope, 'env: 'scope> {
|
||||
inner: SendPtr<ScopeInner>,
|
||||
worker: SendPtr<WorkerThread>,
|
||||
_scope: PhantomData<&'scope mut &'scope ()>,
|
||||
_env: PhantomData<&'env mut &'env ()>,
|
||||
}

impl ScopeInner {
    fn from_worker(worker: &WorkerThread) -> Self {
        Self {
            outstanding_jobs: AtomicUsize::new(0),
            parker: worker.receiver.get_token(),
            panic: AtomicPtr::new(ptr::null_mut()),
        }
    }

    fn increment(&self) {
        self.outstanding_jobs.fetch_add(1, Ordering::Relaxed);
    }

    fn decrement(&self, worker: &WorkerThread) {
        if self.outstanding_jobs.fetch_sub(1, Ordering::Relaxed) == 1 {
            worker
                .context
                .queue
                .as_sender()
                .unicast(Message::ScopeFinished, self.parker);
        }
    }

    fn panicked(&self, err: Box<dyn Any + Send + 'static>) {
        unsafe {
            let err = Box::into_raw(Box::new(err));
            if !self
                .panic
                .compare_exchange(ptr::null_mut(), err, Ordering::AcqRel, Ordering::Acquire)
                .is_ok()
            {
                // someone else already set the panic, so we drop the error
                _ = Box::from_raw(err);
            }
        }
    }

    fn maybe_propagate_panic(&self) {
        let err = self.panic.swap(ptr::null_mut(), Ordering::AcqRel);

        if err.is_null() {
            return;
        } else {
            // SAFETY: we have exclusive access to the panic error, so we can safely resume it.
            unsafe {
                let err = *Box::from_raw(err);
                std::panic::resume_unwind(err);
            }
        }
    }
}
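
// --- Editor's illustrative sketch (not part of this commit) ---
// The "first panic wins" protocol used by `ScopeInner::panicked` above,
// reduced to plain std: racing threads CAS a boxed payload into a shared
// AtomicPtr; exactly one CAS succeeds, the losers free their allocation, and
// whoever drains the slot takes exclusive ownership of the stored payload.
fn first_panic_wins_sketch() {
    use std::any::Any;
    use std::ptr;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicPtr, Ordering};

    let slot: Arc<AtomicPtr<Box<dyn Any + Send>>> = Arc::new(AtomicPtr::new(ptr::null_mut()));

    let handles: Vec<_> = (0..4)
        .map(|i| {
            let slot = Arc::clone(&slot);
            std::thread::spawn(move || {
                // double-box so the AtomicPtr only has to store a thin pointer
                let err: Box<Box<dyn Any + Send>> = Box::new(Box::new(format!("panic #{i}")));
                let raw = Box::into_raw(err);
                if slot
                    .compare_exchange(ptr::null_mut(), raw, Ordering::AcqRel, Ordering::Acquire)
                    .is_err()
                {
                    // lost the race: reclaim and drop our own allocation
                    unsafe { drop(Box::from_raw(raw)) };
                }
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }

    // swap gives exclusive ownership, mirroring maybe_propagate_panic()
    let raw = slot.swap(ptr::null_mut(), Ordering::AcqRel);
    assert!(!raw.is_null());
    let payload: Box<dyn Any + Send> = unsafe { *Box::from_raw(raw) };
    assert!(payload.downcast_ref::<String>().is_some());
}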

// find below a sketch of an unbalanced tree:
//
//                  []
//                 /  \
//               []    []
//              /  \  /  \
//            []   [][]   []
//           /  \        /  \
//         []   []     []   []
//        /  \        /  \
//      []   []     []   []
//     /  \
//   []   []

// in this tree of join() calls, it is possible to wait for a long time, so it
// is necessary to keep waking up when a job is shared.

// the worker waits on its latch, which may be woken by:
// - a job finishing
// - another thread sharing a job
// - the heartbeat waking up the worker // does this make sense? if the thread
//   was sleeping, it didn't have any work to share.

pub struct Scope2<'scope, 'env: 'scope> {
    // latch to wait on before the scope finishes
    job_counter: CountLatch<WakeLatch>,
    job_counter: CountLatch,
    // local threadpool
    context: Arc<Context>,
    // panic error
    panic: AtomicPtr<Box<dyn Any + Send + 'static>>,
    // variant lifetime
    _pd: PhantomData<fn(&'scope ())>,
    _scope: PhantomData<&'scope mut &'scope ()>,
    _env: PhantomData<&'env mut &'env ()>,
}

pub fn scope<'scope, F, R>(f: F) -> R
pub fn scope<'env, F, R>(f: F) -> R
where
    F: FnOnce(&Scope<'scope>) -> R + Send,
    F: for<'scope> FnOnce(Scope<'scope, 'env>) -> R + Send,
    R: Send,
{
    Scope::<'scope>::scope(f)
    scope_with_context(Context::global_context(), f)
}

impl<'scope> Scope<'scope> {
    fn wait_for_jobs(&self, worker: &WorkerThread) {
        tracing::trace!("waiting for {} jobs to finish.", self.job_counter.count());
        tracing::trace!("thread id: {:?}, jobs: {:?}", worker.index, unsafe {
            worker.queue.as_ref_unchecked()
        });

        // set worker index in the job counter
        self.job_counter.inner().set_worker_index(worker.index);
        worker.wait_until_latch(self.job_counter.as_core_latch());
    }

    pub fn scope<F, R>(f: F) -> R
    where
        F: FnOnce(&Self) -> R + Send,
        R: Send,
    {
        run_in_worker(|worker| {
            // SAFETY: we call complete() after creating this scope, which
            // ensures that any jobs spawned from the scope exit before the
            // scope closes.
            let this = unsafe { Self::from_context(worker.context.clone()) };
            this.complete(worker, || f(&this))
        })
    }

    fn scope_with_context<F, R>(context: Arc<Context>, f: F) -> R
    where
        F: FnOnce(&Self) -> R + Send,
        R: Send,
    {
        context.run_in_worker(|worker| {
            // SAFETY: we call complete() after creating this scope, which
            // ensures that any jobs spawned from the scope exit before the
            // scope closes.
            let this = unsafe { Self::from_context(context.clone()) };
            this.complete(worker, || f(&this))
        })
    }
pub fn scope_with_context<'env, F, R>(context: &Arc<Context>, f: F) -> R
where
    F: for<'scope> FnOnce(Scope<'scope, 'env>) -> R + Send,
    R: Send,
{
    context.run_in_worker(|worker| {
        // SAFETY: we call complete() after creating this scope, which
        // ensures that any jobs spawned from the scope exit before the
        // scope closes.
        let inner = pin::pin!(ScopeInner::from_worker(worker));
        let this = Scope::<'_, 'env>::new(worker, inner.as_ref());
        this.complete(|| f(this))
    })
}

impl<'scope, 'env> Scope<'scope, 'env> {
    /// should be called from within a worker thread.
    fn complete<F, R>(&self, worker: &WorkerThread, f: F) -> R
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    fn complete<F, R>(&self, f: F) -> R
    where
        F: FnOnce() -> R + Send,
        R: Send,
        F: FnOnce() -> R,
    {
        use std::panic::{AssertUnwindSafe, catch_unwind};

        #[allow(dead_code)]
        fn make_job<F: FnOnce() -> T, T>(f: F) -> Job<T> {
            #[align(8)]
            unsafe fn harness<F: FnOnce() -> T, T>(this: *const (), job: *const Job<T>) {
                let f = unsafe { Box::from_raw(this.cast::<F>().cast_mut()) };

                let result = catch_unwind(AssertUnwindSafe(move || f()));

                let job = unsafe { Box::from_raw(job.cast_mut()) };
                job.complete(result);
            }

            Job::<T>::new(harness::<F, T>, unsafe {
                NonNull::new_unchecked(Box::into_raw(Box::new(f))).cast()
            })
        }

        let result = match catch_unwind(AssertUnwindSafe(|| f())) {
            Ok(val) => Some(val),
            Err(payload) => {

@@ -110,70 +190,107 @@ impl<'scope> Scope<'scope> {
            }
        };

        self.wait_for_jobs(worker);
        self.maybe_propagate_panic();
        self.wait_for_jobs();
        let inner = self.inner();
        inner.maybe_propagate_panic();

        // SAFETY: if result panicked, we would have propagated the panic above.
        result.unwrap()
    }

    /// resumes the panic if one happened in this scope.
    fn maybe_propagate_panic(&self) {
        let err_ptr = self.panic.load(Ordering::Relaxed);
        if !err_ptr.is_null() {
            unsafe {
                let err = Box::from_raw(err_ptr);
                std::panic::resume_unwind(*err);
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    fn wait_for_jobs(&self) {
        loop {
            let count = self.inner().outstanding_jobs.load(Ordering::Relaxed);
            #[cfg(feature = "tracing")]
            tracing::trace!("waiting for {} jobs to finish.", count);
            if count == 0 {
                break;
            }

            match self.worker().receiver.recv() {
                Message::Shared(shared_job) => unsafe {
                    SharedJob::execute(shared_job, self.worker());
                },
                Message::ScopeFinished => {
                    #[cfg(feature = "tracing")]
                    tracing::trace!("scope finished, decrementing outstanding jobs.");
                    assert_eq!(self.inner().outstanding_jobs.load(Ordering::Acquire), 0);
                    break;
                }
                Message::WakeUp | Message::Exit => {}
            }
        }
    }
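
// --- Editor's illustrative sketch (not part of this commit) ---
// The wait loop above is a "help while you wait" pattern: instead of blocking
// until the scope counter reaches zero, the worker keeps executing jobs that
// arrive on its receiver. A reduced std-only version of the same idea, with a
// hypothetical boxed-closure job type:
fn help_while_waiting_sketch() {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, mpsc};

    enum Msg {
        Job(Box<dyn FnOnce() + Send>),
        Finished,
    }

    let outstanding = Arc::new(AtomicUsize::new(2));
    let (tx, rx) = mpsc::channel::<Msg>();

    // two pending jobs that will be handed back to the waiting thread
    for _ in 0..2 {
        let outstanding = Arc::clone(&outstanding);
        let tx2 = tx.clone();
        tx.send(Msg::Job(Box::new(move || {
            // the last job to finish signals the waiter, like ScopeFinished
            if outstanding.fetch_sub(1, Ordering::Relaxed) == 1 {
                tx2.send(Msg::Finished).unwrap();
            }
        })))
        .unwrap();
    }

    // the waiting thread makes progress by running jobs until the count drops
    loop {
        if outstanding.load(Ordering::Relaxed) == 0 {
            break;
        }
        match rx.recv().unwrap() {
            Msg::Job(job) => job(),
            Msg::Finished => break,
        }
    }
    assert_eq!(outstanding.load(Ordering::Relaxed), 0);
}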

    /// stores the first panic that happened in this scope.
    fn panicked(&self, err: Box<dyn Any + Send + 'static>) {
        self.panic.load(Ordering::Relaxed).is_null().then(|| {
            use core::mem::ManuallyDrop;
            let mut boxed = ManuallyDrop::new(Box::new(err));
    fn decrement(&self) {
        self.inner().decrement(self.worker());
    }

            let err_ptr: *mut Box<dyn Any + Send + 'static> = &mut **boxed;
            if self
                .panic
                .compare_exchange(
                    ptr::null_mut(),
                    err_ptr,
                    Ordering::SeqCst,
                    Ordering::Relaxed,
                )
                .is_ok()
            {
                // we successfully set the panic, no need to drop
            } else {
                // drop the error, someone else already set it
                _ = ManuallyDrop::into_inner(boxed);
            }
        });
    fn inner(&self) -> &ScopeInner {
        unsafe { self.inner.as_ref() }
    }

    /// stores the first panic that happened in this scope.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    fn panicked(&self, err: Box<dyn Any + Send + 'static>) {
        self.inner().panicked(err);
    }

    pub fn spawn<F>(&self, f: F)
    where
        F: FnOnce(&Scope<'scope>) + Send,
        F: FnOnce(Self) + Send,
    {
        self.context.run_in_worker(|worker| {
            self.job_counter.increment();
        struct SpawnedJob<F> {
            f: F,
            inner: SendPtr<ScopeInner>,
        }

            let this = SendPtr::new_const(self).unwrap();
        impl<F> SpawnedJob<F> {
            fn new<'scope, 'env, T>(f: F, inner: SendPtr<ScopeInner>) -> Job
            where
                F: FnOnce(Scope<'scope, 'env>) -> T + Send,
                'env: 'scope,
                T: Send,
            {
                Job::from_harness(
                    Self::harness,
                    Box::into_non_null(Box::new(Self { f, inner })).cast(),
                )
            }

            let job = Box::new(HeapJob::new(move || unsafe {
                _ = f(this.as_ref());
                this.as_ref().job_counter.decrement();
            }))
            .into_boxed_job();
            #[align(8)]
            unsafe fn harness<'scope, 'env, T>(
                worker: &WorkerThread,
                this: NonNull<()>,
                _: Option<Sender>,
            ) where
                F: FnOnce(Scope<'scope, 'env>) -> T + Send,
                'env: 'scope,
                T: Send,
            {
                let Self { f, inner } =
                    unsafe { *Box::<SpawnedJob<F>>::from_non_null(this.cast()) };
                let scope = unsafe { Scope::<'scope, 'env>::new_unchecked(worker, inner) };

                tracing::trace!("allocated heapjob");
                // SAFETY: we are in a worker thread, so the inner is valid.
                (f)(scope);
            }
        }

            worker.push_front(job);
        self.inner().increment();
        let job = SpawnedJob::new(
            move |scope| {
                if let Err(payload) = catch_unwind(AssertUnwindSafe(|| f(scope))) {
                    scope.inner().panicked(payload);
                }

                tracing::trace!("leaked heapjob");
            });
                scope.decrement();
            },
            self.inner,
        );

        self.context().inject_job(job.share(None));
    }

    pub fn spawn_future<T, F>(&self, future: F) -> async_task::Task<T>

@@ -181,87 +298,337 @@ impl<'scope> Scope<'scope> {
        F: Future<Output = T> + Send + 'scope,
        T: Send + 'scope,
    {
        self.context.run_in_worker(|worker| {
            self.job_counter.increment();

            let this = SendPtr::new_const(&self.job_counter).unwrap();

            let future = async move {
                let _guard = DropGuard::new(move || unsafe {
                    this.as_ref().decrement();
                });
                future.await
            };

            let schedule = move |runnable: Runnable| {
                #[align(8)]
                unsafe fn harness<T>(this: *const (), job: *const Job<T>) {
                    unsafe {
                        let runnable =
                            Runnable::<()>::from_raw(NonNull::new_unchecked(this.cast_mut()));
                        runnable.run();

                        // SAFETY: job was turned into raw
                        drop(Box::from_raw(job.cast_mut()));
                    }
                }

                let job = Box::new(Job::<T>::new(harness::<T>, runnable.into_raw()));

                // casting into Job<()> here
                worker.push_front(Box::into_raw(job) as _);
            };

            let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) };

            runnable.schedule();

            task
        })
        self.spawn_async_internal(move |_| future)
    }

    #[allow(dead_code)]
    fn spawn_async<'a, T, Fut, Fn>(&'a self, f: Fn) -> async_task::Task<T>
    pub fn spawn_async<T, Fut, Fn>(&self, f: Fn) -> async_task::Task<T>
    where
        Fn: FnOnce(&Scope) -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
        T: Send + 'static,
        Fn: FnOnce(Self) -> Fut + Send + 'scope,
        Fut: Future<Output = T> + Send + 'scope,
        T: Send + 'scope,
    {
        let this = SendPtr::new_const(self).unwrap();
        let future = async move { f(unsafe { this.as_ref() }).await };

        self.spawn_future(future)
        self.spawn_async_internal(f)
    }

    #[inline]
    fn spawn_async_internal<T, Fut, Fn>(&self, f: Fn) -> async_task::Task<T>
    where
        Fn: FnOnce(Self) -> Fut + Send + 'scope,
        Fut: Future<Output = T> + Send + 'scope,
        T: Send + 'scope,
    {
        self.inner().increment();

        // TODO: make sure this worker lasts long enough for the
        // reference to remain valid for the duration of the future.
        let scope = unsafe { Self::new_unchecked(self.worker.as_ref(), self.inner) };

        let future = async move {
            let _guard = DropGuard::new(move || {
                scope.decrement();
            });

            // TODO: handle panics here
            f(scope).await
        };

        let schedule = move |runnable: Runnable| {
            #[align(8)]
            unsafe fn harness(_: &WorkerThread, this: NonNull<()>, _: Option<Sender>) {
                unsafe {
                    let runnable = Runnable::<()>::from_raw(this.cast());
                    runnable.run();
                }
            }

            let job = Job::<()>::from_harness(harness, runnable.into_raw());

            // casting into Job<()> here
            self.context().inject_job(job.share(None));
            // WorkerThread::current_ref()
            //     .expect("spawn_async_internal is run in workerthread.")
            //     .push_front(job);
        };

        let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) };

        runnable.schedule();

        task
    }
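
// --- Editor's illustrative sketch (not part of this commit) ---
// The Runnable/schedule protocol that `spawn_async_internal` builds on, shown
// with the safe `async_task::spawn` and a mutex-guarded queue standing in for
// the worker's job queue: every wake hands the Runnable back to the schedule
// closure, and whoever pops it calls `run()` to poll the future once.
fn async_task_schedule_sketch() {
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};

    let queue: Arc<Mutex<VecDeque<async_task::Runnable>>> =
        Arc::new(Mutex::new(VecDeque::new()));

    let q = Arc::clone(&queue);
    let schedule = move |runnable: async_task::Runnable| {
        q.lock().unwrap().push_back(runnable);
    };

    let (runnable, task) = async_task::spawn(async { 1 + 2 }, schedule);
    runnable.schedule(); // enqueue the first poll

    // a single-threaded "executor": run whatever has been scheduled
    loop {
        let next = queue.lock().unwrap().pop_front(); // guard dropped here
        match next {
            Some(runnable) => {
                runnable.run();
            }
            None => break,
        }
    }

    // the future finished on its first poll, so the task resolves immediately
    assert_eq!(futures::executor::block_on(task), 3);
}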

    pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        RA: Send,
        RB: Send,
        A: FnOnce(&Self) -> RA + Send,
        B: FnOnce(&Self) -> RB + Send,
        A: FnOnce(Self) -> RA + Send,
        B: FnOnce(Self) -> RB,
    {
        let worker = WorkerThread::current_ref().expect("join is run in workerthread.");
        let this = SendPtr::new_const(self).unwrap();
        use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
        use std::{
            cell::UnsafeCell,
            mem::{self, ManuallyDrop},
        };

        worker.join_heartbeat_every::<_, _, _, _, 64>(
        let worker = self.worker();

        struct ScopeJob<F> {
            f: UnsafeCell<ManuallyDrop<F>>,
            inner: SendPtr<ScopeInner>,
            _pin: PhantomPinned,
        }

        impl<F> ScopeJob<F> {
            fn new(f: F, inner: SendPtr<ScopeInner>) -> Self {
                Self {
                    f: UnsafeCell::new(ManuallyDrop::new(f)),
                    inner,
                    _pin: PhantomPinned,
                }
            }

            fn into_job<'scope, 'env, T>(self: Pin<&Self>) -> Job<T>
            where
                F: FnOnce(Scope<'scope, 'env>) -> T + Send,
                'env: 'scope,
                T: Send,
            {
                let this = this;
                move || a(unsafe { this.as_ref() })
            },
                Job::from_harness(Self::harness, NonNull::from(&*self).cast())
            }

            unsafe fn unwrap(&self) -> F {
                unsafe { ManuallyDrop::take(&mut *self.f.get()) }
            }

            #[align(8)]
            unsafe fn harness<'scope, 'env, T>(
                worker: &WorkerThread,
                this: NonNull<()>,
                sender: Option<Sender>,
            ) where
                F: FnOnce(Scope<'scope, 'env>) -> T + Send,
                'env: 'scope,
                T: Send,
            {
                let this = this;
                move || b(unsafe { this.as_ref() })
            },
        )
                let this: &ScopeJob<F> = unsafe { this.cast().as_ref() };
                let sender: Option<Sender<T>> = unsafe { mem::transmute(sender) };
                let f = unsafe { this.unwrap() };
                let scope = unsafe { Scope::<'scope, 'env>::new_unchecked(worker, this.inner) };

                let result = catch_unwind(AssertUnwindSafe(|| f(scope)));

                let sender = sender.unwrap();
                unsafe {
                    sender.send_as_ref(result);
                    worker
                        .context
                        .queue
                        .as_sender()
                        .unicast(Message::WakeUp, ReceiverToken::from_parker(sender.parker()));
                }
            }
        }

        impl<'scope, 'env, F, T> IntoJob<T> for Pin<&ScopeJob<F>>
        where
            F: FnOnce(Scope<'scope, 'env>) -> T + Send,
            'env: 'scope,
            T: Send,
        {
            fn into_job(self) -> Job<T> {
                self.into_job()
            }
        }

        impl<'scope, 'env, F, T> InlineJob<T> for Pin<&ScopeJob<F>>
        where
            F: FnOnce(Scope<'scope, 'env>) -> T + Send,
            'env: 'scope,
            T: Send,
        {
            fn run_inline(self, worker: &WorkerThread) -> T {
                unsafe { self.unwrap()(Scope::<'scope, 'env>::new_unchecked(worker, self.inner)) }
            }
        }

        let _pinned = ScopeJob::new(a, self.inner);
        let job = unsafe { Pin::new_unchecked(&_pinned) };

        let (a, b) = worker.join_heartbeat2(job, |_| b(*self));

        // touch job here to ensure it is not dropped before we run the join.
        drop(_pinned);
        (a, b)

        // let stack = ScopeJob::new(a, self.inner);
        // let job = ScopeJob::into_job(&stack);

        // worker.push_back(&job);

        // worker.tick();

        // let rb = match catch_unwind(AssertUnwindSafe(|| b(*self))) {
        //     Ok(val) => val,
        //     Err(payload) => {
        //         #[cfg(feature = "tracing")]
        //         tracing::debug!("join_heartbeat: b panicked, waiting for a to finish");
        //         std::hint::cold_path();

        //         // if b panicked, we need to wait for a to finish
        //         let mut receiver = job.take_receiver();
        //         worker.wait_until_pred(|| match &receiver {
        //             Some(recv) => recv.poll().is_some(),
        //             None => {
        //                 receiver = job.take_receiver();
        //                 false
        //             }
        //         });

        //         resume_unwind(payload);
        //     }
        // };

        // let ra = if let Some(recv) = job.take_receiver() {
        //     match worker.wait_until_recv(recv) {
        //         Some(t) => crate::util::unwrap_or_panic(t),
        //         None => {
        //             #[cfg(feature = "tracing")]
        //             tracing::trace!(
        //                 "join_heartbeat: job was shared, but reclaimed, running a() inline"
        //             );
        //             // the job was shared, but not yet stolen, so we get to run the
        //             // job inline
        //             unsafe { stack.unwrap()(*self) }
        //         }
        //     }
        // } else {
        //     worker.pop_back();

        //     unsafe {
        //         // SAFETY: we just popped the job from the queue, so it is safe to unwrap.
        //         #[cfg(feature = "tracing")]
        //         tracing::trace!("join_heartbeat: job was not shared, running a() inline");
        //         stack.unwrap()(*self)
        //     }
        // };

        // (ra, rb)
    }

    unsafe fn from_context(ctx: Arc<Context>) -> Self {
    fn new(worker: &WorkerThread, inner: Pin<&'scope ScopeInner>) -> Self {
        // SAFETY: we are creating a new scope, so the inner is valid.
        unsafe { Self::new_unchecked(worker, SendPtr::new_const(&*inner).unwrap()) }
    }

    unsafe fn new_unchecked(worker: &WorkerThread, inner: SendPtr<ScopeInner>) -> Self {
        Self {
            context: ctx.clone(),
            job_counter: CountLatch::new(WakeLatch::new(ctx, 0)),
            panic: AtomicPtr::new(ptr::null_mut()),
            _pd: PhantomData,
            inner,
            worker: SendPtr::new_const(worker).unwrap(),
            _scope: PhantomData,
            _env: PhantomData,
        }
    }

    pub fn context(&self) -> &Arc<Context> {
        unsafe { &self.worker.as_ref().context }
    }
    pub fn worker(&self) -> &WorkerThread {
        unsafe { self.worker.as_ref() }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicU8;

    use super::*;
    use crate::ThreadPool;

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn scope_spawn_sync() {
        let pool = ThreadPool::new_with_threads(1);
        let count = Arc::new(AtomicU8::new(0));

        scope_with_context(&pool.context, |scope| {
            scope.spawn(|_| {
                count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            });
        });

        assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 1);
    }

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn scope_join_one() {
        let pool = ThreadPool::new_with_threads(1);
        let count = AtomicU8::new(0);

        let a = pool.scope(|scope| {
            let (a, b) = scope.join(
                |_| count.fetch_add(1, Ordering::Relaxed) + 4,
                |_| count.fetch_add(2, Ordering::Relaxed) + 6,
            );
            a + b
        });

        assert_eq!(count.load(Ordering::Relaxed), 3);
    }

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn scope_join_many() {
        let pool = ThreadPool::new_with_threads(1);

        fn sum<'scope, 'env>(scope: Scope<'scope, 'env>, n: usize) -> usize {
            if n == 0 {
                return 0;
            }

            let (l, r) = scope.join(|s| sum(s, n - 1), |s| sum(s, n - 1));

            l + r + 1
        }

        pool.scope(|scope| {
            let total = sum(scope, 5);
            // assert_eq!(total, 1023);
            eprintln!("Total sum: {}", total);
        });
    }

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn scope_spawn_future() {
        let pool = ThreadPool::new_with_threads(1);
        let mut x = 0;
        pool.scope(|scope| {
            let task = scope.spawn_async(|_| async {
                x += 1;
            });

            task.detach();
        });

        assert_eq!(x, 1);
    }

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn scope_spawn_many() {
        let pool = ThreadPool::new_with_threads(1);
        let count = Arc::new(AtomicU8::new(0));

        pool.scope(|scope| {
            for _ in 0..10 {
                let count = count.clone();
                scope.spawn(move |_| {
                    count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                });
            }
        });

        assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 10);
    }
}

@@ -1 +1,104 @@
use std::sync::Arc;

use crate::{Scope, context::Context, scope::scope_with_context};

#[repr(transparent)]
pub struct ThreadPool {
    pub(crate) context: Arc<Context>,
}

impl Drop for ThreadPool {
    fn drop(&mut self) {
        // TODO: Ensure that the context is properly cleaned up when the thread pool is dropped.
        self.context.set_should_exit();
    }
}

impl ThreadPool {
    pub fn new_with_threads(num_threads: usize) -> Self {
        let context = Context::new_with_threads(num_threads);
        Self { context }
    }

    /// Creates a new thread pool with a thread per hardware thread.
    pub fn new() -> Self {
        let context = Context::new();
        Self { context }
    }

    pub fn global() -> &'static Self {
        // SAFETY: ThreadPool is a transparent wrapper around Arc<Context>,
        // so this cast is sound.
        unsafe { core::mem::transmute(Context::global_context()) }
    }

    pub fn scope<'env, F, R>(&self, f: F) -> R
    where
        F: for<'scope> FnOnce(Scope<'scope, 'env>) -> R + Send,
        R: Send,
    {
        scope_with_context(&self.context, f)
    }

    pub fn spawn<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        self.context.spawn(f)
    }

    pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        RA: Send,
        RB: Send,
        A: FnOnce() -> RA + Send,
        B: FnOnce() -> RB + Send,
    {
        self.context.join(a, b)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn pool_spawn_borrow() {
        let pool = ThreadPool::new_with_threads(1);
        let mut x = 0;
        pool.scope(|scope| {
            scope.spawn(|_| {
                #[cfg(feature = "tracing")]
                tracing::info!("Incrementing x");
                x += 1;
            });
        });
        assert_eq!(x, 1);
    }

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn pool_spawn_future() {
        let pool = ThreadPool::new_with_threads(1);
        let mut x = 0;
        let task = pool.scope(|scope| {
            let task = scope.spawn_async(|_| async {
                x += 1;
            });

            task
        });

        futures::executor::block_on(task);
        assert_eq!(x, 1);
    }

    #[test]
    #[cfg_attr(all(not(miri), feature = "tracing"), tracing_test::traced_test)]
    fn pool_join() {
        let pool = ThreadPool::new_with_threads(1);
        let (a, b) = pool.join(|| 3 + 4, || 5 * 6);
        assert_eq!(a, 7);
        assert_eq!(b, 30);
    }
}

@@ -54,7 +54,7 @@ impl<T> core::fmt::Pointer for SendPtr<T> {
    }
}

unsafe impl<T> Send for SendPtr<T> {}
unsafe impl<T> core::marker::Send for SendPtr<T> {}

impl<T> Deref for SendPtr<T> {
    type Target = NonNull<T>;

@@ -93,12 +93,17 @@ impl<T> SendPtr<T> {
    pub const unsafe fn new_const_unchecked(ptr: *const T) -> Self {
        unsafe { Self::new_unchecked(ptr.cast_mut()) }
    }

    pub(crate) unsafe fn as_ref(&self) -> &T {
        unsafe { self.0.as_ref() }
    }
}

/// A tagged atomic pointer that can store a pointer and a tag `BITS` wide in the same space
/// as the pointer.
/// The pointer must leave its low `BITS` bits free, i.e. `align_of::<T>() >= 2^BITS`.
#[repr(transparent)]
#[derive(Debug)]
pub struct TaggedAtomicPtr<T, const BITS: u8> {
    ptr: AtomicPtr<()>,
    _pd: PhantomData<T>,

@@ -133,6 +138,19 @@ impl<T, const BITS: u8> TaggedAtomicPtr<T, BITS> {
        self.ptr.load(order).addr() & Self::mask()
    }

    pub fn fetch_or_tag(&self, tag: usize, order: Ordering) -> usize {
        let mask = Self::mask();
        let old_ptr = self.ptr.fetch_or(tag & mask, order);
        old_ptr.addr() & mask
    }

    /// returns the tag and clears it
    pub fn take_tag(&self, order: Ordering) -> usize {
        let mask = Self::mask();
        let old_ptr = self.ptr.fetch_and(!mask, order);
        old_ptr.addr() & mask
    }
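
    // Editor's illustrative note (not part of this commit): with BITS = 2 the
    // mask arithmetic above works out as
    //
    //     mask        = (1 << BITS) - 1    = 0b11
    //     packed      = ptr_bits | (tag & mask)
    //     tag(packed) = packed & mask
    //     ptr(packed) = packed & !mask
    //
    // which is sound only while `align_of::<T>() >= 4`, so the low two bits of
    // a real pointer are always zero.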

    /// returns tag
    #[inline(always)]
    fn compare_exchange_tag_inner(

@@ -163,7 +181,6 @@ impl<T, const BITS: u8> TaggedAtomicPtr<T, BITS> {
    }

    /// returns tag
    #[inline]
    #[allow(dead_code)]
    pub fn compare_exchange_tag(
        &self,

@@ -182,7 +199,6 @@ impl<T, const BITS: u8> TaggedAtomicPtr<T, BITS> {
    }

    /// returns tag
    #[inline]
    pub fn compare_exchange_weak_tag(
        &self,
        old: usize,

@@ -402,3 +418,106 @@ pub fn available_parallelism() -> usize {
        .map(|n| n.get())
        .unwrap_or(1)
}

#[repr(transparent)]
pub struct Send<T>(pub(self) T);

unsafe impl<T> core::marker::Send for Send<T> {}

impl<T> Deref for Send<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
impl<T> DerefMut for Send<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl<T> Send<T> {
    pub unsafe fn new(value: T) -> Self {
        Self(value)
    }
}

pub fn unwrap_or_panic<T>(result: std::thread::Result<T>) -> T {
    match result {
        Ok(value) => value,
        Err(payload) => std::panic::resume_unwind(payload),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tagged_ptr_zero_tag() {
        let ptr = Box::into_raw(Box::new(42u32));
        let tagged_ptr = TaggedAtomicPtr::<u32, 2>::new(ptr, 0);
        assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0);
        assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr);

        unsafe {
            _ = Box::from_raw(ptr);
        }
    }

    #[test]
    fn tagged_ptr_exchange() {
        let ptr = Box::into_raw(Box::new(42u32));
        let tagged_ptr = TaggedAtomicPtr::<u32, 2>::new(ptr, 0b11);
        assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0b11);
        assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr);

        assert_eq!(
            tagged_ptr
                .compare_exchange_tag(0b11, 0b10, Ordering::Relaxed, Ordering::Relaxed)
                .unwrap(),
            0b11
        );

        assert_eq!(tagged_ptr.tag(Ordering::Relaxed), 0b10);
        assert_eq!(tagged_ptr.ptr(Ordering::Relaxed).as_ptr(), ptr);

        unsafe {
            _ = Box::from_raw(ptr);
        }
    }

    #[test]
    fn value_inline() {
        assert!(SmallBox::<u32>::is_inline(), "u32 should be inline");
        assert!(SmallBox::<u8>::is_inline(), "u8 should be inline");
        assert!(
            SmallBox::<Box<u32>>::is_inline(),
            "Box<u32> should be inline"
        );
        assert!(
            SmallBox::<[u32; 2]>::is_inline(),
            "[u32; 2] should be inline"
        );
        assert!(
            !SmallBox::<[u32; 3]>::is_inline(),
            "[u32; 3] should not be inline"
        );
        assert!(SmallBox::<usize>::is_inline(), "usize should be inline");

        #[repr(C, align(16))]
        struct LargeType(u8);
        assert!(
            !SmallBox::<LargeType>::is_inline(),
            "LargeType should not be inline"
        );

        #[repr(C, align(4))]
        struct SmallType(u8);
        assert!(
            SmallBox::<SmallType>::is_inline(),
            "SmallType should be inline"
        );
    }
}

@@ -1,26 +1,34 @@
#[cfg(feature = "metrics")]
use std::sync::atomic::Ordering;

use std::{
    cell::{Cell, UnsafeCell},
    ptr::NonNull,
    sync::Arc,
    sync::{Arc, Barrier},
    time::Duration,
};

use crossbeam_utils::CachePadded;
use parking_lot_core::SpinWait;
#[cfg(feature = "metrics")]
use werkzeug::CachePadded;

use crate::{
    context::{Context, Heartbeat},
    job::{Job, JobList, JobResult},
    latch::{AsCoreLatch, CoreLatch, Probe},
    channel::Receiver,
    context::{Context, Message},
    heartbeat::OwnedHeartbeatReceiver,
    job::{Job2 as Job, JobQueue as JobList, SharedJob},
    queue,
    util::DropGuard,
};

pub struct WorkerThread {
    pub(crate) context: Arc<Context>,
    pub(crate) index: usize,
    pub(crate) receiver: queue::Receiver<Message>,
    pub(crate) queue: UnsafeCell<JobList>,
    heartbeat: Arc<CachePadded<Heartbeat>>,
    pub(crate) heartbeat: OwnedHeartbeatReceiver,
    pub(crate) join_count: Cell<u8>,

    #[cfg(feature = "metrics")]
    pub(crate) metrics: CachePadded<crate::metrics::WorkerMetrics>,
}

thread_local! {

@@ -29,25 +37,25 @@ thread_local! {

impl WorkerThread {
    pub fn new_in(context: Arc<Context>) -> Self {
        let (heartbeat, index) = context.shared().new_heartbeat();
        let heartbeat = context.heartbeats.new_heartbeat();

        Self {
            receiver: context.queue.new_receiver(),
            context,
            index,
            queue: UnsafeCell::new(JobList::new()),
            heartbeat,
            join_count: Cell::new(0),
            #[cfg(feature = "metrics")]
            metrics: CachePadded::new(crate::metrics::WorkerMetrics::default()),
        }
    }

    fn new() -> Self {
        let context = Context::global_context().clone();
        Self::new_in(context)
    }
}

impl WorkerThread {
    pub fn run(self: Box<Self>) {
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all, fields(
        worker = self.heartbeat.index(),
    )))]
    pub fn run(self: Box<Self>, barrier: Arc<Barrier>) {
        let this = Box::into_raw(self);
        unsafe {
            Self::set_current(this);

@@ -59,107 +67,111 @@ impl WorkerThread {
            Self::drop_in_place(this);
        });

        #[cfg(feature = "tracing")]
        tracing::trace!("WorkerThread::run: starting worker thread");

        barrier.wait();
        unsafe {
            (&*this).run_inner();
        }

        #[cfg(feature = "metrics")]
        unsafe {
            eprintln!("{:?}", (&*this).metrics);
        }

        #[cfg(feature = "tracing")]
        tracing::trace!("WorkerThread::run: worker thread finished");
    }

    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    fn run_inner(&self) {
        let mut job = self.context.shared().pop_job();
        'outer: loop {
            let mut guard = loop {
                if let Some(job) = job {
                    self.execute(job);
                }
        loop {
            if self.context.should_exit() {
                break;
            }

                let mut guard = self.context.shared();
                if guard.should_exit() {
                    // if the context is stopped, break out of the outer loop which
                    // will exit the thread.
                    break 'outer;
            match self.receiver.recv() {
                Message::Shared(shared_job) => {
                    self.execute(shared_job);
                }

                match guard.pop_job() {
                    Some(job) => {
                        tracing::trace!("worker: popping job: {:?}", job);
                        // found job, continue inner loop
                        continue;
                    }
                    None => {
                        tracing::trace!("worker: no job, waiting for shared job");
                        // no more jobs, break out of inner loop and wait for shared job
                        break guard;
                    }
                }
            };

            self.context.shared_job.wait(&mut guard);
            job = guard.pop_job();
                Message::Exit => break,
                Message::WakeUp | Message::ScopeFinished => {}
            }
        }
    }
}

impl WorkerThread {
    #[inline(always)]
    fn tick(&self) {
        if self.heartbeat.is_pending() {
    /// Checks if the worker thread has received a heartbeat, and if so,
    /// attempts to share a job with other workers. If a job was popped from
    /// the queue, but not shared, this function runs the job locally.
    pub(crate) fn tick(&self) {
        if self.heartbeat.take() {
            #[cfg(feature = "metrics")]
            self.metrics.num_heartbeats.fetch_add(1, Ordering::Relaxed);
            #[cfg(feature = "tracing")]
            tracing::trace!(
                "received heartbeat, thread id: {:?}",
                self.heartbeat.index()
            );

            self.heartbeat_cold();
        }
    }

    #[inline]
    fn execute(&self, job: NonNull<Job>) {
        self.tick();
        Job::execute(job);
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    fn execute(&self, job: SharedJob) {
        unsafe { SharedJob::execute(job, self) };
        // TODO: maybe tick here?
    }

    /// Attempts to share a job with other workers within the same context.
    /// returns `true` if the job was shared, `false` if it was not.
    #[cold]
    fn heartbeat_cold(&self) {
        let mut guard = self.context.shared();
        if let Some(job) = self.pop_back() {
            #[cfg(feature = "tracing")]
            tracing::trace!("heartbeat: sharing job: {:?}", job);

        if !guard.jobs.contains_key(&self.index) {
            if let Some(job) = self.pop_back() {
                tracing::trace!("heartbeat: sharing job: {:?}", job);
            #[cfg(feature = "metrics")]
            self.metrics.num_jobs_shared.fetch_add(1, Ordering::Relaxed);

            if let Err(Message::Shared(job)) =
                self.context
                    .queue
                    .as_sender()
                    .try_anycast(Message::Shared(unsafe {
                        job.as_ref()
                            .share(Some(self.receiver.get_token().as_parker()))
                    }))
            {
                unsafe {
                    job.as_ref().set_pending();
                    SharedJob::execute(job, self);
                }
                guard.jobs.insert(self.index, job);
                self.context.notify_shared_job();
            }
        }

        self.heartbeat.clear();
    }
}
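
// --- Editor's illustrative sketch (not part of this commit) ---
// Heartbeat scheduling in miniature: each worker keeps a private deque and
// only promotes ("shares") a job when the periodic heartbeat has fired, so
// the hot path stays free of synchronization. `Job` is a boxed closure
// standing in for the crate's Job2/SharedJob machinery, and the Mutex-guarded
// Vec stands in for the anycast queue.
fn heartbeat_sharing_sketch() {
    use std::collections::VecDeque;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::{Arc, Mutex};

    type Job = Box<dyn FnOnce() + Send>;

    struct Worker {
        local: VecDeque<Job>,           // private: no locking on the hot path
        heartbeat: Arc<AtomicBool>,     // set by a ~100µs heartbeat thread
        injector: Arc<Mutex<Vec<Job>>>, // where other workers pick up shared jobs
    }

    impl Worker {
        // called on every job boundary, like `tick()` above
        fn tick(&mut self) {
            if self.heartbeat.swap(false, Ordering::Relaxed) {
                self.share_cold();
            }
        }

        #[cold]
        fn share_cold(&mut self) {
            // share one job from the back of the deque, mirroring the
            // pop_back() in heartbeat_cold above
            if let Some(job) = self.local.pop_back() {
                self.injector.lock().unwrap().push(job);
            }
        }
    }

    let heartbeat = Arc::new(AtomicBool::new(true)); // pretend one just fired
    let injector = Arc::new(Mutex::new(Vec::new()));
    let mut local = VecDeque::new();
    local.push_back(Box::new(|| println!("shared job ran")) as Job);

    let mut worker = Worker {
        local,
        heartbeat,
        injector: Arc::clone(&injector),
    };

    worker.tick();
    assert_eq!(injector.lock().unwrap().len(), 1); // the job was shared
}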

impl WorkerThread {
    #[inline]
    pub fn pop_back(&self) -> Option<NonNull<Job>> {
        unsafe { self.queue.as_mut_unchecked().pop_back() }
    }

    #[inline]
    pub fn push_back(&self, job: *const Job) {
        unsafe { self.queue.as_mut_unchecked().push_back(job) }
    pub fn push_back<T>(&self, job: *const Job<T>) {
        unsafe { self.queue.as_mut_unchecked().push_back(job.cast()) }
    }
    pub fn push_front<T>(&self, job: *const Job<T>) {
        unsafe { self.queue.as_mut_unchecked().push_front(job.cast()) }
    }

    #[inline]
    pub fn pop_front(&self) -> Option<NonNull<Job>> {
        unsafe { self.queue.as_mut_unchecked().pop_front() }
    }

    #[inline]
    pub fn push_front(&self, job: *const Job) {
        unsafe { self.queue.as_mut_unchecked().push_front(job) }
    }
}

impl WorkerThread {
    #[inline]
    pub fn current_ref<'a>() -> Option<&'a Self> {
        unsafe { (*WORKER.with(UnsafeCell::get)).map(|ptr| ptr.as_ref()) }
    }

@@ -190,50 +202,50 @@ impl WorkerThread {

    unsafe fn drop_in_place(this: *mut Self) {
        unsafe {
            this.drop_in_place();
            drop(Box::from_raw(this));
            // SAFETY: this is only called when the thread is exiting, so we can
            // safely drop the thread. We use `drop_in_place` to prevent `Box`
            // from creating a no-alias reference to the worker thread.
            core::ptr::drop_in_place(this);
            _ = Box::<core::mem::ManuallyDrop<Self>>::from_raw(this as _);
        }
    }
}

pub struct HeartbeatThread {
    ctx: Arc<Context>,
    num_workers: usize,
}

impl HeartbeatThread {
    const HEARTBEAT_INTERVAL: Duration = Duration::from_micros(100);

    pub fn new(ctx: Arc<Context>) -> Self {
        Self { ctx }
    pub fn new(ctx: Arc<Context>, num_workers: usize) -> Self {
        Self { ctx, num_workers }
    }

    pub fn run(self) {
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
    pub fn run(self, barrier: Arc<Barrier>) {
        #[cfg(feature = "tracing")]
        tracing::trace!("new heartbeat thread {:?}", std::thread::current());
        barrier.wait();

        let mut i = 0;
        loop {
            let sleep_for = {
                let mut guard = self.ctx.shared();
                if guard.should_exit() {
            // loop {
            //     if self.ctx.should_exit() || self.ctx.queue.num_receivers() != self.num_workers
            //     {
            //         break;
            //     }

            //     self.ctx.heartbeat.park();
            // }
            if self.ctx.should_exit() {
                break;
            }

            let mut n = 0;
            guard.heartbeats.retain(|_, b| {
                b.upgrade()
                    .inspect(|heartbeat| {
                        if n == i {
                            if heartbeat.set_pending() {
                                heartbeat.latch.set();
                            }
                        }
                        n += 1;
                    })
                    .is_some()
            });
            let num_heartbeats = guard.heartbeats.len();

            drop(guard);
            self.ctx.heartbeats.notify_nth(i);
            let num_heartbeats = self.ctx.heartbeats.len();

            if i >= num_heartbeats {
                i = 0;

@@ -252,145 +264,19 @@ impl HeartbeatThread {
}

impl WorkerThread {
    #[cold]
    fn wait_until_latch_cold(&self, latch: &CoreLatch) {
        // does this optimise?
        assert!(!latch.probe());

        'outer: while !latch.probe() {
            // take a shared job, if it exists
            if let Some(shared_job) = self.context.shared().jobs.remove(&self.index) {
                self.execute(shared_job);
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
    pub fn wait_until_recv<T: Send>(&self, recv: Receiver<T>) -> std::thread::Result<T> {
        loop {
            if let Some(result) = recv.poll() {
                break result;
            }

            // process local jobs before locking shared context
            while let Some(job) = self.pop_front() {
                unsafe {
                    job.as_ref().set_pending();
                }
                self.execute(job);
            }

            while !latch.probe() {
                let job = self.context.shared().pop_job();

                match job {
                    Some(job) => {
                        self.execute(job);

                        continue 'outer;
                    }
                    None => {
                        tracing::trace!("waiting for shared job, thread id: {:?}", self.index);

                        // TODO: wait on latch? if we have something that can
                        // signal being done, e.g. can be waited on instead of
                        // shared jobs, we should wait on it instead, but we
                        // would also want to receive shared jobs still?
                        // Spin? probably just wastes CPU time.
                        // self.context.shared_job.wait(&mut guard);
                        // if spin.spin() {
                        //     // wait for more shared jobs.
                        //     // self.context.shared_job.wait(&mut guard);
                        //     return;
                        // }
                        // Yield? same as spinning, really, so just exit and let the upstream use wait
                        // std::thread::yield_now();

                        self.heartbeat.latch.wait_and_reset();
                        // since we were sleeping, the shared job can't be populated,
                        // so resuming the inner loop is fine.
                    }
                }
            match self.receiver.recv() {
                Message::Shared(shared_job) => unsafe {
                    SharedJob::execute(shared_job, self);
                },
                Message::WakeUp | Message::Exit | Message::ScopeFinished => {}
            }
        }
        return;
    }

    pub fn wait_until_job<T>(&self, job: &Job<T>, latch: &CoreLatch) -> Option<JobResult<T>> {
        // we've already checked that the job was popped from the queue
        // check if shared job is our job
        if let Some(shared_job) = self.context.shared().jobs.remove(&self.index) {
            if core::ptr::eq(shared_job.as_ptr(), job as *const Job<T> as _) {
                // this is the job we are looking for, so we want to
                // short-circuit and call it inline
                return None;
            } else {
                // this isn't the job we are looking for, but we still need to
                // execute it
                self.execute(shared_job);
            }
        }

        // do the usual thing and wait for the job's latch
        if !latch.probe() {
            self.wait_until_latch_cold(latch);
        }

        Some(job.wait())
    }

    pub fn wait_until_latch<L>(&self, latch: &L)
    where
        L: AsCoreLatch,
    {
        let latch = latch.as_core_latch();
        if !latch.probe() {
            self.wait_until_latch_cold(latch)
        }
    }

    #[inline]
    fn wait_until_predicate<F>(&self, pred: F)
    where
        F: Fn() -> bool,
    {
        'outer: while !pred() {
            // take a shared job, if it exists
            if let Some(shared_job) = self.context.shared().jobs.remove(&self.index) {
                self.execute(shared_job);
            }

            // process local jobs before locking shared context
            while let Some(job) = self.pop_front() {
                unsafe {
                    job.as_ref().set_pending();
                }
                self.execute(job);
            }

            while !pred() {
                let mut guard = self.context.shared();
                let mut _spin = SpinWait::new();

                match guard.pop_job() {
                    Some(job) => {
                        drop(guard);
                        self.execute(job);

                        continue 'outer;
                    }
                    None => {
                        tracing::trace!("waiting for shared job, thread id: {:?}", self.index);

                        // TODO: wait on latch? if we have something that can
                        // signal being done, e.g. can be waited on instead of
                        // shared jobs, we should wait on it instead, but we
                        // would also want to receive shared jobs still?
                        // Spin? probably just wastes CPU time.
                        // self.context.shared_job.wait(&mut guard);
                        // if spin.spin() {
                        //     // wait for more shared jobs.
                        //     // self.context.shared_job.wait(&mut guard);
                        //     return;
                        // }
                        // Yield? same as spinning, really, so just exit and let the upstream use wait
                        // std::thread::yield_now();
                        return;
                    }
                }
            }
        }
        return;
    }
}

173
examples/join.rs
Normal file

@@ -0,0 +1,173 @@
use executor::praetor::{Scope, ThreadPool};

use executor::util::tree::Tree;

const TREE_SIZE: usize = 16;

fn join_scope(tree_size: usize) {
    let pool = ThreadPool::new();

    let tree = Tree::new(tree_size, 1);

    fn sum(tree: &Tree<u32>, node: usize, scope: &Scope) -> u32 {
        let node = tree.get(node);
        let (l, r) = scope.join(
            |s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(),
            |s| {
                node.right
                    .map(|node| sum(tree, node, s))
                    .unwrap_or_default()
            },
        );

        // eprintln!("node: {node:?}, l: {l}, r: {r}");

        node.leaf + l + r
    }

    for _ in 0..1000 {
        let sum = pool.scope(|s| sum(&tree, tree.root().unwrap(), s));
        std::hint::black_box(sum);
    }
}

fn join_pool(tree_size: usize) {
    let pool = ThreadPool::new();

    let tree = Tree::new(tree_size, 1);

    fn sum(tree: &Tree<u32>, node: usize, pool: &ThreadPool) -> u32 {
        let node = tree.get(node);
        let (l, r) = pool.join(
            || {
                node.left
                    .map(|node| sum(tree, node, pool))
                    .unwrap_or_default()
            },
            || {
                node.right
                    .map(|node| sum(tree, node, pool))
                    .unwrap_or_default()
            },
        );

        // eprintln!("node: {node:?}, l: {l}, r: {r}");

        node.leaf + l + r
    }

    let sum = sum(&tree, tree.root().unwrap(), &pool);
    eprintln!("sum: {sum}");
}

fn join_distaff(tree_size: usize) {
    use distaff::*;
    let pool = ThreadPool::new_with_threads(6);

    let tree = Tree::new(tree_size, 1);

    fn sum<'scope, 'env>(tree: &Tree<u32>, node: usize, scope: Scope<'scope, 'env>) -> u32 {
        let node = tree.get(node);
        let (l, r) = scope.join(
            |s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(),
            |s| {
                node.right
                    .map(|node| sum(tree, node, s))
                    .unwrap_or_default()
            },
        );

        // eprintln!("node: {node:?}, l: {l}, r: {r}");

        node.leaf + l + r
    }

    for _ in 0..1000 {
        let sum = pool.scope(|s| sum(&tree, tree.root().unwrap(), s));
        eprintln!("sum: {sum}");
        std::hint::black_box(sum);
    }
}

fn join_chili(tree_size: usize) {
    let tree = Tree::new(tree_size, 1u32);

    fn sum(tree: &Tree<u32>, node: usize, scope: &mut chili::Scope<'_>) -> u32 {
        let node = tree.get(node);
        let (l, r) = scope.join(
            |s| node.left.map(|node| sum(tree, node, s)).unwrap_or_default(),
            |s| {
                node.right
                    .map(|node| sum(tree, node, s))
                    .unwrap_or_default()
            },
        );

        node.leaf + l + r
    }

    for _ in 0..1000 {
        let sum = sum(&tree, tree.root().unwrap(), &mut chili::Scope::global());
        std::hint::black_box(sum);
    }
}

fn join_rayon(tree_size: usize) {
    let tree = Tree::new(tree_size, 1u32);

    fn sum(tree: &Tree<u32>, node: usize) -> u32 {
        let node = tree.get(node);
        let (l, r) = rayon::join(
            || node.left.map(|node| sum(tree, node)).unwrap_or_default(),
            || node.right.map(|node| sum(tree, node)).unwrap_or_default(),
        );

        node.leaf + l + r
    }

    for _ in 0..1000 {
        let sum = sum(&tree, tree.root().unwrap());
        std::hint::black_box(sum);
    }
}

fn main() {
    // tracing_subscriber::fmt::init();
    use tracing_subscriber::layer::SubscriberExt;
    tracing::subscriber::set_global_default(
        tracing_subscriber::registry().with(tracing_tracy::TracyLayer::default()),
    )
    .expect("Failed to set global default subscriber");

    eprintln!("Press Enter to start profiling...");
    std::io::stdin().read_line(&mut String::new()).unwrap();

    let size = std::env::args()
        .nth(2)
        .and_then(|s| s.parse::<usize>().ok())
        .unwrap_or(TREE_SIZE);

    match std::env::args().nth(1).as_deref() {
        Some("scope") => join_scope(size),
        Some("pool") => join_pool(size),
        Some("chili") => join_chili(size),
        Some("distaff") => join_distaff(size),
        Some("rayon") => join_rayon(size),
        _ => {
            eprintln!(
                "Usage: {} [scope|pool|chili|distaff|rayon] <tree_size={}>",
                std::env::args().next().unwrap(),
                TREE_SIZE
            );
            return;
        }
    }

    eprintln!("Done!");
    println!("Done!");
    // // wait for user input before exiting
    // std::io::stdin().read_line(&mut String::new()).unwrap();
}

@@ -813,7 +813,7 @@ mod job {
        }
    }

    /// call this when popping value from local queue
    /// must be called before `execute()`
    pub fn set_pending(&self) {
        let mut spin = SpinWait::new();
        loop {