vidya/crates/renderer/src/sync.rs
Janis f0fff72bce idek so much when i thought i was only doing egui integration
egui cant draw yet, but textures are loaded/updated
2024-12-29 15:48:55 +01:00

301 lines
9.2 KiB
Rust

use std::{
future::Future,
marker::PhantomData,
sync::{atomic::AtomicU32, Arc},
time::Duration,
};
use super::Device;
use ash::{prelude::*, vk};
use crossbeam::channel::{Receiver, Sender};
/// Work item handed to a waiter thread: the primitive to block on, paired with
/// the waker that is invoked once that wait has completed successfully.
type Message = (SyncPrimitive, std::task::Waker);
/// Pool of on-demand threads that block on sync primitives and wake the
/// associated task waker once the wait succeeds.
pub struct SyncThreadpool {
/// Both ends of the hand-off channel; the receiver is cloned into every
/// waiter thread, the sender is used to submit work.
channel: (Sender<Message>, Receiver<Message>),
/// Timeout forwarded to each `Fence::wait_on` call made by a waiter thread;
/// a wait that times out is simply retried.
timeout: u64,
/// How long an idle waiter thread waits for new work before it exits.
thread_dies_after: Duration,
/// Upper bound on the number of waiter threads alive at the same time.
max_threads: u32,
/// Number of waiter-thread slots currently in use; shared with the threads
/// so they can release their slot when they exit.
num_threads: Arc<AtomicU32>,
}
/// The kinds of blocking waits a waiter thread can perform.
#[derive(Debug)]
enum SyncPrimitive {
/// Wait until the fence is signaled.
Fence(Arc<Fence>),
/// Wait until the whole device is idle.
// actually, I think this is an awful idea because I would have to hold a
// lock on all queues.
DeviceIdle(Device),
}
impl SyncThreadpool {
pub fn new() -> SyncThreadpool {
Self::with_max_threads(512)
}
pub fn with_max_threads(max_threads: u32) -> SyncThreadpool {
Self {
// 0-capacity channel to ensure immediate consumption of fences
channel: crossbeam::channel::bounded(0),
max_threads,
num_threads: Arc::new(AtomicU32::new(0)),
timeout: u64::MAX,
thread_dies_after: Duration::from_secs(5),
}
}
fn try_spawn_thread(&self) -> Option<()> {
use std::sync::atomic::Ordering;
match self
.num_threads
.fetch_update(Ordering::Release, Ordering::Acquire, |i| {
if i < self.max_threads {
Some(i + 1)
} else {
None
}
}) {
Ok(tid) => {
struct SyncThread {
timeout: u64,
thread_dies_after: Duration,
num_threads: Arc<AtomicU32>,
rx: Receiver<Message>,
}
impl SyncThread {
fn run(self, barrier: Arc<std::sync::Barrier>) {
tracing::info!("spawned new sync thread");
barrier.wait();
while let Ok((sync, waker)) = self.rx.recv_timeout(self.thread_dies_after) {
tracing::info!("received ({:?}, {:?})", sync, waker);
loop {
let wait_result = match &sync {
SyncPrimitive::Fence(fence) => {
fence.wait_on(Some(self.timeout))
}
SyncPrimitive::DeviceIdle(device) => device.wait_idle(),
};
match wait_result {
Ok(_) => {
waker.wake();
break;
}
Err(vk::Result::TIMEOUT) => {}
Err(err) => {
tracing::error!(
"failed to wait on {sync:?} in waiter thread: {err}"
);
break;
}
}
}
}
// because I don't want some thread to not spawn as soon as this one exists
self.num_threads.fetch_sub(1, Ordering::AcqRel);
}
}
let thread = SyncThread {
timeout: self.timeout,
thread_dies_after: self.thread_dies_after,
num_threads: self.num_threads.clone(),
rx: self.channel.1.clone(),
};
let barrier = Arc::new(std::sync::Barrier::new(2));
std::thread::Builder::new()
.name(format!("fence-waiter-{tid}"))
.spawn({
let barrier = barrier.clone();
move || {
thread.run(barrier);
}
});
barrier.wait();
Some(())
}
Err(_) => {
tracing::error!(
"sync-threadpool exceeded local thread limit ({})",
self.max_threads
);
None
}
}
}
fn spawn_waiter(&self, fence: Arc<Fence>, waker: std::task::Waker) {
use std::sync::atomic::Ordering;
let mut msg = (SyncPrimitive::Fence(fence), waker);
while let Err(err) = self.channel.0.try_send(msg) {
match err {
crossbeam::channel::TrySendError::Full(msg2) => {
msg = msg2;
self.try_spawn_thread();
}
crossbeam::channel::TrySendError::Disconnected(_) => {
//tracing::error!("sync-threadpool channel disconnected?");
unreachable!()
}
}
}
}
}
/// Owning wrapper around a `vk::Semaphore`; the handle is destroyed when the
/// wrapper is dropped.
pub struct Semaphore {
/// Device the semaphore was created from and is destroyed with.
device: Device,
/// The raw semaphore handle.
inner: vk::Semaphore,
}
/// Owning wrapper around a `vk::Fence`; the handle is destroyed when the
/// wrapper is dropped.
pub struct Fence {
/// Device the fence belongs to; used for waiting, resetting and destruction.
dev: Device,
/// The raw fence handle.
fence: vk::Fence,
}
impl std::fmt::Debug for Fence {
    /// Formats the fence as `Fence { fence: <handle> }`; the owning device is
    /// deliberately left out of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Fence");
        out.field("fence", &self.fence);
        out.finish()
    }
}
impl Drop for Fence {
    /// Destroys the wrapped fence handle on its device.
    fn drop(&mut self) {
        let raw = self.fence;
        // SAFETY: relies on the invariant of `Fence::new` that this value owns
        // `raw` -- TODO confirm no waiter is still using the handle at this point.
        unsafe { self.dev.dev().destroy_fence(raw, None) };
    }
}
impl Fence {
    /// Wraps an existing raw fence handle.
    ///
    /// # Safety
    ///
    /// The returned value destroys `fence` on `dev` when it is dropped, so the
    /// caller hands over ownership of the handle and must not destroy it
    /// elsewhere. `fence` is presumably required to have been created from
    /// `dev` -- verify against callers.
    unsafe fn new(dev: Device, fence: vk::Fence) -> Fence {
        Self { dev, fence }
    }

    /// Creates a new fence in the unsignaled state.
    ///
    /// # Errors
    ///
    /// Returns the Vulkan error reported by fence creation.
    pub fn create(dev: Device) -> VkResult<Fence> {
        Self::create_with_flags(dev, vk::FenceCreateFlags::empty())
    }

    /// Creates a new fence that starts out signaled.
    ///
    /// # Errors
    ///
    /// Returns the Vulkan error reported by fence creation.
    pub fn create_signaled(dev: Device) -> VkResult<Fence> {
        Self::create_with_flags(dev, vk::FenceCreateFlags::SIGNALED)
    }

    /// Shared constructor: creates a fence on `dev` with the given flags.
    fn create_with_flags(dev: Device, flags: vk::FenceCreateFlags) -> VkResult<Fence> {
        let info = vk::FenceCreateInfo::default().flags(flags);
        // SAFETY: `info` is a fully initialised create-info; assumes `dev`
        // wraps a live logical device.
        let raw = unsafe { dev.dev().create_fence(&info, None) }?;
        // SAFETY: `raw` was just created from `dev` and is owned by nobody
        // else, so handing ownership to the wrapper is sound.
        Ok(unsafe { Self::new(dev, raw) })
    }

    /// Blocks until the fence is signaled or `timeout` elapses.
    ///
    /// `None` waits with the maximum timeout (`u64::MAX`). The value is passed
    /// unchanged to `wait_for_fences`.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `wait_for_fences`; the waiter thread in
    /// this file treats `vk::Result::TIMEOUT` as "not signaled yet".
    pub fn wait_on(&self, timeout: Option<u64>) -> Result<(), vk::Result> {
        let fences = core::slice::from_ref(&self.fence);
        let timeout = timeout.unwrap_or(u64::MAX);
        // SAFETY: `self.fence` is owned by this wrapper and stays alive for
        // the duration of the call.
        unsafe { self.dev.dev().wait_for_fences(fences, true, timeout) }
    }

    /// Returns the raw fence handle (still owned by `self`).
    pub fn fence(&self) -> vk::Fence {
        self.fence
    }

    /// Returns whether the fence is currently signaled.
    ///
    /// A failed status query is reported as `false`.
    pub fn is_signaled(&self) -> bool {
        // SAFETY: `self.fence` is owned by this wrapper and still alive.
        let status = unsafe { self.dev.dev().get_fence_status(self.fence) };
        status.unwrap_or(false)
    }

    /// Resets the fence to the unsignaled state.
    ///
    /// # Errors
    ///
    /// Returns the Vulkan error reported by `reset_fences`.
    pub fn reset(&self) -> Result<(), vk::Result> {
        let fences = core::slice::from_ref(&self.fence);
        // SAFETY: `self.fence` is owned by this wrapper and still alive.
        // NOTE(review): callers are presumably responsible for making sure the
        // fence is not part of a pending submission -- confirm.
        unsafe { self.dev.dev().reset_fences(fences) }
    }
}
impl AsRef<vk::Fence> for Fence {
    /// Borrows the raw fence handle; ownership stays with `self`.
    fn as_ref(&self) -> &vk::Fence {
        &self.fence
    }
}
impl Semaphore {
    /// Creates a binary semaphore on `device`.
    ///
    /// # Errors
    ///
    /// Returns the Vulkan error reported by semaphore creation.
    pub fn new(device: Device) -> VkResult<Self> {
        // A binary semaphore keeps the default initial value of 0.
        Self::create_with_type(device, vk::SemaphoreType::BINARY, 0)
    }

    /// Creates a timeline semaphore on `device` whose counter starts at `value`.
    ///
    /// # Errors
    ///
    /// Returns the Vulkan error reported by semaphore creation.
    pub fn new_timeline(device: Device, value: u64) -> VkResult<Self> {
        Self::create_with_type(device, vk::SemaphoreType::TIMELINE, value)
    }

    /// Returns the raw semaphore handle (still owned by `self`).
    pub fn semaphore(&self) -> vk::Semaphore {
        self.inner
    }

    /// Shared constructor: chains a `SemaphoreTypeCreateInfo` describing
    /// `kind` and `initial` onto the create-info and creates the semaphore.
    fn create_with_type(
        device: Device,
        kind: vk::SemaphoreType,
        initial: u64,
    ) -> VkResult<Self> {
        let mut type_info = vk::SemaphoreTypeCreateInfo::default()
            .semaphore_type(kind)
            .initial_value(initial);
        let create_info = vk::SemaphoreCreateInfo::default().push_next(&mut type_info);

        // SAFETY: `create_info` and its chained `type_info` outlive the call;
        // assumes `device` wraps a live logical device.
        let inner = unsafe { device.dev().create_semaphore(&create_info, None) }?;
        Ok(Self { device, inner })
    }
}
impl Drop for Semaphore {
    /// Destroys the wrapped semaphore handle on its device.
    fn drop(&mut self) {
        let raw = self.inner;
        // SAFETY: this wrapper owns `raw` -- TODO confirm no queue operation
        // still references the semaphore when it is dropped.
        unsafe { self.device.dev().destroy_semaphore(raw, None) };
    }
}
/// Future that completes once the wrapped fence is signaled.
pub struct FenceFuture<'a> {
/// The fence being awaited; shared with waiter threads while pending.
fence: Arc<Fence>,
// lifetime parameter to prevent release of resources while the future is pending
_pd: PhantomData<&'a ()>,
}
impl FenceFuture<'_> {
/// Unsafe because `fence` must not be destroyed while this future is live.
pub unsafe fn from_fence(device: Device, fence: vk::Fence) -> Self {
Self {
fence: Arc::new(Fence::new(device, fence)),
_pd: PhantomData,
}
}
pub fn new(fence: Arc<Fence>) -> Self {
Self {
fence,
_pd: PhantomData,
}
}
pub fn block(&self) -> VkResult<()> {
self.fence.wait_on(None)?;
self.fence.reset()
}
}
impl Future for FenceFuture<'_> {
type Output = ();
fn poll(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
if self.fence.is_signaled() {
tracing::info!("fence ({:?}) is signaled", self.fence);
_ = self.fence.reset();
std::task::Poll::Ready(())
} else {
self.fence
.dev
.sync_threadpool()
.spawn_waiter(self.fence.clone(), cx.waker().clone());
std::task::Poll::Pending
}
}
}