use std::{borrow::Cow, path::Path, sync::Arc}; use ash::{prelude::*, vk}; use crate::{ define_device_owned_handle, device::{Device, DeviceOwnedDebugObject}, }; #[derive(Debug)] pub struct ShaderStageDesc<'a> { pub flags: vk::PipelineShaderStageCreateFlags, pub module: &'a ShaderModule, pub stage: vk::ShaderStageFlags, pub entry: Cow<'a, std::ffi::CStr>, // specialization: Option } #[derive(Debug, Default)] pub struct DescriptorSetLayoutBindingDesc { pub binding: u32, pub count: u32, pub kind: vk::DescriptorType, pub stage: vk::ShaderStageFlags, pub flags: Option, } #[derive(Debug, Default)] pub struct DescriptorSetLayoutDesc<'a> { pub flags: vk::DescriptorSetLayoutCreateFlags, pub bindings: &'a [DescriptorSetLayoutBindingDesc], pub name: Option>, } #[derive(Debug, Default)] pub struct PipelineLayoutDesc<'a> { pub descriptor_set_layouts: &'a [&'a DescriptorSetLayout], pub push_constant_ranges: &'a [vk::PushConstantRange], pub name: Option>, } #[derive(Debug)] pub enum PipelineDesc<'a> { Compute(ComputePipelineDesc<'a>), Graphics(GraphicsPipelineDesc<'a>), } #[derive(Debug)] pub struct ComputePipelineDesc<'a> { pub flags: vk::PipelineCreateFlags, pub name: Option>, pub shader_stage: ShaderStageDesc<'a>, pub layout: Arc, pub base_pipeline: Option>, } #[derive(Debug, Default)] pub struct VertexInputState<'a> { // pub flags: vk::PipelineVertexInputStateCreateFlags, pub bindings: &'a [vk::VertexInputBindingDescription], pub attributes: &'a [vk::VertexInputAttributeDescription], } #[derive(Debug, Default)] pub struct TessellationState { pub flags: vk::PipelineTessellationStateCreateFlags, pub patch_control_points: u32, } #[derive(Debug, Default)] pub struct InputAssemblyState { // pub flags: vk::PipelineInputAssemblyStateCreateFlags, pub topology: vk::PrimitiveTopology, pub primitive_restart: bool, } #[derive(Debug, Default)] pub struct ViewportState<'a> { pub num_scissors: u32, pub scissors: Option<&'a [vk::Rect2D]>, pub num_viewports: u32, pub viewports: Option<&'a [vk::Viewport]>, } #[derive(Debug, Default)] pub struct DepthBiasState { pub clamp: f32, pub constant_factor: f32, pub slope_factor: f32, } #[derive(Debug)] pub struct RasterizationState { pub depth_clamp_enable: bool, pub discard_enable: bool, pub line_width: f32, pub cull_mode: vk::CullModeFlags, pub depth_bias: Option, pub polygon_mode: vk::PolygonMode, } impl Default for RasterizationState { fn default() -> Self { Self { depth_clamp_enable: false, line_width: 1.0, cull_mode: vk::CullModeFlags::BACK, depth_bias: Default::default(), polygon_mode: vk::PolygonMode::FILL, discard_enable: false, } } } #[derive(Debug)] pub struct MultisampleState<'a> { pub flags: vk::PipelineMultisampleStateCreateFlags, pub sample_shading_enable: bool, pub rasterization_samples: vk::SampleCountFlags, pub min_sample_shading: f32, pub sample_mask: &'a [vk::SampleMask], pub alpha_to_coverage_enable: bool, pub alpha_to_one_enable: bool, } impl<'a> Default for MultisampleState<'a> { fn default() -> Self { Self { flags: Default::default(), sample_shading_enable: Default::default(), rasterization_samples: vk::SampleCountFlags::TYPE_1, min_sample_shading: 1.0, sample_mask: Default::default(), alpha_to_coverage_enable: Default::default(), alpha_to_one_enable: Default::default(), } } } #[derive(Debug)] pub struct DepthBounds { pub min: f32, pub max: f32, } #[derive(Debug, Default)] pub struct DepthState { pub write_enable: bool, /// sets depthTestEnable to true when `Some` pub compare_op: Option, /// sets depthBoundsTestEnable to true when `Some` pub bounds: Option, } #[derive(Debug, Default)] pub struct StencilState { pub front: vk::StencilOpState, pub back: vk::StencilOpState, } #[derive(Debug, Default)] pub struct DepthStencilState { pub flags: vk::PipelineDepthStencilStateCreateFlags, pub depth: Option, pub stencil: Option, } #[derive(Debug, Default)] pub struct ColorBlendState<'a> { pub flags: vk::PipelineColorBlendStateCreateFlags, pub attachments: &'a [vk::PipelineColorBlendAttachmentState], pub logic_op: Option, pub blend_constants: [f32; 4], } #[derive(Debug, Default)] pub struct RenderingState<'a> { pub color_formats: &'a [vk::Format], pub depth_format: Option, pub stencil_format: Option, } #[derive(Debug, Default)] pub struct DynamicState<'a> { pub flags: vk::PipelineDynamicStateCreateFlags, pub dynamic_states: &'a [vk::DynamicState], } #[derive(Debug)] pub struct GraphicsPipelineDesc<'a> { pub flags: vk::PipelineCreateFlags, pub name: Option>, pub shader_stages: &'a [ShaderStageDesc<'a>], pub render_pass: Option, pub layout: &'a PipelineLayout, pub subpass: Option, pub base_pipeline: Option>, pub vertex_input: Option>, pub input_assembly: Option, pub tessellation: Option, pub viewport: Option>, pub rasterization: Option, pub multisample: Option>, pub depth_stencil: Option, pub color_blend: Option>, pub dynamic: Option>, pub rendering: Option>, } #[derive(Debug, Default)] pub struct DescriptorPoolDesc<'a> { pub flags: vk::DescriptorPoolCreateFlags, pub name: Option>, pub sizes: &'a [vk::DescriptorPoolSize], pub max_sets: u32, } #[derive(Debug)] pub struct DescriptorSetAllocDesc<'a> { pub name: Option>, pub layout: &'a DescriptorSetLayout, } define_device_owned_handle! { #[derive(Debug)] pub DescriptorPool(vk::DescriptorPool) {} => |this| unsafe { this.device().dev().destroy_descriptor_pool(this.handle(), None); } } impl DescriptorPool { pub fn new(device: Device, desc: DescriptorPoolDesc) -> VkResult { let info = &vk::DescriptorPoolCreateInfo::default() .flags(desc.flags) .max_sets(desc.max_sets) .pool_sizes(desc.sizes); let handle = unsafe { device.dev().create_descriptor_pool(info, None)? }; Self::construct(device, handle, desc.name) } pub fn allocate(&self, descs: &[DescriptorSetAllocDesc]) -> VkResult> { let layouts = descs .iter() .map(|desc| desc.layout.handle()) .collect::>(); let info = &vk::DescriptorSetAllocateInfo::default() .descriptor_pool(self.handle()) .set_layouts(&layouts); let sets = unsafe { self.device().dev().allocate_descriptor_sets(&info)? }; for (&set, desc) in sets.iter().zip(descs) { if let Some(name) = desc.name.as_ref() { self.device().debug_name_object(set, &name)?; } } Ok(sets) } // pub fn free(&self) {} #[allow(dead_code)] pub fn reset(&self) -> VkResult<()> { unsafe { self.device() .dev() .reset_descriptor_pool(self.handle(), vk::DescriptorPoolResetFlags::empty()) } } } define_device_owned_handle! { #[derive(Debug)] pub DescriptorSetLayout(vk::DescriptorSetLayout) {} => |this| unsafe { this.device().dev().destroy_descriptor_set_layout(this.handle(), None); } } impl DescriptorSetLayout { pub fn new(device: Device, desc: DescriptorSetLayoutDesc) -> VkResult { let (flags, bindings): (Vec<_>, Vec<_>) = desc .bindings .iter() .map(|binding| { let flag = binding.flags.unwrap_or_default(); let binding = vk::DescriptorSetLayoutBinding::default() .binding(binding.binding) .descriptor_count(binding.count) .descriptor_type(binding.kind) .stage_flags(binding.stage); (flag, binding) }) .unzip(); let flags = &mut vk::DescriptorSetLayoutBindingFlagsCreateInfo::default().binding_flags(&flags); let mut info = vk::DescriptorSetLayoutCreateInfo::default() .bindings(&bindings) .flags(desc.flags); if device.features().version >= vk::API_VERSION_1_2 || device .features() .supports_extension(&crate::make_extention_properties( ash::ext::descriptor_indexing::NAME, ash::ext::descriptor_indexing::SPEC_VERSION, )) { info = info.push_next(flags); } let layout = unsafe { device.dev().create_descriptor_set_layout(&info, None)? }; Self::construct(device, layout, desc.name) } } use crate::device::DeviceOwned; define_device_owned_handle! { #[derive(Debug)] pub PipelineLayout(vk::PipelineLayout) {} => |this| unsafe { this.device().dev().destroy_pipeline_layout(this.handle(), None); } } impl PipelineLayout { pub fn new(device: Device, desc: PipelineLayoutDesc) -> VkResult { let set_layouts = desc .descriptor_set_layouts .iter() .map(|desc| desc.handle()) .collect::>(); let info = &vk::PipelineLayoutCreateInfo::default() .set_layouts(&set_layouts) .push_constant_ranges(desc.push_constant_ranges); let layout = unsafe { device.dev().create_pipeline_layout(info, None)? }; Self::construct(device, layout, desc.name) } } #[derive(Debug, Default)] pub struct SamplerDesc { pub flags: vk::SamplerCreateFlags, pub min_filter: vk::Filter, pub mag_filter: vk::Filter, pub mipmap_mode: vk::SamplerMipmapMode, pub address_u: vk::SamplerAddressMode, pub address_v: vk::SamplerAddressMode, pub address_w: vk::SamplerAddressMode, pub mip_lod_bias: f32, pub anisotropy_enable: bool, pub max_anisotropy: f32, pub compare_op: Option, pub min_lod: f32, pub max_lod: f32, pub border_color: vk::BorderColor, pub unnormalized_coordinates: bool, } impl Eq for SamplerDesc {} impl PartialEq for SamplerDesc { fn eq(&self, other: &Self) -> bool { use crate::util::eq_f32; self.flags == other.flags && self.min_filter == other.min_filter && self.mag_filter == other.mag_filter && self.mipmap_mode == other.mipmap_mode && self.address_u == other.address_u && self.address_v == other.address_v && self.address_w == other.address_w && self.anisotropy_enable == other.anisotropy_enable && self.compare_op == other.compare_op && eq_f32(self.mip_lod_bias, other.mip_lod_bias) && eq_f32(self.max_anisotropy, other.max_anisotropy) && eq_f32(self.min_lod, other.min_lod) && eq_f32(self.max_lod, other.max_lod) && self.border_color == other.border_color && self.unnormalized_coordinates == other.unnormalized_coordinates } } impl std::hash::Hash for SamplerDesc { fn hash(&self, state: &mut H) { use crate::util::hash_f32; self.flags.hash(state); self.min_filter.hash(state); self.mag_filter.hash(state); self.mipmap_mode.hash(state); self.address_u.hash(state); self.address_v.hash(state); self.address_w.hash(state); hash_f32(state, self.mip_lod_bias); hash_f32(state, self.max_anisotropy); hash_f32(state, self.min_lod); hash_f32(state, self.max_lod); self.anisotropy_enable.hash(state); self.compare_op.hash(state); self.border_color.hash(state); self.unnormalized_coordinates.hash(state); } } define_device_owned_handle! { #[derive(Debug)] pub Sampler(vk::Sampler) {} => |this| unsafe { this.device().dev().destroy_sampler(this.handle(), None); } } impl Sampler { pub fn new(device: Device, desc: &SamplerDesc) -> VkResult { let info = &vk::SamplerCreateInfo::default() .flags(desc.flags) .min_filter(desc.min_filter) .mag_filter(desc.mag_filter) .mip_lod_bias(desc.mip_lod_bias) .mipmap_mode(desc.mipmap_mode) .address_mode_u(desc.address_u) .address_mode_v(desc.address_v) .address_mode_w(desc.address_w) .anisotropy_enable(desc.anisotropy_enable) .max_anisotropy(desc.max_anisotropy) .compare_enable(desc.compare_op.is_some()) .compare_op(desc.compare_op.unwrap_or_default()) .min_lod(desc.min_lod) .max_lod(desc.max_lod) .border_color(desc.border_color) .unnormalized_coordinates(desc.unnormalized_coordinates); let handle = unsafe { device.dev().create_sampler(info, None)? }; Self::construct(device, handle, None) } } define_device_owned_handle! { #[derive(Debug)] pub ShaderModule(vk::ShaderModule) {} => |this| unsafe { this.device().dev().destroy_shader_module(this.handle(), None); } } impl ShaderModule { pub fn new_from_path>(device: Device, path: P) -> crate::Result { use std::io::{BufReader, Read, Seek}; let mut file = std::fs::File::open(path)?; let size = file.seek(std::io::SeekFrom::End(0))? / 4; file.seek(std::io::SeekFrom::Start(0))?; let mut reader = BufReader::new(file); let mut buffer = Vec::::with_capacity(size as usize); buffer.resize(size as usize, 0); let size = reader.read(bytemuck::cast_slice_mut(buffer.as_mut_slice()))?; buffer.resize(size / 4, 0); Ok(Self::new_from_memory(device, &buffer)?) } pub fn new_from_memory(device: Device, buffer: &[u32]) -> VkResult { let info = &vk::ShaderModuleCreateInfo::default().code(buffer); let module = unsafe { device.dev().create_shader_module(info, None)? }; Self::construct(device, module, None) } } #[derive(Debug)] pub struct Pipeline { pipeline: DeviceOwnedDebugObject, bind_point: vk::PipelineBindPoint, } impl Drop for Pipeline { fn drop(&mut self) { unsafe { self.pipeline .dev() .dev() .destroy_pipeline(self.pipeline.handle(), None); } } } impl ShaderStageDesc<'_> { fn into_create_info(&self) -> vk::PipelineShaderStageCreateInfo { vk::PipelineShaderStageCreateInfo::default() .module(self.module.handle()) .flags(self.flags) .stage(self.stage) .name(&self.entry) } } impl Pipeline { pub fn new(device: Device, desc: PipelineDesc) -> VkResult { let name: Option>; let bind_point: vk::PipelineBindPoint; let result = match desc { PipelineDesc::Compute(desc) => { name = desc.name; bind_point = vk::PipelineBindPoint::COMPUTE; let info = &vk::ComputePipelineCreateInfo::default() .flags(desc.flags) .layout(desc.layout.handle()) .base_pipeline_handle( desc.base_pipeline .map(|p| p.handle()) .unwrap_or(vk::Pipeline::null()), ) .stage(desc.shader_stage.into_create_info()); unsafe { device.dev().create_compute_pipelines( vk::PipelineCache::null(), core::slice::from_ref(info), None, ) } } PipelineDesc::Graphics(desc) => { name = desc.name; bind_point = vk::PipelineBindPoint::GRAPHICS; let stages = desc .shader_stages .iter() .map(|stage| stage.into_create_info()) .collect::>(); let vertex_input = desc.vertex_input.map(|vertex| { vk::PipelineVertexInputStateCreateInfo::default() .vertex_attribute_descriptions(vertex.attributes) .vertex_binding_descriptions(vertex.bindings) }); let input_assembly = desc.input_assembly.map(|state| { vk::PipelineInputAssemblyStateCreateInfo::default() .primitive_restart_enable(state.primitive_restart) .topology(state.topology) }); let tessellation = desc.tessellation.map(|state| { vk::PipelineTessellationStateCreateInfo::default() .flags(state.flags) .patch_control_points(state.patch_control_points) }); let viewport = desc.viewport.map(|state| { let mut info = vk::PipelineViewportStateCreateInfo::default() .scissor_count(state.num_scissors) .viewport_count(state.num_viewports); if let Some(viewports) = state.viewports { info = info.viewports(viewports); } if let Some(scissors) = state.scissors { info = info.scissors(scissors); } info }); let rasterization = desc.rasterization.map(|state| { let mut info = vk::PipelineRasterizationStateCreateInfo::default() .line_width(state.line_width) .cull_mode(state.cull_mode) .polygon_mode(state.polygon_mode) .rasterizer_discard_enable(state.discard_enable) .depth_clamp_enable(state.depth_clamp_enable); if let Some(depth_bias) = state.depth_bias { info = info .depth_bias_enable(true) .depth_bias_clamp(depth_bias.clamp) .depth_bias_constant_factor(depth_bias.constant_factor) .depth_bias_slope_factor(depth_bias.slope_factor); } info }); let multisample = desc.multisample.map(|state| { let info = vk::PipelineMultisampleStateCreateInfo::default() .flags(state.flags) .min_sample_shading(state.min_sample_shading) .rasterization_samples(state.rasterization_samples) .sample_mask(state.sample_mask) .sample_shading_enable(state.sample_shading_enable) .alpha_to_coverage_enable(state.alpha_to_coverage_enable) .alpha_to_one_enable(state.alpha_to_one_enable); info }); let color_blend = desc.color_blend.map(|state| { let info = vk::PipelineColorBlendStateCreateInfo::default() .flags(state.flags) .attachments(state.attachments) .blend_constants(state.blend_constants) .logic_op(state.logic_op.unwrap_or(Default::default())) .logic_op_enable(state.logic_op.is_some()); info }); let depth_stencil = desc.depth_stencil.map(|state| { let mut info = vk::PipelineDepthStencilStateCreateInfo::default().flags(state.flags); if let Some(depth) = state.depth { info = info .depth_compare_op(depth.compare_op.unwrap_or(vk::CompareOp::default())) .depth_test_enable(depth.compare_op.is_some()) .depth_write_enable(depth.write_enable) .depth_bounds_test_enable(depth.bounds.is_some()); if let Some(bounds) = depth.bounds { info = info .max_depth_bounds(bounds.max) .min_depth_bounds(bounds.min); } } if let Some(stencil) = state.stencil { info = info .stencil_test_enable(true) .front(stencil.front) .back(stencil.back); } info }); let dynamic = desc.dynamic.map(|state| { let info = vk::PipelineDynamicStateCreateInfo::default() .flags(state.flags) .dynamic_states(state.dynamic_states); info }); let mut rendering = desc.rendering.map(|state| { let info = vk::PipelineRenderingCreateInfo::default() .color_attachment_formats(state.color_formats) .depth_attachment_format(state.depth_format.unwrap_or_default()) .stencil_attachment_format(state.stencil_format.unwrap_or_default()); info }); fn option_to_ptr(option: &Option) -> *const T { option .as_ref() .map(|t| t as *const T) .unwrap_or(core::ptr::null()) } let mut info = vk::GraphicsPipelineCreateInfo { flags: desc.flags, stage_count: stages.len() as u32, p_stages: stages.as_ptr(), p_vertex_input_state: option_to_ptr(&vertex_input), p_input_assembly_state: option_to_ptr(&input_assembly), p_tessellation_state: option_to_ptr(&tessellation), p_viewport_state: option_to_ptr(&viewport), p_rasterization_state: option_to_ptr(&rasterization), p_multisample_state: option_to_ptr(&multisample), p_depth_stencil_state: option_to_ptr(&depth_stencil), p_color_blend_state: option_to_ptr(&color_blend), p_dynamic_state: option_to_ptr(&dynamic), layout: desc.layout.handle(), render_pass: desc.render_pass.unwrap_or(vk::RenderPass::null()), subpass: desc.subpass.unwrap_or(0), base_pipeline_handle: desc .base_pipeline .map(|piepline| piepline.pipeline.handle()) .unwrap_or(vk::Pipeline::null()), ..Default::default() }; if let Some(rendering) = rendering.as_mut() { info = info.push_next(rendering) } unsafe { device.dev().create_graphics_pipelines( vk::PipelineCache::null(), core::slice::from_ref(&info), None, ) } } }; let pipeline = match result { Ok(pipelines) => pipelines[0], Err((pipelines, error)) => { tracing::error!("failed to create pipelines with :{error}"); for pipeline in pipelines { unsafe { device.dev().destroy_pipeline(pipeline, None); } } return Err(error.into()); } }; Ok(Self { pipeline: DeviceOwnedDebugObject::new(device, pipeline, name)?, bind_point, }) } pub fn handle(&self) -> vk::Pipeline { self.pipeline.handle() } pub fn bind_point(&self) -> vk::PipelineBindPoint { self.bind_point } }