diff --git a/source/opt/private_to_local_pass.cpp b/source/opt/private_to_local_pass.cpp index 4904e058b3d..5724a40ab21 100644 --- a/source/opt/private_to_local_pass.cpp +++ b/source/opt/private_to_local_pass.cpp @@ -26,6 +26,7 @@ namespace opt { namespace { constexpr uint32_t kVariableStorageClassInIdx = 0; constexpr uint32_t kSpvTypePointerTypeIdInIdx = 1; +constexpr uint32_t kEntryPointFunctionIdInIdx = 1; } // namespace Pass::Status PrivateToLocalPass::Process() { @@ -48,9 +49,11 @@ Pass::Status PrivateToLocalPass::Process() { continue; } - Function* target_function = FindLocalFunction(inst); - if (target_function != nullptr) { - variables_to_move.push_back({&inst, target_function}); + // TODO: Handle all functions. + // TODO: Might want to return the map from entry points to functions. + std::set target_functions = FindLocalFunctions(inst); + if (!target_functions.empty()) { + variables_to_move.push_back({&inst, *(target_functions.begin())}); } } @@ -85,31 +88,52 @@ Pass::Status PrivateToLocalPass::Process() { return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); } -Function* PrivateToLocalPass::FindLocalFunction(const Instruction& inst) const { - bool found_first_use = false; - Function* target_function = nullptr; +std::set PrivateToLocalPass::FindLocalFunctions(const Instruction& inst) const { + // Create a map of entry points to the function id containing the first use of the instruction. There + // must only be one function per entry point if we wish to substitute the private variable. + // TODO: Make a map of uint32_t to Function*? + std::unordered_map ep_to_use {}; + + // Return target functions that can substitute the variable. + + auto const result_id = inst.result_id(); context()->get_def_use_mgr()->ForEachUser( - inst.result_id(), - [&target_function, &found_first_use, this](Instruction* use) { - BasicBlock* current_block = context()->get_instr_block(use); - if (current_block == nullptr) { - return; - } + result_id, + [&target_functions, &ep_to_use, &inst, this](Instruction* use) { + BasicBlock* current_block = context()->get_instr_block(use); + if (current_block == nullptr) { + return; + } - if (!IsValidUse(use)) { - found_first_use = true; - target_function = nullptr; + Function* current_function = current_block->GetParent(); + if (!IsValidUse(use)) { + ep_to_use.erase(current_function); + return; + } + + // Find all entry points that can reach the use instruction. + std::unordered_set ep_ids; + std::set visited_ids; + FindEntryPointFuncs(current_function->result_id(), ep_ids, visited_ids); + + // Update the map of entry points. If the function isn't found, then add it. If the function is found, + // then it must match the current function; otherwise, substitution will not be allowed. + for (auto const ep_id : ep_ids) { + auto const ep_search = ep_to_use.find(ep_id); + if(ep_search == std::end(ep_to_use)) { + ep_to_use[ep_id] = current_function->result_id(); + } else if(ep_to_use[ep_id] != current_function->result_id()) { return; } - Function* current_function = current_block->GetParent(); - if (!found_first_use) { - found_first_use = true; - target_function = current_function; - } else if (target_function != current_function) { - target_function = nullptr; - } - }); - return target_function; + } + }); + + // TODO: Copy functions from ep_to_use to return variable. + std::set target_functions {}; + //if(target_functions.empty()) + + + return target_functions; } // namespace opt bool PrivateToLocalPass::MoveVariable(Instruction* variable, @@ -232,5 +256,77 @@ bool PrivateToLocalPass::UpdateUses(Instruction* inst) { return true; } +bool PrivateToLocalPass::IsEntryPointFunc(const Function* func) const { + for (auto& entry_point : get_module()->entry_points()) { + if (entry_point.GetSingleWordInOperand(kEntryPointFunctionIdInIdx) == + func->result_id()) { + return true; + } + } + + return false; +} + +Instruction PrivateToLocalPass::GetEntryPointFunc(const Function& func) const { + // if(IsEntryPointFunc(func)) { + // return func.DefInst(); + // } + + Instruction* ep_func {nullptr}; + context()->get_def_use_mgr()->WhileEachUser(func.result_id(), + [&ep_func](Instruction* use) { + switch (use->opcode()) { + case spv::Op::OpFunctionCall: + ep_func = use; + return false; + break; + default: + return true; + break; + }; + + }); + + return *ep_func; +} + +bool PrivateToLocalPass::IsEntryPointFunc(const uint32_t& func_id) const { + for (auto& entry_point : get_module()->entry_points()) { + if (entry_point.GetSingleWordInOperand(kEntryPointFunctionIdInIdx) == func_id) { + return true; + } + } + + return false; +} + +// A function may be reached from more than one entry point. +void PrivateToLocalPass::FindEntryPointFuncs(const Function* func, + std::unordered_set& ep_ids, + std::set& visited_ids) const { + // Ignore cycles. Stop if we've visited this function already. + if(visited_ids.find(func) != std::end(visited_ids)) { + return; + } else { + visited_ids.insert(func); + } + + if(IsEntryPointFunc(func)) { + ep_ids.insert(func); + } + + context()->get_def_use_mgr()->ForEachUser(func->result_id(), [this, &ep_ids, &visited_ids](Instruction* use) { + switch (use->opcode()) { + case spv::Op::OpFunctionCall: { + auto current_function = context()->get_instr_block(use)->GetParent(); + FindEntryPointFuncs(current_function->result_id(), ep_ids, visited_ids); + break; + } + default: + break; + }; + }); +} + } // namespace opt } // namespace spvtools diff --git a/source/opt/private_to_local_pass.h b/source/opt/private_to_local_pass.h index e96a965e910..e0ff71b3950 100644 --- a/source/opt/private_to_local_pass.h +++ b/source/opt/private_to_local_pass.h @@ -44,11 +44,12 @@ class PrivateToLocalPass : public Pass { // class of |function|. Returns false if the variable could not be moved. bool MoveVariable(Instruction* variable, Function* function); + // TODO: Update the comment. // |inst| is an instruction declaring a variable. If that variable is // referenced in a single function and all of uses are valid as defined by // |IsValidUse|, then that function is returned. Otherwise, the return // value is |nullptr|. - Function* FindLocalFunction(const Instruction& inst) const; + std::set FindLocalFunctions(const Instruction& inst) const; // Returns true is |inst| is a valid use of a pointer. In this case, a // valid use is one where the transformation is able to rewrite the type to @@ -65,6 +66,14 @@ class PrivateToLocalPass : public Pass { // change of the base pointer now pointing to the function storage class. bool UpdateUse(Instruction* inst, Instruction* user); bool UpdateUses(Instruction* inst); + + bool IsEntryPointFunc(const Function* func) const; + Instruction GetEntryPointFunc(const Function& func) const; + + bool IsEntryPointFunc(const uint32_t& func_id) const; + void FindEntryPointFuncs(const Function* func, + std::unordered_set& ep_ids, + std::set& visited_ids) const; }; } // namespace opt diff --git a/test/opt/private_to_local_test.cpp b/test/opt/private_to_local_test.cpp index f7c37c91119..0347d5519ae 100644 --- a/test/opt/private_to_local_test.cpp +++ b/test/opt/private_to_local_test.cpp @@ -495,6 +495,52 @@ TEST_F(PrivateToLocalTest, DebugPrivateToLocal) { SinglePassRunAndMatch(text, true); } +TEST_F(PrivateToLocalTest, TwoEntryPoints) { + const std::string text = R"( +; CHECK-NOT: OpEntryPoint GLCompute %foo "foo" %in %priv1 %priv2 +; CHECK: OpEntryPoint GLCompute %foo "foo" %in +; CHECK: %priv1 = OpVariable {{%\w+}} Function +; CHECK: %priv2 = OpVariable {{%\w+}} Function +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %foo "foo" %in %priv1 %priv2 +OpExecutionMode %foo LocalSize 1 1 1 +OpName %foo "foo" +OpName %in "in" +OpName %priv1 "priv1" +OpName %priv2 "priv2" +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%ptr_ssbo_int = OpTypePointer StorageBuffer %int +%ptr_private_int = OpTypePointer Private %int +%in = OpVariable %ptr_ssbo_int StorageBuffer +%priv1 = OpVariable %ptr_private_int Private +%priv2 = OpVariable %ptr_private_int Private +%void_fn = OpTypeFunction %void +%foo = OpFunction %void None %void_fn +%entry = OpLabel +%1 = OpFunctionCall %void %bar1 +%2 = OpFunctionCall %void %bar2 +OpReturn +OpFunctionEnd +%bar1 = OpFunction %void None %void_fn +%3 = OpLabel +%ld1 = OpLoad %int %in +OpStore %priv1 %ld1 +OpReturn +OpFunctionEnd +%bar2 = OpFunction %void None %void_fn +%4 = OpLabel +%ld2 = OpLoad %int %in +OpStore %priv2 %ld2 +OpReturn +OpFunctionEnd +)"; + + SetTargetEnv(SPV_ENV_UNIVERSAL_1_4); + SinglePassRunAndMatch(text, true); +} + } // namespace } // namespace opt } // namespace spvtools