Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIR-V] Add support for inline SPIR-V types #125316

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions llvm/docs/SPIRVUsage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,32 @@ parameters of its underlying image type, so that a sampled image for the
previous type has the representation
``target("spirv.SampledImage, void, 1, 1, 0, 0, 0, 0, 0)``.

.. _inline-spirv-types:

Inline SPIR-V Types
-------------------

HLSL allows users to create types representing specific SPIR-V types, using ``vk::SpirvType`` and
``vk::SpirvOpaqueType``. These are specified in the `Inline SPIR-V`_ proposal. They may be
represented using target extension types:

.. _Inline SPIR-V: https://microsoft.github.io/hlsl-specs/proposals/0011-inline-spirv.html#types

.. table:: Inline SPIR-V Types

========================== =================== ==============================
LLVM type name LLVM type arguments LLVM integer arguments
========================== =================== ==============================
``spirv.Type`` SPIR-V operands opcode, size, alignment
``spirv.IntegralConstant`` integral type value
``spirv.Literal`` (none) value
========================== =================== ==============================

The operand arguments to ``spirv.Type`` may be either a ``spirv.IntegralConstant`` type,
representing an ``OpConstant`` id operand, a ``spirv.Literal`` type, representing an immediate
literal operand, or any other type, representing the id of that type as an operand.
``spirv.IntegralConstant`` and ``spirv.Literal`` may not be used outside of this context.

s-perron marked this conversation as resolved.
Show resolved Hide resolved
.. _spirv-intrinsics:

Target Intrinsics
Expand Down
20 changes: 20 additions & 0 deletions llvm/lib/IR/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,26 @@ static TargetTypeInfo getTargetTypeInfo(const TargetExtType *Ty) {
if (Name == "spirv.Image")
return TargetTypeInfo(PointerType::get(C, 0), TargetExtType::CanBeGlobal,
TargetExtType::CanBeLocal);
if (Name == "spirv.Type") {
assert(Ty->getNumIntParameters() == 3 &&
"Wrong number of parameters for spirv.Type");

auto Size = Ty->getIntParameter(1);
auto Alignment = Ty->getIntParameter(2);

// LLVM expects variables that can be allocated to have an alignment and
// size. Default to using a 32-bit int as the layout type if none are
// present.
llvm::Type *LayoutType = Type::getInt32Ty(C);
if (Size > 0 && Alignment > 0)
LayoutType =
ArrayType::get(Type::getIntNTy(C, Alignment), Size * 8 / Alignment);

return TargetTypeInfo(LayoutType, TargetExtType::CanBeGlobal,
TargetExtType::CanBeLocal);
}
if (Name == "spirv.IntegralConstant" || Name == "spirv.Literal")
return TargetTypeInfo(Type::getVoidTy(C));
if (Name.starts_with("spirv."))
return TargetTypeInfo(PointerType::get(C, 0), TargetExtType::HasZeroInit,
TargetExtType::CanBeGlobal,
Expand Down
30 changes: 30 additions & 0 deletions llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ void SPIRVInstPrinter::printInst(const MCInst *MI, uint64_t Address,
recordOpExtInstImport(MI);
} else if (OpCode == SPIRV::OpExtInst) {
printOpExtInst(MI, OS);
} else if (OpCode == SPIRV::UNKNOWN_type) {
printUnknownType(MI, OS);
} else {
// Print any extra operands for variadic instructions.
const MCInstrDesc &MCDesc = MII.get(OpCode);
Expand Down Expand Up @@ -314,6 +316,34 @@ void SPIRVInstPrinter::printOpDecorate(const MCInst *MI, raw_ostream &O) {
}
}

void SPIRVInstPrinter::printUnknownType(const MCInst *MI, raw_ostream &O) {
const auto EnumOperand = MI->getOperand(1);
assert(EnumOperand.isImm() &&
"second operand of UNKNOWN_type must be opcode!");

const auto Enumerant = EnumOperand.getImm();
const auto NumOps = MI->getNumOperands();

// Encode the instruction enumerant and word count into the opcode
const auto OpCode = (0xFF & NumOps) << 16 | (0xFF & Enumerant);

// Print the opcode using the spirv-as unknown opcode syntax
O << "OpUnknown(" << Enumerant << ", " << NumOps << ") ";

// The result ID must be printed after the opcode when using this syntax
printOperand(MI, 0, O);

O << " ";

const MCInstrDesc &MCDesc = MII.get(MI->getOpcode());
unsigned NumFixedOps = MCDesc.getNumOperands();
if (NumOps == NumFixedOps)
return;

// Print the rest of the operands
printRemainingVariableOps(MI, NumFixedOps, O, true);
}

static void printExpr(const MCExpr *Expr, raw_ostream &O) {
#ifndef NDEBUG
const MCSymbolRefExpr *SRE;
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class SPIRVInstPrinter : public MCInstPrinter {

void printOpDecorate(const MCInst *MI, raw_ostream &O);
void printOpExtInst(const MCInst *MI, raw_ostream &O);
void printUnknownType(const MCInst *MI, raw_ostream &O);
void printRemainingVariableOps(const MCInst *MI, unsigned StartIndex,
raw_ostream &O, bool SkipFirstSpace = false,
bool SkipImmediates = false);
Expand Down
25 changes: 25 additions & 0 deletions llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ class SPIRVMCCodeEmitter : public MCCodeEmitter {
void encodeInstruction(const MCInst &MI, SmallVectorImpl<char> &CB,
SmallVectorImpl<MCFixup> &Fixups,
const MCSubtargetInfo &STI) const override;
void encodeUnknownType(const MCInst &MI, SmallVectorImpl<char> &CB,
SmallVectorImpl<MCFixup> &Fixups,
const MCSubtargetInfo &STI) const;
};

} // end anonymous namespace
Expand Down Expand Up @@ -104,10 +107,32 @@ static void emitUntypedInstrOperands(const MCInst &MI,
emitOperand(Op, CB);
}

void SPIRVMCCodeEmitter::encodeUnknownType(const MCInst &MI,
SmallVectorImpl<char> &CB,
SmallVectorImpl<MCFixup> &Fixups,
const MCSubtargetInfo &STI) const {
// Encode the first 32 SPIR-V bytes with the number of args and the opcode.
const uint64_t OpCode = MI.getOperand(1).getImm();
const uint32_t NumWords = MI.getNumOperands();
const uint32_t FirstWord = (0xFF & NumWords) << 16 | (0xFF & OpCode);

// encoding: <opcode+len> <result type> [<operand0> <operand1> ...]
support::endian::write(CB, FirstWord, llvm::endianness::little);

emitOperand(MI.getOperand(0), CB);
for (unsigned i = 2; i < NumWords; ++i)
emitOperand(MI.getOperand(i), CB);
cassiebeckley marked this conversation as resolved.
Show resolved Hide resolved
}

void SPIRVMCCodeEmitter::encodeInstruction(const MCInst &MI,
SmallVectorImpl<char> &CB,
SmallVectorImpl<MCFixup> &Fixups,
const MCSubtargetInfo &STI) const {
if (MI.getOpcode() == SPIRV::UNKNOWN_type) {
encodeUnknownType(MI, CB, Fixups, STI);
return;
}

// Encode the first 32 SPIR-V bytes with the number of args and the opcode.
const uint64_t OpCode = getBinaryCodeForInstr(MI, Fixups, STI);
const uint32_t NumWords = MI.getNumOperands() + 1;
Expand Down
120 changes: 88 additions & 32 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2868,6 +2868,56 @@ static SPIRVType *getSampledImageType(const TargetExtType *OpaqueType,
return GR->getOrCreateOpTypeSampledImage(OpaqueImageType, MIRBuilder);
}

static SPIRVType *getInlineSpirvType(const TargetExtType *ExtensionType,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
assert(ExtensionType->getNumIntParameters() == 3 &&
"Inline SPIR-V type builtin takes an opcode, size, and alignment "
"parameter");
auto Opcode = ExtensionType->getIntParameter(0);

SmallVector<MCOperand> Operands;
for (llvm::Type *Param : ExtensionType->type_params()) {
if (const TargetExtType *ParamEType = dyn_cast<TargetExtType>(Param)) {
if (ParamEType->getName() == "spirv.IntegralConstant") {
assert(ParamEType->getNumTypeParameters() == 1 &&
"Inline SPIR-V integral constant builtin must have a type "
"parameter");
assert(ParamEType->getNumIntParameters() == 1 &&
"Inline SPIR-V integral constant builtin must have a "
"value parameter");

auto OperandValue = ParamEType->getIntParameter(0);
auto *OperandType = ParamEType->getTypeParameter(0);

const SPIRVType *OperandSPIRVType =
GR->getOrCreateSPIRVType(OperandType, MIRBuilder);

Operands.push_back(MCOperand::createReg(GR->buildConstantInt(
OperandValue, MIRBuilder, OperandSPIRVType, true)));
continue;
} else if (ParamEType->getName() == "spirv.Literal") {
assert(ParamEType->getNumTypeParameters() == 0 &&
"Inline SPIR-V literal builtin does not take type "
"parameters");
assert(ParamEType->getNumIntParameters() == 1 &&
"Inline SPIR-V literal builtin must have an integer "
"parameter");

auto OperandValue = ParamEType->getIntParameter(0);

Operands.push_back(MCOperand::createImm(OperandValue));
continue;
}
}
const SPIRVType *TypeOperand = GR->getOrCreateSPIRVType(Param, MIRBuilder);
Operands.push_back(MCOperand::createReg(GR->getSPIRVTypeID(TypeOperand)));
}

return GR->getOrCreateUnknownType(ExtensionType, MIRBuilder, Opcode,
Operands);
}

namespace SPIRV {
TargetExtType *parseBuiltinTypeNameToTargetExtType(std::string TypeName,
LLVMContext &Context) {
Expand Down Expand Up @@ -2940,39 +2990,45 @@ SPIRVType *lowerBuiltinType(const Type *OpaqueType,
const StringRef Name = BuiltinType->getName();
LLVM_DEBUG(dbgs() << "Lowering builtin type: " << Name << "\n");

// Lookup the demangled builtin type in the TableGen records.
const SPIRV::BuiltinType *TypeRecord = SPIRV::lookupBuiltinType(Name);
if (!TypeRecord)
report_fatal_error("Missing TableGen record for builtin type: " + Name);

// "Lower" the BuiltinType into TargetType. The following get<...>Type methods
// use the implementation details from TableGen records or TargetExtType
// parameters to either create a new OpType<...> machine instruction or get an
// existing equivalent SPIRVType from GlobalRegistry.
SPIRVType *TargetType;
switch (TypeRecord->Opcode) {
case SPIRV::OpTypeImage:
TargetType = getImageType(BuiltinType, AccessQual, MIRBuilder, GR);
break;
case SPIRV::OpTypePipe:
TargetType = getPipeType(BuiltinType, MIRBuilder, GR);
break;
case SPIRV::OpTypeDeviceEvent:
TargetType = GR->getOrCreateOpTypeDeviceEvent(MIRBuilder);
break;
case SPIRV::OpTypeSampler:
TargetType = getSamplerType(MIRBuilder, GR);
break;
case SPIRV::OpTypeSampledImage:
TargetType = getSampledImageType(BuiltinType, MIRBuilder, GR);
break;
case SPIRV::OpTypeCooperativeMatrixKHR:
TargetType = getCoopMatrType(BuiltinType, MIRBuilder, GR);
break;
default:
TargetType =
getNonParameterizedType(BuiltinType, TypeRecord, MIRBuilder, GR);
break;
if (Name == "spirv.Type") {
TargetType = getInlineSpirvType(BuiltinType, MIRBuilder, GR);
} else {
// Lookup the demangled builtin type in the TableGen records.
const SPIRV::BuiltinType *TypeRecord = SPIRV::lookupBuiltinType(Name);
if (!TypeRecord)
report_fatal_error("Missing TableGen record for builtin type: " + Name);

// "Lower" the BuiltinType into TargetType. The following get<...>Type
// methods use the implementation details from TableGen records or
// TargetExtType parameters to either create a new OpType<...> machine
// instruction or get an existing equivalent SPIRVType from
// GlobalRegistry.

switch (TypeRecord->Opcode) {
case SPIRV::OpTypeImage:
TargetType = getImageType(BuiltinType, AccessQual, MIRBuilder, GR);
break;
case SPIRV::OpTypePipe:
TargetType = getPipeType(BuiltinType, MIRBuilder, GR);
break;
case SPIRV::OpTypeDeviceEvent:
TargetType = GR->getOrCreateOpTypeDeviceEvent(MIRBuilder);
break;
case SPIRV::OpTypeSampler:
TargetType = getSamplerType(MIRBuilder, GR);
break;
case SPIRV::OpTypeSampledImage:
TargetType = getSampledImageType(BuiltinType, MIRBuilder, GR);
break;
case SPIRV::OpTypeCooperativeMatrixKHR:
TargetType = getCoopMatrType(BuiltinType, MIRBuilder, GR);
break;
default:
TargetType =
getNonParameterizedType(BuiltinType, TypeRecord, MIRBuilder, GR);
break;
}
}

// Emit OpName instruction if a new OpType<...> instruction was added
Expand Down
22 changes: 22 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1406,6 +1406,28 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
return SpirvTy;
}

SPIRVType *SPIRVGlobalRegistry::getOrCreateUnknownType(
const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode,
const ArrayRef<MCOperand> Operands) {
Register ResVReg = DT.find(Ty, &MIRBuilder.getMF());
if (ResVReg.isValid())
return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg);
ResVReg = createTypeVReg(MIRBuilder);

MachineInstrBuilder MIB =
MIRBuilder.buildInstr(SPIRV::UNKNOWN_type).addDef(ResVReg).addImm(Opcode);
for (MCOperand Operand : Operands) {
if (Operand.isReg()) {
MIB.addUse(Operand.getReg());
} else if (Operand.isImm()) {
MIB.addImm(Operand.getImm());
}
}

DT.add(Ty, &MIRBuilder.getMF(), ResVReg);
return MIB;
}

const MachineInstr *
SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
MachineIRBuilder &MIRBuilder) {
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,11 @@ class SPIRVGlobalRegistry {
MachineIRBuilder &MIRBuilder,
unsigned Opcode);

SPIRVType *getOrCreateUnknownType(const Type *Ty,
MachineIRBuilder &MIRBuilder,
unsigned Opcode,
const ArrayRef<MCOperand> Operands);

const TargetRegisterClass *getRegClass(SPIRVType *SpvType) const;
LLT getRegType(SPIRVType *SpvType) const;
};
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstrFormats.td
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ class Op<bits<16> Opcode, dag outs, dag ins, string asmstr, list<dag> pattern =
let Pattern = pattern;
}

class UnknownOp<dag outs, dag ins, string asmstr, list<dag> pattern = []>
: Op<0, outs, ins, asmstr, pattern> {
let isPseudo = 1;
}

// Pseudo instructions
class Pseudo<dag outs, dag ins> : Op<0, outs, ins, ""> {
let isPseudo = 1;
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ let isCodeGenOnly=1 in {
def GET_vpID: Pseudo<(outs vpID:$dst_id), (ins vpID:$src)>;
}

def UNKNOWN_type
: UnknownOp<(outs TYPE:$type), (ins i32imm:$opcode, variable_ops), " ">;

def SPVTypeBin : SDTypeProfile<1, 2, []>;

def assigntype : SDNode<"SPIRVISD::AssignType", SPVTypeBin>;
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,8 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
bool HasDefs = I.getNumDefs() > 0;
Register ResVReg = HasDefs ? I.getOperand(0).getReg() : Register(0);
SPIRVType *ResType = HasDefs ? GR.getSPIRVTypeForVReg(ResVReg) : nullptr;
assert(!HasDefs || ResType || I.getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
assert(!HasDefs || ResType || I.getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
I.getOpcode() == TargetOpcode::G_IMPLICIT_DEF);
if (spvSelect(ResVReg, ResType, I)) {
if (HasDefs) // Make all vregs 64 bits (for SPIR-V IDs).
for (unsigned i = 0; i < I.getNumDefs(); ++i)
Expand Down
Loading