Skip to content

Commit

Permalink
[Conversion] XeGPU load/store/prefetch2d lowering to raw_send (#662)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dewei-Wang-sh authored Sep 27, 2023
1 parent faa58d5 commit 9d3fdc8
Show file tree
Hide file tree
Showing 11 changed files with 538 additions and 284 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ index 6f0f728f811e..c2ad6ff24bea 100644
@@ -4039,16 +4039,30 @@ def SPIRV_SamplerUseAttr: SPIRV_I32EnumAttr<
"image_sampler_use_info",
[SPIRV_ISUI_SamplerUnknown, SPIRV_ISUI_NeedSampler, SPIRV_ISUI_NoSampler]>;

-def SPIRV_ML_ColumnMajor : I32EnumAttrCase<"ColumnMajor", 0>;
-def SPIRV_ML_RowMajor : I32EnumAttrCase<"RowMajor", 1>;
-def SPIRV_ML_PackedA : I32EnumAttrCase<"PackedA", 2>;
Expand All @@ -45,7 +45,7 @@ index 6f0f728f811e..c2ad6ff24bea 100644
+ SPIRV_I32EnumAttr<"MatrixLayout", "valid SPIR-V MatrixLayout", "matrixLayout", [
+ SPIRV_ML_RowMajor, SPIRV_ML_ColumnMajor, SPIRV_ML_Packed, SPIRV_ML_Unused
]>;

+def SPIRV_ML_MATRIX_A : I32EnumAttrCase<"MatrixA", 0>;
+def SPIRV_ML_MATRIX_B : I32EnumAttrCase<"MatrixB", 1>;
+def SPIRV_ML_MATRIX_ACC : I32EnumAttrCase<"Accumulator", 2>;
Expand All @@ -65,18 +65,18 @@ index 07f2f158ecab..e0b3c5448a44 100644
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -459,7 +459,8 @@ public:
using Base::Base;

static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows,
- unsigned columns, MatrixLayout matrixLayout);
+ unsigned columns, MatrixLayout matrixLayout,
+ MatrixUse matrixUse);
Type getElementType() const;

/// Return the scope of the joint matrix.
@@ -472,6 +473,9 @@ public:
/// return the layout of the matrix
MatrixLayout getMatrixLayout() const;

+ /// return the use of the matrix
+ MatrixUse getMatrixUse() const;
+
Expand All @@ -88,7 +88,7 @@ index 9188f8b699b4..4c099bf77a88 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -394,7 +394,8 @@ static Type parseCooperativeMatrixNVType(SPIRVDialect const &dialect,

// joint-matrix-type ::= `!spirv.jointmatrix` `<`rows `x` columns `x`
// element-type
-// `,` layout `,` scope`>`
Expand All @@ -111,7 +111,7 @@ index 9188f8b699b4..4c099bf77a88 100644
- matrixLayout);
+ matrixLayout, matrixUse);
}

// TODO: Reorder methods to be utilities first and parse*Type
@@ -952,7 +957,8 @@ static void print(JointMatrixINTELType type, DialectAsmPrinter &os) {
os << "jointmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
Expand All @@ -121,46 +121,46 @@ index 9188f8b699b4..4c099bf77a88 100644
+ os << ", " << stringifyScope(type.getScope()) << ", "
+ << stringifyMatrixUse(type.getMatrixUse()) << ">";
}

static void print(MatrixType type, DialectAsmPrinter &os) {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 741d8069471d..49ded5c60951 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -352,7 +352,8 @@ void CooperativeMatrixNVType::getCapabilities(
//===----------------------------------------------------------------------===//

struct spirv::detail::JointMatrixTypeStorage : public TypeStorage {
- using KeyTy = std::tuple<Type, unsigned, unsigned, MatrixLayout, Scope>;
+ using KeyTy =
+ std::tuple<Type, unsigned, unsigned, MatrixLayout, Scope, MatrixUse>;

static JointMatrixTypeStorage *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
@@ -361,26 +362,29 @@ struct spirv::detail::JointMatrixTypeStorage : public TypeStorage {
}

bool operator==(const KeyTy &key) const {
- return key == KeyTy(elementType, rows, columns, matrixLayout, scope);
+ return key ==
+ KeyTy(elementType, rows, columns, matrixLayout, scope, matrixUse);
}

JointMatrixTypeStorage(const KeyTy &key)
: elementType(std::get<0>(key)), rows(std::get<1>(key)),
- columns(std::get<2>(key)), scope(std::get<4>(key)),
- matrixLayout(std::get<3>(key)) {}
+ columns(std::get<2>(key)), matrixLayout(std::get<3>(key)),
+ scope(std::get<4>(key)), matrixUse(std::get<5>(key)) {}

Type elementType;
unsigned rows;
unsigned columns;
Scope scope;
MatrixLayout matrixLayout;
+ MatrixUse matrixUse;
};

JointMatrixINTELType JointMatrixINTELType::get(Type elementType, Scope scope,
unsigned rows, unsigned columns,
- MatrixLayout matrixLayout) {
Expand All @@ -170,12 +170,12 @@ index 741d8069471d..49ded5c60951 100644
- matrixLayout, scope);
+ matrixLayout, scope, matrixUse);
}

Type JointMatrixINTELType::getElementType() const {
@@ -397,6 +401,10 @@ MatrixLayout JointMatrixINTELType::getMatrixLayout() const {
return getImpl()->matrixLayout;
}

+MatrixUse JointMatrixINTELType::getMatrixUse() const {
+ return getImpl()->matrixUse;
+}
Expand All @@ -188,7 +188,7 @@ index 90416289134b..4598dc608034 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -939,7 +939,7 @@ spirv::Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {

LogicalResult
spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
- if (operands.size() != 6) {
Expand Down Expand Up @@ -221,15 +221,15 @@ index 90416289134b..4598dc608034 100644
}
unsigned rows = getConstantInt(operands[2]).getInt();
unsigned columns = getConstantInt(operands[3]).getInt();

- typeMap[operands[0]] = spirv::JointMatrixINTELType::get(
- elementTy, scope.value(), rows, columns, matrixLayout.value());
+ typeMap[operands[0]] =
+ spirv::JointMatrixINTELType::get(elementTy, scope.value(), rows, columns,
+ matrixLayout.value(), matrixUse.value());
return success();
}

diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 988e60d08edf..b6ec58648d72 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
Expand All @@ -252,23 +252,23 @@ index 988e60d08edf..b6ec58648d72 100644
+ getConstantOp(static_cast<uint32_t>(jointMatrixType.getMatrixUse())));
return success();
}

diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index ccf4240f8e56..a793564e0477 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -396,10 +396,9 @@ static void emitAvailabilityQueryForBitEnum(const Record &enumDef,
avail.getMergeInstanceType(), avail.getQueryFnName(),
enumName);

- os << formatv(
- " assert(::llvm::popcount(static_cast<{0}>(value)) <= 1"
- " && \"cannot have more than one bit set\");\n",
- underlyingType);
+ os << formatv(" assert(::llvm::popcount(static_cast<{0}>(value)) <= 1"
+ " && \"cannot have more than one bit set\");\n",
+ underlyingType);

os << " switch (value) {\n";
for (const auto &caseSpecPair : classCasePair.getValue()) {
@@ -523,7 +522,8 @@ static void emitAttributeSerialization(const Attribute &attr,
Expand Down Expand Up @@ -301,6 +301,5 @@ index ccf4240f8e56..a793564e0477 100644
if (valueArg->isVariableLength()) {
if (i != e - 1) {
PrintFatalError(loc, "SPIR-V ops can have Variadic<..> or "
--
--
2.34.1

Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
commit c64e2cf1aaae2f9c41c4f9ccf932695417027432
Author: Dewei Wang <[email protected]>
Date: Thu Sep 7 00:18:12 2023 +0800

[mlir][spirv] fix linkage_name StringAttr

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
index 259a96651abb..7eaa6c10ae6e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
@@ -48,7 +48,7 @@ def SPIRV_CapabilityArrayAttr : TypedArrayAttrBase<

def SPIRV_LinkageAttributesAttr : SPIRV_Attr<"LinkageAttributes", "linkage_attributes"> {
let parameters = (ins
- "std::string":$linkage_name,
+ "StringAttr":$linkage_name,
"mlir::spirv::LinkageTypeAttr":$linkage_type
);
let assemblyFormat = "`<` struct(params) `>`";
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 4598dc608034..3c74ab5fb5d8 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -282,10 +282,11 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
// 3 + ceildiv(strlen(name), 4).
unsigned wordIndex = 2;
auto linkageName = spirv::decodeStringLiteral(words, wordIndex).str();
+ auto linkageNameAttr = opBuilder.getStringAttr(linkageName);
auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
static_cast<::mlir::spirv::LinkageType>(words[wordIndex++]));
auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
- linkageName, linkageTypeAttr);
+ linkageNameAttr, linkageTypeAttr);
decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
break;
}
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index b6ec58648d72..893d0a609a9a 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -224,7 +224,7 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
// e.g., LinkageAttributes=["linkageName", linkageType].
auto linkageAttr =
llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr.getValue());
- auto linkageName = linkageAttr.getLinkageName();
+ auto linkageName = linkageAttr.getLinkageName().getValue();
auto linkageType = linkageAttr.getLinkageType().getValue();
// Encode the Linkage Name (string literal to uint32_t).
spirv::encodeStringLiteralInto(args, linkageName);
2 changes: 1 addition & 1 deletion include/imex/Dialect/XeGPU/IR/XeGPUAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def XeGPU_CacheReadAttr : I32EnumAttr<
"CacheReadHint", "", [ I32EnumAttrCase<"UNCACHED", 0, "uncached">,
I32EnumAttrCase<"CACHED", 1, "cached">,
I32EnumAttrCase<"STREAMING", 2, "streaming">,
I32EnumAttrCase<"READ_INVALDIATE", 3, "read_invalidiate"> ]> {
I32EnumAttrCase<"READ_INVALIDATE", 3, "read_invalidate"> ]> {

let cppNamespace = "::imex::xegpu";
}
Expand Down
Loading

0 comments on commit 9d3fdc8

Please sign in to comment.