[SPIR-V] Improve type inference for a known instruction's builtin: OpGroupAsyncCopy (llvm#96895)

This PR improves type inference for a known instruction's builtin, OpGroupAsyncCopy:
* deduce the element type of one source/destination pointer when it's possible to
deduce the type of the other pointer argument (see the sketch below), and
* validate src and dest types, unfolding a parameter if it's a
structure wrapper around a scalar/vector type.
VyacheslavLevytskyy authored Jul 3, 2024
1 parent 1db4221 commit bf9e9e5
Showing 5 changed files with 203 additions and 9 deletions.
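A minimal LLVM IR sketch of the first point (function and argument names are illustrative; the updated spirv-event-null.ll test below exercises the same pattern). Nothing types %dst directly, but the getelementptr types %src as %Vec4; because %dst and %src are the two pointer operands of one __spirv_GroupAsyncCopy call, the backend can now propagate the deduced element type to %dst, and the new validation step unfolds the single-member %Vec4 wrapper to <4 x i8> on both sides:

%Vec4 = type { <4 x i8> }

define spir_kernel void @sample(ptr addrspace(3) %dst, ptr addrspace(1) %src) {
entry:
  ; this GEP is the only source of type information for either pointer
  %p = getelementptr inbounds %Vec4, ptr addrspace(1) %src, i64 0
  %e = tail call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32 2, ptr addrspace(3) %dst, ptr addrspace(1) %p, i64 16, i64 10, target("spirv.Event") zeroinitializer)
  ret void
}

declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32, ptr addrspace(3), ptr addrspace(1), i64, i64, target("spirv.Event"))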
85 changes: 81 additions & 4 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -184,10 +184,16 @@ lookupBuiltin(StringRef DemangledCall,
               SPIRV::InstructionSet::InstructionSet Set,
               Register ReturnRegister, const SPIRVType *ReturnType,
               const SmallVectorImpl<Register> &Arguments) {
+  const static std::string PassPrefix = "(anonymous namespace)::";
+  std::string BuiltinName;
+  // Itanium Demangler result may have "(anonymous namespace)::" prefix
+  if (DemangledCall.starts_with(PassPrefix.c_str()))
+    BuiltinName = DemangledCall.substr(PassPrefix.length());
+  else
+    BuiltinName = DemangledCall;
   // Extract the builtin function name and types of arguments from the call
   // skeleton.
-  std::string BuiltinName =
-      DemangledCall.substr(0, DemangledCall.find('(')).str();
+  BuiltinName = BuiltinName.substr(0, BuiltinName.find('('));
 
   // Account for possible "__spirv_ocl_" prefix in SPIR-V friendly LLVM IR
   if (BuiltinName.rfind("__spirv_ocl_", 0) == 0)
@@ -2377,9 +2383,80 @@ static bool generateLoadStoreInst(const SPIRV::IncomingCall *Call,
   return true;
 }
 
-/// Lowers a builtin funtion call using the provided \p DemangledCall skeleton
-/// and external instruction \p Set.
 namespace SPIRV {
+// Try to find a builtin function's attributes by a demangled function name and
+// return a tuple <builtin group, op code, ext instruction number>, or a
+// special tuple value <-1, 0, 0> if the builtin function is not found.
+// Not all builtin functions are supported, only those with a ready-to-use op
+// code or instruction number defined in TableGen.
+// TODO: consider a major rework of mapping demangled calls to builtin
+// functions to unify the search and decrease the number of individual cases.
+std::tuple<int, unsigned, unsigned>
+mapBuiltinToOpcode(const StringRef DemangledCall,
+                   SPIRV::InstructionSet::InstructionSet Set) {
+  Register Reg;
+  SmallVector<Register> Args;
+  std::unique_ptr<const IncomingCall> Call =
+      lookupBuiltin(DemangledCall, Set, Reg, nullptr, Args);
+  if (!Call)
+    return std::make_tuple(-1, 0, 0);
+
+  switch (Call->Builtin->Group) {
+  case SPIRV::Relational:
+  case SPIRV::Atomic:
+  case SPIRV::Barrier:
+  case SPIRV::CastToPtr:
+  case SPIRV::ImageMiscQuery:
+  case SPIRV::SpecConstant:
+  case SPIRV::Enqueue:
+  case SPIRV::AsyncCopy:
+  case SPIRV::LoadStore:
+  case SPIRV::CoopMatr:
+    if (const auto *R =
+            SPIRV::lookupNativeBuiltin(Call->Builtin->Name, Call->Builtin->Set))
+      return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
+    break;
+  case SPIRV::Extended:
+    if (const auto *R = SPIRV::lookupExtendedBuiltin(Call->Builtin->Name,
+                                                     Call->Builtin->Set))
+      return std::make_tuple(Call->Builtin->Group, 0, R->Number);
+    break;
+  case SPIRV::VectorLoadStore:
+    if (const auto *R = SPIRV::lookupVectorLoadStoreBuiltin(Call->Builtin->Name,
+                                                            Call->Builtin->Set))
+      return std::make_tuple(SPIRV::Extended, 0, R->Number);
+    break;
+  case SPIRV::Group:
+    if (const auto *R = SPIRV::lookupGroupBuiltin(Call->Builtin->Name))
+      return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
+    break;
+  case SPIRV::AtomicFloating:
+    if (const auto *R = SPIRV::lookupAtomicFloatingBuiltin(Call->Builtin->Name))
+      return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
+    break;
+  case SPIRV::IntelSubgroups:
+    if (const auto *R = SPIRV::lookupIntelSubgroupsBuiltin(Call->Builtin->Name))
+      return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
+    break;
+  case SPIRV::GroupUniform:
+    if (const auto *R = SPIRV::lookupGroupUniformBuiltin(Call->Builtin->Name))
+      return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
+    break;
+  case SPIRV::WriteImage:
+    return std::make_tuple(Call->Builtin->Group, SPIRV::OpImageWrite, 0);
+  case SPIRV::Select:
+    return std::make_tuple(Call->Builtin->Group, TargetOpcode::G_SELECT, 0);
+  case SPIRV::Construct:
+    return std::make_tuple(Call->Builtin->Group, SPIRV::OpCompositeConstruct,
+                           0);
+  case SPIRV::KernelClock:
+    return std::make_tuple(Call->Builtin->Group, SPIRV::OpReadClockKHR, 0);
+  default:
+    return std::make_tuple(-1, 0, 0);
+  }
+  return std::make_tuple(-1, 0, 0);
+}
+
 std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
                                  SPIRV::InstructionSet::InstructionSet Set,
                                  MachineIRBuilder &MIRBuilder,
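A usage sketch for the new helper (hypothetical caller code, not part of the change; it mirrors how SPIRVEmitIntrinsics.cpp consumes the API further down, and DemangledName is assumed to exist in the caller):

  // DemangledName would come from getOclOrSpirvBuiltinDemangledName(), as in
  // SPIRVEmitIntrinsics.cpp below.
  auto [Grp, Opcode, ExtNo] =
      SPIRV::mapBuiltinToOpcode(DemangledName, SPIRV::InstructionSet::OpenCL_std);
  if (Grp == -1) {
    // Not a recognized builtin: no TableGen-defined opcode or ext number.
  } else if (Opcode == SPIRV::OpGroupAsyncCopy) {
    // Recognized opcode: OpGroupAsyncCopy-specific type deduction can run.
  }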
8 changes: 7 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVBuiltins.h
@@ -19,7 +19,7 @@
 
 namespace llvm {
 namespace SPIRV {
-/// Lowers a builtin funtion call using the provided \p DemangledCall skeleton
+/// Lowers a builtin function call using the provided \p DemangledCall skeleton
 /// and external instruction \p Set.
 ///
 /// \return the lowering success status if the called function is a recognized
@@ -38,6 +38,12 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
                                  const SmallVectorImpl<Register> &Args,
                                  SPIRVGlobalRegistry *GR);
 
+/// Helper function for finding a builtin function's attributes
+/// by a demangled function name. Defined in SPIRVBuiltins.cpp.
+std::tuple<int, unsigned, unsigned>
+mapBuiltinToOpcode(const StringRef DemangledCall,
+                   SPIRV::InstructionSet::InstructionSet Set);
+
 /// Parses the provided \p ArgIdx argument base type in the \p DemangledCall
 /// skeleton. A base type is either a basic type (e.g. i32 for int), pointer
 /// element type (e.g. i8 for char*), or builtin type (TargetExtType).
28 changes: 27 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -67,6 +67,7 @@ class SPIRVEmitIntrinsics
   DenseMap<Instruction *, Constant *> AggrConsts;
   DenseMap<Instruction *, Type *> AggrConstTypes;
   DenseSet<Instruction *> AggrStores;
+  SPIRV::InstructionSet::InstructionSet InstrSet;
 
   // deduce element type of untyped pointers
   Type *deduceElementType(Value *I, bool UnknownElemTypeI8);
@@ -384,9 +385,10 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
       std::string DemangledName =
           getOclOrSpirvBuiltinDemangledName(CalledF->getName());
       auto AsArgIt = ResTypeByArg.find(DemangledName);
-      if (AsArgIt != ResTypeByArg.end())
+      if (AsArgIt != ResTypeByArg.end()) {
         Ty = deduceElementTypeHelper(CI->getArgOperand(AsArgIt->second),
                                      Visited);
+      }
     }
   }
 
@@ -544,6 +546,28 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I) {
       KnownElemTy = ElemTy1;
       Ops.push_back(std::make_pair(Op0, 0));
     }
-  }
+  } else if (auto *CI = dyn_cast<CallInst>(I)) {
+    if (Function *CalledF = CI->getCalledFunction()) {
+      std::string DemangledName =
+          getOclOrSpirvBuiltinDemangledName(CalledF->getName());
+      if (DemangledName.length() > 0 &&
+          !StringRef(DemangledName).starts_with("llvm.")) {
+        auto [Grp, Opcode, ExtNo] =
+            SPIRV::mapBuiltinToOpcode(DemangledName, InstrSet);
+        if (Opcode == SPIRV::OpGroupAsyncCopy) {
+          for (unsigned i = 0, PtrCnt = 0; i < CI->arg_size() && PtrCnt < 2;
+               ++i) {
+            Value *Op = CI->getArgOperand(i);
+            if (!isPointerTy(Op->getType()))
+              continue;
+            ++PtrCnt;
+            if (Type *ElemTy = GR->findDeducedElementType(Op))
+              KnownElemTy = ElemTy; // src will rewrite dest if both are defined
+            Ops.push_back(std::make_pair(Op, i));
+          }
+        }
+      }
+    }
+  }
 
   // There is no enough info to deduce types or all is valid.
@@ -1385,6 +1409,8 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
 
   const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(Func);
   GR = ST.getSPIRVGlobalRegistry();
+  InstrSet = ST.isOpenCLEnv() ? SPIRV::InstructionSet::OpenCL_std
+                              : SPIRV::InstructionSet::GLSL_std_450;
 
   F = &Func;
   IRBuilder<> B(Func.getContext());
37 changes: 37 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -203,6 +203,39 @@ static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
   doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
 }
 
+static void validateGroupAsyncCopyPtr(const SPIRVSubtarget &STI,
+                                      MachineRegisterInfo *MRI,
+                                      SPIRVGlobalRegistry &GR, MachineInstr &I,
+                                      unsigned OpIdx) {
+  MachineFunction *MF = I.getParent()->getParent();
+  Register OpReg = I.getOperand(OpIdx).getReg();
+  Register OpTypeReg = getTypeReg(MRI, OpReg);
+  SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
+  if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
+    return;
+  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
+  if (!ElemType || ElemType->getOpcode() != SPIRV::OpTypeStruct ||
+      ElemType->getNumOperands() != 2)
+    return;
+  // It's a structure-wrapper around another type with a single member field.
+  SPIRVType *MemberType =
+      GR.getSPIRVTypeForVReg(ElemType->getOperand(1).getReg());
+  if (!MemberType)
+    return;
+  unsigned MemberTypeOp = MemberType->getOpcode();
+  if (MemberTypeOp != SPIRV::OpTypeVector && MemberTypeOp != SPIRV::OpTypeInt &&
+      MemberTypeOp != SPIRV::OpTypeFloat && MemberTypeOp != SPIRV::OpTypeBool)
+    return;
+  // It's a structure-wrapper around a valid type. Insert a bitcast before the
+  // instruction to keep SPIR-V code valid.
+  SPIRV::StorageClass::StorageClass SC =
+      static_cast<SPIRV::StorageClass::StorageClass>(
+          OpType->getOperand(1).getImm());
+  MachineIRBuilder MIB(I);
+  SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(MemberType, MIB, SC);
+  doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
+}
+
 // Insert a bitcast before the function call instruction to keep SPIR-V code
 // valid when there is a type mismatch between actual and expected types of an
 // argument:
@@ -380,6 +413,10 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
                                   SPIRV::OpTypeBool))
         MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
       break;
+    case SPIRV::OpGroupAsyncCopy:
+      validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 3);
+      validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 4);
+      break;
     case SPIRV::OpGroupWaitEvents:
       // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
       validateGroupWaitEventsPtr(STI, MRI, GR, MI);
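For illustration, the unwrap that validateGroupAsyncCopyPtr performs on a struct-wrapped pointer operand, sketched in SPIR-V assembly (all ids are illustrative, not taken from the test below): a pointer to a single-member wrapper struct is rebuilt as a pointer to the member type in the same storage class, and a bitcast to that type is inserted before the call:

%v4      = OpTypeVector %char 4
%wrap    = OpTypeStruct %v4              ; single-member wrapper
%ptrWrap = OpTypePointer Workgroup %wrap ; operand's original type
%ptrV4   = OpTypePointer Workgroup %v4   ; same storage class, unwrapped
%cast    = OpBitcast %ptrV4 %dst         ; inserted right before the call
%res     = OpGroupAsyncCopy %evTy %scope %cast %src %num %stride %evt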
54 changes: 51 additions & 3 deletions llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
@@ -8,7 +8,18 @@
 ; CHECK-DAG: %[[#TyStruct:]] = OpTypeStruct %[[#TyEvent]]
 ; CHECK-DAG: %[[#ConstEvent:]] = OpConstantNull %[[#TyEvent]]
 ; CHECK-DAG: %[[#TyEventPtr:]] = OpTypePointer Function %[[#TyEvent]]
+; CHECK-DAG: %[[#TyEventPtrGen:]] = OpTypePointer Generic %[[#TyEvent]]
 ; CHECK-DAG: %[[#TyStructPtr:]] = OpTypePointer Function %[[#TyStruct]]
+; CHECK-DAG: %[[#TyChar:]] = OpTypeInt 8 0
+; CHECK-DAG: %[[#TyV4:]] = OpTypeVector %[[#TyChar]] 4
+; CHECK-DAG: %[[#TyStructV4:]] = OpTypeStruct %[[#TyV4]]
+; CHECK-DAG: %[[#TyPtrSV4_W:]] = OpTypePointer Workgroup %[[#TyStructV4]]
+; CHECK-DAG: %[[#TyPtrSV4_CW:]] = OpTypePointer CrossWorkgroup %[[#TyStructV4]]
+; CHECK-DAG: %[[#TyPtrV4_W:]] = OpTypePointer Workgroup %[[#TyV4]]
+; CHECK-DAG: %[[#TyPtrV4_CW:]] = OpTypePointer CrossWorkgroup %[[#TyV4]]
 
+; Check correct translation of __spirv_GroupAsyncCopy and target("spirv.Event") zeroinitializer
+
 ; CHECK: OpFunction
 ; CHECK: OpFunctionParameter
 ; CHECK: %[[#Src:]] = OpFunctionParameter
@@ -17,12 +28,13 @@
 ; CHECK: %[[#Dest:]] = OpInBoundsPtrAccessChain
 ; CHECK: %[[#CopyRes:]] = OpGroupAsyncCopy %[[#TyEvent]] %[[#]] %[[#Dest]] %[[#Src]] %[[#]] %[[#]] %[[#ConstEvent]]
 ; CHECK: OpStore %[[#EventVar]] %[[#CopyRes]]
+; CHECK: OpFunctionEnd
 
-%"class.sycl::_V1::device_event" = type { target("spirv.Event") }
+%StructEvent = type { target("spirv.Event") }
 
-define spir_kernel void @foo(ptr addrspace(1) %_arg_out_ptr, ptr addrspace(3) noundef %_arg_local_acc) {
+define spir_kernel void @foo(ptr addrspace(1) %_arg_out_ptr, ptr addrspace(3) %_arg_local_acc) {
 entry:
-  %var = alloca %"class.sycl::_V1::device_event"
+  %var = alloca %StructEvent
   %dev_event.i.sroa.0 = alloca target("spirv.Event")
   %add.ptr.i26 = getelementptr inbounds i32, ptr addrspace(1) %_arg_out_ptr, i64 0
   %call3.i = tail call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32 2, ptr addrspace(1) %add.ptr.i26, ptr addrspace(3) %_arg_local_acc, i64 16, i64 10, target("spirv.Event") zeroinitializer)
@@ -31,3 +43,39 @@ entry:
 }
 
 declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32, ptr addrspace(1), ptr addrspace(3), i64, i64, target("spirv.Event"))
+
+; Check correct type inference when calling __spirv_GroupAsyncCopy:
+; we expect the backend to deduce the element type of %_arg_Local from the
+; facts that the type of %_arg can be deduced and that %_arg_Local and %_arg
+; are the destination/source pointer arguments of the same OpGroupAsyncCopy.
+
+; CHECK: OpFunction
+; CHECK: %[[#BarArg1:]] = OpFunctionParameter %[[#TyPtrSV4_W]]
+; CHECK: %[[#BarArg2:]] = OpFunctionParameter %[[#TyPtrSV4_CW]]
+; CHECK: %[[#EventVarBar:]] = OpVariable %[[#TyStructPtr]] Function
+; CHECK: %[[#SrcBar:]] = OpInBoundsPtrAccessChain %[[#TyPtrSV4_CW]] %[[#BarArg2]] %[[#]]
+; CHECK-DAG: %[[#BarArg1Casted:]] = OpBitcast %[[#TyPtrV4_W]] %[[#BarArg1]]
+; CHECK-DAG: %[[#SrcBarCasted:]] = OpBitcast %[[#TyPtrV4_CW]] %[[#SrcBar]]
+; CHECK: %[[#ResBar:]] = OpGroupAsyncCopy %[[#TyEvent]] %[[#]] %[[#BarArg1Casted]] %[[#SrcBarCasted]] %[[#]] %[[#]] %[[#ConstEvent]]
+; CHECK: %[[#EventVarBarCasted:]] = OpBitcast %[[#TyEventPtr]] %[[#EventVarBar]]
+; CHECK: OpStore %[[#EventVarBarCasted]] %[[#ResBar]]
+; CHECK: %[[#EventVarBarCasted2:]] = OpBitcast %[[#TyEventPtr]] %[[#EventVarBar]]
+; CHECK: %[[#EventVarBarGen:]] = OpPtrCastToGeneric %[[#TyEventPtrGen]] %[[#EventVarBarCasted2]]
+; CHECK: OpGroupWaitEvents %[[#]] %[[#]] %[[#EventVarBarGen]]
+; CHECK: OpFunctionEnd
+
+%Vec4 = type { <4 x i8> }
+
+define spir_kernel void @bar(ptr addrspace(3) %_arg_Local, ptr addrspace(1) readonly %_arg) {
+entry:
+  %E1 = alloca %StructEvent
+  %srcptr = getelementptr inbounds %Vec4, ptr addrspace(1) %_arg, i64 0
+  %r1 = tail call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32 2, ptr addrspace(3) %_arg_Local, ptr addrspace(1) %srcptr, i64 16, i64 10, target("spirv.Event") zeroinitializer)
+  store target("spirv.Event") %r1, ptr %E1
+  %E.ascast.i = addrspacecast ptr %E1 to ptr addrspace(4)
+  call spir_func void @_Z23__spirv_GroupWaitEventsjiP9ocl_event(i32 2, i32 1, ptr addrspace(4) %E.ascast.i)
+  ret void
+}
+
+declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32, ptr addrspace(3), ptr addrspace(1), i64, i64, target("spirv.Event"))
+declare dso_local spir_func void @_Z23__spirv_GroupWaitEventsjiP9ocl_event(i32, i32, ptr addrspace(4))
