diff options
Diffstat (limited to 'thirdparty/glslang/SPIRV/SpvPostProcess.cpp')
-rw-r--r-- | thirdparty/glslang/SPIRV/SpvPostProcess.cpp | 71 |
1 files changed, 58 insertions, 13 deletions
diff --git a/thirdparty/glslang/SPIRV/SpvPostProcess.cpp b/thirdparty/glslang/SPIRV/SpvPostProcess.cpp index d40174d172..dd6dabce0d 100644 --- a/thirdparty/glslang/SPIRV/SpvPostProcess.cpp +++ b/thirdparty/glslang/SPIRV/SpvPostProcess.cpp @@ -44,10 +44,8 @@ #include <algorithm> #include "SpvBuilder.h" - #include "spirv.hpp" -#include "GlslangToSpv.h" -#include "SpvBuilder.h" + namespace spv { #include "GLSL.std.450.h" #include "GLSL.ext.KHR.h" @@ -113,8 +111,6 @@ void Builder::postProcessType(const Instruction& inst, Id typeId) } } break; - case OpAccessChain: - case OpPtrAccessChain: case OpCopyObject: break; case OpFConvert: @@ -161,26 +157,43 @@ void Builder::postProcessType(const Instruction& inst, Id typeId) switch (inst.getImmediateOperand(1)) { case GLSLstd450Frexp: case GLSLstd450FrexpStruct: - if (getSpvVersion() < glslang::EShTargetSpv_1_3 && containsType(typeId, OpTypeInt, 16)) + if (getSpvVersion() < spv::Spv_1_3 && containsType(typeId, OpTypeInt, 16)) addExtension(spv::E_SPV_AMD_gpu_shader_int16); break; case GLSLstd450InterpolateAtCentroid: case GLSLstd450InterpolateAtSample: case GLSLstd450InterpolateAtOffset: - if (getSpvVersion() < glslang::EShTargetSpv_1_3 && containsType(typeId, OpTypeFloat, 16)) + if (getSpvVersion() < spv::Spv_1_3 && containsType(typeId, OpTypeFloat, 16)) addExtension(spv::E_SPV_AMD_gpu_shader_half_float); break; default: break; } break; + case OpAccessChain: + case OpPtrAccessChain: + if (isPointerType(typeId)) + break; + if (basicTypeOp == OpTypeInt) { + if (width == 16) + addCapability(CapabilityInt16); + else if (width == 8) + addCapability(CapabilityInt8); + } default: - if (basicTypeOp == OpTypeFloat && width == 16) - addCapability(CapabilityFloat16); - if (basicTypeOp == OpTypeInt && width == 16) - addCapability(CapabilityInt16); - if (basicTypeOp == OpTypeInt && width == 8) - addCapability(CapabilityInt8); + if (basicTypeOp == OpTypeInt) { + if (width == 16) + addCapability(CapabilityInt16); + else if (width == 8) + addCapability(CapabilityInt8); + else if (width == 64) + addCapability(CapabilityInt64); + } else if (basicTypeOp == OpTypeFloat) { + if (width == 16) + addCapability(CapabilityFloat16); + else if (width == 64) + addCapability(CapabilityFloat64); + } break; } } @@ -436,6 +449,38 @@ void Builder::postProcessFeatures() { } } } + + // If any Vulkan memory model-specific functionality is used, update the + // OpMemoryModel to match. + if (capabilities.find(spv::CapabilityVulkanMemoryModelKHR) != capabilities.end()) { + memoryModel = spv::MemoryModelVulkanKHR; + addIncorporatedExtension(spv::E_SPV_KHR_vulkan_memory_model, spv::Spv_1_5); + } + + // Add Aliased decoration if there's more than one Workgroup Block variable. + if (capabilities.find(spv::CapabilityWorkgroupMemoryExplicitLayoutKHR) != capabilities.end()) { + assert(entryPoints.size() == 1); + auto &ep = entryPoints[0]; + + std::vector<Id> workgroup_variables; + for (int i = 0; i < (int)ep->getNumOperands(); i++) { + if (!ep->isIdOperand(i)) + continue; + + const Id id = ep->getIdOperand(i); + const Instruction *instr = module.getInstruction(id); + if (instr->getOpCode() != spv::OpVariable) + continue; + + if (instr->getImmediateOperand(0) == spv::StorageClassWorkgroup) + workgroup_variables.push_back(id); + } + + if (workgroup_variables.size() > 1) { + for (size_t i = 0; i < workgroup_variables.size(); i++) + addDecoration(workgroup_variables[i], spv::DecorationAliased); + } + } } #endif |