summaryrefslogtreecommitdiff
path: root/drivers/vulkan
diff options
context:
space:
mode:
authorreduz <reduzio@gmail.com>2021-07-09 16:48:28 -0300
committerRĂ©mi Verschelde <rverschelde@gmail.com>2021-07-11 23:16:09 +0200
commitb2f6db7aa817159bd57d88bff8af017a89759aab (patch)
tree26b3401a34ee165adb9c41bdf63c6f6fc90cb0d6 /drivers/vulkan
parentfb3961b2ef9ed03501f98a8aa621f78679cc2be9 (diff)
Implement Specialization Constants
* Added support to our local copy of SpirV Reflect (which does not support it). * Pass them on render or compute pipeline creation. * Not implemented in our shaders yet.
Diffstat (limited to 'drivers/vulkan')
-rw-r--r--drivers/vulkan/rendering_device_vulkan.cpp160
-rw-r--r--drivers/vulkan/rendering_device_vulkan.h10
2 files changed, 164 insertions, 6 deletions
diff --git a/drivers/vulkan/rendering_device_vulkan.cpp b/drivers/vulkan/rendering_device_vulkan.cpp
index 6c1f1e4852..d3d49503d8 100644
--- a/drivers/vulkan/rendering_device_vulkan.cpp
+++ b/drivers/vulkan/rendering_device_vulkan.cpp
@@ -4374,6 +4374,8 @@ RID RenderingDeviceVulkan::shader_create(const Vector<ShaderStageData> &p_stages
uint32_t stages_processed = 0;
+ Vector<Shader::SpecializationConstant> specialization_constants;
+
bool is_compute = false;
uint32_t compute_local_size[3] = { 0, 0, 0 };
@@ -4560,6 +4562,62 @@ RID RenderingDeviceVulkan::shader_create(const Vector<ShaderStageData> &p_stages
}
}
+ {
+ //specialization constants
+
+ uint32_t sc_count = 0;
+ result = spvReflectEnumerateSpecializationConstants(&module, &sc_count, nullptr);
+ ERR_FAIL_COND_V_MSG(result != SPV_REFLECT_RESULT_SUCCESS, RID(),
+ "Reflection of SPIR-V shader stage '" + String(shader_stage_names[p_stages[i].shader_stage]) + "' failed enumerating specialization constants.");
+
+ if (sc_count) {
+ Vector<SpvReflectSpecializationConstant *> spec_constants;
+ spec_constants.resize(sc_count);
+
+ result = spvReflectEnumerateSpecializationConstants(&module, &sc_count, spec_constants.ptrw());
+ ERR_FAIL_COND_V_MSG(result != SPV_REFLECT_RESULT_SUCCESS, RID(),
+ "Reflection of SPIR-V shader stage '" + String(shader_stage_names[p_stages[i].shader_stage]) + "' failed obtaining specialization constants.");
+
+ for (uint32_t j = 0; j < sc_count; j++) {
+ int32_t existing = -1;
+ Shader::SpecializationConstant sconst;
+ sconst.constant.constant_id = spec_constants[j]->constant_id;
+ switch (spec_constants[j]->constant_type) {
+ case SPV_REFLECT_SPECIALIZATION_CONSTANT_BOOL: {
+ sconst.constant.type = PIPELINE_SPECIALIZATION_CONSTANT_TYPE_BOOL;
+ sconst.constant.bool_value = spec_constants[j]->default_value.int_bool_value != 0;
+ } break;
+ case SPV_REFLECT_SPECIALIZATION_CONSTANT_INT: {
+ sconst.constant.type = PIPELINE_SPECIALIZATION_CONSTANT_TYPE_INT;
+ sconst.constant.int_value = spec_constants[j]->default_value.int_bool_value;
+ } break;
+ case SPV_REFLECT_SPECIALIZATION_CONSTANT_FLOAT: {
+ sconst.constant.type = PIPELINE_SPECIALIZATION_CONSTANT_TYPE_FLOAT;
+ sconst.constant.float_value = spec_constants[j]->default_value.float_value;
+ } break;
+ }
+ sconst.stage_flags = 1 << p_stages[i].shader_stage;
+
+ print_line("spec constant " + itos(i) + ": " + String(spec_constants[j]->name) + " type " + itos(spec_constants[j]->constant_type) + " id " + itos(spec_constants[j]->constant_id));
+
+ for (int k = 0; k < specialization_constants.size(); k++) {
+ if (specialization_constants[k].constant.constant_id == sconst.constant.constant_id) {
+ ERR_FAIL_COND_V_MSG(specialization_constants[k].constant.type != sconst.constant.type, RID(), "More than one specialization constant used for id (" + itos(sconst.constant.constant_id) + "), but their types differ.");
+ ERR_FAIL_COND_V_MSG(specialization_constants[k].constant.int_value != sconst.constant.int_value, RID(), "More than one specialization constant used for id (" + itos(sconst.constant.constant_id) + "), but their default values differ.");
+ existing = k;
+ break;
+ }
+ }
+
+ if (existing > 0) {
+ specialization_constants.write[existing].stage_flags |= sconst.stage_flags;
+ } else {
+ specialization_constants.push_back(sconst);
+ }
+ }
+ }
+ }
+
if (stage == SHADER_STAGE_VERTEX) {
uint32_t iv_count = 0;
result = spvReflectEnumerateInputVariables(&module, &iv_count, nullptr);
@@ -4656,6 +4714,7 @@ RID RenderingDeviceVulkan::shader_create(const Vector<ShaderStageData> &p_stages
shader.compute_local_size[0] = compute_local_size[0];
shader.compute_local_size[1] = compute_local_size[1];
shader.compute_local_size[2] = compute_local_size[2];
+ shader.specialization_constants = specialization_constants;
String error_text;
@@ -5651,7 +5710,7 @@ Vector<uint8_t> RenderingDeviceVulkan::buffer_get_data(RID p_buffer) {
/**** RENDER PIPELINE ****/
/*************************/
-RID RenderingDeviceVulkan::render_pipeline_create(RID p_shader, FramebufferFormatID p_framebuffer_format, VertexFormatID p_vertex_format, RenderPrimitive p_render_primitive, const PipelineRasterizationState &p_rasterization_state, const PipelineMultisampleState &p_multisample_state, const PipelineDepthStencilState &p_depth_stencil_state, const PipelineColorBlendState &p_blend_state, int p_dynamic_state_flags, uint32_t p_for_render_pass) {
+RID RenderingDeviceVulkan::render_pipeline_create(RID p_shader, FramebufferFormatID p_framebuffer_format, VertexFormatID p_vertex_format, RenderPrimitive p_render_primitive, const PipelineRasterizationState &p_rasterization_state, const PipelineMultisampleState &p_multisample_state, const PipelineDepthStencilState &p_depth_stencil_state, const PipelineColorBlendState &p_blend_state, int p_dynamic_state_flags, uint32_t p_for_render_pass, const Vector<PipelineSpecializationConstant> &p_specialization_constants) {
_THREAD_SAFE_METHOD_
//needs a shader
@@ -5969,8 +6028,63 @@ RID RenderingDeviceVulkan::render_pipeline_create(RID p_shader, FramebufferForma
graphics_pipeline_create_info.pNext = nullptr;
graphics_pipeline_create_info.flags = 0;
- graphics_pipeline_create_info.stageCount = shader->pipeline_stages.size();
- graphics_pipeline_create_info.pStages = shader->pipeline_stages.ptr();
+ Vector<VkPipelineShaderStageCreateInfo> pipeline_stages = shader->pipeline_stages;
+ Vector<VkSpecializationInfo> specialization_info;
+ Vector<Vector<VkSpecializationMapEntry>> specialization_map_entries;
+ Vector<uint32_t> specialization_constant_data;
+
+ if (shader->specialization_constants.size()) {
+ specialization_constant_data.resize(shader->specialization_constants.size());
+ uint32_t *data_ptr = specialization_constant_data.ptrw();
+ specialization_info.resize(pipeline_stages.size());
+ specialization_map_entries.resize(pipeline_stages.size());
+ for (int i = 0; i < shader->specialization_constants.size(); i++) {
+ //see if overriden
+ const Shader::SpecializationConstant &sc = shader->specialization_constants[i];
+ data_ptr[i] = sc.constant.int_value; //just copy the 32 bits
+
+ for (int j = 0; j < p_specialization_constants.size(); j++) {
+ const PipelineSpecializationConstant &psc = p_specialization_constants[j];
+ if (psc.constant_id == sc.constant.constant_id) {
+ ERR_FAIL_COND_V_MSG(psc.type != sc.constant.type, RID(), "Specialization constant provided for id (" + itos(sc.constant.constant_id) + ") is of the wrong type.");
+ data_ptr[i] = sc.constant.int_value;
+ break;
+ }
+ }
+
+ VkSpecializationMapEntry entry;
+
+ entry.constantID = sc.constant.constant_id;
+ entry.offset = i * sizeof(uint32_t);
+ entry.size = sizeof(uint32_t);
+
+ for (int j = 0; j < SHADER_STAGE_MAX; j++) {
+ if (sc.stage_flags & (1 << j)) {
+ VkShaderStageFlagBits stage = shader_stage_masks[j];
+ for (int k = 0; k < pipeline_stages.size(); k++) {
+ if (pipeline_stages[k].stage == stage) {
+ specialization_map_entries.write[k].push_back(entry);
+ }
+ }
+ }
+ }
+ }
+
+ for (int k = 0; k < pipeline_stages.size(); k++) {
+ if (specialization_map_entries[k].size()) {
+ specialization_info.write[k].dataSize = specialization_constant_data.size() * sizeof(uint32_t);
+ specialization_info.write[k].pData = data_ptr;
+ specialization_info.write[k].mapEntryCount = specialization_map_entries[k].size();
+ specialization_info.write[k].pMapEntries = specialization_map_entries[k].ptr();
+
+ pipeline_stages.write[k].pSpecializationInfo = specialization_info.ptr();
+ }
+ }
+ }
+
+ graphics_pipeline_create_info.stageCount = pipeline_stages.size();
+ graphics_pipeline_create_info.pStages = pipeline_stages.ptr();
+
graphics_pipeline_create_info.pVertexInputState = &pipeline_vertex_input_state_create_info;
graphics_pipeline_create_info.pInputAssemblyState = &input_assembly_create_info;
graphics_pipeline_create_info.pTessellationState = &tessellation_create_info;
@@ -6039,7 +6153,7 @@ bool RenderingDeviceVulkan::render_pipeline_is_valid(RID p_pipeline) {
/**** COMPUTE PIPELINE ****/
/**************************/
-RID RenderingDeviceVulkan::compute_pipeline_create(RID p_shader) {
+RID RenderingDeviceVulkan::compute_pipeline_create(RID p_shader, const Vector<PipelineSpecializationConstant> &p_specialization_constants) {
_THREAD_SAFE_METHOD_
//needs a shader
@@ -6061,6 +6175,44 @@ RID RenderingDeviceVulkan::compute_pipeline_create(RID p_shader) {
compute_pipeline_create_info.basePipelineHandle = VK_NULL_HANDLE;
compute_pipeline_create_info.basePipelineIndex = 0;
+ VkSpecializationInfo specialization_info;
+ Vector<VkSpecializationMapEntry> specialization_map_entries;
+ Vector<uint32_t> specialization_constant_data;
+
+ if (shader->specialization_constants.size()) {
+ specialization_constant_data.resize(shader->specialization_constants.size());
+ uint32_t *data_ptr = specialization_constant_data.ptrw();
+ for (int i = 0; i < shader->specialization_constants.size(); i++) {
+ //see if overriden
+ const Shader::SpecializationConstant &sc = shader->specialization_constants[i];
+ data_ptr[i] = sc.constant.int_value; //just copy the 32 bits
+
+ for (int j = 0; j < p_specialization_constants.size(); j++) {
+ const PipelineSpecializationConstant &psc = p_specialization_constants[j];
+ if (psc.constant_id == sc.constant.constant_id) {
+ ERR_FAIL_COND_V_MSG(psc.type != sc.constant.type, RID(), "Specialization constant provided for id (" + itos(sc.constant.constant_id) + ") is of the wrong type.");
+ data_ptr[i] = sc.constant.int_value;
+ break;
+ }
+ }
+
+ VkSpecializationMapEntry entry;
+
+ entry.constantID = sc.constant.constant_id;
+ entry.offset = i * sizeof(uint32_t);
+ entry.size = sizeof(uint32_t);
+
+ specialization_map_entries.push_back(entry);
+ }
+
+ specialization_info.dataSize = specialization_constant_data.size() * sizeof(uint32_t);
+ specialization_info.pData = data_ptr;
+ specialization_info.mapEntryCount = specialization_map_entries.size();
+ specialization_info.pMapEntries = specialization_map_entries.ptr();
+
+ compute_pipeline_create_info.stage.pSpecializationInfo = &specialization_info;
+ }
+
ComputePipeline pipeline;
VkResult err = vkCreateComputePipelines(device, VK_NULL_HANDLE, 1, &compute_pipeline_create_info, nullptr, &pipeline.pipeline);
ERR_FAIL_COND_V_MSG(err, RID(), "vkCreateComputePipelines failed with error " + itos(err) + ".");
diff --git a/drivers/vulkan/rendering_device_vulkan.h b/drivers/vulkan/rendering_device_vulkan.h
index ff9ad71268..8b95ff43b8 100644
--- a/drivers/vulkan/rendering_device_vulkan.h
+++ b/drivers/vulkan/rendering_device_vulkan.h
@@ -623,11 +623,17 @@ class RenderingDeviceVulkan : public RenderingDevice {
uint32_t compute_local_size[3] = { 0, 0, 0 };
+ struct SpecializationConstant {
+ PipelineSpecializationConstant constant;
+ uint32_t stage_flags = 0;
+ };
+
bool is_compute = false;
int max_output = 0;
Vector<Set> sets;
Vector<uint32_t> set_formats;
Vector<VkPipelineShaderStageCreateInfo> pipeline_stages;
+ Vector<SpecializationConstant> specialization_constants;
VkPipelineLayout pipeline_layout = VK_NULL_HANDLE;
};
@@ -1100,14 +1106,14 @@ public:
/**** RENDER PIPELINE ****/
/*************************/
- virtual RID render_pipeline_create(RID p_shader, FramebufferFormatID p_framebuffer_format, VertexFormatID p_vertex_format, RenderPrimitive p_render_primitive, const PipelineRasterizationState &p_rasterization_state, const PipelineMultisampleState &p_multisample_state, const PipelineDepthStencilState &p_depth_stencil_state, const PipelineColorBlendState &p_blend_state, int p_dynamic_state_flags = 0, uint32_t p_for_render_pass = 0);
+ virtual RID render_pipeline_create(RID p_shader, FramebufferFormatID p_framebuffer_format, VertexFormatID p_vertex_format, RenderPrimitive p_render_primitive, const PipelineRasterizationState &p_rasterization_state, const PipelineMultisampleState &p_multisample_state, const PipelineDepthStencilState &p_depth_stencil_state, const PipelineColorBlendState &p_blend_state, int p_dynamic_state_flags = 0, uint32_t p_for_render_pass = 0, const Vector<PipelineSpecializationConstant> &p_specialization_constants = Vector<PipelineSpecializationConstant>());
virtual bool render_pipeline_is_valid(RID p_pipeline);
/**************************/
/**** COMPUTE PIPELINE ****/
/**************************/
- virtual RID compute_pipeline_create(RID p_shader);
+ virtual RID compute_pipeline_create(RID p_shader, const Vector<PipelineSpecializationConstant> &p_specialization_constants = Vector<PipelineSpecializationConstant>());
virtual bool compute_pipeline_is_valid(RID p_pipeline);
/****************/