//! Vulkan synchronization primitives (fences and semaphores) plus a small
//! on-demand threadpool that turns blocking fence waits into async wakeups.
use std::{
    future::Future,
    marker::PhantomData,
    sync::{atomic::AtomicU32, Arc},
    time::Duration,
};

use super::Device;
use ash::{prelude::*, vk};
use crossbeam::channel::{Receiver, Sender};
|
|
|
|
/// Unit of work for the sync threadpool: the primitive to block on plus the
/// waker to invoke once the wait completes.
type Message = (SyncPrimitive, std::task::Waker);
|
|
|
|
/// Lazily-grown pool of OS threads that block on Vulkan sync primitives and
/// wake the associated async tasks when the wait completes.
pub struct SyncThreadpool {
    // Shared rendezvous channel (capacity 0): a send only succeeds while a
    // worker is actively receiving, which is what drives on-demand spawning.
    channel: (Sender<Message>, Receiver<Message>),
    // Per-wait timeout handed to `Fence::wait_on`; u64::MAX = wait forever.
    timeout: u64,
    // How long an idle worker blocks on the channel before exiting.
    thread_dies_after: Duration,
    // Upper bound on concurrently live worker threads.
    max_threads: u32,
    // Live worker count; shared with the workers so they decrement it on exit.
    num_threads: Arc<AtomicU32>,
}
|
|
|
|
/// The kinds of Vulkan waits a worker thread can block on.
#[derive(Debug)]
enum SyncPrimitive {
    /// Block until this fence becomes signaled.
    Fence(Arc<Fence>),
    // actually, I think this is an awful idea because I would have to hold a
    // lock on all queues.
    /// Block until the whole device is idle.
    DeviceIdle(Device),
}
|
|
|
|
impl SyncThreadpool {
|
|
pub fn new() -> SyncThreadpool {
|
|
Self::with_max_threads(512)
|
|
}
|
|
|
|
pub fn with_max_threads(max_threads: u32) -> SyncThreadpool {
|
|
Self {
|
|
// 0-capacity channel to ensure immediate consumption of fences
|
|
channel: crossbeam::channel::bounded(0),
|
|
max_threads,
|
|
num_threads: Arc::new(AtomicU32::new(0)),
|
|
timeout: u64::MAX,
|
|
thread_dies_after: Duration::from_secs(5),
|
|
}
|
|
}
|
|
|
|
fn try_spawn_thread(&self) -> Option<()> {
|
|
use std::sync::atomic::Ordering;
|
|
match self
|
|
.num_threads
|
|
.fetch_update(Ordering::Release, Ordering::Acquire, |i| {
|
|
if i < self.max_threads {
|
|
Some(i + 1)
|
|
} else {
|
|
None
|
|
}
|
|
}) {
|
|
Ok(tid) => {
|
|
struct SyncThread {
|
|
timeout: u64,
|
|
thread_dies_after: Duration,
|
|
num_threads: Arc<AtomicU32>,
|
|
rx: Receiver<Message>,
|
|
}
|
|
|
|
impl SyncThread {
|
|
fn run(self, barrier: Arc<std::sync::Barrier>) {
|
|
tracing::info!("spawned new sync thread");
|
|
barrier.wait();
|
|
while let Ok((sync, waker)) = self.rx.recv_timeout(self.thread_dies_after) {
|
|
tracing::info!("received ({:?}, {:?})", sync, waker);
|
|
loop {
|
|
let wait_result = match &sync {
|
|
SyncPrimitive::Fence(fence) => {
|
|
fence.wait_on(Some(self.timeout))
|
|
}
|
|
SyncPrimitive::DeviceIdle(device) => device.wait_idle(),
|
|
};
|
|
|
|
match wait_result {
|
|
Ok(_) => {
|
|
waker.wake();
|
|
break;
|
|
}
|
|
Err(vk::Result::TIMEOUT) => {}
|
|
Err(err) => {
|
|
tracing::error!(
|
|
"failed to wait on {sync:?} in waiter thread: {err}"
|
|
);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// because I don't want some thread to not spawn as soon as this one exists
|
|
self.num_threads.fetch_sub(1, Ordering::AcqRel);
|
|
}
|
|
}
|
|
|
|
let thread = SyncThread {
|
|
timeout: self.timeout,
|
|
thread_dies_after: self.thread_dies_after,
|
|
num_threads: self.num_threads.clone(),
|
|
rx: self.channel.1.clone(),
|
|
};
|
|
|
|
let barrier = Arc::new(std::sync::Barrier::new(2));
|
|
std::thread::Builder::new()
|
|
.name(format!("fence-waiter-{tid}"))
|
|
.spawn({
|
|
let barrier = barrier.clone();
|
|
move || {
|
|
thread.run(barrier);
|
|
}
|
|
});
|
|
barrier.wait();
|
|
Some(())
|
|
}
|
|
Err(_) => {
|
|
tracing::error!(
|
|
"sync-threadpool exceeded local thread limit ({})",
|
|
self.max_threads
|
|
);
|
|
None
|
|
}
|
|
}
|
|
}
|
|
|
|
fn spawn_waiter(&self, fence: Arc<Fence>, waker: std::task::Waker) {
|
|
use std::sync::atomic::Ordering;
|
|
|
|
let mut msg = (SyncPrimitive::Fence(fence), waker);
|
|
while let Err(err) = self.channel.0.try_send(msg) {
|
|
match err {
|
|
crossbeam::channel::TrySendError::Full(msg2) => {
|
|
msg = msg2;
|
|
self.try_spawn_thread();
|
|
}
|
|
crossbeam::channel::TrySendError::Disconnected(_) => {
|
|
//tracing::error!("sync-threadpool channel disconnected?");
|
|
unreachable!()
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Owned Vulkan semaphore (binary or timeline); destroyed on drop.
pub struct Semaphore {
    // Kept so the semaphore can be destroyed against the device it came from.
    device: Device,
    inner: vk::Semaphore,
}
|
|
|
|
/// Owned Vulkan fence; destroyed on drop.
pub struct Fence {
    // Kept so the fence can be destroyed against the device it came from.
    dev: Device,
    fence: vk::Fence,
}
|
|
|
|
impl std::fmt::Debug for Fence {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
f.debug_struct("Fence").field("fence", &self.fence).finish()
|
|
}
|
|
}
|
|
|
|
impl Drop for Fence {
    fn drop(&mut self) {
        // SAFETY: `fence` was created from `dev` and is owned exclusively by
        // this struct, so destroying it exactly once here is valid.
        // NOTE(review): Vulkan requires the fence not be in use by pending
        // queue work when destroyed — confirm callers wait before dropping.
        unsafe {
            self.dev.dev().destroy_fence(self.fence, None);
        }
    }
}
|
|
|
|
impl Fence {
|
|
unsafe fn new(dev: Device, fence: vk::Fence) -> Fence {
|
|
Self { dev, fence }
|
|
}
|
|
pub fn create(dev: Device) -> VkResult<Fence> {
|
|
unsafe {
|
|
Ok(Self::new(
|
|
dev.clone(),
|
|
dev.dev()
|
|
.create_fence(&vk::FenceCreateInfo::default(), None)?,
|
|
))
|
|
}
|
|
}
|
|
pub fn create_signaled(dev: Device) -> VkResult<Fence> {
|
|
unsafe {
|
|
Ok(Self::new(
|
|
dev.clone(),
|
|
dev.dev().create_fence(
|
|
&vk::FenceCreateInfo::default().flags(vk::FenceCreateFlags::SIGNALED),
|
|
None,
|
|
)?,
|
|
))
|
|
}
|
|
}
|
|
pub fn wait_on(&self, timeout: Option<u64>) -> Result<(), vk::Result> {
|
|
use core::slice::from_ref;
|
|
unsafe {
|
|
self.dev
|
|
.dev()
|
|
.wait_for_fences(from_ref(&self.fence), true, timeout.unwrap_or(u64::MAX))
|
|
}
|
|
}
|
|
pub fn fence(&self) -> vk::Fence {
|
|
self.fence
|
|
}
|
|
pub fn is_signaled(&self) -> bool {
|
|
unsafe { self.dev.dev().get_fence_status(self.fence).unwrap_or(false) }
|
|
}
|
|
pub fn reset(&self) -> Result<(), vk::Result> {
|
|
unsafe {
|
|
self.dev
|
|
.dev()
|
|
.reset_fences(core::slice::from_ref(&self.fence))
|
|
}
|
|
}
|
|
}
|
|
impl AsRef<vk::Fence> for Fence {
    /// Borrows the underlying raw fence handle.
    fn as_ref(&self) -> &vk::Fence {
        // BUGFIX: this was `todo!()`, which panicked at runtime for any
        // caller going through `AsRef`.
        &self.fence
    }
}
|
|
|
|
impl Semaphore {
|
|
pub fn new(device: Device) -> VkResult<Self> {
|
|
let mut type_info =
|
|
vk::SemaphoreTypeCreateInfo::default().semaphore_type(vk::SemaphoreType::BINARY);
|
|
let create_info = vk::SemaphoreCreateInfo::default().push_next(&mut type_info);
|
|
let inner = unsafe { device.dev().create_semaphore(&create_info, None)? };
|
|
|
|
Ok(Self { device, inner })
|
|
}
|
|
pub fn new_timeline(device: Device, value: u64) -> VkResult<Self> {
|
|
let mut type_info = vk::SemaphoreTypeCreateInfo::default()
|
|
.semaphore_type(vk::SemaphoreType::TIMELINE)
|
|
.initial_value(value);
|
|
let create_info = vk::SemaphoreCreateInfo::default().push_next(&mut type_info);
|
|
let inner = unsafe { device.dev().create_semaphore(&create_info, None)? };
|
|
|
|
Ok(Self { device, inner })
|
|
}
|
|
pub fn semaphore(&self) -> vk::Semaphore {
|
|
self.inner
|
|
}
|
|
}
|
|
|
|
impl Drop for Semaphore {
    fn drop(&mut self) {
        // SAFETY: `inner` was created from `device` and is owned exclusively
        // by this struct, so destroying it exactly once here is valid.
        // NOTE(review): Vulkan requires no pending queue work still reference
        // the semaphore at destruction time — confirm callers ensure this.
        unsafe {
            self.device.dev().destroy_semaphore(self.inner, None);
        }
    }
}
|
|
|
|
/// Future that resolves once its [`Fence`] is signaled, delegating the
/// blocking wait to a sync-threadpool worker instead of the executor.
pub struct FenceFuture<'a> {
    fence: Arc<Fence>,
    // lifetime parameter to prevent release of resources while future is pending
    _pd: PhantomData<&'a ()>,
}
|
|
|
|
impl FenceFuture<'_> {
    /// Unsafe because `fence` must not be destroyed while this future is live.
    ///
    /// # Safety
    /// `fence` must be a valid handle created from `device`. The wrapping
    /// [`Fence`] takes ownership and will destroy the handle when the last
    /// `Arc` clone drops, so the caller must not destroy it elsewhere.
    pub unsafe fn from_fence(device: Device, fence: vk::Fence) -> Self {
        Self {
            fence: Arc::new(Fence::new(device, fence)),
            _pd: PhantomData,
        }
    }
    /// Wraps an owned fence in a future that resolves when it is signaled.
    pub fn new(fence: Arc<Fence>) -> Self {
        Self {
            fence,
            _pd: PhantomData,
        }
    }
    /// Synchronously waits for the fence (no timeout), then resets it.
    pub fn block(&self) -> VkResult<()> {
        self.fence.wait_on(None)?;
        self.fence.reset()
    }
}
|
|
|
|
impl Future for FenceFuture<'_> {
    type Output = ();

    /// Completes once the fence is signaled, resetting it as a side effect.
    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        if self.fence.is_signaled() {
            tracing::info!("fence ({:?}) is signaled", self.fence);
            // Auto-reset so the fence can be reused; the reset error (if any)
            // is deliberately discarded.
            _ = self.fence.reset();
            std::task::Poll::Ready(())
        } else {
            // Hand (fence, waker) to the device's sync threadpool, which
            // blocks on the fence and wakes this task when it signals.
            // NOTE(review): every poll enqueues another waiter — repeated
            // polls of a pending future (e.g. under select!) can tie up
            // multiple pool threads on the same fence. Confirm acceptable.
            self.fence
                .dev
                .sync_threadpool()
                .spawn_waiter(self.fence.clone(), cx.waker().clone());
            std::task::Poll::Pending
        }
    }
}
|