@@ -115,6 +115,9 @@ SPIRVType *SPIRVType::getVectorComponentType() const {
115
115
return static_cast <const SPIRVTypeVector *>(this )->getComponentType ();
116
116
if (OpCode == internal::OpTypeJointMatrixINTEL)
117
117
return static_cast <const SPIRVTypeJointMatrixINTEL *>(this )->getCompType ();
118
+ if (OpCode == OpTypeCooperativeMatrixKHR)
119
+ return static_cast <const SPIRVTypeCooperativeMatrixKHR *>(this )
120
+ ->getCompType ();
118
121
assert (0 && " getVectorComponentType(): Not a vector or joint matrix type" );
119
122
return nullptr ;
120
123
}
@@ -156,7 +159,7 @@ bool SPIRVType::isTypeBool() const { return OpCode == OpTypeBool; }
156
159
157
160
bool SPIRVType::isTypeComposite () const {
158
161
return isTypeVector () || isTypeArray () || isTypeStruct () ||
159
- isTypeJointMatrixINTEL ();
162
+ isTypeJointMatrixINTEL () || isTypeCooperativeMatrixKHR () ;
160
163
}
161
164
162
165
bool SPIRVType::isTypeFloat (unsigned Bits) const {
@@ -203,6 +206,10 @@ bool SPIRVType::isTypeJointMatrixINTEL() const {
203
206
OpCode == internal::OpTypeJointMatrixINTELv2;
204
207
}
205
208
209
+ bool SPIRVType::isTypeCooperativeMatrixKHR () const {
210
+ return OpCode == OpTypeCooperativeMatrixKHR;
211
+ }
212
+
206
213
bool SPIRVType::isTypeVectorBool () const {
207
214
return isTypeVector () && getVectorComponentType ()->isTypeBool ();
208
215
}
@@ -306,4 +313,24 @@ void SPIRVTypeJointMatrixINTEL::decode(std::istream &I) {
306
313
Decoder >> Id >> CompType >> Args;
307
314
}
308
315
316
+ SPIRVTypeCooperativeMatrixKHR::SPIRVTypeCooperativeMatrixKHR (
317
+ SPIRVModule *M, SPIRVId TheId, SPIRVType *CompType,
318
+ std::vector<SPIRVValue *> Args)
319
+ : SPIRVType(M, FixedWC, OpTypeCooperativeMatrixKHR, TheId),
320
+ CompType(CompType), Args(std::move(Args)) {}
321
+
322
+ SPIRVTypeCooperativeMatrixKHR::SPIRVTypeCooperativeMatrixKHR ()
323
+ : SPIRVType(OpTypeCooperativeMatrixKHR), CompType(nullptr ),
324
+ Args({nullptr , nullptr , nullptr , nullptr }) {}
325
+
326
+ void SPIRVTypeCooperativeMatrixKHR::encode (spv_ostream &O) const {
327
+ auto Encoder = getEncoder (O);
328
+ Encoder << Id << CompType << Args;
329
+ }
330
+
331
+ void SPIRVTypeCooperativeMatrixKHR::decode (std::istream &I) {
332
+ auto Decoder = getDecoder (I);
333
+ Decoder >> Id >> CompType >> Args;
334
+ }
335
+
309
336
} // namespace SPIRV
0 commit comments