use std::{ future::Future, marker::PhantomData, sync::{atomic::AtomicU32, Arc}, time::Duration, }; use super::Device; use ash::{prelude::*, vk}; use crossbeam::channel::{Receiver, Sender}; type Message = (SyncPrimitive, std::task::Waker); pub struct SyncThreadpool { channel: (Sender, Receiver), timeout: u64, thread_dies_after: Duration, max_threads: u32, num_threads: Arc, } #[derive(Debug)] enum SyncPrimitive { Fence(Arc), // actually, I think this is an awful idea because I would have to hold a // lock on all queues. DeviceIdle(Device), } impl SyncThreadpool { pub fn new() -> SyncThreadpool { Self::with_max_threads(512) } pub fn with_max_threads(max_threads: u32) -> SyncThreadpool { Self { // 0-capacity channel to ensure immediate consumption of fences channel: crossbeam::channel::bounded(0), max_threads, num_threads: Arc::new(AtomicU32::new(0)), timeout: u64::MAX, thread_dies_after: Duration::from_secs(5), } } fn try_spawn_thread(&self) -> Option<()> { use std::sync::atomic::Ordering; match self .num_threads .fetch_update(Ordering::Release, Ordering::Acquire, |i| { if i < self.max_threads { Some(i + 1) } else { None } }) { Ok(tid) => { struct SyncThread { timeout: u64, thread_dies_after: Duration, num_threads: Arc, rx: Receiver, } impl SyncThread { fn run(self, barrier: Arc) { tracing::info!("spawned new sync thread"); barrier.wait(); while let Ok((sync, waker)) = self.rx.recv_timeout(self.thread_dies_after) { tracing::info!("received ({:?}, {:?})", sync, waker); loop { let wait_result = match &sync { SyncPrimitive::Fence(fence) => { fence.wait_on(Some(self.timeout)) } SyncPrimitive::DeviceIdle(device) => device.wait_idle(), }; match wait_result { Ok(_) => { waker.wake(); break; } Err(vk::Result::TIMEOUT) => {} Err(err) => { tracing::error!( "failed to wait on {sync:?} in waiter thread: {err}" ); break; } } } } // because I don't want some thread to not spawn as soon as this one exists self.num_threads.fetch_sub(1, Ordering::AcqRel); } } let thread = SyncThread { timeout: self.timeout, thread_dies_after: self.thread_dies_after, num_threads: self.num_threads.clone(), rx: self.channel.1.clone(), }; let barrier = Arc::new(std::sync::Barrier::new(2)); std::thread::Builder::new() .name(format!("fence-waiter-{tid}")) .spawn({ let barrier = barrier.clone(); move || { thread.run(barrier); } }); barrier.wait(); Some(()) } Err(_) => { tracing::error!( "sync-threadpool exceeded local thread limit ({})", self.max_threads ); None } } } fn spawn_waiter(&self, fence: Arc, waker: std::task::Waker) { use std::sync::atomic::Ordering; let mut msg = (SyncPrimitive::Fence(fence), waker); while let Err(err) = self.channel.0.try_send(msg) { match err { crossbeam::channel::TrySendError::Full(msg2) => { msg = msg2; self.try_spawn_thread(); } crossbeam::channel::TrySendError::Disconnected(_) => { //tracing::error!("sync-threadpool channel disconnected?"); unreachable!() } } } } } pub struct Semaphore { device: Device, inner: vk::Semaphore, } pub struct Fence { dev: Device, fence: vk::Fence, } impl std::fmt::Debug for Fence { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Fence").field("fence", &self.fence).finish() } } impl Drop for Fence { fn drop(&mut self) { unsafe { self.dev.dev().destroy_fence(self.fence, None); } } } impl Fence { unsafe fn new(dev: Device, fence: vk::Fence) -> Fence { Self { dev, fence } } pub fn create(dev: Device) -> VkResult { unsafe { Ok(Self::new( dev.clone(), dev.dev() .create_fence(&vk::FenceCreateInfo::default(), None)?, )) } } pub fn create_signaled(dev: Device) -> VkResult { unsafe { Ok(Self::new( dev.clone(), dev.dev().create_fence( &vk::FenceCreateInfo::default().flags(vk::FenceCreateFlags::SIGNALED), None, )?, )) } } pub fn wait_on(&self, timeout: Option) -> Result<(), vk::Result> { use core::slice::from_ref; unsafe { self.dev .dev() .wait_for_fences(from_ref(&self.fence), true, timeout.unwrap_or(u64::MAX)) } } pub fn fence(&self) -> vk::Fence { self.fence } pub fn is_signaled(&self) -> bool { unsafe { self.dev.dev().get_fence_status(self.fence).unwrap_or(false) } } pub fn reset(&self) -> Result<(), vk::Result> { unsafe { self.dev .dev() .reset_fences(core::slice::from_ref(&self.fence)) } } } impl AsRef for Fence { fn as_ref(&self) -> &vk::Fence { todo!() } } impl Semaphore { pub fn new(device: Device) -> VkResult { let mut type_info = vk::SemaphoreTypeCreateInfo::default().semaphore_type(vk::SemaphoreType::BINARY); let create_info = vk::SemaphoreCreateInfo::default().push_next(&mut type_info); let inner = unsafe { device.dev().create_semaphore(&create_info, None)? }; Ok(Self { device, inner }) } pub fn new_timeline(device: Device, value: u64) -> VkResult { let mut type_info = vk::SemaphoreTypeCreateInfo::default() .semaphore_type(vk::SemaphoreType::TIMELINE) .initial_value(value); let create_info = vk::SemaphoreCreateInfo::default().push_next(&mut type_info); let inner = unsafe { device.dev().create_semaphore(&create_info, None)? }; Ok(Self { device, inner }) } pub fn semaphore(&self) -> vk::Semaphore { self.inner } } impl Drop for Semaphore { fn drop(&mut self) { unsafe { self.device.dev().destroy_semaphore(self.inner, None); } } } pub struct FenceFuture<'a> { fence: Arc, // lifetime parameter to prevent release of resources while future is pendign _pd: PhantomData<&'a ()>, } impl FenceFuture<'_> { /// Unsafe because `fence` must not be destroyed while this future is live. pub unsafe fn from_fence(device: Device, fence: vk::Fence) -> Self { Self { fence: Arc::new(Fence::new(device, fence)), _pd: PhantomData, } } pub fn new(fence: Arc) -> Self { Self { fence, _pd: PhantomData, } } pub fn block(&self) -> VkResult<()> { self.fence.wait_on(None)?; self.fence.reset() } } impl Future for FenceFuture<'_> { type Output = (); fn poll( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll { if self.fence.is_signaled() { tracing::info!("fence ({:?}) is signaled", self.fence); _ = self.fence.reset(); std::task::Poll::Ready(()) } else { self.fence .dev .sync_threadpool() .spawn_waiter(self.fence.clone(), cx.waker().clone()); std::task::Poll::Pending } } }