Skip to content

Commit e0c9de8

Browse files
authored
Initial implementation of SPV_KHR_cooperative_matrix extension (KhronosGroup#2099)
The intention is to replace existing SPV_INTEL_joint_matrix extension to the Khronos one in future. Spec: https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_cooperative_matrix.asciidoc
1 parent eb051c7 commit e0c9de8

15 files changed

+293
-3
lines changed

include/LLVMSPIRVExtensions.inc

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ EXT(SPV_KHR_uniform_group_instructions)
1414
EXT(SPV_KHR_subgroup_rotate)
1515
EXT(SPV_KHR_non_semantic_info)
1616
EXT(SPV_KHR_shader_clock)
17+
EXT(SPV_KHR_cooperative_matrix)
1718
EXT(SPV_INTEL_subgroups)
1819
EXT(SPV_INTEL_media_block_io)
1920
EXT(SPV_INTEL_device_side_avc_motion_estimation)

lib/SPIRV/OCLUtil.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -899,6 +899,7 @@ SPIRAddressSpace getOCLOpaqueTypeAddrSpace(Op OpCode) {
899899
return SPIRV_SAMPLER_T_ADDR_SPACE;
900900
case internal::OpTypeJointMatrixINTEL:
901901
case internal::OpTypeJointMatrixINTELv2:
902+
case OpTypeCooperativeMatrixKHR:
902903
return SPIRAS_Global;
903904
default:
904905
if (isSubgroupAvcINTELTypeOpCode(OpCode))

lib/SPIRV/SPIRVInternal.h

+2
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ const static char PipeStorage[] = "PipeStorage";
311311
const static char ConstantPipeStorage[] = "ConstantPipeStorage";
312312
const static char VmeImageINTEL[] = "VmeImageINTEL";
313313
const static char JointMatrixINTEL[] = "JointMatrixINTEL";
314+
const static char CooperativeMatrixKHR[] = "CooperativeMatrixKHR";
314315
const static char BufferSurfaceINTEL[] = "BufferSurfaceINTEL";
315316
} // namespace kSPIRVTypeName
316317

@@ -957,6 +958,7 @@ template <> inline void SPIRVMap<std::string, Op, SPIRVOpaqueType>::init() {
957958
_SPIRV_OP(AvcSicResultINTEL)
958959
_SPIRV_OP(VmeImageINTEL)
959960
_SPIRV_OP(BufferSurfaceINTEL)
961+
_SPIRV_OP(CooperativeMatrixKHR)
960962
#undef _SPIRV_OP
961963
add("JointMatrixINTEL", internal::OpTypeJointMatrixINTEL);
962964
}

lib/SPIRV/SPIRVReader.cpp

+18
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,22 @@ Type *SPIRVToLLVM::transType(SPIRVType *T, bool UseTPT) {
469469
T, llvm::TargetExtType::get(*Context, "spirv.JointMatrixINTEL",
470470
transType(MT->getCompType()), Params));
471471
}
472+
case OpTypeCooperativeMatrixKHR: {
473+
auto *MT = static_cast<SPIRVTypeCooperativeMatrixKHR *>(T);
474+
unsigned Scope =
475+
static_cast<SPIRVConstant *>(MT->getScope())->getZExtIntValue();
476+
unsigned Rows =
477+
static_cast<SPIRVConstant *>(MT->getRows())->getZExtIntValue();
478+
unsigned Cols =
479+
static_cast<SPIRVConstant *>(MT->getColumns())->getZExtIntValue();
480+
unsigned Use =
481+
static_cast<SPIRVConstant *>(MT->getUse())->getZExtIntValue();
482+
483+
std::vector<unsigned> Params = {Scope, Rows, Cols, Use};
484+
return mapType(
485+
T, llvm::TargetExtType::get(*Context, "spirv.CooperativeMatrixKHR",
486+
transType(MT->getCompType()), Params));
487+
}
472488
case OpTypeForwardPointer: {
473489
SPIRVTypeForwardPointer *FP =
474490
static_cast<SPIRVTypeForwardPointer *>(static_cast<SPIRVEntry *>(T));
@@ -2217,6 +2233,7 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
22172233
return mapValue(BV, ConstantStruct::get(ST, CV));
22182234
}
22192235
case internal::OpTypeJointMatrixINTEL:
2236+
case OpTypeCooperativeMatrixKHR:
22202237
return mapValue(BV, transSPIRVBuiltinFromInst(CC, BB));
22212238
default:
22222239
llvm_unreachable("Unhandled type!");
@@ -3293,6 +3310,7 @@ Instruction *SPIRVToLLVM::transSPIRVBuiltinFromInst(SPIRVInstruction *BI,
32933310
case OpUDotAccSatKHR:
32943311
case OpSUDotAccSatKHR:
32953312
case internal::OpJointMatrixLoadINTEL:
3313+
case OpCooperativeMatrixLoadKHR:
32963314
AddRetTypePostfix = true;
32973315
break;
32983316
default: {

lib/SPIRV/SPIRVWriter.cpp

+11
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,17 @@ SPIRVType *LLVMToSPIRVBase::transType(Type *T) {
518518
Args.emplace_back(transConstant(getUInt32(M, Op)));
519519
return mapType(T, BM->addJointMatrixINTELType(ElemTy, Args));
520520
}
521+
case OpTypeCooperativeMatrixKHR: {
522+
// The expected representation is:
523+
// target("spirv.CooperativeMatrixKHR", %element_type, %scope%, %rows%,
524+
// %cols%, %use%)
525+
auto *ElemTy = transType(TargetTy->getTypeParameter(0));
526+
ArrayRef<unsigned> Ops = TargetTy->int_params();
527+
std::vector<SPIRVValue *> Args;
528+
for (const auto &Op : Ops)
529+
Args.emplace_back(transConstant(getUInt32(M, Op)));
530+
return mapType(T, BM->addCooperativeMatrixKHRType(ElemTy, Args));
531+
}
521532
default:
522533
if (isSubgroupAvcINTELTypeOpCode(Opcode))
523534
return mapType(T, BM->addSubgroupAvcINTELType(Opcode));

lib/SPIRV/libSPIRV/SPIRVInstruction.h

+21
Original file line numberDiff line numberDiff line change
@@ -1925,6 +1925,7 @@ class SPIRVCompositeConstruct : public SPIRVInstruction {
19251925
case OpTypeStruct:
19261926
case internal::OpTypeJointMatrixINTEL:
19271927
case internal::OpTypeJointMatrixINTELv2:
1928+
case OpTypeCooperativeMatrixKHR:
19281929
break;
19291930
default:
19301931
assert(false && "Invalid type");
@@ -3386,6 +3387,26 @@ class SPIRVJointMatrixINTELWorkItemInst : public SPIRVJointMatrixINTELInstBase {
33863387
_SPIRV_OP(JointMatrixGetElementCoord, true, 5)
33873388
#undef _SPIRV_OP
33883389

3390+
class SPIRVCooperativeMatrixKHRInstBase : public SPIRVInstTemplateBase {
3391+
protected:
3392+
std::optional<ExtensionID> getRequiredExtension() const override {
3393+
return ExtensionID::SPV_KHR_cooperative_matrix;
3394+
}
3395+
SPIRVCapVec getRequiredCapability() const override {
3396+
return getVec(CapabilityCooperativeMatrixKHR);
3397+
}
3398+
};
3399+
3400+
#define _SPIRV_OP(x, ...) \
3401+
typedef SPIRVInstTemplate<SPIRVCooperativeMatrixKHRInstBase, Op##x, \
3402+
__VA_ARGS__> \
3403+
SPIRV##x;
3404+
_SPIRV_OP(CooperativeMatrixLoadKHR, true, 5, true, 3)
3405+
_SPIRV_OP(CooperativeMatrixStoreKHR, false, 4, true, 4)
3406+
_SPIRV_OP(CooperativeMatrixLengthKHR, true, 4, false)
3407+
_SPIRV_OP(CooperativeMatrixMulAddKHR, true, 6, true, 3)
3408+
#undef _SPIRV_OP
3409+
33893410
class SPIRVSplitBarrierINTELBase : public SPIRVInstTemplateBase {
33903411
protected:
33913412
SPIRVCapVec getRequiredCapability() const override {

lib/SPIRV/libSPIRV/SPIRVModule.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,8 @@ class SPIRVModuleImpl : public SPIRVModule {
244244
SPIRVTypeVector *addVectorType(SPIRVType *, SPIRVWord) override;
245245
SPIRVTypeJointMatrixINTEL *
246246
addJointMatrixINTELType(SPIRVType *, std::vector<SPIRVValue *>) override;
247+
SPIRVTypeCooperativeMatrixKHR *
248+
addCooperativeMatrixKHRType(SPIRVType *, std::vector<SPIRVValue *>) override;
247249
SPIRVType *addOpaqueGenericType(Op) override;
248250
SPIRVTypeDeviceEvent *addDeviceEventType() override;
249251
SPIRVTypeQueue *addQueueType() override;
@@ -982,6 +984,13 @@ SPIRVModuleImpl::addJointMatrixINTELType(SPIRVType *CompType,
982984
return addType(new SPIRVTypeJointMatrixINTEL(this, getId(), CompType, Args));
983985
}
984986

987+
SPIRVTypeCooperativeMatrixKHR *
988+
SPIRVModuleImpl::addCooperativeMatrixKHRType(SPIRVType *CompType,
989+
std::vector<SPIRVValue *> Args) {
990+
return addType(
991+
new SPIRVTypeCooperativeMatrixKHR(this, getId(), CompType, Args));
992+
}
993+
985994
SPIRVType *SPIRVModuleImpl::addOpaqueGenericType(Op TheOpCode) {
986995
return addType(new SPIRVTypeOpaqueGeneric(TheOpCode, this, getId()));
987996
}

lib/SPIRV/libSPIRV/SPIRVModule.h

+3
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ class SPIRVAsmCallINTEL;
9696
class SPIRVTypeBufferSurfaceINTEL;
9797
class SPIRVTypeTokenINTEL;
9898
class SPIRVTypeJointMatrixINTEL;
99+
class SPIRVTypeCooperativeMatrixKHR;
99100

100101
typedef SPIRVBasicBlock SPIRVLabel;
101102
struct SPIRVTypeImageDescriptor;
@@ -255,6 +256,8 @@ class SPIRVModule {
255256
virtual SPIRVTypeVector *addVectorType(SPIRVType *, SPIRVWord) = 0;
256257
virtual SPIRVTypeJointMatrixINTEL *
257258
addJointMatrixINTELType(SPIRVType *, std::vector<SPIRVValue *>) = 0;
259+
virtual SPIRVTypeCooperativeMatrixKHR *
260+
addCooperativeMatrixKHRType(SPIRVType *, std::vector<SPIRVValue *>) = 0;
258261
virtual SPIRVTypeVoid *addVoidType() = 0;
259262
virtual SPIRVType *addOpaqueGenericType(Op) = 0;
260263
virtual SPIRVTypeDeviceEvent *addDeviceEventType() = 0;

lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h

+1
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,7 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
603603
add(CapabilityDotProduct, "DotProduct");
604604
add(CapabilityDotProductKHR, "DotProductKHR");
605605
add(CapabilityRayCullMaskKHR, "RayCullMaskKHR");
606+
add(CapabilityCooperativeMatrixKHR, "CooperativeMatrixKHR");
606607
add(CapabilityBitInstructions, "BitInstructions");
607608
add(CapabilityGroupNonUniformRotateKHR, "GroupNonUniformRotateKHR");
608609
add(CapabilityAtomicFloat32AddEXT, "AtomicFloat32AddEXT");

lib/SPIRV/libSPIRV/SPIRVOpCode.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,8 @@ inline bool isTypeOpCode(Op OpCode) {
225225
isSubgroupAvcINTELTypeOpCode(OpCode) || OC == OpTypeVmeImageINTEL ||
226226
isVCOpCode(OpCode) || OC == internal::OpTypeTokenINTEL ||
227227
OC == internal::OpTypeJointMatrixINTEL ||
228-
OC == internal::OpTypeJointMatrixINTELv2;
228+
OC == internal::OpTypeJointMatrixINTELv2 ||
229+
OC == OpTypeCooperativeMatrixKHR;
229230
}
230231

231232
inline bool isSpecConstantOpCode(Op OpCode) {

lib/SPIRV/libSPIRV/SPIRVOpCodeEnum.h

+5
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,11 @@ _SPIRV_OP(SUDotKHR, 4452)
335335
_SPIRV_OP(SDotAccSatKHR, 4453)
336336
_SPIRV_OP(UDotAccSatKHR, 4454)
337337
_SPIRV_OP(SUDotAccSatKHR, 4455)
338+
_SPIRV_OP(TypeCooperativeMatrixKHR, 4456)
339+
_SPIRV_OP(CooperativeMatrixLoadKHR, 4457)
340+
_SPIRV_OP(CooperativeMatrixStoreKHR, 4458)
341+
_SPIRV_OP(CooperativeMatrixMulAddKHR, 4459)
342+
_SPIRV_OP(CooperativeMatrixLengthKHR, 4460)
338343
_SPIRV_OP(ReadClockKHR, 5056)
339344
_SPIRV_OP(SubgroupShuffleINTEL, 5571)
340345
_SPIRV_OP(SubgroupShuffleDownINTEL, 5572)

lib/SPIRV/libSPIRV/SPIRVType.cpp

+28-1
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,9 @@ SPIRVType *SPIRVType::getVectorComponentType() const {
115115
return static_cast<const SPIRVTypeVector *>(this)->getComponentType();
116116
if (OpCode == internal::OpTypeJointMatrixINTEL)
117117
return static_cast<const SPIRVTypeJointMatrixINTEL *>(this)->getCompType();
118+
if (OpCode == OpTypeCooperativeMatrixKHR)
119+
return static_cast<const SPIRVTypeCooperativeMatrixKHR *>(this)
120+
->getCompType();
118121
assert(0 && "getVectorComponentType(): Not a vector or joint matrix type");
119122
return nullptr;
120123
}
@@ -156,7 +159,7 @@ bool SPIRVType::isTypeBool() const { return OpCode == OpTypeBool; }
156159

157160
bool SPIRVType::isTypeComposite() const {
158161
return isTypeVector() || isTypeArray() || isTypeStruct() ||
159-
isTypeJointMatrixINTEL();
162+
isTypeJointMatrixINTEL() || isTypeCooperativeMatrixKHR();
160163
}
161164

162165
bool SPIRVType::isTypeFloat(unsigned Bits) const {
@@ -203,6 +206,10 @@ bool SPIRVType::isTypeJointMatrixINTEL() const {
203206
OpCode == internal::OpTypeJointMatrixINTELv2;
204207
}
205208

209+
bool SPIRVType::isTypeCooperativeMatrixKHR() const {
210+
return OpCode == OpTypeCooperativeMatrixKHR;
211+
}
212+
206213
bool SPIRVType::isTypeVectorBool() const {
207214
return isTypeVector() && getVectorComponentType()->isTypeBool();
208215
}
@@ -306,4 +313,24 @@ void SPIRVTypeJointMatrixINTEL::decode(std::istream &I) {
306313
Decoder >> Id >> CompType >> Args;
307314
}
308315

316+
SPIRVTypeCooperativeMatrixKHR::SPIRVTypeCooperativeMatrixKHR(
317+
SPIRVModule *M, SPIRVId TheId, SPIRVType *CompType,
318+
std::vector<SPIRVValue *> Args)
319+
: SPIRVType(M, FixedWC, OpTypeCooperativeMatrixKHR, TheId),
320+
CompType(CompType), Args(std::move(Args)) {}
321+
322+
SPIRVTypeCooperativeMatrixKHR::SPIRVTypeCooperativeMatrixKHR()
323+
: SPIRVType(OpTypeCooperativeMatrixKHR), CompType(nullptr),
324+
Args({nullptr, nullptr, nullptr, nullptr}) {}
325+
326+
void SPIRVTypeCooperativeMatrixKHR::encode(spv_ostream &O) const {
327+
auto Encoder = getEncoder(O);
328+
Encoder << Id << CompType << Args;
329+
}
330+
331+
void SPIRVTypeCooperativeMatrixKHR::decode(std::istream &I) {
332+
auto Decoder = getDecoder(I);
333+
Decoder >> Id >> CompType >> Args;
334+
}
335+
309336
} // namespace SPIRV

lib/SPIRV/libSPIRV/SPIRVType.h

+29
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ class SPIRVType : public SPIRVEntry {
9696
bool isTypeStruct() const;
9797
bool isTypeVector() const;
9898
bool isTypeJointMatrixINTEL() const;
99+
bool isTypeCooperativeMatrixKHR() const;
99100
bool isTypeVectorInt() const;
100101
bool isTypeVectorFloat() const;
101102
bool isTypeVectorBool() const;
@@ -1115,5 +1116,33 @@ class SPIRVTypeJointMatrixINTEL : public SPIRVType {
11151116
}
11161117
};
11171118

1119+
class SPIRVTypeCooperativeMatrixKHR : public SPIRVType {
1120+
SPIRVType *CompType;
1121+
std::vector<SPIRVValue *> Args;
1122+
1123+
public:
1124+
const static Op OC = OpTypeCooperativeMatrixKHR;
1125+
const static SPIRVWord FixedWC = 7;
1126+
// Incomplete constructor
1127+
SPIRVTypeCooperativeMatrixKHR(SPIRVModule *M, SPIRVId TheId,
1128+
SPIRVType *CompType,
1129+
std::vector<SPIRVValue *> Args);
1130+
// Incomplete constructor
1131+
SPIRVTypeCooperativeMatrixKHR();
1132+
_SPIRV_DCL_ENCDEC
1133+
std::optional<ExtensionID> getRequiredExtension() const override {
1134+
return ExtensionID::SPV_KHR_cooperative_matrix;
1135+
}
1136+
SPIRVCapVec getRequiredCapability() const override {
1137+
return getVec(CapabilityCooperativeMatrixKHR);
1138+
}
1139+
1140+
SPIRVType *getCompType() const { return CompType; }
1141+
SPIRVValue *getScope() const { return Args[0]; }
1142+
SPIRVValue *getRows() const { return Args[1]; }
1143+
SPIRVValue *getColumns() const { return Args[2]; }
1144+
SPIRVValue *getUse() const { return Args[3]; }
1145+
};
1146+
11181147
} // namespace SPIRV
11191148
#endif // SPIRV_LIBSPIRV_SPIRVTYPE_H

spirv-headers-tag.conf

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1feaf4414eb2b353764d01d88f8aa4bcc67b60db
1+
9b527c0fb60124936d0906d44803bec51a0200fb

0 commit comments

Comments
 (0)