diff --git a/include/imex/Dialect/XeGPU/IR/XeGPUAttrs.td b/include/imex/Dialect/XeGPU/IR/XeGPUAttrs.td index 5f4864511..9d75719e2 100644 --- a/include/imex/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/include/imex/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -21,26 +21,26 @@ def XeGPU_ScatteredAttr : XeGPUAttr<"Scattered", "scattered"> { } def XeGPU_SgMapAttr: XeGPUAttr<"SgMap", "sg_map"> { - let parameters = (ins + let parameters = (ins ArrayRefParameter<"unsigned">:$mmaBlockSize, ArrayRefParameter<"unsigned">:$wiLayout, ArrayRefParameter<"unsigned">:$wiData); - + // In format of #xegpu.sg_map<{mma_block_size = [2, 4], wi_layout = [2, 4], wi_data = [2, 4]}> let assemblyFormat = "`<` custom($mmaBlockSize, $wiLayout, $wiData) `>`"; } def XeGPU_WgMapAttr: XeGPUAttr<"WgMap", "wg_map"> { - let parameters = (ins + let parameters = (ins ArrayRefParameter<"unsigned">:$sgLayout, ArrayRefParameter<"unsigned">:$sgData); - + // In format of #xegpu.wg_map<{sg_layout = [2, 4], sg_data = [2, 4]}> let assemblyFormat = "`<` custom($sgLayout, $sgData) `>`"; } def XeGPU_XeMapAttr: XeGPUAttr<"XeMap", "xe_map"> { - let parameters = (ins + let parameters = (ins XeGPU_WgMapAttr: $wg, XeGPU_SgMapAttr: $sg); diff --git a/include/imex/Dialect/XeGPU/IR/XeGPUOps.td b/include/imex/Dialect/XeGPU/IR/XeGPUOps.td index fe5372471..e2105c092 100644 --- a/include/imex/Dialect/XeGPU/IR/XeGPUOps.td +++ b/include/imex/Dialect/XeGPU/IR/XeGPUOps.td @@ -255,9 +255,9 @@ def XeGPU_CreateDescOp (scattered) subviews. It accepts the following parameters: * source: a 1D memref or pointer (uint64_t) represents the memory object. - * offsets: In VectorCompute (VC) mode, it is a 1D vector containing offsets of each access point, the size is aligned with - supportted group size, e.g., vector<16xindex>. And each element in the vector corresponds to a - work item (SIMT lane) in the subgroup. + * offsets: In VectorCompute (VC) mode, it is a 1D vector containing offsets of each access point, the size is aligned with + supportted group size, e.g., vector<16xindex>. And each element in the vector corresponds to a + work item (SIMT lane) in the subgroup. In SIMT mode (default), it is an index scalar representing the offset of the access point. * memory_scope: [optional attribute] indicates where the memory is located, "global" for global memory (default), and "slm" for shared memory. * chunk_size_per_lane: [optional attribute] indicates number of continious elements accessed for each offset, default is 1. @@ -267,7 +267,7 @@ def XeGPU_CreateDescOp %c0 = arith.constant dense<0, 16, 32, 64> : vector<4xindex> %1 = xegpu.create_tdesc %a, %c0: memref<1024xf32> -> TensorDesc<4xf32> - Example 2. It assumes subgroup size is 4, and each workitem access 8 elements. + Example 2. It assumes subgroup size is 4, and each workitem access 8 elements. It will access totally 32 data elements: a[0:7], a[16:23], a[32:39], a[64:71] %0 = memref.alloc() : memref<1024xf32> %c0 = arith.constant dense<0, 16, 32, 64> : vector<4xindex> @@ -330,7 +330,7 @@ def XeGPU_LoadNDOp : XeGPU_Op<"load_nd"> { DefaultValuedAttr: $mode); let results = (outs XeGPU_ValueType: $value); - // Format: xegpu.load_nd %1 {transpose = [1, 0], l1_hint = cached, l2_hint = uncached, l3_hint=streaming} + // Format: xegpu.load_nd %1 {transpose = [1, 0], l1_hint = cached, l2_hint = uncached, l3_hint=streaming} // : !xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32> let hasCustomAssemblyFormat = 1; @@ -363,7 +363,7 @@ def XeGPU_PrefetchNDOp : XeGPU_Op<"prefetch_nd", []> { DefaultValuedAttr: $mode ); - // In format of: xegpu.prefetch_nd %tdesc {l1_hint = cached, l2_hint = uncached}: + // In format of: xegpu.prefetch_nd %tdesc {l1_hint = cached, l2_hint = uncached}: // !xegpu.tensor_desc<8x16xf16> let hasCustomAssemblyFormat = 1; } @@ -417,7 +417,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load"> { let results = (outs XeGPU_ValueType: $value); - // In format of: %2 = xegpu.load %1, %0 {transpose = [1, 0], l1_hint = cached, l2_hint = uncached} + // In format of: %2 = xegpu.load %1, %0 {transpose = [1, 0], l1_hint = cached, l2_hint = uncached} // : !xegpu.tensor_desc<16x8xf32, #xegpu.scattered>, vector<16x8xi1> -> vector<8x16xf32> let hasCustomAssemblyFormat = 1; let hasVerifier = 1; @@ -483,11 +483,11 @@ def XeGPU_UpdateOffsetOp def XeGPU_InvokeSIMDOp : XeGPU_Op<"invoke_SIMD", []> { let summary = "Invoke_SIMD operation"; let description = [{ - The `xegpu.invoke_SIMD` operation works similar to a direct call to a function. But it is + The `xegpu.invoke_SIMD` operation works similar to a direct call to a function. But it is special to Intel GPU. }]; - let arguments = (ins FlatSymbolRefAttr:$callee, + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands, XeGPU_ArgTypeAttr: $argType); let results = (outs Variadic); @@ -560,7 +560,7 @@ def XeGPU_CreateNbarrierOp let results = (outs Builtin_Vector: $result); let assemblyFormat = [{ - $nbarrier_id `,` $nbarrier_role + $nbarrier_id `,` $nbarrier_role attr-dict `:` `(` qualified(type($nbarrier_id)) `,` qualified(type($nbarrier_role)) `)` `->` qualified(type($result)) }]; diff --git a/include/imex/Dialect/XeGPU/IR/XeGPUTypes.td b/include/imex/Dialect/XeGPU/IR/XeGPUTypes.td index cef55c67d..e7f0723f4 100644 --- a/include/imex/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/include/imex/Dialect/XeGPU/IR/XeGPUTypes.td @@ -101,7 +101,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc", }]; // let assemblyFormat = "`<` custom($shape, $elementType) (`,` custom($encoding)^)? `>`"; - let assemblyFormat = "`<` custom($shape, $elementType) (`,` $encoding^)? `>`"; + let assemblyFormat = "`<` custom($shape, $elementType) (`,` $encoding^)? `>`"; } #endif // _XEGPU_TYPES_TD_INCLUDED_ diff --git a/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 1b245244b..83b31b8e2 100644 --- a/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -77,86 +77,101 @@ static void printShapeAndType(mlir::AsmPrinter &printer, printer << type; } -template -static mlir::LogicalResult parseArrayList(mlir::AsmParser &parser, - llvm::SmallVector &array, +template +static mlir::LogicalResult parseArrayList(mlir::AsmParser &parser, + llvm::SmallVector &array, bool parsePrecedenceEqual = false) { mlir::FailureOr> result; // Parse literal '=' if (parsePrecedenceEqual) - if (parser.parseEqual()) return mlir::failure(); + if (parser.parseEqual()) + return mlir::failure(); // Parse literal '[' - if (parser.parseLSquare()) return mlir::failure(); + if (parser.parseLSquare()) + return mlir::failure(); result = mlir::FieldParser<::llvm::SmallVector>::parse(parser); - if (::mlir::failed(result)) return mlir::failure(); + if (::mlir::failed(result)) + return mlir::failure(); // Parse literal ']' - if (parser.parseRSquare()) return mlir::failure(); + if (parser.parseRSquare()) + return mlir::failure(); array = result.value(); return mlir::success(); } -template -static void printArrayElement(mlir::AsmPrinter &printer, - llvm::StringRef keyword, +template +static void printArrayElement(mlir::AsmPrinter &printer, + llvm::StringRef keyword, llvm::ArrayRef array) { printer << keyword; printer << ' ' << "="; printer << ' ' << "["; printer.printStrippedAttrOrType(array); - printer << "]"; + printer << "]"; } - -static mlir::LogicalResult parseSgMapAttrElements(mlir::AsmParser &parser, - llvm::SmallVector &mmaBlockSize, - llvm::SmallVector &layout, - llvm::SmallVector &data) { +static mlir::LogicalResult parseSgMapAttrElements( + mlir::AsmParser &parser, llvm::SmallVector &mmaBlockSize, + llvm::SmallVector &layout, llvm::SmallVector &data) { auto loc = parser.getCurrentLocation(); auto parseElt = [&]() -> mlir::LogicalResult { return mlir::AsmParser::KeywordSwitch(parser) - .Case("mma_block_size", [&](llvm::StringRef, llvm::SMLoc) { - return parseArrayList(parser, mmaBlockSize, true); - }) - .Case("wi_layout", [&](llvm::StringRef, llvm::SMLoc) { - return parseArrayList(parser, layout, true); - }) - .Case("wi_data", [&](llvm::StringRef, llvm::SMLoc) { - return parseArrayList(parser, data, true); - }) - .Default([&](llvm::StringRef keyword, llvm::SMLoc) { - llvm::dbgs() << "\n3. Default currLoc: " << llvm::StringRef(parser.getCurrentLocation().getPointer()) << "\n"; - llvm::dbgs() << "\n3. keyword: " << keyword << "\n"; - return mlir::failure(); - }); - }; - - if (parser.parseLBrace()) return mlir::failure(); - if (parser.parseCommaSeparatedList(parseElt)) return mlir::failure(); - if (parser.parseRBrace()) return mlir::failure(); + .Case("mma_block_size", + [&](llvm::StringRef, llvm::SMLoc) { + return parseArrayList(parser, mmaBlockSize, true); + }) + .Case("wi_layout", + [&](llvm::StringRef, llvm::SMLoc) { + return parseArrayList(parser, layout, true); + }) + .Case("wi_data", + [&](llvm::StringRef, llvm::SMLoc) { + return parseArrayList(parser, data, true); + }) + .Default([&](llvm::StringRef keyword, llvm::SMLoc) { + llvm::dbgs() << "\n3. Default currLoc: " + << llvm::StringRef( + parser.getCurrentLocation().getPointer()) + << "\n"; + llvm::dbgs() << "\n3. keyword: " << keyword << "\n"; + return mlir::failure(); + }); + }; + + if (parser.parseLBrace()) + return mlir::failure(); + if (parser.parseCommaSeparatedList(parseElt)) + return mlir::failure(); + if (parser.parseRBrace()) + return mlir::failure(); if (mmaBlockSize.size() != 2) { - parser.emitError(loc, "failed to parse SgMapAttr: missing mma_block_size which is to be a `llvm::ArrayRef` with size 2"); + parser.emitError(loc, + "failed to parse SgMapAttr: missing mma_block_size which " + "is to be a `llvm::ArrayRef` with size 2"); return mlir::failure(); } if (layout.size() != 2) { - parser.emitError(loc, "failed to parse SgMapAttr: missing wi_layout which is to be a `llvm::ArrayRef` with size 2"); + parser.emitError(loc, "failed to parse SgMapAttr: missing wi_layout which " + "is to be a `llvm::ArrayRef` with size 2"); return mlir::failure(); } if (data.size() != 2) { - parser.emitError(loc, "failed to parse SgMapAttr: missing wi_data which is to be a `llvm::ArrayRef` with size 2"); + parser.emitError(loc, "failed to parse SgMapAttr: missing wi_data which is " + "to be a `llvm::ArrayRef` with size 2"); return mlir::failure(); } return mlir::success(); } -static void printSgMapAttrElements(mlir::AsmPrinter &printer, - llvm::ArrayRef mmaBlockSize, - llvm::ArrayRef layout, - llvm::ArrayRef data) { +static void printSgMapAttrElements(mlir::AsmPrinter &printer, + llvm::ArrayRef mmaBlockSize, + llvm::ArrayRef layout, + llvm::ArrayRef data) { printer << "{"; printArrayElement(printer, "mma_block_size", mmaBlockSize); printer << "," << ' '; @@ -166,40 +181,48 @@ static void printSgMapAttrElements(mlir::AsmPrinter &printer, printer << "}"; } -static mlir::LogicalResult parseWgMapAttrElements(mlir::AsmParser &parser, - llvm::SmallVector &layout, - llvm::SmallVector &data) { +static mlir::LogicalResult +parseWgMapAttrElements(mlir::AsmParser &parser, + llvm::SmallVector &layout, + llvm::SmallVector &data) { auto loc = parser.getCurrentLocation(); auto parseElt = [&]() -> mlir::LogicalResult { return mlir::AsmParser::KeywordSwitch(parser) - .Case("sg_layout", [&](llvm::StringRef, llvm::SMLoc) { - return parseArrayList(parser, layout, true); - }) - .Case("sg_data", [&](llvm::StringRef, llvm::SMLoc) { - return parseArrayList(parser, data, true); - }) - .Default([&](llvm::StringRef keyword, llvm::SMLoc) { - return mlir::failure(); - }); + .Case("sg_layout", + [&](llvm::StringRef, llvm::SMLoc) { + return parseArrayList(parser, layout, true); + }) + .Case("sg_data", + [&](llvm::StringRef, llvm::SMLoc) { + return parseArrayList(parser, data, true); + }) + .Default([&](llvm::StringRef keyword, llvm::SMLoc) { + return mlir::failure(); + }); }; - if (parser.parseLBrace()) return mlir::failure(); - if (parser.parseCommaSeparatedList(parseElt)) return mlir::failure(); - if (parser.parseRBrace()) return mlir::failure(); + if (parser.parseLBrace()) + return mlir::failure(); + if (parser.parseCommaSeparatedList(parseElt)) + return mlir::failure(); + if (parser.parseRBrace()) + return mlir::failure(); if (layout.size() != 2) { - parser.emitError(loc, "failed to parse WgMapAttr: missing sg_layout which is to be a `llvm::ArrayRef` with size 2"); + parser.emitError(loc, "failed to parse WgMapAttr: missing sg_layout which " + "is to be a `llvm::ArrayRef` with size 2"); return mlir::failure(); } if (data.size() != 2) { - parser.emitError(loc, "failed to parse WgMapAttr: missing sg_data which is to be a `llvm::ArrayRef` with size 2"); + parser.emitError(loc, "failed to parse WgMapAttr: missing sg_data which is " + "to be a `llvm::ArrayRef` with size 2"); return mlir::failure(); } return mlir::success(); } -static void printWgMapAttrElements(mlir::AsmPrinter &printer, - llvm::ArrayRef layout, - llvm::ArrayRef data) { +static void printWgMapAttrElements(mlir::AsmPrinter &printer, + llvm::ArrayRef layout, + llvm::ArrayRef data) { printer << "{"; printArrayElement(printer, "sg_layout", layout); printer << "," << ' '; @@ -207,48 +230,60 @@ static void printWgMapAttrElements(mlir::AsmPrinter &printer, printer << "}"; } - mlir::Attribute XeMapAttr::parse(mlir::AsmParser &parser, mlir::Type type) { imex::xegpu::WgMapAttr wg; imex::xegpu::SgMapAttr sg; // Parse literal '<' - if (parser.parseLess()) return {}; - - auto parseElt = [&]() -> mlir::ParseResult { - mlir::OptionalParseResult result = mlir::AsmParser::KeywordSwitch(parser) - .Case("sg", [&](llvm::StringRef, llvm::SMLoc) { - if (parser.parseEqual()) return mlir::failure(); - llvm::SmallVector mmaBlockSize; - llvm::SmallVector wiLayout; - llvm::SmallVector wiData; - if (mlir::failed(parseSgMapAttrElements(parser, mmaBlockSize, wiLayout, wiData))) - return mlir::failure(); - sg = imex::xegpu::SgMapAttr::get(parser.getContext(), mmaBlockSize, wiLayout, wiData); - return mlir::success(!!sg); - }) - .Case("wg", [&](llvm::StringRef, llvm::SMLoc) { - if (parser.parseEqual()) return mlir::failure(); - llvm::SmallVector sgLayout; - llvm::SmallVector sgData; - if(mlir::failed(parseWgMapAttrElements(parser, sgLayout, sgData))) - return mlir::failure(); - wg = imex::xegpu::WgMapAttr::get(parser.getContext(), sgLayout, sgData); - return mlir::success(!!wg); - }) - .Default([&](llvm::StringRef keyword, llvm::SMLoc) { - return std::nullopt; - }); + if (parser.parseLess()) + return {}; + + auto parseElt = [&]() -> mlir::ParseResult { + mlir::OptionalParseResult result = + mlir::AsmParser::KeywordSwitch(parser) + .Case("sg", + [&](llvm::StringRef, llvm::SMLoc) { + if (parser.parseEqual()) + return mlir::failure(); + llvm::SmallVector mmaBlockSize; + llvm::SmallVector wiLayout; + llvm::SmallVector wiData; + if (mlir::failed(parseSgMapAttrElements( + parser, mmaBlockSize, wiLayout, wiData))) + return mlir::failure(); + sg = imex::xegpu::SgMapAttr::get( + parser.getContext(), mmaBlockSize, wiLayout, wiData); + return mlir::success(!!sg); + }) + .Case("wg", + [&](llvm::StringRef, llvm::SMLoc) { + if (parser.parseEqual()) + return mlir::failure(); + llvm::SmallVector sgLayout; + llvm::SmallVector sgData; + if (mlir::failed( + parseWgMapAttrElements(parser, sgLayout, sgData))) + return mlir::failure(); + wg = imex::xegpu::WgMapAttr::get(parser.getContext(), + sgLayout, sgData); + return mlir::success(!!wg); + }) + .Default([&](llvm::StringRef keyword, llvm::SMLoc) { + return std::nullopt; + }); return result.value(); }; // Parse wg and sg attrs - if (parser.parseCommaSeparatedList(parseElt)) return {}; + if (parser.parseCommaSeparatedList(parseElt)) + return {}; // Parse literal '>' - if (parser.parseGreater()) return {}; + if (parser.parseGreater()) + return {}; - if(!wg && !sg) { - parser.emitError(parser.getCurrentLocation(), "Expecting at least one of sg and wg attributes.\n"); + if (!wg && !sg) { + parser.emitError(parser.getCurrentLocation(), + "Expecting at least one of sg and wg attributes.\n"); return {}; } @@ -265,9 +300,11 @@ void XeMapAttr::print(mlir::AsmPrinter &printer) const { } if (getSg()) { - if (printSep) printer << ", "; + if (printSep) + printer << ", "; printer << "sg = "; - printSgMapAttrElements(printer, getSg().getMmaBlockSize(), getSg().getWiLayout(), getSg().getWiData()); + printSgMapAttrElements(printer, getSg().getMmaBlockSize(), + getSg().getWiLayout(), getSg().getWiData()); } printer << ">"; @@ -281,4 +318,3 @@ void XeMapAttr::print(mlir::AsmPrinter &printer) const { #include #define GET_TYPEDEF_CLASSES #include - diff --git a/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 562964f63..a8f0a8aa6 100644 --- a/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -63,7 +63,7 @@ const int TK_SIZE_FOR_D8 = 32; // return isValid; // } -static void transpose(llvm::ArrayRef trans, +static void transpose(llvm::ArrayRef trans, std::vector &shape) { std::vector old = shape; for (size_t i = 0; i < trans.size(); i++) @@ -73,15 +73,16 @@ static void transpose(llvm::ArrayRef trans, static void dropOnes(std::vector &array) { std::vector old = array; array.clear(); - for(auto v: old) { - if (v != 1) array.push_back(v); + for (auto v : old) { + if (v != 1) + array.push_back(v); } }; static bool isMappingAttr(mlir::Attribute attr) { - return attr && (llvm::isa(attr) - || llvm::isa(attr) - || llvm::isa(attr)); + return attr && (llvm::isa(attr) || + llvm::isa(attr) || + llvm::isa(attr)); } bool dpasSupportedTypes(mlir::Type type, bool isResult) { @@ -200,7 +201,7 @@ parseOptionalAttrDict(mlir::OpAsmParser &parser, mlir::OperationState &result, } if (nameId == "mode") { - return parseCustomEnumAttr(parser, result, nameId); + return parseCustomEnumAttr(parser, result, nameId); } if (nameId == "chunk_size_per_lane" || nameId == "vnni_axis") @@ -295,7 +296,8 @@ mlir::ParseResult CreateNdDescOp::parse(mlir::OpAsmParser &parser, return ::mlir::failure(); } - if (parseOptionalAttrDict(parser, result, {"memory_scope", "boundary_check", "mode"})) + if (parseOptionalAttrDict(parser, result, + {"memory_scope", "boundary_check", "mode"})) return mlir::failure(); if (parser.parseColon()) @@ -375,7 +377,8 @@ mlir::LogicalResult CreateNdDescOp::verify() { auto encoding = getTensorDesc().getType().getEncoding(); if (mode == imex::xegpu::Mode::SIMT && !isMappingAttr(encoding)) { - return emitOpError("Expecting either SgMap, WgMap or XeMap attribute for SIMT mode operators.\n"); + return emitOpError("Expecting either SgMap, WgMap or XeMap attribute for " + "SIMT mode operators.\n"); } // it is invalid to have both dynamic and static shape @@ -489,7 +492,8 @@ mlir::LogicalResult CreateDescOp::verify() { if (llvm::isa(offsetTy)) { shape = llvm::dyn_cast(offsetTy).getShape().vec(); if (shape.size() != 1) - return emitOpError("Expecting the offset is either a 1D vector (for VC) or scalar (for SIMT)."); + return emitOpError("Expecting the offset is either a 1D vector (for VC) " + "or scalar (for SIMT)."); } if (offsetTy.isIndex() || chunkSize != 1) { @@ -501,9 +505,10 @@ mlir::LogicalResult CreateDescOp::verify() { "tensor descriptor, or one less than."); } - if (!tdescTy.getEncoding()) - return emitOpError("Expecting the presence of scattered attribute for tensor descriptor."); - + if (!tdescTy.getEncoding()) + return emitOpError( + "Expecting the presence of scattered attribute for tensor descriptor."); + return mlir::success(); } @@ -585,10 +590,11 @@ void LoadNDOp::print(::mlir::OpAsmPrinter &printer) { mlir::LogicalResult LoadNDOp::verify() { auto tdescTy = getTensorDesc().getType(); - auto valueTy = llvm::dyn_cast(getValue().getType()); + auto valueTy = llvm::dyn_cast(getValue().getType()); if (tdescTy.getRank() != 2) - return emitOpError("The TensorDesc for LoadNDOp should be a 2D TensorDesc."); + return emitOpError( + "The TensorDesc for LoadNDOp should be a 2D TensorDesc."); if (!valueTy) return emitOpError("Invalid result, it should be a VectorType.\n"); @@ -597,9 +603,11 @@ mlir::LogicalResult LoadNDOp::verify() { auto valueElemTy = valueTy.getElementType(); if (tdescElemTy != valueElemTy) - return emitOpError("Value should have the same element type as TensorDesc."); + return emitOpError( + "Value should have the same element type as TensorDesc."); - { // TODO: The following logic are architecture dependent, pending to be moved out + { // TODO: The following logic are architecture dependent, pending to be moved + // out auto width = tdescTy.getShape()[1]; auto height = tdescTy.getShape()[0]; auto elemTyByteWidth = tdescElemTy.getIntOrFloatBitWidth() / 8; @@ -615,13 +623,12 @@ mlir::LogicalResult LoadNDOp::verify() { if (height < MIN_2D_BLOCK_HEIGHT_IN_ELEMENTS || height > MAX_2D_BLOCK_HEIGHT_IN_ELEMENTS) { - return emitOpError( - "Invalid height size for 2D block load. The specification expects the " - "value to be in range [1, 32].\n"); + return emitOpError("Invalid height size for 2D block load. The " + "specification expects the " + "value to be in range [1, 32].\n"); } } - auto mode = getMode(); auto tdescShape = tdescTy.getShape().vec(); auto valueShape = valueTy.getShape().vec(); @@ -632,7 +639,8 @@ mlir::LogicalResult LoadNDOp::verify() { auto encoding = tdescTy.getEncoding(); if (!isMappingAttr(encoding)) { - return emitOpError("Expecting either SgMap, WgMap or XeMap attribute for SIMT mode operators.\n"); + return emitOpError("Expecting either SgMap, WgMap or XeMap attribute for " + "SIMT mode operators.\n"); } if (auto xeMapAttr = llvm::dyn_cast(encoding)) { @@ -646,14 +654,14 @@ mlir::LogicalResult LoadNDOp::verify() { if (wgMap) { auto sgData = wgMap.getSgData(); auto sgLayout = wgMap.getSgLayout(); - for(size_t i = 0; i < sgData.size(); i++) { - if (tdescShape[i] % sgLayout[i] != 0 || - tdescShape[i] % sgData[i] != 0 || - tdescShape[i] % sgData[i] != 0) - return emitOpError("Invalid WgMapAttr. It should meet the following conditions: " - "tdescShape[i] % sgLayout[i] == 0 && " - "tdescShape[i] % sgData[i] == 0 && " - "tdescShape[i] % sgData[i] == 0"); + for (size_t i = 0; i < sgData.size(); i++) { + if (tdescShape[i] % sgLayout[i] != 0 || + tdescShape[i] % sgData[i] != 0 || tdescShape[i] % sgData[i] != 0) + return emitOpError( + "Invalid WgMapAttr. It should meet the following conditions: " + "tdescShape[i] % sgLayout[i] == 0 && " + "tdescShape[i] % sgData[i] == 0 && " + "tdescShape[i] % sgData[i] == 0"); tdescShape[i] /= sgLayout[i]; } // dropOnes(tdescShape); @@ -664,29 +672,29 @@ mlir::LogicalResult LoadNDOp::verify() { auto wiLayout = sgMap.getWiLayout(); auto wiData = sgMap.getWiData(); for (size_t i = 0; i < blockSize.size(); i++) { - if (tdescShape[i] % blockSize[i] != 0 || - blockSize[i] % wiLayout[i] != 0 || - blockSize[i] % wiData[i] != 0 || + if (tdescShape[i] % blockSize[i] != 0 || + blockSize[i] % wiLayout[i] != 0 || blockSize[i] % wiData[i] != 0 || blockSize[i] % (wiLayout[i] * wiData[i]) != 0) { - return emitOpError("Invalid SgMapAttr. It should meet the following conditions: " - "blockSize[i] % wiLayout[i] == 0 && " - "blockSize[i] % wiData[i] == 0 && " - "blockSize[i] % wiData[i] == 0 && " - "tdescShape[i] % blockSize[i] == 0"); - + return emitOpError( + "Invalid SgMapAttr. It should meet the following conditions: " + "blockSize[i] % wiLayout[i] == 0 && " + "blockSize[i] % wiData[i] == 0 && " + "blockSize[i] % wiData[i] == 0 && " + "tdescShape[i] % blockSize[i] == 0"); } - auto tmp = blockSize[i]/wiLayout[i]; + auto tmp = blockSize[i] / wiLayout[i]; tdescShape[i] /= blockSize[i]; tdescShape[i] *= tmp; } - } + } } if (getTranspose()) { auto trans = getTranspose().value(); - if (tdescShape.size() >= trans.size()) + if (tdescShape.size() >= trans.size()) transpose(trans, tdescShape); - else emitWarning("Invalid transpose attr. It is ignored."); + else + emitWarning("Invalid transpose attr. It is ignored."); } if (getVnniAxis()) { @@ -697,16 +705,25 @@ mlir::LogicalResult LoadNDOp::verify() { dropOnes(tdescShape); } if (tdescShape != valueShape) - return emitOpError("Result shape doesn't match TensorDesc shape." - "The expected shape is " + makeString(tdescShape) + ", while " - "the given shape is " + makeString(valueShape) + ". " - "In VC mode, when VNNI is not enabled, the result should have the same " - "shape (or transposed shape if transpose is also enabled) as TensorDesc; " - "when VNNI is enabled, the result should have one more dimention than the " - "TensorDesc, with last dimention having vnni factor, but having same number " - "of total data elements. The vnni factor are typically calculated as simd_lane_width / elementTypeBitWidth. " - "For element type having more than 32 bits, vnni shouldn't be used. " - "In SIMT mode, the shape is derived from the mapping attributes.\n"); + return emitOpError( + "Result shape doesn't match TensorDesc shape." + "The expected shape is " + + makeString(tdescShape) + + ", while " + "the given shape is " + + makeString(valueShape) + + ". " + "In VC mode, when VNNI is not enabled, the result should have the same " + "shape (or transposed shape if transpose is also enabled) as " + "TensorDesc; " + "when VNNI is enabled, the result should have one more dimention than " + "the " + "TensorDesc, with last dimention having vnni factor, but having same " + "number " + "of total data elements. The vnni factor are typically calculated as " + "simd_lane_width / elementTypeBitWidth. " + "For element type having more than 32 bits, vnni shouldn't be used. " + "In SIMT mode, the shape is derived from the mapping attributes.\n"); return mlir::success(); } @@ -729,8 +746,8 @@ ::mlir::ParseResult StoreNDOp::parse(::mlir::OpAsmParser &parser, if (parser.parseOperand(TensorDescRawOperands[0])) return mlir::failure(); - if (parseOptionalAttrDict(parser, result, {"mode", "l1_hint", "l2_hint", "l3_hint"}, - true)) + if (parseOptionalAttrDict(parser, result, + {"mode", "l1_hint", "l2_hint", "l3_hint"}, true)) return mlir::failure(); if (parser.parseColon()) @@ -781,12 +798,12 @@ void StoreNDOp::print(::mlir::OpAsmPrinter &printer) { } mlir::LogicalResult StoreNDOp::verify() { - auto dstTy = getTensorDesc().getType(); // Tile - auto valTy = llvm::dyn_cast(getValue().getType()); // Vector - + auto dstTy = getTensorDesc().getType(); // Tile + auto valTy = llvm::dyn_cast(getValue().getType()); // Vector if (dstTy.getRank() != 2) - return emitOpError("The TensorDesc for StoreNdOp should be a 2D TensorDesc."); + return emitOpError( + "The TensorDesc for StoreNdOp should be a 2D TensorDesc."); if (!valTy) return emitOpError("Invalid value operand, it should be a VectorType.\n"); @@ -799,8 +816,8 @@ mlir::LogicalResult StoreNDOp::verify() { "the elem type of memory (dst) shape.\n"); } - - { // TODO: The following logic are architecture dependent, pending to be moved out + { // TODO: The following logic are architecture dependent, pending to be moved + // out auto width = dstTy.getShape()[1]; auto height = dstTy.getShape()[0]; auto elemTyByteWidth = dstElemTy.getIntOrFloatBitWidth() / 8; @@ -815,8 +832,9 @@ mlir::LogicalResult StoreNDOp::verify() { if (height < MIN_2D_BLOCK_HEIGHT_IN_ELEMENTS || height > MAX_2D_BLOCK_HEIGHT_IN_ELEMENTS) { - return emitOpError("Invalid height size for 2D block write. The specification" - "expects the value to be in range [1, 32].\n"); + return emitOpError( + "Invalid height size for 2D block write. The specification" + "expects the value to be in range [1, 32].\n"); } } @@ -824,11 +842,13 @@ mlir::LogicalResult StoreNDOp::verify() { if (mode == imex::xegpu::Mode::VC) { // for VC mode, no attr attached if (dstTy.getShape() != valTy.getShape()) - return emitOpError("In VC mode, the value (vector) shape doesn't match the memory (dst) shape.\n"); + return emitOpError("In VC mode, the value (vector) shape doesn't match " + "the memory (dst) shape.\n"); } else { auto encoding = dstTy.getEncoding(); if (!isMappingAttr(encoding)) { - return emitOpError("Expecting either SgMap, WgMap or XeMap attribute for SIMT mode operators.\n"); + return emitOpError("Expecting either SgMap, WgMap or XeMap attribute for " + "SIMT mode operators.\n"); } imex::xegpu::WgMapAttr wgMap; @@ -846,7 +866,7 @@ mlir::LogicalResult StoreNDOp::verify() { if (wgMap) { auto sgData = wgMap.getSgData(); auto sgLayout = wgMap.getSgLayout(); - for(size_t i = 0; i < sgData.size(); i++) { + for (size_t i = 0; i < sgData.size(); i++) { assert(shape[i] % sgLayout[i] == 0); assert(shape[i] % sgData[i] == 0); assert(shape[i] % (sgLayout[i] * sgData[i]) == 0); @@ -863,15 +883,16 @@ mlir::LogicalResult StoreNDOp::verify() { assert(blockSize[i] % wiLayout[i] == 0); assert(blockSize[i] % wiData[i] == 0); assert(shape[i] % blockSize[i] == 0); - auto tmp = blockSize[i]/wiLayout[i]; + auto tmp = blockSize[i] / wiLayout[i]; shape[i] /= blockSize[i]; shape[i] *= tmp; } } if (shape != valTy.getShape().vec()) - return emitOpError("In SIMT mode, the value (vector) shape doesn't match the memory" - "(dst) shape as derived according to the mapping rule.\n"); + return emitOpError( + "In SIMT mode, the value (vector) shape doesn't match the memory" + "(dst) shape as derived according to the mapping rule.\n"); } return mlir::success(); @@ -890,7 +911,8 @@ ::mlir::ParseResult PrefetchNDOp::parse(::mlir::OpAsmParser &parser, if (parser.parseOperand(TensorDescRawOperands[0])) return ::mlir::failure(); - if (parseOptionalAttrDict(parser, result, {"mode", "l1_hint", "l2_hint", "l3_hint"})) + if (parseOptionalAttrDict(parser, result, + {"mode", "l1_hint", "l2_hint", "l3_hint"})) return mlir::failure(); if (parser.parseColon()) @@ -944,11 +966,12 @@ mlir::LogicalResult DpasOp::verify() { if (getAcc()) { if (getAccType() != getResultType()) - return emitOpError("Accumulator and Result for dpas op should have the same type (both shape and element type)."); + return emitOpError("Accumulator and Result for dpas op should have the " + "same type (both shape and element type)."); } - // TODO: SIMT makes it harder to check semantic errors for DPAS op. - // the only thing we can check seems to be vnni factor. But it + // TODO: SIMT makes it harder to check semantic errors for DPAS op. + // the only thing we can check seems to be vnni factor. But it // depends on hardware though. // if (!dpasSupportedShapes(*this)) { // return emitOpError("Incorrect shapes for dpas op"); @@ -990,7 +1013,8 @@ ::mlir::ParseResult LoadGatherOp::parse(::mlir::OpAsmParser &parser, if (parser.parseOperand(maskRawOperands[0])) return mlir::failure(); - if (parseOptionalAttrDict(parser, result, + if (parseOptionalAttrDict( + parser, result, {"mode", "vnni_axis", "transpose", "l1_hint", "l2_hint", "l3_hint"})) return mlir::failure(); @@ -1034,9 +1058,8 @@ void LoadGatherOp::print(mlir::OpAsmPrinter &printer) { printer << ' ' << "{"; printer << "mode = " << getMode(); - if (getVnniAxisAttr()) + if (getVnniAxisAttr()) printer << ", vnni_axis = " << getVnniAxis().value(); - if (getTransposeAttr()) { printer << ", transpose = "; @@ -1076,16 +1099,17 @@ mlir::LogicalResult LoadGatherOp::verify() { }; auto tdescElemTy = getElementType(tdescTy); - auto valueElemTy = getElementType(valueTy); + auto valueElemTy = getElementType(valueTy); if (tdescElemTy != valueElemTy) - return emitOpError("Value should have the same element type as TensorDesc."); + return emitOpError( + "Value should have the same element type as TensorDesc."); auto getShape = [&](mlir::Type type, std::vector &shape) -> void { if (type.isIntOrIndexOrFloat()) shape.push_back(1); else if (llvm::isa(type)) shape = llvm::dyn_cast(type).getShape().vec(); - else + else assert(0 && "Unreachable !!!"); }; @@ -1099,9 +1123,10 @@ mlir::LogicalResult LoadGatherOp::verify() { if (getTranspose()) { auto trans = getTranspose().value(); - if (tdescShape.size() >= trans.size()) + if (tdescShape.size() >= trans.size()) transpose(trans, tdescShape); - else emitWarning("Invalid transpose attr. It is ignored."); + else + emitWarning("Invalid transpose attr. It is ignored."); } if (getVnniAxis()) { @@ -1113,13 +1138,18 @@ mlir::LogicalResult LoadGatherOp::verify() { } if (valueShape != tdescShape) - return emitOpError("Result shape doesn't match TensorDesc shape. when VNNI is not enabled," - "the result should have the same shape (or transposed shape if transpose" - "is also enabled) as TensorDesc. When VNNI is enabled, the result should" - "have one more dimention than the TensorDesc, with last dimention having" - "vnni factor, but having same number of total data elements. The vnni " - "factor are typically calculated as simd_lane_width / elementTypeBitWidth." - "For element type having more than 32 bits, vnni shouldn't be used.\n"); + return emitOpError( + "Result shape doesn't match TensorDesc shape. when VNNI is not enabled," + "the result should have the same shape (or transposed shape if " + "transpose" + "is also enabled) as TensorDesc. When VNNI is enabled, the result " + "should" + "have one more dimention than the TensorDesc, with last dimention " + "having" + "vnni factor, but having same number of total data elements. The vnni " + "factor are typically calculated as simd_lane_width / " + "elementTypeBitWidth." + "For element type having more than 32 bits, vnni shouldn't be used.\n"); return ::mlir::success(); } @@ -1168,8 +1198,8 @@ ::mlir::ParseResult StoreScatterOp::parse(::mlir::OpAsmParser &parser, if (parser.parseOperand(maskRawOperands[0])) return mlir::failure(); - if (parseOptionalAttrDict(parser, result, {"mode", "l1_hint", "l2_hint", "l3_hint"}, - true)) + if (parseOptionalAttrDict(parser, result, + {"mode", "l1_hint", "l2_hint", "l3_hint"}, true)) return mlir::failure(); if (parser.parseColon()) @@ -1242,51 +1272,57 @@ ::mlir::LogicalResult StoreScatterOp::verify() { shape.push_back(1); else if (llvm::isa(type)) shape = llvm::dyn_cast(type).getShape().vec(); - else + else assert(0 && "Unreachable !!!"); }; getShape(valueTy, valueShape); getShape(maskTy, maskShape); - if (tdescTy.getShape().vec() != maskShape || valueShape != maskShape ) { - return emitOpError("Mask and value should have the same shape/size as TensorDesc." - "Mask and Value can be scalar if TensorDesc is in form of TensorDesc<1xf16>."); + if (tdescTy.getShape().vec() != maskShape || valueShape != maskShape) { + return emitOpError( + "Mask and value should have the same shape/size as TensorDesc." + "Mask and Value can be scalar if TensorDesc is in form of " + "TensorDesc<1xf16>."); } return ::mlir::success(); } ::mlir::LogicalResult UpdateOffsetOp::verify() { auto srcTy = getTensorDesc().getType(); - auto offTy = getOffsets().getType(); + auto offTy = getOffsets().getType(); auto resTy = getResult().getType(); if (srcTy != resTy) - return emitOpError("The result should have the same type" - "(shape and encoding attribute) as the input TensorDesc."); + return emitOpError( + "The result should have the same type" + "(shape and encoding attribute) as the input TensorDesc."); auto shape = srcTy.getShape(); auto encoding = srcTy.getEncoding(); if (!encoding || !llvm::isa(encoding)) { - return emitOpError("Invalid TensorDesc, it should have a scattered attribute."); + return emitOpError( + "Invalid TensorDesc, it should have a scattered attribute."); } - // For VC mode with chunkSize > 1. For chunkSize == 1, it is hard to distinguish - // between VC and SIMT mode by only looking at updateOffsetOp itself. So current - // verifier skipped these two cases. + // For VC mode with chunkSize > 1. For chunkSize == 1, it is hard to + // distinguish between VC and SIMT mode by only looking at updateOffsetOp + // itself. So current verifier skipped these two cases. if (shape.size() == 2) { if (!llvm::isa(offTy)) - return emitOpError("Based on TensorDesc shape, it is an VC tensor descriptor, " - "in which the offset should be an 1D vector."); + return emitOpError( + "Based on TensorDesc shape, it is an VC tensor descriptor, " + "in which the offset should be an 1D vector."); auto vecTy = llvm::dyn_cast(offTy); if (vecTy.getRank() != 1) - return emitOpError("The index should be an 1D vector Type for VC mode tensor descriptor."); + return emitOpError("The index should be an 1D vector Type for VC mode " + "tensor descriptor."); if (shape[0] != vecTy.getShape()[0]) - return emitOpError("For VC Mode TensorDesc. The offset should have same" - "length as the dim-0 of TensorDesc."); + return emitOpError("For VC Mode TensorDesc. The offset should have same" + "length as the dim-0 of TensorDesc."); } return ::mlir::success(); diff --git a/test/Dialect/XeGPU/IR/create_nd_tdesc.mlir b/test/Dialect/XeGPU/IR/create_nd_tdesc.mlir index 20882542c..f23e6f659 100644 --- a/test/Dialect/XeGPU/IR/create_nd_tdesc.mlir +++ b/test/Dialect/XeGPU/IR/create_nd_tdesc.mlir @@ -106,8 +106,8 @@ func.func @test_create_nd_tdesc_vc_9(%src: memref, %w : index, %h : ind %c1 = arith.constant 1 : index // CHECK: xegpu.create_nd_tdesc // CHECK-SAME: {mode = simt, memory_scope = slm, boundary_check = true} - // CHECK-SAME: !xegpu.tensor_desc<64x128xf32, #xegpu.xe_map> - %1 = xegpu.create_nd_tdesc %src[8, %x], [%h, %w], [%w, %c1] {memory_scope = slm, boundary_check = true} : memref + // CHECK-SAME: !xegpu.tensor_desc<64x128xf32, #xegpu.xe_map> + %1 = xegpu.create_nd_tdesc %src[8, %x], [%h, %w], [%w, %c1] {memory_scope = slm, boundary_check = true} : memref -> !xegpu.tensor_desc<64x128xf32, #xegpu.xe_map> return } diff --git a/test/Dialect/XeGPU/IR/create_tdesc.mlir b/test/Dialect/XeGPU/IR/create_tdesc.mlir index 079159655..98e58b55b 100644 --- a/test/Dialect/XeGPU/IR/create_tdesc.mlir +++ b/test/Dialect/XeGPU/IR/create_tdesc.mlir @@ -7,8 +7,8 @@ // CHECK-LABEL: func @test_create_tdesc_vc({{.*}}) { func.func @test_create_tdesc_vc(%src: ui64, %offsets : vector<16 x index>) { - // CHECK: xegpu.create_tdesc %arg0, %arg1 - // CHECK-SAME: {mode = vc, memory_scope = global, chunk_size_per_lane = 1} + // CHECK: xegpu.create_tdesc %arg0, %arg1 + // CHECK-SAME: {mode = vc, memory_scope = global, chunk_size_per_lane = 1} // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> %1 = xegpu.create_tdesc %src, %offsets {mode = vc}: ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> return @@ -16,8 +16,8 @@ func.func @test_create_tdesc_vc(%src: ui64, %offsets : vector<16 x index>) { // CHECK-LABEL: func @test_create_tdesc_vc_2({{.*}}) { func.func @test_create_tdesc_vc_2(%src: ui64, %offsets : vector<16 x index>) { - // CHECK: xegpu.create_tdesc %arg0, %arg1 - // CHECK-SAME: {mode = vc, memory_scope = slm, chunk_size_per_lane = 1} + // CHECK: xegpu.create_tdesc %arg0, %arg1 + // CHECK-SAME: {mode = vc, memory_scope = slm, chunk_size_per_lane = 1} // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> %1 = xegpu.create_tdesc %src, %offsets {mode = vc, memory_scope=slm} : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> @@ -26,20 +26,20 @@ func.func @test_create_tdesc_vc_2(%src: ui64, %offsets : vector<16 x index>) { // CHECK-LABEL: func @test_create_tdesc_vc_3({{.*}}) { func.func @test_create_tdesc_vc_3(%src: ui64, %offsets : vector<16 x index>) { - // CHECK: xegpu.create_tdesc %arg0, %arg1 + // CHECK: xegpu.create_tdesc %arg0, %arg1 // CHECK-SAME: {mode = vc, memory_scope = global, chunk_size_per_lane = 8} // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scattered> - %1 = xegpu.create_tdesc %src, %offsets {mode = vc, chunk_size_per_lane = 8} + %1 = xegpu.create_tdesc %src, %offsets {mode = vc, chunk_size_per_lane = 8} : ui64, vector<16 x index> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scattered> return } // CHECK-LABEL: func @test_create_tdesc_vc_4({{.*}}) { func.func @test_create_tdesc_vc_4(%src: ui64, %offsets : vector<16 x index>) { - // CHECK: xegpu.create_tdesc %arg0, %arg1 - // CHECK-SAME: {mode = vc, memory_scope = slm, chunk_size_per_lane = 2} + // CHECK: xegpu.create_tdesc %arg0, %arg1 + // CHECK-SAME: {mode = vc, memory_scope = slm, chunk_size_per_lane = 2} // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scattered> - %1 = xegpu.create_tdesc %src, %offsets {mode = vc, memory_scope = slm, chunk_size_per_lane = 2} + %1 = xegpu.create_tdesc %src, %offsets {mode = vc, memory_scope = slm, chunk_size_per_lane = 2} : ui64, vector<16 x index> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scattered> return } @@ -75,4 +75,3 @@ func.func @test_create_tdesc_vc_7(%src: memref, %offset : index) { : memref, index -> !xegpu.tensor_desc<1xf32, #xegpu.scattered> return } - diff --git a/test/Dialect/XeGPU/IR/invalid.mlir b/test/Dialect/XeGPU/IR/invalid.mlir index 893741193..f3d68254c 100644 --- a/test/Dialect/XeGPU/IR/invalid.mlir +++ b/test/Dialect/XeGPU/IR/invalid.mlir @@ -44,7 +44,7 @@ func.func @test_create_nd_tdesc_vc_4(%input: memref) { %c8 = arith.constant 8 : index // expected-error@+1 {{Expecting the rank of shape, strides and offsets should match with each other}} - %1 = xegpu.create_nd_tdesc %input[%c1], [%c8], [%c1] {mode = vc} + %1 = xegpu.create_nd_tdesc %input[%c1], [%c8], [%c1] {mode = vc} : memref -> !xegpu.tensor_desc<8x16xf32> return } @@ -55,7 +55,7 @@ func.func @test_create_nd_tdesc_vc_5(%input: memref<24x32x64xf32>) { %c8 = arith.constant 8 : index // expected-error@+1 {{operand #0 must be 1D/2D memref}} - %1 = xegpu.create_nd_tdesc %input[%c1, %c1, %c8] {mode = vc} + %1 = xegpu.create_nd_tdesc %input[%c1, %c1, %c8] {mode = vc} : memref<24x32x64xf32> -> !xegpu.tensor_desc<8x16x8xf32> return } @@ -63,7 +63,7 @@ func.func @test_create_nd_tdesc_vc_5(%input: memref<24x32x64xf32>) { // ----- func.func @test_create_tdesc(%src: ui64, %offsets : vector<16x8xindex>) { // expected-error@+1 {{operand #1 must be vector of index values of ranks 1}} - %1 = xegpu.create_tdesc %src, %offsets {mode = vc} + %1 = xegpu.create_tdesc %src, %offsets {mode = vc} : ui64, vector<16x8xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scattered> return } @@ -82,5 +82,3 @@ func.func @test_load_gather(%src: ui64, %offsets : vector<16xindex>) { : !xegpu.tensor_desc<16x8xf16, #xegpu.scattered>, vector<16x8xi1> -> vector<8x8x4xf16> return } - - diff --git a/test/Dialect/XeGPU/IR/load_gather.mlir b/test/Dialect/XeGPU/IR/load_gather.mlir index 39ce00088..b04dd022d 100644 --- a/test/Dialect/XeGPU/IR/load_gather.mlir +++ b/test/Dialect/XeGPU/IR/load_gather.mlir @@ -16,7 +16,7 @@ func.func @test_load_gather_vc(%src: ui64, %offsets : vector<16xindex>) { // CHECK: xegpu.load // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached} // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32> - %2 = xegpu.load %1, %0 {mode = vc, l1_hint = cached, l2_hint = uncached} + %2 = xegpu.load %1, %0 {mode = vc, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32> return } @@ -33,7 +33,7 @@ func.func @test_load_gather_vc_2(%src: ui64, %offsets : vector<16xindex>) { // CHECK: xegpu.load // CHECK-SAME: {mode = vc, transpose = [1, 0], l1_hint = cached, l2_hint = uncached} // CHECK-SAME: !xegpu.tensor_desc<16x8xf32, #xegpu.scattered>, vector<16x8xi1> -> vector<8x16xf32> - %2 = xegpu.load %1, %0 {mode = vc, transpose = [1, 0], l1_hint = cached, l2_hint = uncached} + %2 = xegpu.load %1, %0 {mode = vc, transpose = [1, 0], l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<16x8xf32, #xegpu.scattered>, vector<16x8xi1> -> vector<8x16xf32> return } @@ -69,8 +69,7 @@ func.func @test_load_gather_vc_4(%src: ui64, %offsets : vector<16xindex>) { // CHECK: xegpu.load // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached} // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32> - %2 = xegpu.load %1, %0 {mode = vc, l1_hint = cached, l2_hint = uncached} + %2 = xegpu.load %1, %0 {mode = vc, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32> return } - diff --git a/test/Dialect/XeGPU/IR/store_scatter.mlir b/test/Dialect/XeGPU/IR/store_scatter.mlir index 19238afab..6786692f7 100644 --- a/test/Dialect/XeGPU/IR/store_scatter.mlir +++ b/test/Dialect/XeGPU/IR/store_scatter.mlir @@ -61,4 +61,3 @@ func.func @test_store_scatter(%src: ui64, %offsets : index, %dst: ui64) { : f32, !xegpu.tensor_desc<1xf32, #xegpu.scattered>, i1 return } - diff --git a/test/Dialect/XeGPU/IR/update_offset.mlir b/test/Dialect/XeGPU/IR/update_offset.mlir index b1e712a3c..539a72f48 100644 --- a/test/Dialect/XeGPU/IR/update_offset.mlir +++ b/test/Dialect/XeGPU/IR/update_offset.mlir @@ -11,7 +11,7 @@ func.func @test_update_offset_VC(%src: ui64, %offsets : vector<16 x index>) { // CHECK: xegpu.create_tdesc // CHECK-SAME: {mode = vc, memory_scope = global, chunk_size_per_lane = 1} // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> - %1 = xegpu.create_tdesc %src, %offsets {mode = vc} + %1 = xegpu.create_tdesc %src, %offsets {mode = vc} : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> // CHECK: xegpu.load @@ -57,4 +57,3 @@ func.func @test_update_offset(%src: ui64, %offsets : index) { return } -