Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
leshikus committed Aug 28, 2024
2 parents 914f24e + eb8c81a commit 30ffd1c
Show file tree
Hide file tree
Showing 39 changed files with 3,077 additions and 545 deletions.
9 changes: 6 additions & 3 deletions include/imex/Dialect/XeTile/IR/XeTileOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,9 @@ def XeTile_TileMMAOp : XeTile_Op<"tile_mma", []> {
// Returns the element type of the type of operand A (as obtained via getA()).
mlir::Type getElementType() {
return getA().getType().getElementType();
}
// Returns the type of the value obtained via getOutput(), as an
// mlir::VectorType.
mlir::VectorType getOutputType() {
return getOutput().getType();
}
}];

let hasVerifier = 1;
Expand Down Expand Up @@ -581,7 +584,7 @@ def XeTile_TransposeOp: XeTile_Op<"transpose", []> {
let hasVerifier = 1;
}

def XeTile_ReduceOp: XeTile_Op<"reduce", []> {
def XeTile_ReductionOp: XeTile_Op<"reduction", []> {
let summary = "performs a reduction operation over a 2D vector.";
let description = [{
It has the same semantics as the `vector.multi_reduction`,
Expand All @@ -591,10 +594,10 @@ def XeTile_ReduceOp: XeTile_Op<"reduce", []> {

let arguments = (ins Vector_CombiningKindAttr: $kind,
XeTile_2DOr4DVector: $source,
DenseI64ArrayAttr: $reduction_dim);
DenseI64ArrayAttr: $reduction_dims);
let results = (outs XeTile_2DOr4DVector: $result);
let assemblyFormat = [{
$kind `,` $source $reduction_dim attr-dict `:` type($source) `->` type($result)
$kind `,` $source $reduction_dims attr-dict `:` type($source) `->` type($result)
}];

let hasVerifier = 1;
Expand Down
2 changes: 2 additions & 0 deletions include/imex/Dialect/XeTile/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ std::unique_ptr<mlir::Pass> createXeTileInitDuplicatePass();

std::unique_ptr<mlir::Pass>
createXeTileBlockingPass(const std::string &device = "pvc");
/// Creates an instance of the new XeTile blocking pass.
/// @param device name of the target device; defaults to "pvc", the same
///        default used by createXeTileBlockingPass.
/// @return an owning pointer to the newly created pass.
std::unique_ptr<mlir::Pass>
createNewXeTileBlockingPass(const std::string &device = "pvc");
std::unique_ptr<mlir::Pass> createXeTileBlockAligningPass();
std::unique_ptr<mlir::Pass> createXeTileWgToSgPass();
std::unique_ptr<mlir::Pass> createXeTileOptimizeTransposePass();
Expand Down
26 changes: 26 additions & 0 deletions include/imex/Dialect/XeTile/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -130,5 +130,31 @@ def XeTileCanonicalization : Pass<"xetile-canonicalization", "::mlir::gpu::GPUMo
];
}

// Declares the `new-xetile-blocking` pass, which is anchored on
// gpu.module operations (::mlir::gpu::GPUModuleOp).
def NewXeTileBlocking : Pass<"new-xetile-blocking", "::mlir::gpu::GPUModuleOp">{
let summary = "transform XeTile large tiles(input) into arrays of smaller "
"blocks with appropriate size, such that the operator on each "
"of the blocks can be mapped into one hardware instruction.";

let description = [{
This transform pass preprocesses the xetile program by decomposing large XeTile tiles
into smaller ones that can be handled by a hardware instruction. It is going to replace
the xetile-blocking pass.
}];

// The constructor is invoked without arguments, so the factory's own default
// for `device` applies.
let constructor = "imex::createNewXeTileBlockingPass()";
// Dialects that must be loaded before this pass runs.
let dependentDialects = ["imex::xetile::XeTileDialect",
"mlir::arith::ArithDialect",
"mlir::math::MathDialect",
"mlir::gpu::GPUDialect",
"mlir::memref::MemRefDialect",
"mlir::vector::VectorDialect"];

// Command-line options; `device` is a std::string that defaults to "pvc".
let options = [
Option<"device", "device", "std::string",
/*default=*/"\"pvc\"",
"gpu platform architecture where these ops are running">
];
}


#endif // _XeTile_PASSES_TD_INCLUDED_
5 changes: 5 additions & 0 deletions include/imex/ExecutionEngine/ImexRunnerUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ _mlir_ciface_fillResource1DRandomF16(UnrankedMemRefType<f16> *ptr,
const float lower, const float upper,
const bool genInt);

/// C-linkage runner utility declared alongside the F16 variant
/// (_mlir_ciface_fillResource1DRandomF16), taking a float memref instead.
/// NOTE(review): presumably fills the memref referenced by `ptr` with random
/// values between `lower` and `upper`, and `genInt` presumably selects
/// integer-valued samples; the implementation is not visible here, so
/// confirm against the definition.
extern "C" IMEX_RUNNERUTILS_EXPORT void
_mlir_ciface_fillResource1DRandomF32(UnrankedMemRefType<float> *ptr,
const float lower, const float upper,
const bool genInt);

extern "C" IMEX_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemrefBF16(UnrankedMemRefType<bf16> *m);
extern "C" IMEX_RUNNERUTILS_EXPORT void
Expand Down
13 changes: 7 additions & 6 deletions lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -736,15 +736,16 @@ extern llvm::SmallVector<mlir::Value> lowerInnerReductionWithVectorReduction(
mlir::vector::CombiningKind kind, mlir::Location loc, mlir::Type elemTy,
XeOneToNPatternRewriter &rewriter);

struct SgTileReduceOpPattern : public XeOneToNConversion<xetile::ReduceOp> {
using XeOneToNConversion<xetile::ReduceOp>::XeOneToNConversion;
struct SgTileReductionOpPattern
: public XeOneToNConversion<xetile::ReductionOp> {
using XeOneToNConversion<xetile::ReductionOp>::XeOneToNConversion;

mlir::LogicalResult
matchAndRewrite(xetile::ReduceOp op, OpAdaptor adaptor,
matchAndRewrite(xetile::ReductionOp op, OpAdaptor adaptor,
XeOneToNPatternRewriter &rewriter) const override {
auto srcTy = op.getSource().getType();
auto elemTy = srcTy.getElementType();
auto dims = op.getReductionDim();
auto dims = op.getReductionDims();
// its input should be a 4D vector, and has 2 reduction dims,
// otherwise run the blocking pass first.
if (dims.size() != 2 || srcTy.getRank() != 4)
Expand Down Expand Up @@ -1092,8 +1093,8 @@ void populateXeTileOpConversionPatterns(imex::XeOneToNTypeConverter &converter,
SgTileMMAOpPattern, SgUpdateTileOffsetOpPattern,
SgTransposeOpPattern<mlir::vector::TransposeOp>,
SgTransposeOpPattern<xetile::TransposeOp>, SgBroadcastOpPattern,
SgTileReduceOpPattern, SgVectorCreateMaskOpPattern>(patterns.getContext(),
converter, analysis);
SgTileReductionOpPattern, SgVectorCreateMaskOpPattern>(
patterns.getContext(), converter, analysis);
patterns.insert<ElementWiseOpPattern<mlir::arith::NegFOp, 1>,
ElementWiseOpPattern<mlir::math::ExpOp, 1>,
ElementWiseOpPattern<mlir::math::SinOp, 1>,
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/XeTile/IR/XeTileOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -859,8 +859,8 @@ mlir::LogicalResult TransposeOp::verify() {
return mlir::success();
}

mlir::LogicalResult ReduceOp::verify() {
auto dims = getReductionDim();
mlir::LogicalResult ReductionOp::verify() {
auto dims = getReductionDims();
auto resShape = getResult().getType().getShape();
for (auto i : dims)
if (resShape[i] != 1)
Expand Down
23 changes: 12 additions & 11 deletions lib/Dialect/XeTile/Transforms/Blocking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -556,29 +556,30 @@ struct VectorMultiDimReductionOpPattern
}
};

struct TileReduceOpPattern
: public XeTileConversion<xetile::ReduceOp, TileUsageAnalysis> {
struct TileReductionOpPattern
: public XeTileConversion<xetile::ReductionOp, TileUsageAnalysis> {

using XeTileConversion<xetile::ReduceOp, TileUsageAnalysis>::XeTileConversion;
using XeTileConversion<xetile::ReductionOp,
TileUsageAnalysis>::XeTileConversion;

TileReduceOpPattern(mlir::MLIRContext *context,
imex::XeTypeConverter &converter,
TileUsageAnalysis &analysis,
std::shared_ptr<XeuArchInterface> ptruArch)
TileReductionOpPattern(mlir::MLIRContext *context,
imex::XeTypeConverter &converter,
TileUsageAnalysis &analysis,
std::shared_ptr<XeuArchInterface> ptruArch)
: XeTileConversion(context, converter, analysis) {
this->uArchInterface = ptruArch;
}

std::shared_ptr<XeuArchInterface> uArchInterface = nullptr;

mlir::LogicalResult
matchAndRewrite(xetile::ReduceOp op, OpAdaptor adaptor,
matchAndRewrite(xetile::ReductionOp op, OpAdaptor adaptor,
OpPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto srcTy = op.getSource().getType();
auto elemTy = srcTy.getElementType();
auto shape = srcTy.getShape();
auto reductionDims = op.getReductionDim();
auto reductionDims = op.getReductionDims();

if (srcTy.getRank() != 2 || reductionDims.size() != 1)
return rewriter.notifyMatchFailure(
Expand Down Expand Up @@ -611,7 +612,7 @@ struct TileReduceOpPattern

auto newSource =
addPackOp(adaptor.getSource(), {blkSizes[0], blkSizes[1]}, rewriter);
auto newDest = rewriter.create<xetile::ReduceOp>(
auto newDest = rewriter.create<xetile::ReductionOp>(
loc, newDestType, op.getKind(), newSource, newReductionDims);
auto unpack = addUnpackOp(newDest.getResult(), rewriter);
rewriter.replaceOp(op, unpack);
Expand Down Expand Up @@ -1161,7 +1162,7 @@ void populateXeTileBlockingPatterns(
VectorizableOpPattern, SCFForOpPattern, SCFYieldOpPattern,
InitTileOpPattern, LoadTileOpPattern, StoreTileOpPattern,
TileMMAOpPattern, UpdateTileOffsetOpPattern,
VectorMultiDimReductionOpPattern, TileReduceOpPattern,
VectorMultiDimReductionOpPattern, TileReductionOpPattern,
TileBroadcastOpPattern>(patterns.getContext(), converter,
analysis, ptruArch);
patterns.insert<TransposeOpPattern<mlir::vector::TransposeOp>,
Expand Down
Loading

0 comments on commit 30ffd1c

Please sign in to comment.