summaryrefslogtreecommitdiff
path: root/thirdparty/glslang/SPIRV/SpvPostProcess.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'thirdparty/glslang/SPIRV/SpvPostProcess.cpp')
-rw-r--r--thirdparty/glslang/SPIRV/SpvPostProcess.cpp78
1 files changed, 51 insertions, 27 deletions
diff --git a/thirdparty/glslang/SPIRV/SpvPostProcess.cpp b/thirdparty/glslang/SPIRV/SpvPostProcess.cpp
index 6e1f7cf61f..d40174d172 100644
--- a/thirdparty/glslang/SPIRV/SpvPostProcess.cpp
+++ b/thirdparty/glslang/SPIRV/SpvPostProcess.cpp
@@ -39,6 +39,7 @@
#include <cassert>
#include <cstdlib>
+#include <unordered_map>
#include <unordered_set>
#include <algorithm>
@@ -51,16 +52,13 @@ namespace spv {
#include "GLSL.std.450.h"
#include "GLSL.ext.KHR.h"
#include "GLSL.ext.EXT.h"
-#ifdef AMD_EXTENSIONS
#include "GLSL.ext.AMD.h"
-#endif
-#ifdef NV_EXTENSIONS
#include "GLSL.ext.NV.h"
-#endif
}
namespace spv {
+#ifndef GLSLANG_WEB
// Hook to visit each operand type and result type of an instruction.
// Will be called multiple times for one instruction, once for each typed
// operand and the result.
@@ -160,7 +158,6 @@ void Builder::postProcessType(const Instruction& inst, Id typeId)
}
break;
case OpExtInst:
-#if AMD_EXTENSIONS
switch (inst.getImmediateOperand(1)) {
case GLSLstd450Frexp:
case GLSLstd450FrexpStruct:
@@ -176,7 +173,6 @@ void Builder::postProcessType(const Instruction& inst, Id typeId)
default:
break;
}
-#endif
break;
default:
if (basicTypeOp == OpTypeFloat && width == 16)
@@ -222,12 +218,10 @@ void Builder::postProcess(Instruction& inst)
addCapability(CapabilityImageQuery);
break;
-#ifdef NV_EXTENSIONS
case OpGroupNonUniformPartitionNV:
addExtension(E_SPV_NV_shader_subgroup_partitioned);
addCapability(CapabilityGroupNonUniformPartitionedNV);
break;
-#endif
case OpLoad:
case OpStore:
@@ -326,17 +320,16 @@ void Builder::postProcess(Instruction& inst)
}
}
}
-
-// Called for each instruction in a reachable block.
-void Builder::postProcessReachable(const Instruction&)
-{
- // did have code here, but questionable to do so without deleting the instructions
-}
+#endif
// comment in header
-void Builder::postProcess()
+void Builder::postProcessCFG()
{
+ // reachableBlocks is the set of blockss reached via control flow, or which are
+ // unreachable continue targert or unreachable merge.
std::unordered_set<const Block*> reachableBlocks;
+ std::unordered_map<Block*, Block*> headerForUnreachableContinue;
+ std::unordered_set<Block*> unreachableMerges;
std::unordered_set<Id> unreachableDefinitions;
// Collect IDs defined in unreachable blocks. For each function, label the
// reachable blocks first. Then for each unreachable block, collect the
@@ -344,16 +337,41 @@ void Builder::postProcess()
for (auto fi = module.getFunctions().cbegin(); fi != module.getFunctions().cend(); fi++) {
Function* f = *fi;
Block* entry = f->getEntryBlock();
- inReadableOrder(entry, [&reachableBlocks](const Block* b) { reachableBlocks.insert(b); });
+ inReadableOrder(entry,
+ [&reachableBlocks, &unreachableMerges, &headerForUnreachableContinue]
+ (Block* b, ReachReason why, Block* header) {
+ reachableBlocks.insert(b);
+ if (why == ReachDeadContinue) headerForUnreachableContinue[b] = header;
+ if (why == ReachDeadMerge) unreachableMerges.insert(b);
+ });
for (auto bi = f->getBlocks().cbegin(); bi != f->getBlocks().cend(); bi++) {
Block* b = *bi;
- if (reachableBlocks.count(b) == 0) {
- for (auto ii = b->getInstructions().cbegin(); ii != b->getInstructions().cend(); ii++)
+ if (unreachableMerges.count(b) != 0 || headerForUnreachableContinue.count(b) != 0) {
+ auto ii = b->getInstructions().cbegin();
+ ++ii; // Keep potential decorations on the label.
+ for (; ii != b->getInstructions().cend(); ++ii)
+ unreachableDefinitions.insert(ii->get()->getResultId());
+ } else if (reachableBlocks.count(b) == 0) {
+ // The normal case for unreachable code. All definitions are considered dead.
+ for (auto ii = b->getInstructions().cbegin(); ii != b->getInstructions().cend(); ++ii)
unreachableDefinitions.insert(ii->get()->getResultId());
}
}
}
+ // Modify unreachable merge blocks and unreachable continue targets.
+ // Delete their contents.
+ for (auto mergeIter = unreachableMerges.begin(); mergeIter != unreachableMerges.end(); ++mergeIter) {
+ (*mergeIter)->rewriteAsCanonicalUnreachableMerge();
+ }
+ for (auto continueIter = headerForUnreachableContinue.begin();
+ continueIter != headerForUnreachableContinue.end();
+ ++continueIter) {
+ Block* continue_target = continueIter->first;
+ Block* header = continueIter->second;
+ continue_target->rewriteAsCanonicalUnreachableContinue(header);
+ }
+
// Remove unneeded decorations, for unreachable instructions
decorations.erase(std::remove_if(decorations.begin(), decorations.end(),
[&unreachableDefinitions](std::unique_ptr<Instruction>& I) -> bool {
@@ -361,7 +379,11 @@ void Builder::postProcess()
return unreachableDefinitions.count(decoration_id) != 0;
}),
decorations.end());
+}
+#ifndef GLSLANG_WEB
+// comment in header
+void Builder::postProcessFeatures() {
// Add per-instruction capabilities, extensions, etc.,
// Look for any 8/16 bit type in physical storage buffer class, and set the
@@ -371,24 +393,17 @@ void Builder::postProcess()
Instruction* type = groupedTypes[OpTypePointer][t];
if (type->getImmediateOperand(0) == (unsigned)StorageClassPhysicalStorageBufferEXT) {
if (containsType(type->getIdOperand(1), OpTypeInt, 8)) {
- addExtension(spv::E_SPV_KHR_8bit_storage);
+ addIncorporatedExtension(spv::E_SPV_KHR_8bit_storage, spv::Spv_1_5);
addCapability(spv::CapabilityStorageBuffer8BitAccess);
}
if (containsType(type->getIdOperand(1), OpTypeInt, 16) ||
containsType(type->getIdOperand(1), OpTypeFloat, 16)) {
- addExtension(spv::E_SPV_KHR_16bit_storage);
+ addIncorporatedExtension(spv::E_SPV_KHR_16bit_storage, spv::Spv_1_3);
addCapability(spv::CapabilityStorageBuffer16BitAccess);
}
}
}
- // process all reachable instructions...
- for (auto bi = reachableBlocks.cbegin(); bi != reachableBlocks.cend(); ++bi) {
- const Block* block = *bi;
- const auto function = [this](const std::unique_ptr<Instruction>& inst) { postProcessReachable(*inst.get()); };
- std::for_each(block->getInstructions().begin(), block->getInstructions().end(), function);
- }
-
// process all block-contained instructions
for (auto fi = module.getFunctions().cbegin(); fi != module.getFunctions().cend(); fi++) {
Function* f = *fi;
@@ -422,5 +437,14 @@ void Builder::postProcess()
}
}
}
+#endif
+
+// comment in header
+void Builder::postProcess() {
+ postProcessCFG();
+#ifndef GLSLANG_WEB
+ postProcessFeatures();
+#endif
+}
}; // end spv namespace