Skip to content

Commit

Permalink
Update llvm and redo patch 0003 and 0004
Browse files Browse the repository at this point in the history
Fix all post llvm update issues.
Add filter for SPIRV JointMatrix tests.
  • Loading branch information
silee2 committed Sep 29, 2023
1 parent 914763d commit 6e981bc
Show file tree
Hide file tree
Showing 29 changed files with 374 additions and 375 deletions.
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
ba7cb620ac002a94af0e1656ba591308f7073ab9
d20190e68413634b87f0f9426312a0e9d8456d18
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
From 810eccbbbb872402391a0f01a53aaf0205ea10c4 Mon Sep 17 00:00:00 2001
From 9837f33f72569d59c0ebef542e3647ec554f369c Mon Sep 17 00:00:00 2001
From: Md Abdullah Shahneous Bari <[email protected]>
Date: Tue, 8 Aug 2023 00:28:31 +0000
Subject: [PATCH] Update the Joint Matrix support to match Spec supported by
IGC
Date: Tue, 26 Sep 2023 19:07:40 +0000
Subject: [PATCH] Update the Joint Matrix support to match IGC spec

Update the Joint Matrix support to match the following spec:
https://github.com/MrSidims/llvm/blob/private/MrSidims/add-matrix-use/sycl/doc/design/spirv-extensions/SPV_INTEL_joint_matrix.asciidoc
---
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 30 ++++++++++++++-----
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 31 +++++++++++++------
.../mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 6 +++-
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 12 ++++++--
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 20 +++++++++----
.../SPIRV/Deserialization/Deserializer.cpp | 17 +++++++----
.../Target/SPIRV/Serialization/Serializer.cpp | 5 +++-
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | 16 +++++-----
7 files changed, 75 insertions(+), 31 deletions(-)
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 12 +++++--
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 20 ++++++++----
.../SPIRV/Deserialization/Deserializer.cpp | 17 +++++++---
.../Target/SPIRV/Serialization/Serializer.cpp | 2 ++
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | 10 +++---
7 files changed, 70 insertions(+), 28 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 6f0f728f811e..c2ad6ff24bea 100644
index 1013cbc8ca56..4a374e713e3c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4039,16 +4039,30 @@ def SPIRV_SamplerUseAttr: SPIRV_I32EnumAttr<
@@ -4038,15 +4038,28 @@ def SPIRV_SamplerUseAttr: SPIRV_I32EnumAttr<
"image_sampler_use_info",
[SPIRV_ISUI_SamplerUnknown, SPIRV_ISUI_NeedSampler, SPIRV_ISUI_NoSampler]>;

Expand All @@ -32,6 +31,7 @@ index 6f0f728f811e..c2ad6ff24bea 100644
-def SPIRV_MatrixLayoutAttr :
- SPIRV_I32EnumAttr<"MatrixLayout", "valid SPIR-V MatrixLayout", "matrixLayout", [
- SPIRV_ML_ColumnMajor, SPIRV_ML_RowMajor, SPIRV_ML_PackedA, SPIRV_ML_PackedB
- ]>;
+// Change the layout parameter to IGC spec, the currnet MLIR version
+// does not match the IGC spec, IGC spec has been updated
+// https://github.com/MrSidims/llvm/blob/private/MrSidims/add-matrix-use/sycl/doc/design/spirv-extensions/SPV_INTEL_joint_matrix.asciidoc
Expand All @@ -44,8 +44,8 @@ index 6f0f728f811e..c2ad6ff24bea 100644
+ def SPIRV_MatrixLayoutAttr :
+ SPIRV_I32EnumAttr<"MatrixLayout", "valid SPIR-V MatrixLayout", "matrixLayout", [
+ SPIRV_ML_RowMajor, SPIRV_ML_ColumnMajor, SPIRV_ML_Packed, SPIRV_ML_Unused
]>;

+ ]>;
+
+def SPIRV_ML_MATRIX_A : I32EnumAttrCase<"MatrixA", 0>;
+def SPIRV_ML_MATRIX_B : I32EnumAttrCase<"MatrixB", 1>;
+def SPIRV_ML_MATRIX_ACC : I32EnumAttrCase<"Accumulator", 2>;
Expand All @@ -54,11 +54,9 @@ index 6f0f728f811e..c2ad6ff24bea 100644
+ SPIRV_I32EnumAttr<"MatrixUse", "valid SPIR-V MatrixUse", "matrixUse", [
+ SPIRV_ML_MATRIX_A, SPIRV_ML_MATRIX_B, SPIRV_ML_MATRIX_ACC
+ ]>;
+
+

// Cooperative Matrix Use for the SPV_KHR_cooperative_matrix extension.
def SPIRV_KHR_CMU_MatrixA : I32EnumAttrCase<"MatrixA", 0>;
def SPIRV_KHR_CMU_MatrixB : I32EnumAttrCase<"MatrixB", 1>;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 07f2f158ecab..e0b3c5448a44 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
Expand All @@ -84,10 +82,10 @@ index 07f2f158ecab..e0b3c5448a44 100644
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 9188f8b699b4..4c099bf77a88 100644
index a51d77dda78b..fa4cd8dc8447 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -394,7 +394,8 @@ static Type parseCooperativeMatrixNVType(SPIRVDialect const &dialect,
@@ -392,7 +392,8 @@ static Type parseCooperativeMatrixNVType(SPIRVDialect const &dialect,

// joint-matrix-type ::= `!spirv.jointmatrix` `<`rows `x` columns `x`
// element-type
Expand All @@ -97,7 +95,7 @@ index 9188f8b699b4..4c099bf77a88 100644
static Type parseJointMatrixType(SPIRVDialect const &dialect,
DialectAsmParser &parser) {
if (parser.parseLess())
@@ -421,10 +422,14 @@ static Type parseJointMatrixType(SPIRVDialect const &dialect,
@@ -419,10 +420,14 @@ static Type parseJointMatrixType(SPIRVDialect const &dialect,
if (parser.parseComma() ||
spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
return Type();
Expand All @@ -113,7 +111,7 @@ index 9188f8b699b4..4c099bf77a88 100644
}

// TODO: Reorder methods to be utilities first and parse*Type
@@ -952,7 +957,8 @@ static void print(JointMatrixINTELType type, DialectAsmPrinter &os) {
@@ -926,7 +931,8 @@ static void print(JointMatrixINTELType type, DialectAsmPrinter &os) {
os << "jointmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
os << type.getElementType() << ", "
<< stringifyMatrixLayout(type.getMatrixLayout());
Expand All @@ -124,10 +122,10 @@ index 9188f8b699b4..4c099bf77a88 100644

static void print(MatrixType type, DialectAsmPrinter &os) {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 741d8069471d..49ded5c60951 100644
index 39d6603a46f9..57e4d5c8fb81 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -352,7 +352,8 @@ void CooperativeMatrixNVType::getCapabilities(
@@ -336,7 +336,8 @@ void CooperativeMatrixNVType::getCapabilities(
//===----------------------------------------------------------------------===//

struct spirv::detail::JointMatrixTypeStorage : public TypeStorage {
Expand All @@ -137,7 +135,7 @@ index 741d8069471d..49ded5c60951 100644

static JointMatrixTypeStorage *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
@@ -361,26 +362,29 @@ struct spirv::detail::JointMatrixTypeStorage : public TypeStorage {
@@ -345,26 +346,29 @@ struct spirv::detail::JointMatrixTypeStorage : public TypeStorage {
}

bool operator==(const KeyTy &key) const {
Expand Down Expand Up @@ -172,7 +170,7 @@ index 741d8069471d..49ded5c60951 100644
}

Type JointMatrixINTELType::getElementType() const {
@@ -397,6 +401,10 @@ MatrixLayout JointMatrixINTELType::getMatrixLayout() const {
@@ -381,6 +385,10 @@ MatrixLayout JointMatrixINTELType::getMatrixLayout() const {
return getImpl()->matrixLayout;
}

Expand All @@ -184,10 +182,10 @@ index 741d8069471d..49ded5c60951 100644
SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 90416289134b..4598dc608034 100644
index ce8b3ab38946..a253d69c5267 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -939,7 +939,7 @@ spirv::Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {
@@ -977,7 +977,7 @@ LogicalResult spirv::Deserializer::processCooperativeMatrixTypeNV(

LogicalResult
spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
Expand All @@ -196,22 +194,22 @@ index 90416289134b..4598dc608034 100644
return emitError(unknownLoc, "OpTypeJointMatrix must have element "
"type and row x column parameters");
}
@@ -949,7 +949,13 @@ spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
return emitError(unknownLoc, "OpTypeJointMatrix references undefined <id> ")
@@ -988,6 +988,14 @@ spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
<< operands[1];
}
-

+ auto matrixUse =
+ spirv::symbolizeMatrixUse(getConstantInt(operands[6]).getInt());
+ if (!matrixUse) {
+ return emitError(unknownLoc,
+ "OpTypeJointMatrix references undefined Use <id> ")
+ << operands[6];
+ }
+
auto scope = spirv::symbolizeScope(getConstantInt(operands[5]).getInt());
if (!scope) {
return emitError(unknownLoc,
@@ -960,14 +966,15 @@ spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
@@ -998,14 +1006,15 @@ spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
spirv::symbolizeMatrixLayout(getConstantInt(operands[4]).getInt());
if (!matrixLayout) {
return emitError(unknownLoc,
Expand All @@ -231,20 +229,10 @@ index 90416289134b..4598dc608034 100644
}

diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 988e60d08edf..b6ec58648d72 100644
index dad085e21b42..e030eb3767ad 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -222,7 +222,8 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
case spirv::Decoration::LinkageAttributes: {
// Get the value of the Linkage Attributes
// e.g., LinkageAttributes=["linkageName", linkageType].
- auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr.getValue());
+ auto linkageAttr =
+ llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr.getValue());
auto linkageName = linkageAttr.getLinkageName();
auto linkageType = linkageAttr.getLinkageType().getValue();
// Encode the Linkage Name (string literal to uint32_t).
@@ -639,6 +640,8 @@ LogicalResult Serializer::prepareBasicType(
@@ -653,6 +653,8 @@ LogicalResult Serializer::prepareBasicType(
static_cast<uint32_t>(jointMatrixType.getMatrixLayout())));
operands.push_back(
getConstantOp(static_cast<uint32_t>(jointMatrixType.getScope())));
Expand All @@ -254,52 +242,17 @@ index 988e60d08edf..b6ec58648d72 100644
}

diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index ccf4240f8e56..a793564e0477 100644
index 9aeb14d14eec..4b273aa0f04a 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -396,10 +396,9 @@ static void emitAvailabilityQueryForBitEnum(const Record &enumDef,
avail.getMergeInstanceType(), avail.getQueryFnName(),
enumName);

- os << formatv(
- " assert(::llvm::popcount(static_cast<{0}>(value)) <= 1"
- " && \"cannot have more than one bit set\");\n",
- underlyingType);
+ os << formatv(" assert(::llvm::popcount(static_cast<{0}>(value)) <= 1"
+ " && \"cannot have more than one bit set\");\n",
+ underlyingType);
@@ -523,7 +523,7 @@ static mlir::GenRegistration
constexpr llvm::StringLiteral constantIdEnumAttrs[] = {
"SPIRV_ScopeAttr", "SPIRV_KHR_CooperativeMatrixUseAttr",
"SPIRV_KHR_CooperativeMatrixLayoutAttr", "SPIRV_MemorySemanticsAttr",
- "SPIRV_MatrixLayoutAttr"};
+ "SPIRV_MatrixLayoutAttr", "SPIRV_MatrixUseAttr"};

os << " switch (value) {\n";
for (const auto &caseSpecPair : classCasePair.getValue()) {
@@ -523,7 +522,8 @@ static void emitAttributeSerialization(const Attribute &attr,
<< formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar, attrName);
if (attr.getAttrDefName() == "SPIRV_ScopeAttr" ||
attr.getAttrDefName() == "SPIRV_MemorySemanticsAttr" ||
- attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr") {
+ attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr" ||
+ attr.getAttrDefName() == "SPIRV_MatrixUseAttr") {
// These two enums are encoded as <id> to constant values in SPIR-V blob,
// but we directly use the constant value as attribute in SPIR-V dialect. So
// need to handle them separately from normal enum attributes.
@@ -818,7 +818,8 @@ static void emitAttributeDeserialization(const Attribute &attr,
raw_ostream &os) {
if (attr.getAttrDefName() == "SPIRV_ScopeAttr" ||
attr.getAttrDefName() == "SPIRV_MemorySemanticsAttr" ||
- attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr") {
+ attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr" ||
+ attr.getAttrDefName() == "SPIRV_MatrixUseAttr") {
// These two enums are encoded as <id> to constant values in SPIR-V blob,
// but we directly use the constant value as attribute in SPIR-V dialect. So
// need to handle them separately from normal enum attributes.
@@ -926,7 +927,8 @@ static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
// Process operands/attributes
for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
auto argument = op.getArg(i);
- if (auto *valueArg = llvm::dyn_cast_if_present<NamedTypeConstraint *>(argument)) {
+ if (auto *valueArg =
+ llvm::dyn_cast_if_present<NamedTypeConstraint *>(argument)) {
if (valueArg->isVariableLength()) {
if (i != e - 1) {
PrintFatalError(loc, "SPIR-V ops can have Variadic<..> or "
/// Generates code to serialize attributes of a SPIRV_Op `op` into `os`. The
/// generates code extracts the attribute with name `attrName` from
--
2.34.1
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
commit c64e2cf1aaae2f9c41c4f9ccf932695417027432
Author: Dewei Wang <[email protected]>
Date: Thu Sep 7 00:18:12 2023 +0800
From 4bae3239d087490f7cbfb07befafe770ecff6d22 Mon Sep 17 00:00:00 2001
From: Dewei Wang <[email protected]>
Date: Fri, 29 Sep 2023 10:30:54 -0700
Subject: [PATCH] [mlir][spirv] fix linkage_name StringAttr

[mlir][spirv] fix linkage_name StringAttr
---
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td | 2 +-
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp | 3 ++-
mlir/lib/Target/SPIRV/Serialization/Serializer.cpp | 2 +-
3 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
index 259a96651abb..7eaa6c10ae6e 100644
index f2c1ee5cfd56..74d36445e311 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
@@ -48,7 +48,7 @@ def SPIRV_CapabilityArrayAttr : TypedArrayAttrBase<
Expand All @@ -18,10 +23,10 @@ index 259a96651abb..7eaa6c10ae6e 100644
);
let assemblyFormat = "`<` struct(params) `>`";
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 4598dc608034..3c74ab5fb5d8 100644
index ce8b3ab38946..18ec1ef7a660 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -282,10 +282,11 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
@@ -281,10 +281,11 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
// 3 + ceildiv(strlen(name), 4).
unsigned wordIndex = 2;
auto linkageName = spirv::decodeStringLiteral(words, wordIndex).str();
Expand All @@ -35,15 +40,17 @@ index 4598dc608034..3c74ab5fb5d8 100644
break;
}
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index b6ec58648d72..893d0a609a9a 100644
index dad085e21b42..9565dc982f30 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -224,7 +224,7 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
@@ -223,7 +223,7 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
// Get the value of the Linkage Attributes
// e.g., LinkageAttributes=["linkageName", linkageType].
auto linkageAttr =
llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr.getValue());
auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr.getValue());
- auto linkageName = linkageAttr.getLinkageName();
+ auto linkageName = linkageAttr.getLinkageName().getValue();
auto linkageType = linkageAttr.getLinkageType().getValue();
// Encode the Linkage Name (string literal to uint32_t).
spirv::encodeStringLiteralInto(args, linkageName);
--
2.42.0
1 change: 1 addition & 0 deletions include/imex/Dialect/Dist/IR/DistOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#ifndef _Dist_OPS_H_INCLUDED_
#define _Dist_OPS_H_INCLUDED_

#include <mlir/Bytecode/BytecodeOpInterface.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/OpDefinition.h>
Expand Down
8 changes: 4 additions & 4 deletions lib/Conversion/GPUXToLLVM/GPUXToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ class ConvertLaunchFuncOpToGpuRuntimeCallPattern
mlir::Value one = rewriter.create<mlir::LLVM::ConstantOp>(
loc, llvmInt32Type, rewriter.getI32IntegerAttr(1));
auto computeTypeSize = [&](mlir::Type type) -> mlir::Value {
auto nullPtr = rewriter.create<mlir::LLVM::NullOp>(loc, type);
auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, type);
auto gep = rewriter.create<mlir::LLVM::GEPOp>(loc, type, nullPtr, one);
return rewriter.create<mlir::LLVM::PtrToIntOp>(loc, llvmIndexType, gep);
};
Expand Down Expand Up @@ -483,7 +483,7 @@ class ConvertLaunchFuncOpToGpuRuntimeCallPattern
range, i);
}

auto nullPtr = rewriter.create<mlir::LLVM::NullOp>(loc, llvmPointerType);
auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPointerType);
auto nullRange = [&]() {
auto zero = rewriter.create<mlir::LLVM::ConstantOp>(
loc, llvmIndexType, rewriter.getIntegerAttr(llvmIndexType, 0));
Expand Down Expand Up @@ -547,8 +547,8 @@ class ConvertGpuStreamCreatePattern
// TODO: Pass nullptrs now for the current workflow where user is
// not passing device and context. Add different streambuilders
// later.
auto device = rewriter.create<mlir::LLVM::NullOp>(loc, llvmPointerType);
auto context = rewriter.create<mlir::LLVM::NullOp>(loc, llvmPointerType);
auto device = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPointerType);
auto context = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPointerType);
auto res = streamCreateCallBuilder.create(loc, rewriter, {device, context});
rewriter.replaceOp(op, res.getResults());
return mlir::success();
Expand Down
4 changes: 2 additions & 2 deletions lib/Conversion/PTensorToLinalg/PTensorToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,9 +394,9 @@ static BodyType getBodyBuilder(::imex::ptensor::EWBinOpId binOp,
// case ptensor::LSHIFT] =
// case ptensor::MATMUL] =
case ptensor::MAXIMUM:
return buildTrivial<mlir::arith::MaxSIOp, mlir::arith::MaxFOp>(typ);
return buildTrivial<mlir::arith::MaxSIOp, mlir::arith::MaximumFOp>(typ);
case ptensor::MINIMUM:
return buildTrivial<mlir::arith::MinSIOp, mlir::arith::MinFOp>(typ);
return buildTrivial<mlir::arith::MinSIOp, mlir::arith::MinimumFOp>(typ);
case ptensor::MODULO:
return buildTrivial<mlir::arith::RemSIOp, mlir::arith::RemFOp>(typ);
case ptensor::MULTIPLY:
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/PassDetail.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#define _IMEX_CONVERSION_PASSDETAIL_H_

#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/FunctionInterfaces.h>
#include <mlir/Interfaces/FunctionInterfaces.h>
#include <mlir/Pass/Pass.h>

namespace mlir {
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/PTensor/Transforms/PassDetail.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#define _PTENSOR_PASSDETAIL_H_INCLUDED_

#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/FunctionInterfaces.h>
#include <mlir/Interfaces/FunctionInterfaces.h>
#include <mlir/Pass/Pass.h>

namespace mlir {
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/XeGPU/Transforms/PassDetail.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#define _XeGPU_PASSDETAIL_H_INCLUDED_

#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/FunctionInterfaces.h>
#include <mlir/Interfaces/FunctionInterfaces.h>
#include <mlir/Pass/Pass.h>

namespace mlir {
Expand Down
Loading

0 comments on commit 6e981bc

Please sign in to comment.