diff options
Diffstat (limited to 'thirdparty/glslang/SPIRV/SpvPostProcess.cpp')
-rw-r--r-- | thirdparty/glslang/SPIRV/SpvPostProcess.cpp | 78 |
1 files changed, 51 insertions, 27 deletions
diff --git a/thirdparty/glslang/SPIRV/SpvPostProcess.cpp b/thirdparty/glslang/SPIRV/SpvPostProcess.cpp index 6e1f7cf61f..d40174d172 100644 --- a/thirdparty/glslang/SPIRV/SpvPostProcess.cpp +++ b/thirdparty/glslang/SPIRV/SpvPostProcess.cpp @@ -39,6 +39,7 @@ #include <cassert> #include <cstdlib> +#include <unordered_map> #include <unordered_set> #include <algorithm> @@ -51,16 +52,13 @@ namespace spv { #include "GLSL.std.450.h" #include "GLSL.ext.KHR.h" #include "GLSL.ext.EXT.h" -#ifdef AMD_EXTENSIONS #include "GLSL.ext.AMD.h" -#endif -#ifdef NV_EXTENSIONS #include "GLSL.ext.NV.h" -#endif } namespace spv { +#ifndef GLSLANG_WEB // Hook to visit each operand type and result type of an instruction. // Will be called multiple times for one instruction, once for each typed // operand and the result. @@ -160,7 +158,6 @@ void Builder::postProcessType(const Instruction& inst, Id typeId) } break; case OpExtInst: -#if AMD_EXTENSIONS switch (inst.getImmediateOperand(1)) { case GLSLstd450Frexp: case GLSLstd450FrexpStruct: @@ -176,7 +173,6 @@ void Builder::postProcessType(const Instruction& inst, Id typeId) default: break; } -#endif break; default: if (basicTypeOp == OpTypeFloat && width == 16) @@ -222,12 +218,10 @@ void Builder::postProcess(Instruction& inst) addCapability(CapabilityImageQuery); break; -#ifdef NV_EXTENSIONS case OpGroupNonUniformPartitionNV: addExtension(E_SPV_NV_shader_subgroup_partitioned); addCapability(CapabilityGroupNonUniformPartitionedNV); break; -#endif case OpLoad: case OpStore: @@ -326,17 +320,16 @@ void Builder::postProcess(Instruction& inst) } } } - -// Called for each instruction in a reachable block. -void Builder::postProcessReachable(const Instruction&) -{ - // did have code here, but questionable to do so without deleting the instructions -} +#endif // comment in header -void Builder::postProcess() +void Builder::postProcessCFG() { + // reachableBlocks is the set of blockss reached via control flow, or which are + // unreachable continue targert or unreachable merge. std::unordered_set<const Block*> reachableBlocks; + std::unordered_map<Block*, Block*> headerForUnreachableContinue; + std::unordered_set<Block*> unreachableMerges; std::unordered_set<Id> unreachableDefinitions; // Collect IDs defined in unreachable blocks. For each function, label the // reachable blocks first. Then for each unreachable block, collect the @@ -344,16 +337,41 @@ void Builder::postProcess() for (auto fi = module.getFunctions().cbegin(); fi != module.getFunctions().cend(); fi++) { Function* f = *fi; Block* entry = f->getEntryBlock(); - inReadableOrder(entry, [&reachableBlocks](const Block* b) { reachableBlocks.insert(b); }); + inReadableOrder(entry, + [&reachableBlocks, &unreachableMerges, &headerForUnreachableContinue] + (Block* b, ReachReason why, Block* header) { + reachableBlocks.insert(b); + if (why == ReachDeadContinue) headerForUnreachableContinue[b] = header; + if (why == ReachDeadMerge) unreachableMerges.insert(b); + }); for (auto bi = f->getBlocks().cbegin(); bi != f->getBlocks().cend(); bi++) { Block* b = *bi; - if (reachableBlocks.count(b) == 0) { - for (auto ii = b->getInstructions().cbegin(); ii != b->getInstructions().cend(); ii++) + if (unreachableMerges.count(b) != 0 || headerForUnreachableContinue.count(b) != 0) { + auto ii = b->getInstructions().cbegin(); + ++ii; // Keep potential decorations on the label. + for (; ii != b->getInstructions().cend(); ++ii) + unreachableDefinitions.insert(ii->get()->getResultId()); + } else if (reachableBlocks.count(b) == 0) { + // The normal case for unreachable code. All definitions are considered dead. + for (auto ii = b->getInstructions().cbegin(); ii != b->getInstructions().cend(); ++ii) unreachableDefinitions.insert(ii->get()->getResultId()); } } } + // Modify unreachable merge blocks and unreachable continue targets. + // Delete their contents. + for (auto mergeIter = unreachableMerges.begin(); mergeIter != unreachableMerges.end(); ++mergeIter) { + (*mergeIter)->rewriteAsCanonicalUnreachableMerge(); + } + for (auto continueIter = headerForUnreachableContinue.begin(); + continueIter != headerForUnreachableContinue.end(); + ++continueIter) { + Block* continue_target = continueIter->first; + Block* header = continueIter->second; + continue_target->rewriteAsCanonicalUnreachableContinue(header); + } + // Remove unneeded decorations, for unreachable instructions decorations.erase(std::remove_if(decorations.begin(), decorations.end(), [&unreachableDefinitions](std::unique_ptr<Instruction>& I) -> bool { @@ -361,7 +379,11 @@ void Builder::postProcess() return unreachableDefinitions.count(decoration_id) != 0; }), decorations.end()); +} +#ifndef GLSLANG_WEB +// comment in header +void Builder::postProcessFeatures() { // Add per-instruction capabilities, extensions, etc., // Look for any 8/16 bit type in physical storage buffer class, and set the @@ -371,24 +393,17 @@ void Builder::postProcess() Instruction* type = groupedTypes[OpTypePointer][t]; if (type->getImmediateOperand(0) == (unsigned)StorageClassPhysicalStorageBufferEXT) { if (containsType(type->getIdOperand(1), OpTypeInt, 8)) { - addExtension(spv::E_SPV_KHR_8bit_storage); + addIncorporatedExtension(spv::E_SPV_KHR_8bit_storage, spv::Spv_1_5); addCapability(spv::CapabilityStorageBuffer8BitAccess); } if (containsType(type->getIdOperand(1), OpTypeInt, 16) || containsType(type->getIdOperand(1), OpTypeFloat, 16)) { - addExtension(spv::E_SPV_KHR_16bit_storage); + addIncorporatedExtension(spv::E_SPV_KHR_16bit_storage, spv::Spv_1_3); addCapability(spv::CapabilityStorageBuffer16BitAccess); } } } - // process all reachable instructions... - for (auto bi = reachableBlocks.cbegin(); bi != reachableBlocks.cend(); ++bi) { - const Block* block = *bi; - const auto function = [this](const std::unique_ptr<Instruction>& inst) { postProcessReachable(*inst.get()); }; - std::for_each(block->getInstructions().begin(), block->getInstructions().end(), function); - } - // process all block-contained instructions for (auto fi = module.getFunctions().cbegin(); fi != module.getFunctions().cend(); fi++) { Function* f = *fi; @@ -422,5 +437,14 @@ void Builder::postProcess() } } } +#endif + +// comment in header +void Builder::postProcess() { + postProcessCFG(); +#ifndef GLSLANG_WEB + postProcessFeatures(); +#endif +} }; // end spv namespace |