Skip to content

Commit

Permalink
Update LLVM SHA to 7fc792cba7663b2aa54f259515319d74a5625be0 and updat…
Browse files Browse the repository at this point in the history
…e LLVM patches.
  • Loading branch information
silee2 committed Dec 6, 2023
1 parent efa1f09 commit 1b773a9
Show file tree
Hide file tree
Showing 10 changed files with 94 additions and 115 deletions.
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
49af6502c6dcb4a7f7520178bd14df396f78240c
7fc792cba7663b2aa54f259515319d74a5625be0
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ requirement initially, then do the check for capability inferred extension.
14 files changed, 311 insertions(+), 61 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 1e61aa747967..6f0f728f811e 100644
index ee1fbba1e284..ee112b3b0099 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4126,7 +4126,12 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">;
@@ -4146,7 +4146,12 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">;
def SPIRV_Float32 : TypeAlias<F32, "Float32">;
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
Expand All @@ -58,21 +58,17 @@ index 1e61aa747967..6f0f728f811e 100644
[SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>;
// Component type check is done in the type parser for the following SPIR-V
// dialect-specific types so we use "Any" here.
@@ -4186,10 +4191,10 @@ class SPIRV_JointMatrixOfType<list<Type> allowedTypes> :
@@ -4206,7 +4211,7 @@ class SPIRV_JointMatrixOfType<list<Type> allowedTypes> :
"Joint Matrix">;

class SPIRV_ScalarOrVectorOf<Type type> :
- AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>]>;
+ AnyTypeOf<[type, VectorOfLengthRangeAndType<[2, 0x7FFFFFFFFFFFFFFF], [type]>]>;

class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
- AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
+ AnyTypeOf<[type, VectorOfLengthRangeAndType<[2, 0x7FFFFFFFFFFFFFFF], [type]>,
SPIRV_CoopMatrixOfType<[type]>, SPIRV_CoopMatrixNVOfType<[type]>]>;
class SPIRV_VectorOf<Type type> :
- VectorOfLengthAndType<[2, 3, 4, 8,16], [type]>;
+ VectorOfLengthRangeAndType<[2, 0x7FFFFFFFFFFFFFFF], [type]>;

class SPIRV_MatrixOrCoopMatrixOf<Type type> :
class SPIRV_ScalarOrVectorOf<Type type> :
AnyTypeOf<[type, SPIRV_VectorOf<type>]>;
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index b0b5348baaad..2c569a651f8b 100644
index 03180a687523..e4f2d5562ed7 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -604,6 +604,92 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
Expand Down Expand Up @@ -169,10 +165,10 @@ index b0b5348baaad..2c569a651f8b 100644
// Negative values for `n` index in reverse.
class ShapedTypeWithNthDimOfSize<int n, list<int> allowedSizes> : Type<
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 124d4ed6e8e6..9188f8b699b4 100644
index 8a68decc5878..c6789315ba06 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -183,9 +183,12 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
@@ -185,9 +185,12 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
return Type();
}
Expand All @@ -188,7 +184,7 @@ index 124d4ed6e8e6..9188f8b699b4 100644
return Type();
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 39d6603a46f9..741d8069471d 100644
index f1bac6490837..4db2c8c5c5d0 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -101,9 +101,11 @@ bool CompositeType::classof(Type type) {
Expand Down Expand Up @@ -230,7 +226,7 @@ index 39d6603a46f9..741d8069471d 100644
capabilities.push_back(ref);
}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index c75d217663a9..f7a8a2a3d281 100644
index 2b79c8022b8e..b778e4f4daf9 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -43,9 +43,13 @@ using namespace mlir;
Expand Down Expand Up @@ -389,7 +385,7 @@ index c75d217663a9..f7a8a2a3d281 100644
}

static Type
@@ -1150,16 +1236,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
@@ -1162,16 +1248,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
for (Type valueType : valueTypes) {
Expand Down Expand Up @@ -434,7 +430,7 @@ index 0d92a8e676d8..d61ace8d6876 100644
}

diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index aa2cd649ecd7..b951d7490d64 100644
index 0221e4815a93..9693f96a3300 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -29,6 +29,7 @@ func.func @int32_scalar(%lhs: i32, %rhs: i32) {
Expand All @@ -445,7 +441,7 @@ index aa2cd649ecd7..b951d7490d64 100644
func.func @int32_scalar_srem(%lhs: i32, %rhs: i32) {
// CHECK: %[[LABS:.+]] = spirv.GL.SAbs %[[LHS]] : i32
// CHECK: %[[RABS:.+]] = spirv.GL.SAbs %[[RHS]] : i32
@@ -1362,3 +1363,35 @@ func.func @float_scalar(%arg0: f16) {
@@ -1407,3 +1408,35 @@ func.func @float_scalar(%arg0: f16) {
}

} // end module
Expand Down Expand Up @@ -510,10 +506,10 @@ index 82d750755ffe..6f364c5b0875 100644
} // end module

diff --git a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
index eeaa607b5604..78e2fffda755 100644
index 82a2316f6c78..b3ad053baa3a 100644
--- a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
@@ -97,7 +97,7 @@ func.func @bitwise_or_vector(%arg: vector<4xi32>) -> vector<4xi32> {
@@ -137,7 +137,7 @@ func.func @bitwise_or_all_ones_vector(%arg: vector<3xi8>) -> vector<3xi8> {
// -----

func.func @bitwise_or_float(%arg0: f16, %arg1: f16) -> f16 {
Expand All @@ -522,7 +518,7 @@ index eeaa607b5604..78e2fffda755 100644
%0 = spirv.BitwiseOr %arg0, %arg1 : f16
return %0 : f16
}
@@ -123,7 +123,7 @@ func.func @bitwise_xor_vector(%arg: vector<4xi32>) -> vector<4xi32> {
@@ -163,7 +163,7 @@ func.func @bitwise_xor_vector(%arg: vector<4xi32>) -> vector<4xi32> {
// -----

func.func @bitwise_xor_float(%arg0: f16, %arg1: f16) -> f16 {
Expand All @@ -531,7 +527,7 @@ index eeaa607b5604..78e2fffda755 100644
%0 = spirv.BitwiseXor %arg0, %arg1 : f16
return %0 : f16
}
@@ -149,7 +149,7 @@ func.func @bitwise_and_vector(%arg: vector<4xi32>) -> vector<4xi32> {
@@ -272,7 +272,7 @@ func.func @bitwise_and_zext_vector(%arg: vector<2xi8>) -> vector<2xi32> {
// -----

func.func @bitwise_and_float(%arg0: f16, %arg1: f16) -> f16 {
Expand Down Expand Up @@ -567,7 +563,7 @@ index 7dc0bd99f54b..5dd9901828cd 100644
return
}
diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
index 29a4a4613615..869de34c83b1 100644
index 81ba471d3f51..2dbebb2db98e 100644
--- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
@@ -27,7 +27,7 @@ func.func @exp(%arg0 : i32) -> () {
Expand Down Expand Up @@ -682,4 +678,4 @@ index 9a2e4cf62e37..31a7f616d648 100644
// CHECK: spirv.CL.fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32
%13 = spirv.CL.fma %arg0, %arg1, %arg2 : f32
--
2.42.0
2.34.1
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@ Date: Mon, 24 Jul 2023 18:25:05 +0000
Subject: [PATCH 2/2] Add serialization and de-serialization support for
several decorations.

Added decoratios:
Added decorations:
- Alignment
- DescriptorSet
- FuncParamIOKindINTEL
- NoSignedWrap
- NoUnsignedWrap
- SingleElementVectorINTEL
- VectorComputeCallableFunctionINTEL
- VectorComputeFunctionINTEL
Expand All @@ -20,36 +18,32 @@ Added decoratios:
2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index b84d1d9c2187..90416289134b 100644
index 89e2e7ad52fa..f6bdc646f384 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -242,8 +242,9 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
auto symbol = opBuilder.getStringAttr(attrName);
switch (static_cast<spirv::Decoration>(words[1])) {
@@ -251,8 +251,9 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
static_cast<FPFastMathMode>(words[2])));
break;
- case spirv::Decoration::DescriptorSet:
+ case spirv::Decoration::Alignment:
case spirv::Decoration::Binding:
+ case spirv::Decoration::DescriptorSet:
if (words.size() != 3) {
return emitError(unknownLoc, "OpDecorate with ")
<< decorationName << " needs a single integer literal";
@@ -295,8 +296,14 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
case spirv::Decoration::NonReadable:
case spirv::Decoration::NonWritable:
case spirv::Decoration::NoPerspective:
+ case spirv::Decoration::NoSignedWrap:
+ case spirv::Decoration::NoUnsignedWrap:
case spirv::Decoration::Restrict:
@@ -308,6 +309,10 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
case spirv::Decoration::NoUnsignedWrap:
case spirv::Decoration::RelaxedPrecision:
case spirv::Decoration::Restrict:
+ case spirv::Decoration::SingleElementVectorINTEL:
+ case spirv::Decoration::VectorComputeCallableFunctionINTEL:
+ case spirv::Decoration::VectorComputeFunctionINTEL:
+ case spirv::Decoration::VectorComputeVariableINTEL:
if (words.size() != 2) {
return emitError(unknownLoc, "OpDecoration with ")
<< decorationName << "needs a single target <id>";
@@ -307,6 +314,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
@@ -318,6 +323,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
// it is needed for many validation rules.
decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
break;
Expand All @@ -58,28 +52,24 @@ index b84d1d9c2187..90416289134b 100644
case spirv::Decoration::SpecId:
if (words.size() != 3) {
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 1ef8ff043e69..988e60d08edf 100644
index 9e9a16456cc1..b4d467f6656c 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -231,8 +231,10 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
args.push_back(static_cast<uint32_t>(linkageType));
break;
}
@@ -247,8 +247,10 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
}
return emitError(loc, "expected FPFastMathModeAttr attribute for ")
<< attrName;
+ case spirv::Decoration::Alignment:
case spirv::Decoration::Binding:
case spirv::Decoration::DescriptorSet:
+ case spirv::Decoration::FuncParamIOKindINTEL:
case spirv::Decoration::Location:
if (auto intAttr = dyn_cast<IntegerAttr>(attr.getValue())) {
args.push_back(intAttr.getValue().getZExtValue());
@@ -255,8 +257,14 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
case spirv::Decoration::NonReadable:
case spirv::Decoration::NonWritable:
case spirv::Decoration::NoPerspective:
+ case spirv::Decoration::NoSignedWrap:
+ case spirv::Decoration::NoUnsignedWrap:
case spirv::Decoration::Restrict:
@@ -275,6 +277,10 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
case spirv::Decoration::NoUnsignedWrap:
case spirv::Decoration::RelaxedPrecision:
case spirv::Decoration::Restrict:
+ case spirv::Decoration::SingleElementVectorINTEL:
+ case spirv::Decoration::VectorComputeCallableFunctionINTEL:
+ case spirv::Decoration::VectorComputeFunctionINTEL:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ https://github.com/MrSidims/llvm/blob/private/MrSidims/add-matrix-use/sycl/doc/d
7 files changed, 70 insertions(+), 28 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 1013cbc8ca56..4a374e713e3c 100644
index ee1fbba1e284..d3e756c31889 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4038,15 +4038,28 @@ def SPIRV_SamplerUseAttr: SPIRV_I32EnumAttr<
Expand Down Expand Up @@ -58,11 +58,11 @@ index 1013cbc8ca56..4a374e713e3c 100644
// Cooperative Matrix Use for the SPV_KHR_cooperative_matrix extension.
def SPIRV_KHR_CMU_MatrixA : I32EnumAttrCase<"MatrixA", 0>;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 07f2f158ecab..e0b3c5448a44 100644
index d946d936d4e6..0c08d7c083e5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -459,7 +459,8 @@ public:
using Base::Base;
@@ -457,7 +457,8 @@ public:
static constexpr StringLiteral name = "spirv.jointmatrix";

static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows,
- unsigned columns, MatrixLayout matrixLayout);
Expand All @@ -71,7 +71,7 @@ index 07f2f158ecab..e0b3c5448a44 100644
Type getElementType() const;

/// Return the scope of the joint matrix.
@@ -472,6 +473,9 @@ public:
@@ -470,6 +471,9 @@ public:
/// return the layout of the matrix
MatrixLayout getMatrixLayout() const;

Expand All @@ -82,10 +82,10 @@ index 07f2f158ecab..e0b3c5448a44 100644
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index a51d77dda78b..fa4cd8dc8447 100644
index 8a68decc5878..00905c32f98a 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -392,7 +392,8 @@ static Type parseCooperativeMatrixNVType(SPIRVDialect const &dialect,
@@ -393,7 +393,8 @@ static Type parseCooperativeMatrixNVType(SPIRVDialect const &dialect,

// joint-matrix-type ::= `!spirv.jointmatrix` `<`rows `x` columns `x`
// element-type
Expand All @@ -95,7 +95,7 @@ index a51d77dda78b..fa4cd8dc8447 100644
static Type parseJointMatrixType(SPIRVDialect const &dialect,
DialectAsmParser &parser) {
if (parser.parseLess())
@@ -419,10 +420,14 @@ static Type parseJointMatrixType(SPIRVDialect const &dialect,
@@ -420,10 +421,14 @@ static Type parseJointMatrixType(SPIRVDialect const &dialect,
if (parser.parseComma() ||
spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
return Type();
Expand All @@ -111,7 +111,7 @@ index a51d77dda78b..fa4cd8dc8447 100644
}

// TODO: Reorder methods to be utilities first and parse*Type
@@ -926,7 +931,8 @@ static void print(JointMatrixINTELType type, DialectAsmPrinter &os) {
@@ -927,7 +932,8 @@ static void print(JointMatrixINTELType type, DialectAsmPrinter &os) {
os << "jointmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
os << type.getElementType() << ", "
<< stringifyMatrixLayout(type.getMatrixLayout());
Expand All @@ -122,7 +122,7 @@ index a51d77dda78b..fa4cd8dc8447 100644

static void print(MatrixType type, DialectAsmPrinter &os) {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 39d6603a46f9..57e4d5c8fb81 100644
index f1bac6490837..7890f292a50a 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -336,7 +336,8 @@ void CooperativeMatrixNVType::getCapabilities(
Expand Down Expand Up @@ -182,10 +182,10 @@ index 39d6603a46f9..57e4d5c8fb81 100644
SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index ce8b3ab38946..a253d69c5267 100644
index 89e2e7ad52fa..84e43ae858a6 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -977,7 +977,7 @@ LogicalResult spirv::Deserializer::processCooperativeMatrixTypeNV(
@@ -988,7 +988,7 @@ LogicalResult spirv::Deserializer::processCooperativeMatrixTypeNV(

LogicalResult
spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
Expand All @@ -194,7 +194,7 @@ index ce8b3ab38946..a253d69c5267 100644
return emitError(unknownLoc, "OpTypeJointMatrix must have element "
"type and row x column parameters");
}
@@ -988,6 +988,14 @@ spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
@@ -999,6 +999,14 @@ spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
<< operands[1];
}

Expand All @@ -209,7 +209,7 @@ index ce8b3ab38946..a253d69c5267 100644
auto scope = spirv::symbolizeScope(getConstantInt(operands[5]).getInt());
if (!scope) {
return emitError(unknownLoc,
@@ -998,14 +1006,15 @@ spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
@@ -1009,14 +1017,15 @@ spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
spirv::symbolizeMatrixLayout(getConstantInt(operands[4]).getInt());
if (!matrixLayout) {
return emitError(unknownLoc,
Expand All @@ -229,20 +229,21 @@ index ce8b3ab38946..a253d69c5267 100644
}

diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index dad085e21b42..e030eb3767ad 100644
index 9e9a16456cc1..412be6ac208d 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -653,6 +653,8 @@ LogicalResult Serializer::prepareBasicType(
static_cast<uint32_t>(jointMatrixType.getMatrixLayout())));
operands.push_back(
getConstantOp(static_cast<uint32_t>(jointMatrixType.getScope())));
+ operands.push_back(
@@ -667,7 +667,8 @@ LogicalResult Serializer::prepareBasicType(
operands, elementTypeID, getConstantOp(jointMatrixType.getRows()),
getConstantOp(jointMatrixType.getColumns()),
getConstantOp(static_cast<uint32_t>(jointMatrixType.getMatrixLayout())),
- getConstantOp(static_cast<uint32_t>(jointMatrixType.getScope())));
+ getConstantOp(static_cast<uint32_t>(jointMatrixType.getScope())),
+ getConstantOp(static_cast<uint32_t>(jointMatrixType.getMatrixUse())));
return success();
}

diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 9aeb14d14eec..4b273aa0f04a 100644
index 9aeb14d14eec..d54b267bea47 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -523,7 +523,7 @@ static mlir::GenRegistration
Expand Down
Loading

0 comments on commit 1b773a9

Please sign in to comment.