Skip to content

Commit

Permalink
[opt] Fix uses of type manager in fix storage class (#5740)
Browse files Browse the repository at this point in the history
This removes some uses of the type manager. One use could not be
removed. Instead I had to update GenCopy to not use the type manager,
and be able to copy pointers.

Part of #5691
  • Loading branch information
s-perron authored Jul 24, 2024
1 parent e99a5c0 commit 81a1160
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 62 deletions.
44 changes: 30 additions & 14 deletions source/opt/fix_storage_class.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,22 +141,26 @@ bool FixStorageClass::IsPointerResultType(Instruction* inst) {
if (inst->type_id() == 0) {
return false;
}
const analysis::Type* ret_type =
context()->get_type_mgr()->GetType(inst->type_id());
return ret_type->AsPointer() != nullptr;

Instruction* type_def = get_def_use_mgr()->GetDef(inst->type_id());
return type_def->opcode() == spv::Op::OpTypePointer;
}

bool FixStorageClass::IsPointerToStorageClass(Instruction* inst,
spv::StorageClass storage_class) {
analysis::TypeManager* type_mgr = context()->get_type_mgr();
analysis::Type* pType = type_mgr->GetType(inst->type_id());
const analysis::Pointer* result_type = pType->AsPointer();
if (inst->type_id() == 0) {
return false;
}

if (result_type == nullptr) {
Instruction* type_def = get_def_use_mgr()->GetDef(inst->type_id());
if (type_def->opcode() != spv::Op::OpTypePointer) {
return false;
}

return (result_type->storage_class() == storage_class);
const uint32_t kPointerTypeStorageClassIndex = 0;
spv::StorageClass pointer_storage_class = static_cast<spv::StorageClass>(
type_def->GetSingleWordInOperand(kPointerTypeStorageClassIndex));
return pointer_storage_class == storage_class;
}

bool FixStorageClass::ChangeResultType(Instruction* inst,
Expand Down Expand Up @@ -301,9 +305,11 @@ uint32_t FixStorageClass::WalkAccessChainType(Instruction* inst, uint32_t id) {
break;
}

Instruction* orig_type_inst = get_def_use_mgr()->GetDef(id);
assert(orig_type_inst->opcode() == spv::Op::OpTypePointer);
id = orig_type_inst->GetSingleWordInOperand(1);
Instruction* id_type_inst = get_def_use_mgr()->GetDef(id);
assert(id_type_inst->opcode() == spv::Op::OpTypePointer);
id = id_type_inst->GetSingleWordInOperand(1);
spv::StorageClass input_storage_class =
static_cast<spv::StorageClass>(id_type_inst->GetSingleWordInOperand(0));

for (uint32_t i = start_idx; i < inst->NumInOperands(); ++i) {
Instruction* type_inst = get_def_use_mgr()->GetDef(id);
Expand Down Expand Up @@ -336,9 +342,19 @@ uint32_t FixStorageClass::WalkAccessChainType(Instruction* inst, uint32_t id) {
"Tried to extract from an object where it cannot be done.");
}

return context()->get_type_mgr()->FindPointerToType(
id, static_cast<spv::StorageClass>(
orig_type_inst->GetSingleWordInOperand(0)));
Instruction* orig_type_inst = get_def_use_mgr()->GetDef(inst->type_id());
spv::StorageClass orig_storage_class =
static_cast<spv::StorageClass>(orig_type_inst->GetSingleWordInOperand(0));
assert(orig_type_inst->opcode() == spv::Op::OpTypePointer);
if (orig_type_inst->GetSingleWordInOperand(1) == id &&
input_storage_class == orig_storage_class) {
// The existing type is correct. Avoid the search for the type. Note that if
// there is a duplicate type, the search below could return a different type
// forcing more changes to the code than necessary.
return inst->type_id();
}

return context()->get_type_mgr()->FindPointerToType(id, input_storage_class);
}

// namespace opt
Expand Down
92 changes: 44 additions & 48 deletions source/opt/pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ uint32_t Pass::GetNullId(uint32_t type_id) {

uint32_t Pass::GenerateCopy(Instruction* object_to_copy, uint32_t new_type_id,
Instruction* insertion_position) {
analysis::TypeManager* type_mgr = context()->get_type_mgr();
analysis::ConstantManager* const_mgr = context()->get_constant_mgr();

uint32_t original_type_id = object_to_copy->type_id();
Expand All @@ -95,55 +94,52 @@ uint32_t Pass::GenerateCopy(Instruction* object_to_copy, uint32_t new_type_id,
context(), insertion_position,
IRContext::kAnalysisInstrToBlockMapping | IRContext::kAnalysisDefUse);

analysis::Type* original_type = type_mgr->GetType(original_type_id);
analysis::Type* new_type = type_mgr->GetType(new_type_id);

if (const analysis::Array* original_array_type = original_type->AsArray()) {
uint32_t original_element_type_id =
type_mgr->GetId(original_array_type->element_type());

analysis::Array* new_array_type = new_type->AsArray();
assert(new_array_type != nullptr && "Can't copy an array to a non-array.");
uint32_t new_element_type_id =
type_mgr->GetId(new_array_type->element_type());

std::vector<uint32_t> element_ids;
const analysis::Constant* length_const =
const_mgr->FindDeclaredConstant(original_array_type->LengthId());
assert(length_const->AsIntConstant());
uint32_t array_length = length_const->AsIntConstant()->GetU32();
for (uint32_t i = 0; i < array_length; i++) {
Instruction* extract = ir_builder.AddCompositeExtract(
original_element_type_id, object_to_copy->result_id(), {i});
element_ids.push_back(
GenerateCopy(extract, new_element_type_id, insertion_position));
Instruction* original_type = get_def_use_mgr()->GetDef(original_type_id);
Instruction* new_type = get_def_use_mgr()->GetDef(new_type_id);
assert(new_type->opcode() == original_type->opcode() &&
"Can't copy an aggragate type unless the type correspond.");

switch (original_type->opcode()) {
case spv::Op::OpTypeArray: {
uint32_t original_element_type_id =
original_type->GetSingleWordInOperand(0);
uint32_t new_element_type_id = new_type->GetSingleWordInOperand(0);

std::vector<uint32_t> element_ids;
uint32_t length_id = original_type->GetSingleWordInOperand(1);
const analysis::Constant* length_const =
const_mgr->FindDeclaredConstant(length_id);
assert(length_const->AsIntConstant());
uint32_t array_length = length_const->AsIntConstant()->GetU32();
for (uint32_t i = 0; i < array_length; i++) {
Instruction* extract = ir_builder.AddCompositeExtract(
original_element_type_id, object_to_copy->result_id(), {i});
element_ids.push_back(
GenerateCopy(extract, new_element_type_id, insertion_position));
}

return ir_builder.AddCompositeConstruct(new_type_id, element_ids)
->result_id();
}

return ir_builder.AddCompositeConstruct(new_type_id, element_ids)
->result_id();
} else if (const analysis::Struct* original_struct_type =
original_type->AsStruct()) {
analysis::Struct* new_struct_type = new_type->AsStruct();

const std::vector<const analysis::Type*>& original_types =
original_struct_type->element_types();
const std::vector<const analysis::Type*>& new_types =
new_struct_type->element_types();
std::vector<uint32_t> element_ids;
for (uint32_t i = 0; i < original_types.size(); i++) {
Instruction* extract = ir_builder.AddCompositeExtract(
type_mgr->GetId(original_types[i]), object_to_copy->result_id(), {i});
element_ids.push_back(GenerateCopy(extract, type_mgr->GetId(new_types[i]),
insertion_position));
case spv::Op::OpTypeStruct: {
std::vector<uint32_t> element_ids;
for (uint32_t i = 0; i < original_type->NumInOperands(); i++) {
uint32_t orig_member_type_id = original_type->GetSingleWordInOperand(i);
uint32_t new_member_type_id = new_type->GetSingleWordInOperand(i);
Instruction* extract = ir_builder.AddCompositeExtract(
orig_member_type_id, object_to_copy->result_id(), {i});
element_ids.push_back(
GenerateCopy(extract, new_member_type_id, insertion_position));
}
return ir_builder.AddCompositeConstruct(new_type_id, element_ids)
->result_id();
}
return ir_builder.AddCompositeConstruct(new_type_id, element_ids)
->result_id();
} else {
// If we do not have an aggregate type, then we have a problem. Either we
// found multiple instances of the same type, or we are copying to an
// incompatible type. Either way the code is illegal.
assert(false &&
"Don't know how to copy this type. Code is likely illegal.");
default:
// If we do not have an aggregate type, then we have a problem. Either we
// found multiple instances of the same type, or we are copying to an
// incompatible type. Either way the code is illegal.
assert(false &&
"Don't know how to copy this type. Code is likely illegal.");
}
return 0;
}
Expand Down
34 changes: 34 additions & 0 deletions test/opt/fix_storage_class_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -953,6 +953,40 @@ OpFunctionEnd
SinglePassRunAndCheck<FixStorageClass>(text, text, false, false);
}

// Tests that the pass is not confused when there are multiple definitions
// of a pointer type to the same type with the same storage class.
TEST_F(FixStorageClassTest, DuplicatePointerType) {
const std::string text = R"(OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %1 "main"
OpExecutionMode %1 LocalSize 64 1 1
OpSource HLSL 600
%uint = OpTypeInt 32 0
%uint_0 = OpConstant %uint 0
%uint_3 = OpConstant %uint 3
%_arr_uint_uint_3 = OpTypeArray %uint %uint_3
%void = OpTypeVoid
%7 = OpTypeFunction %void
%_struct_8 = OpTypeStruct %_arr_uint_uint_3
%_ptr_Function__struct_8 = OpTypePointer Function %_struct_8
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Function__arr_uint_uint_3 = OpTypePointer Function %_arr_uint_uint_3
%_ptr_Function_uint_0 = OpTypePointer Function %uint
%_ptr_Function__ptr_Function_uint_0 = OpTypePointer Function %_ptr_Function_uint_0
%1 = OpFunction %void None %7
%14 = OpLabel
%15 = OpVariable %_ptr_Function__ptr_Function_uint_0 Function
%16 = OpVariable %_ptr_Function__struct_8 Function
%17 = OpAccessChain %_ptr_Function__arr_uint_uint_3 %16 %uint_0
%18 = OpAccessChain %_ptr_Function_uint_0 %17 %uint_0
OpStore %15 %18
OpReturn
OpFunctionEnd
)";

SinglePassRunAndCheck<FixStorageClass>(text, text, false);
}

} // namespace
} // namespace opt
} // namespace spvtools

0 comments on commit 81a1160

Please sign in to comment.