vidya/crates/renderer/src/pipeline.rs
2025-01-01 02:00:23 +01:00

725 lines
25 KiB
Rust

use std::{borrow::Cow, path::Path, sync::Arc};
use ash::{prelude::*, vk};
use crate::{
define_device_owned_handle,
device::{Device, DeviceOwnedDebugObject},
};
/// Description of one shader stage of a pipeline.
///
/// Turned into a `vk::PipelineShaderStageCreateInfo` by `into_create_info`.
#[derive(Debug)]
pub struct ShaderStageDesc<'a> {
/// Creation flags forwarded to the stage create info.
pub flags: vk::PipelineShaderStageCreateFlags,
/// Shader module whose handle is used for this stage.
pub module: &'a ShaderModule,
/// Pipeline stage the shader is used for.
pub stage: vk::ShaderStageFlags,
/// Entry point name, passed as the stage's `name`.
pub entry: Cow<'a, std::ffi::CStr>,
// specialization: Option<vk::SpecializationInfo>
}
/// Description of a single binding inside a descriptor set layout.
#[derive(Debug, Default)]
pub struct DescriptorSetLayoutBindingDesc {
/// Binding number within the set.
pub binding: u32,
/// Descriptor count of the binding.
pub count: u32,
/// Descriptor type of the binding.
pub kind: vk::DescriptorType,
/// Shader stages that may access the binding.
pub stage: vk::ShaderStageFlags,
/// Per-binding flags; `None` is treated as the default (empty) flags
/// when the layout is created.
pub flags: Option<vk::DescriptorBindingFlags>,
}
/// Parameters for `DescriptorSetLayout::new`.
#[derive(Debug, Default)]
pub struct DescriptorSetLayoutDesc<'a> {
/// Layout creation flags.
pub flags: vk::DescriptorSetLayoutCreateFlags,
/// Bindings that make up the layout.
pub bindings: &'a [DescriptorSetLayoutBindingDesc],
/// Optional name handed to the handle constructor
/// (presumably used as a debug label; confirm against `construct`).
pub name: Option<Cow<'static, str>>,
}
/// Parameters for `PipelineLayout::new`.
#[derive(Debug, Default)]
pub struct PipelineLayoutDesc<'a> {
/// Descriptor set layouts whose handles become the pipeline layout's set layouts.
pub descriptor_set_layouts: &'a [&'a DescriptorSetLayout],
/// Push constant ranges forwarded unchanged to the create info.
pub push_constant_ranges: &'a [vk::PushConstantRange],
/// Optional name handed to the handle constructor
/// (presumably used as a debug label; confirm against `construct`).
pub name: Option<Cow<'static, str>>,
}
/// Selects which kind of pipeline `Pipeline::new` creates.
#[derive(Debug)]
pub enum PipelineDesc<'a> {
/// Compute pipeline; the resulting `Pipeline` reports the `COMPUTE` bind point.
Compute(ComputePipelineDesc<'a>),
/// Graphics pipeline; the resulting `Pipeline` reports the `GRAPHICS` bind point.
Graphics(GraphicsPipelineDesc<'a>),
}
/// Parameters for creating a compute pipeline through `Pipeline::new`.
#[derive(Debug)]
pub struct ComputePipelineDesc<'a> {
/// Pipeline creation flags.
pub flags: vk::PipelineCreateFlags,
/// Optional name attached to the created pipeline object.
pub name: Option<Cow<'static, str>>,
/// The single shader stage of the compute pipeline.
pub shader_stage: ShaderStageDesc<'a>,
/// Pipeline layout whose handle is used for creation.
pub layout: Arc<PipelineLayout>,
/// Optional pipeline whose handle is passed as the base pipeline;
/// a null handle is used when `None`.
pub base_pipeline: Option<Arc<Pipeline>>,
}
/// Vertex input bindings and attributes of a graphics pipeline.
#[derive(Debug, Default)]
pub struct VertexInputState<'a> {
// pub flags: vk::PipelineVertexInputStateCreateFlags,
/// Vertex buffer binding descriptions.
pub bindings: &'a [vk::VertexInputBindingDescription],
/// Vertex attribute descriptions.
pub attributes: &'a [vk::VertexInputAttributeDescription],
}
/// Tessellation state of a graphics pipeline.
#[derive(Debug, Default)]
pub struct TessellationState {
/// Flags forwarded to the tessellation state create info.
pub flags: vk::PipelineTessellationStateCreateFlags,
/// Number of control points per patch.
pub patch_control_points: u32,
}
/// Input assembly state of a graphics pipeline.
#[derive(Debug, Default)]
pub struct InputAssemblyState {
// pub flags: vk::PipelineInputAssemblyStateCreateFlags,
/// Primitive topology.
pub topology: vk::PrimitiveTopology,
/// Forwarded as `primitive_restart_enable`.
pub primitive_restart: bool,
}
/// Viewport and scissor state of a graphics pipeline.
///
/// The counts are always written to the create info; the arrays are only
/// set when they are `Some`.
#[derive(Debug, Default)]
pub struct ViewportState<'a> {
/// Forwarded as `scissor_count`.
pub num_scissors: u32,
/// Optional static scissor rectangles.
// NOTE(review): the ash `scissors` setter may overwrite the count with the slice length — confirm.
pub scissors: Option<&'a [vk::Rect2D]>,
/// Forwarded as `viewport_count`.
pub num_viewports: u32,
/// Optional static viewports.
// NOTE(review): the ash `viewports` setter may overwrite the count with the slice length — confirm.
pub viewports: Option<&'a [vk::Viewport]>,
}
/// Depth bias parameters; depth bias is enabled when this state is present
/// in `RasterizationState::depth_bias`.
#[derive(Debug, Default)]
pub struct DepthBiasState {
/// Forwarded as `depth_bias_clamp`.
pub clamp: f32,
/// Forwarded as `depth_bias_constant_factor`.
pub constant_factor: f32,
/// Forwarded as `depth_bias_slope_factor`.
pub slope_factor: f32,
}
/// Rasterization state of a graphics pipeline.
#[derive(Debug)]
pub struct RasterizationState {
/// Forwarded as `depth_clamp_enable`.
pub depth_clamp_enable: bool,
/// Forwarded as `rasterizer_discard_enable`.
pub discard_enable: bool,
/// Width of rasterized lines.
pub line_width: f32,
/// Which faces are culled.
pub cull_mode: vk::CullModeFlags,
/// Enables depth bias with the given parameters when `Some`.
pub depth_bias: Option<DepthBiasState>,
/// Polygon fill mode.
pub polygon_mode: vk::PolygonMode,
}
impl Default for RasterizationState {
    /// Filled polygons, back-face culling, a line width of 1.0, and depth
    /// clamp, rasterizer discard and depth bias all disabled.
    fn default() -> Self {
        Self {
            polygon_mode: vk::PolygonMode::FILL,
            cull_mode: vk::CullModeFlags::BACK,
            line_width: 1.0,
            depth_bias: None,
            depth_clamp_enable: false,
            discard_enable: false,
        }
    }
}
/// Multisample state of a graphics pipeline; every field is forwarded to
/// the setter of the same name on `vk::PipelineMultisampleStateCreateInfo`.
#[derive(Debug)]
pub struct MultisampleState<'a> {
/// Flags forwarded to the multisample state create info.
pub flags: vk::PipelineMultisampleStateCreateFlags,
/// Enables per-sample shading.
pub sample_shading_enable: bool,
/// Number of samples used in rasterization.
pub rasterization_samples: vk::SampleCountFlags,
/// Minimum sample shading fraction.
pub min_sample_shading: f32,
/// Sample mask words.
pub sample_mask: &'a [vk::SampleMask],
/// Enables alpha-to-coverage.
pub alpha_to_coverage_enable: bool,
/// Enables alpha-to-one.
pub alpha_to_one_enable: bool,
}
impl<'a> Default for MultisampleState<'a> {
    /// Single-sample rasterization with a minimum sample shading of 1.0,
    /// an empty sample mask, default flags and every toggle switched off.
    fn default() -> Self {
        Self {
            rasterization_samples: vk::SampleCountFlags::TYPE_1,
            min_sample_shading: 1.0,
            sample_mask: &[],
            sample_shading_enable: false,
            alpha_to_coverage_enable: false,
            alpha_to_one_enable: false,
            flags: vk::PipelineMultisampleStateCreateFlags::default(),
        }
    }
}
/// Range used for the depth bounds test.
#[derive(Debug)]
pub struct DepthBounds {
/// Forwarded as `min_depth_bounds`.
pub min: f32,
/// Forwarded as `max_depth_bounds`.
pub max: f32,
}
/// Depth test configuration of a graphics pipeline.
#[derive(Debug, Default)]
pub struct DepthState {
/// Forwarded as `depth_write_enable`.
pub write_enable: bool,
/// sets depthTestEnable to true when `Some`
pub compare_op: Option<vk::CompareOp>,
/// sets depthBoundsTestEnable to true when `Some`
pub bounds: Option<DepthBounds>,
}
/// Stencil operations of a graphics pipeline; the stencil test is enabled
/// when this state is present in `DepthStencilState::stencil`.
#[derive(Debug, Default)]
pub struct StencilState {
/// Stencil operations for front-facing primitives.
pub front: vk::StencilOpState,
/// Stencil operations for back-facing primitives.
pub back: vk::StencilOpState,
}
/// Combined depth/stencil state of a graphics pipeline.
#[derive(Debug, Default)]
pub struct DepthStencilState {
/// Flags forwarded to the depth/stencil state create info.
pub flags: vk::PipelineDepthStencilStateCreateFlags,
/// Depth configuration; when `None` the depth fields keep their defaults.
pub depth: Option<DepthState>,
/// Stencil configuration; the stencil test is enabled only when `Some`.
pub stencil: Option<StencilState>,
}
/// Color blend state of a graphics pipeline.
#[derive(Debug, Default)]
pub struct ColorBlendState<'a> {
/// Flags forwarded to the color blend state create info.
pub flags: vk::PipelineColorBlendStateCreateFlags,
/// Per-attachment blend states.
pub attachments: &'a [vk::PipelineColorBlendAttachmentState],
/// Logic op; `logic_op_enable` is set to true only when `Some`.
pub logic_op: Option<vk::LogicOp>,
/// Blend constants.
pub blend_constants: [f32; 4],
}
/// Attachment formats that `Pipeline::new` chains into the graphics pipeline
/// create info as a `vk::PipelineRenderingCreateInfo`.
#[derive(Debug, Default)]
pub struct RenderingState<'a> {
/// Formats of the color attachments.
pub color_formats: &'a [vk::Format],
/// Depth attachment format; the default format value is used when `None`.
pub depth_format: Option<vk::Format>,
/// Stencil attachment format; the default format value is used when `None`.
pub stencil_format: Option<vk::Format>,
}
/// Dynamic state selection of a graphics pipeline.
#[derive(Debug, Default)]
pub struct DynamicState<'a> {
/// Flags forwarded to the dynamic state create info.
pub flags: vk::PipelineDynamicStateCreateFlags,
/// Pipeline states declared as dynamic.
pub dynamic_states: &'a [vk::DynamicState],
}
/// Parameters for creating a graphics pipeline through `Pipeline::new`.
///
/// Every optional fixed-function state that is `None` is passed to Vulkan
/// as a null pointer.
#[derive(Debug)]
pub struct GraphicsPipelineDesc<'a> {
/// Pipeline creation flags.
pub flags: vk::PipelineCreateFlags,
/// Optional name attached to the created pipeline object.
pub name: Option<Cow<'static, str>>,
/// Shader stages of the pipeline.
pub shader_stages: &'a [ShaderStageDesc<'a>],
/// Render pass to create the pipeline for; a null handle is used when `None`.
pub render_pass: Option<vk::RenderPass>,
/// Pipeline layout whose handle is used for creation.
pub layout: &'a PipelineLayout,
/// Subpass index; 0 is used when `None`.
pub subpass: Option<u32>,
/// Optional pipeline whose handle is passed as the base pipeline;
/// a null handle is used when `None`.
pub base_pipeline: Option<Arc<Pipeline>>,
/// Vertex input state.
pub vertex_input: Option<VertexInputState<'a>>,
/// Input assembly state.
pub input_assembly: Option<InputAssemblyState>,
/// Tessellation state.
pub tessellation: Option<TessellationState>,
/// Viewport/scissor state.
pub viewport: Option<ViewportState<'a>>,
/// Rasterization state.
pub rasterization: Option<RasterizationState>,
/// Multisample state.
pub multisample: Option<MultisampleState<'a>>,
/// Depth/stencil state.
pub depth_stencil: Option<DepthStencilState>,
/// Color blend state.
pub color_blend: Option<ColorBlendState<'a>>,
/// Dynamic state selection.
pub dynamic: Option<DynamicState<'a>>,
/// Attachment formats, chained into the create info's `pNext` when `Some`.
pub rendering: Option<RenderingState<'a>>,
}
/// Parameters for `DescriptorPool::new`.
#[derive(Debug, Default)]
pub struct DescriptorPoolDesc<'a> {
/// Pool creation flags.
pub flags: vk::DescriptorPoolCreateFlags,
/// Optional name handed to the handle constructor
/// (presumably used as a debug label; confirm against `construct`).
pub name: Option<Cow<'static, str>>,
/// Per-descriptor-type pool sizes.
pub sizes: &'a [vk::DescriptorPoolSize],
/// Forwarded as `max_sets`.
pub max_sets: u32,
}
/// Describes one descriptor set to allocate with `DescriptorPool::allocate`.
#[derive(Debug)]
pub struct DescriptorSetAllocDesc<'a> {
/// Optional debug name applied to the allocated set.
pub name: Option<Cow<'static, str>>,
/// Layout the set is allocated with.
pub layout: &'a DescriptorSetLayout,
}
// Device-owned wrapper around `vk::DescriptorPool`. The trailing closure
// destroys the pool through the owning device.
// NOTE(review): presumably the macro runs this closure on drop — confirm
// against `define_device_owned_handle!`.
define_device_owned_handle! {
#[derive(Debug)]
pub DescriptorPool(vk::DescriptorPool) {} => |this| unsafe {
this.device().dev().destroy_descriptor_pool(this.handle(), None);
}
}
impl DescriptorPool {
    /// Creates a descriptor pool on `device` as described by `desc`.
    ///
    /// # Errors
    /// Returns the Vulkan error reported by pool creation, or any error
    /// produced by `Self::construct`.
    pub fn new(device: Device, desc: DescriptorPoolDesc) -> VkResult<Self> {
        let info = vk::DescriptorPoolCreateInfo::default()
            .flags(desc.flags)
            .max_sets(desc.max_sets)
            .pool_sizes(desc.sizes);
        let handle = unsafe { device.dev().create_descriptor_pool(&info, None)? };
        Self::construct(device, handle, desc.name)
    }

    /// Allocates one descriptor set per entry of `descs`, in the same order,
    /// and applies the optional debug name of each entry to its set.
    ///
    /// An empty `descs` returns an empty vector without calling into Vulkan,
    /// because allocating zero descriptor sets is not valid API usage.
    ///
    /// # Errors
    /// Returns the Vulkan error reported by the allocation or by debug naming.
    pub fn allocate(&self, descs: &[DescriptorSetAllocDesc]) -> VkResult<Vec<vk::DescriptorSet>> {
        if descs.is_empty() {
            return Ok(Vec::new());
        }

        let layouts = descs
            .iter()
            .map(|desc| desc.layout.handle())
            .collect::<Vec<_>>();
        let info = vk::DescriptorSetAllocateInfo::default()
            .descriptor_pool(self.handle())
            .set_layouts(&layouts);
        let sets = unsafe { self.device().dev().allocate_descriptor_sets(&info)? };

        for (&set, desc) in sets.iter().zip(descs) {
            if let Some(name) = desc.name.as_ref() {
                self.device().debug_name_object(set, &name)?;
            }
        }
        Ok(sets)
    }

    // pub fn free(&self) {}

    /// Resets the pool with empty reset flags.
    ///
    /// # Errors
    /// Returns the Vulkan error reported by the reset.
    #[allow(dead_code)]
    pub fn reset(&self) -> VkResult<()> {
        unsafe {
            self.device()
                .dev()
                .reset_descriptor_pool(self.handle(), vk::DescriptorPoolResetFlags::empty())
        }
    }
}
// Device-owned wrapper around `vk::DescriptorSetLayout`. The trailing
// closure destroys the layout through the owning device.
// NOTE(review): presumably the macro runs this closure on drop — confirm
// against `define_device_owned_handle!`.
define_device_owned_handle! {
#[derive(Debug)]
pub DescriptorSetLayout(vk::DescriptorSetLayout) {} => |this| unsafe {
this.device().dev().destroy_descriptor_set_layout(this.handle(), None);
}
}
impl DescriptorSetLayout {
    /// Creates a descriptor set layout on `device` from `desc`.
    ///
    /// The per-binding flags are chained into the create info only when the
    /// device reports API version 1.2 or newer, or supports the descriptor
    /// indexing extension; otherwise they are not passed to Vulkan.
    ///
    /// # Errors
    /// Returns the Vulkan error reported by layout creation, or any error
    /// produced by `Self::construct`.
    pub fn new(device: Device, desc: DescriptorSetLayoutDesc) -> VkResult<Self> {
        let mut binding_flags = Vec::with_capacity(desc.bindings.len());
        let mut layout_bindings = Vec::with_capacity(desc.bindings.len());
        for entry in desc.bindings {
            // A binding without explicit flags gets the default (empty) flags.
            binding_flags.push(entry.flags.unwrap_or_default());
            layout_bindings.push(
                vk::DescriptorSetLayoutBinding::default()
                    .binding(entry.binding)
                    .descriptor_count(entry.count)
                    .descriptor_type(entry.kind)
                    .stage_flags(entry.stage),
            );
        }

        let mut flags_info =
            vk::DescriptorSetLayoutBindingFlagsCreateInfo::default().binding_flags(&binding_flags);
        let base_info = vk::DescriptorSetLayoutCreateInfo::default()
            .bindings(&layout_bindings)
            .flags(desc.flags);

        let has_binding_flags_support = device.features().version >= vk::API_VERSION_1_2
            || device
                .features()
                .supports_extension(&crate::make_extention_properties(
                    ash::ext::descriptor_indexing::NAME,
                    ash::ext::descriptor_indexing::SPEC_VERSION,
                ));
        let info = if has_binding_flags_support {
            base_info.push_next(&mut flags_info)
        } else {
            base_info
        };

        let layout = unsafe { device.dev().create_descriptor_set_layout(&info, None)? };
        Self::construct(device, layout, desc.name)
    }
}
use crate::device::DeviceOwned;
// Device-owned wrapper around `vk::PipelineLayout`. The trailing closure
// destroys the layout through the owning device.
// NOTE(review): presumably the macro runs this closure on drop — confirm
// against `define_device_owned_handle!`.
define_device_owned_handle! {
#[derive(Debug)]
pub PipelineLayout(vk::PipelineLayout) {} => |this| unsafe {
this.device().dev().destroy_pipeline_layout(this.handle(), None);
}
}
impl PipelineLayout {
    /// Creates a pipeline layout on `device` from the descriptor set layouts
    /// and push constant ranges listed in `desc`.
    ///
    /// # Errors
    /// Returns the Vulkan error reported by layout creation, or any error
    /// produced by `Self::construct`.
    pub fn new(device: Device, desc: PipelineLayoutDesc) -> VkResult<Self> {
        // Vulkan wants raw handles, so unwrap them from the wrappers first.
        let mut raw_set_layouts = Vec::with_capacity(desc.descriptor_set_layouts.len());
        raw_set_layouts.extend(
            desc.descriptor_set_layouts
                .iter()
                .map(|set_layout| set_layout.handle()),
        );

        let create_info = vk::PipelineLayoutCreateInfo::default()
            .push_constant_ranges(desc.push_constant_ranges)
            .set_layouts(&raw_set_layouts);
        let raw = unsafe { device.dev().create_pipeline_layout(&create_info, None)? };
        Self::construct(device, raw, desc.name)
    }
}
/// Parameters for `Sampler::new`; every field is forwarded to the matching
/// setter on `vk::SamplerCreateInfo`.
///
/// Implements `Eq` and `Hash` by hand because of its `f32` fields.
#[derive(Debug, Default)]
pub struct SamplerDesc {
/// Sampler creation flags.
pub flags: vk::SamplerCreateFlags,
/// Minification filter.
pub min_filter: vk::Filter,
/// Magnification filter.
pub mag_filter: vk::Filter,
/// Mipmap lookup mode.
pub mipmap_mode: vk::SamplerMipmapMode,
/// Addressing mode for the U coordinate.
pub address_u: vk::SamplerAddressMode,
/// Addressing mode for the V coordinate.
pub address_v: vk::SamplerAddressMode,
/// Addressing mode for the W coordinate.
pub address_w: vk::SamplerAddressMode,
/// LOD bias.
pub mip_lod_bias: f32,
/// Enables anisotropic filtering.
pub anisotropy_enable: bool,
/// Maximum anisotropy.
pub max_anisotropy: f32,
/// Comparison operator; `compare_enable` is set to true only when `Some`.
pub compare_op: Option<vk::CompareOp>,
/// Minimum LOD.
pub min_lod: f32,
/// Maximum LOD.
pub max_lod: f32,
/// Border color.
pub border_color: vk::BorderColor,
/// Forwarded as `unnormalized_coordinates`.
pub unnormalized_coordinates: bool,
}
impl Eq for SamplerDesc {}
impl PartialEq for SamplerDesc {
    /// Two descriptions are equal when every field matches; the `f32` fields
    /// are compared with `crate::util::eq_f32` rather than `==`.
    fn eq(&self, other: &Self) -> bool {
        use crate::util::eq_f32;

        // Cheap, non-float fields first.
        let discrete_fields_match = self.flags == other.flags
            && self.min_filter == other.min_filter
            && self.mag_filter == other.mag_filter
            && self.mipmap_mode == other.mipmap_mode
            && self.address_u == other.address_u
            && self.address_v == other.address_v
            && self.address_w == other.address_w
            && self.anisotropy_enable == other.anisotropy_enable
            && self.compare_op == other.compare_op
            && self.border_color == other.border_color
            && self.unnormalized_coordinates == other.unnormalized_coordinates;
        if !discrete_fields_match {
            return false;
        }

        eq_f32(self.mip_lod_bias, other.mip_lod_bias)
            && eq_f32(self.max_anisotropy, other.max_anisotropy)
            && eq_f32(self.min_lod, other.min_lod)
            && eq_f32(self.max_lod, other.max_lod)
    }
}
impl std::hash::Hash for SamplerDesc {
    /// Feeds every field into `state`; the `f32` fields go through
    /// `crate::util::hash_f32`.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        use crate::util::hash_f32;

        // Tuples hash their elements one after another, so grouping the
        // fields keeps the feed order unchanged.
        (
            &self.flags,
            &self.min_filter,
            &self.mag_filter,
            &self.mipmap_mode,
            &self.address_u,
            &self.address_v,
            &self.address_w,
        )
            .hash(state);

        for value in [
            self.mip_lod_bias,
            self.max_anisotropy,
            self.min_lod,
            self.max_lod,
        ] {
            hash_f32(state, value);
        }

        (
            &self.anisotropy_enable,
            &self.compare_op,
            &self.border_color,
            &self.unnormalized_coordinates,
        )
            .hash(state);
    }
}
// Device-owned wrapper around `vk::Sampler`. The trailing closure destroys
// the sampler through the owning device.
// NOTE(review): presumably the macro runs this closure on drop — confirm
// against `define_device_owned_handle!`.
define_device_owned_handle! {
#[derive(Debug)]
pub Sampler(vk::Sampler) {} => |this| unsafe {
this.device().dev().destroy_sampler(this.handle(), None);
}
}
impl Sampler {
    /// Creates a sampler on `device` from `desc`.
    ///
    /// Comparison is enabled only when `desc.compare_op` is `Some`; otherwise
    /// the default compare op is passed with comparison disabled. The sampler
    /// is constructed without a name.
    ///
    /// # Errors
    /// Returns the Vulkan error reported by sampler creation, or any error
    /// produced by `Self::construct`.
    pub fn new(device: Device, desc: &SamplerDesc) -> VkResult<Self> {
        let (compare_enable, compare_op) = match desc.compare_op {
            Some(op) => (true, op),
            None => (false, vk::CompareOp::default()),
        };

        let create_info = vk::SamplerCreateInfo::default()
            .flags(desc.flags)
            // filtering
            .mag_filter(desc.mag_filter)
            .min_filter(desc.min_filter)
            .mipmap_mode(desc.mipmap_mode)
            // addressing
            .address_mode_u(desc.address_u)
            .address_mode_v(desc.address_v)
            .address_mode_w(desc.address_w)
            .border_color(desc.border_color)
            .unnormalized_coordinates(desc.unnormalized_coordinates)
            // level of detail
            .mip_lod_bias(desc.mip_lod_bias)
            .min_lod(desc.min_lod)
            .max_lod(desc.max_lod)
            // anisotropy and comparison
            .anisotropy_enable(desc.anisotropy_enable)
            .max_anisotropy(desc.max_anisotropy)
            .compare_enable(compare_enable)
            .compare_op(compare_op);

        let raw = unsafe { device.dev().create_sampler(&create_info, None)? };
        Self::construct(device, raw, None)
    }
}
// Device-owned wrapper around `vk::ShaderModule`. The trailing closure
// destroys the module through the owning device.
// NOTE(review): presumably the macro runs this closure on drop — confirm
// against `define_device_owned_handle!`.
define_device_owned_handle! {
#[derive(Debug)]
pub ShaderModule(vk::ShaderModule) {} => |this| unsafe {
this.device().dev().destroy_shader_module(this.handle(), None);
}
}
impl ShaderModule {
    /// Loads shader code from the file at `path` and creates a shader module
    /// from it.
    ///
    /// The file is read as a sequence of 32-bit words; trailing bytes that do
    /// not form a whole word are ignored.
    ///
    /// # Errors
    /// Returns an error when the file cannot be opened or fully read, when
    /// its size does not fit in memory on this platform, or when shader
    /// module creation fails.
    pub fn new_from_path<P: AsRef<Path>>(device: Device, path: P) -> crate::Result<Self> {
        use std::io::Read;

        let mut file = std::fs::File::open(path)?;
        let byte_len = file.metadata()?.len();
        let word_count = usize::try_from(byte_len / 4)
            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;

        let mut words = vec![0u32; word_count];
        // A single `read` may return fewer bytes than requested, which would
        // silently truncate the shader; `read_exact` either fills the whole
        // buffer or fails.
        file.read_exact(bytemuck::cast_slice_mut(words.as_mut_slice()))?;

        Ok(Self::new_from_memory(device, &words)?)
    }

    /// Creates a shader module on `device` from the code words in `buffer`.
    /// The module is constructed without a name.
    ///
    /// # Errors
    /// Returns the Vulkan error reported by shader module creation, or any
    /// error produced by `Self::construct`.
    pub fn new_from_memory(device: Device, buffer: &[u32]) -> VkResult<Self> {
        let info = vk::ShaderModuleCreateInfo::default().code(buffer);
        let module = unsafe { device.dev().create_shader_module(&info, None)? };
        Self::construct(device, module, None)
    }
}
/// A compute or graphics pipeline; the Vulkan pipeline is destroyed when
/// this value is dropped.
#[derive(Debug)]
pub struct Pipeline {
/// Pipeline handle together with the device it was created on.
pipeline: DeviceOwnedDebugObject<vk::Pipeline>,
/// `COMPUTE` or `GRAPHICS`, depending on the description used for creation.
bind_point: vk::PipelineBindPoint,
}
impl Drop for Pipeline {
    /// Destroys the wrapped Vulkan pipeline on the device that owns it.
    fn drop(&mut self) {
        let owner = self.pipeline.dev();
        let raw = self.pipeline.handle();
        unsafe {
            owner.dev().destroy_pipeline(raw, None);
        }
    }
}
impl ShaderStageDesc<'_> {
    /// Builds the Vulkan shader stage create info for this description.
    /// The returned value borrows the entry point name from `self`.
    fn into_create_info(&self) -> vk::PipelineShaderStageCreateInfo {
        let entry_point: &std::ffi::CStr = &self.entry;
        let module = self.module.handle();
        vk::PipelineShaderStageCreateInfo::default()
            .stage(self.stage)
            .flags(self.flags)
            .name(entry_point)
            .module(module)
    }
}
impl Pipeline {
    /// Creates a compute or graphics pipeline on `device`, depending on the
    /// variant of `desc`, without a pipeline cache.
    ///
    /// For graphics pipelines every optional state that is `None` is passed
    /// to Vulkan as a null pointer, and `desc.rendering` is chained into the
    /// create info when present.
    ///
    /// When a base pipeline is given it is kept alive until creation has
    /// finished, and the base pipeline index is set to -1 so that the handle
    /// is the one Vulkan derives from.
    ///
    /// # Errors
    /// Returns the Vulkan error reported by pipeline creation (any pipelines
    /// returned alongside the error are destroyed first), or the error from
    /// wrapping the handle in a `DeviceOwnedDebugObject`.
    pub fn new(device: Device, desc: PipelineDesc) -> VkResult<Self> {
        let name: Option<Cow<'static, str>>;
        let bind_point: vk::PipelineBindPoint;
        let result = match desc {
            PipelineDesc::Compute(desc) => {
                name = desc.name;
                bind_point = vk::PipelineBindPoint::COMPUTE;

                // Borrow instead of consuming the `Arc`: dropping the last
                // reference here would destroy the base pipeline before its
                // handle is used below.
                let base_pipeline = desc.base_pipeline.as_ref();
                let info = vk::ComputePipelineCreateInfo::default()
                    .flags(desc.flags)
                    .layout(desc.layout.handle())
                    .base_pipeline_handle(
                        base_pipeline.map_or(vk::Pipeline::null(), |base| base.handle()),
                    )
                    // With an explicit base handle the index must be -1.
                    .base_pipeline_index(if base_pipeline.is_some() { -1 } else { 0 })
                    .stage(desc.shader_stage.into_create_info());

                unsafe {
                    device.dev().create_compute_pipelines(
                        vk::PipelineCache::null(),
                        core::slice::from_ref(&info),
                        None,
                    )
                }
            }
            PipelineDesc::Graphics(desc) => {
                name = desc.name;
                bind_point = vk::PipelineBindPoint::GRAPHICS;

                let stages = desc
                    .shader_stages
                    .iter()
                    .map(|stage| stage.into_create_info())
                    .collect::<Vec<_>>();

                let vertex_input = desc.vertex_input.map(|vertex| {
                    vk::PipelineVertexInputStateCreateInfo::default()
                        .vertex_attribute_descriptions(vertex.attributes)
                        .vertex_binding_descriptions(vertex.bindings)
                });

                let input_assembly = desc.input_assembly.map(|state| {
                    vk::PipelineInputAssemblyStateCreateInfo::default()
                        .primitive_restart_enable(state.primitive_restart)
                        .topology(state.topology)
                });

                let tessellation = desc.tessellation.map(|state| {
                    vk::PipelineTessellationStateCreateInfo::default()
                        .flags(state.flags)
                        .patch_control_points(state.patch_control_points)
                });

                let viewport = desc.viewport.map(|state| {
                    let mut info = vk::PipelineViewportStateCreateInfo::default()
                        .scissor_count(state.num_scissors)
                        .viewport_count(state.num_viewports);
                    if let Some(viewports) = state.viewports {
                        info = info.viewports(viewports);
                    }
                    if let Some(scissors) = state.scissors {
                        info = info.scissors(scissors);
                    }
                    info
                });

                let rasterization = desc.rasterization.map(|state| {
                    let mut info = vk::PipelineRasterizationStateCreateInfo::default()
                        .line_width(state.line_width)
                        .cull_mode(state.cull_mode)
                        .polygon_mode(state.polygon_mode)
                        .rasterizer_discard_enable(state.discard_enable)
                        .depth_clamp_enable(state.depth_clamp_enable);
                    // Depth bias is enabled only when parameters are given.
                    if let Some(depth_bias) = state.depth_bias {
                        info = info
                            .depth_bias_enable(true)
                            .depth_bias_clamp(depth_bias.clamp)
                            .depth_bias_constant_factor(depth_bias.constant_factor)
                            .depth_bias_slope_factor(depth_bias.slope_factor);
                    }
                    info
                });

                let multisample = desc.multisample.map(|state| {
                    vk::PipelineMultisampleStateCreateInfo::default()
                        .flags(state.flags)
                        .min_sample_shading(state.min_sample_shading)
                        .rasterization_samples(state.rasterization_samples)
                        .sample_mask(state.sample_mask)
                        .sample_shading_enable(state.sample_shading_enable)
                        .alpha_to_coverage_enable(state.alpha_to_coverage_enable)
                        .alpha_to_one_enable(state.alpha_to_one_enable)
                });

                let color_blend = desc.color_blend.map(|state| {
                    vk::PipelineColorBlendStateCreateInfo::default()
                        .flags(state.flags)
                        .attachments(state.attachments)
                        .blend_constants(state.blend_constants)
                        .logic_op(state.logic_op.unwrap_or_default())
                        .logic_op_enable(state.logic_op.is_some())
                });

                let depth_stencil = desc.depth_stencil.map(|state| {
                    let mut info =
                        vk::PipelineDepthStencilStateCreateInfo::default().flags(state.flags);
                    if let Some(depth) = state.depth {
                        // The depth test and the bounds test are each enabled
                        // only when their optional parameter is present.
                        info = info
                            .depth_compare_op(depth.compare_op.unwrap_or_default())
                            .depth_test_enable(depth.compare_op.is_some())
                            .depth_write_enable(depth.write_enable)
                            .depth_bounds_test_enable(depth.bounds.is_some());
                        if let Some(bounds) = depth.bounds {
                            info = info
                                .max_depth_bounds(bounds.max)
                                .min_depth_bounds(bounds.min);
                        }
                    }
                    if let Some(stencil) = state.stencil {
                        info = info
                            .stencil_test_enable(true)
                            .front(stencil.front)
                            .back(stencil.back);
                    }
                    info
                });

                let dynamic = desc.dynamic.map(|state| {
                    vk::PipelineDynamicStateCreateInfo::default()
                        .flags(state.flags)
                        .dynamic_states(state.dynamic_states)
                });

                let mut rendering = desc.rendering.map(|state| {
                    vk::PipelineRenderingCreateInfo::default()
                        .color_attachment_formats(state.color_formats)
                        .depth_attachment_format(state.depth_format.unwrap_or_default())
                        .stencil_attachment_format(state.stencil_format.unwrap_or_default())
                });

                /// Pointer to the contained value, or null for `None`.
                fn option_to_ptr<T>(option: &Option<T>) -> *const T {
                    option
                        .as_ref()
                        .map_or(core::ptr::null(), |value| value as *const T)
                }

                // Borrow instead of consuming the `Arc`: dropping the last
                // reference here would destroy the base pipeline before its
                // handle is used below.
                let base_pipeline = desc.base_pipeline.as_ref();

                // The state structs above must outlive the create call, since
                // the create info only stores raw pointers to them.
                let mut info = vk::GraphicsPipelineCreateInfo {
                    flags: desc.flags,
                    stage_count: stages.len() as u32,
                    p_stages: stages.as_ptr(),
                    p_vertex_input_state: option_to_ptr(&vertex_input),
                    p_input_assembly_state: option_to_ptr(&input_assembly),
                    p_tessellation_state: option_to_ptr(&tessellation),
                    p_viewport_state: option_to_ptr(&viewport),
                    p_rasterization_state: option_to_ptr(&rasterization),
                    p_multisample_state: option_to_ptr(&multisample),
                    p_depth_stencil_state: option_to_ptr(&depth_stencil),
                    p_color_blend_state: option_to_ptr(&color_blend),
                    p_dynamic_state: option_to_ptr(&dynamic),
                    layout: desc.layout.handle(),
                    render_pass: desc.render_pass.unwrap_or(vk::RenderPass::null()),
                    subpass: desc.subpass.unwrap_or(0),
                    base_pipeline_handle: base_pipeline
                        .map_or(vk::Pipeline::null(), |base| base.handle()),
                    // With an explicit base handle the index must be -1.
                    base_pipeline_index: if base_pipeline.is_some() { -1 } else { 0 },
                    ..Default::default()
                };
                if let Some(rendering) = rendering.as_mut() {
                    info = info.push_next(rendering)
                }

                unsafe {
                    device.dev().create_graphics_pipelines(
                        vk::PipelineCache::null(),
                        core::slice::from_ref(&info),
                        None,
                    )
                }
            }
        };

        let pipeline = match result {
            // Exactly one create info was submitted, so one handle comes back.
            Ok(pipelines) => pipelines[0],
            Err((pipelines, error)) => {
                tracing::error!("failed to create pipelines with :{error}");
                // Destroy whatever handles were returned alongside the error.
                for pipeline in pipelines {
                    unsafe {
                        device.dev().destroy_pipeline(pipeline, None);
                    }
                }
                return Err(error);
            }
        };

        Ok(Self {
            pipeline: DeviceOwnedDebugObject::new(device, pipeline, name)?,
            bind_point,
        })
    }

    /// Returns the raw Vulkan pipeline handle.
    pub fn handle(&self) -> vk::Pipeline {
        self.pipeline.handle()
    }

    /// Returns the bind point (`COMPUTE` or `GRAPHICS`) this pipeline was
    /// created for.
    pub fn bind_point(&self) -> vk::PipelineBindPoint {
        self.bind_point
    }
}