-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
202 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
//===- TileLoops.cpp ------------------------------------*- C++ -*-===// | ||
// | ||
// Copyright 2023 Intel Corporation | ||
// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
/// | ||
/// \file | ||
/// This file implements the TileLoops transform which tiles loops for GPU | ||
/// mapping. | ||
/// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include <imex/Utils/PassUtils.h> | ||
#include <mlir/Dialect/Func/IR/FuncOps.h> | ||
#include <mlir/Dialect/Linalg/IR/Linalg.h> | ||
#include <mlir/Dialect/SCF/IR/SCF.h> | ||
#include <mlir/Dialect/SCF/Transforms/TileUsingInterface.h> | ||
#include <mlir/Interfaces/TilingInterface.h> | ||
#include <mlir/Pass/Pass.h> | ||
|
||
#include "llvm/Support/Threading.h" | ||
#include <imex/Dialect/Region/RegionUtils.h> | ||
#include <imex/Transforms/Passes.h> | ||
|
||
namespace imex { | ||
#define GEN_PASS_DEF_TILELOOPS | ||
#include "imex/Transforms/Passes.h.inc" | ||
} // namespace imex | ||
|
||
#define DEBUG_TYPE "tile-loops" | ||
|
||
#ifndef NDEBUG | ||
#define DEBUG_MSG(PREFIX, MSG) \ | ||
LLVM_DEBUG(llvm::dbgs() << PREFIX << ": " << MSG << "\n"); | ||
#define DEBUG_OP(PREFIX, MSG, OP) \ | ||
LLVM_DEBUG(llvm::dbgs() << PREFIX << ": " << MSG << " '" << OP->getName() \ | ||
<< "' " << OP->getLoc() << "\n"); | ||
#define DEBUG_OP_VEC(PREFIX, MSG, OPVEC) \ | ||
LLVM_DEBUG(llvm::dbgs() << PREFIX << ": " << MSG << " (" << OPVEC.size() \ | ||
<< ")\n"); \ | ||
for (auto op : OPVEC) { \ | ||
DEBUG_OP(PREFIX, " ", op) \ | ||
} | ||
#endif | ||
|
||
using namespace imex; | ||
|
||
namespace { | ||
|
||
static ::mlir::FailureOr<::mlir::SmallVector<int64_t>> | ||
getDefaultTileSizes(::mlir::linalg::LinalgOp linalgOp, | ||
::mlir::ArrayRef<int64_t> userProvidedTiles) { | ||
// The user-provided tiles are considered from the outer | ||
// most loop. If not enough tiles are provided we pad with | ||
// zeros. | ||
if (!userProvidedTiles.empty()) { | ||
size_t numParallelLoops = linalgOp.getNumParallelLoops(); | ||
size_t nonZeros = 0; | ||
for (auto tile : userProvidedTiles) | ||
if (tile != 0) | ||
nonZeros++; | ||
if (nonZeros > numParallelLoops || | ||
userProvidedTiles.size() > linalgOp.getNumLoops()) { | ||
return ::mlir::failure(); | ||
} | ||
|
||
::mlir::SmallVector<int64_t> userTiles(linalgOp.getNumLoops(), 0); | ||
for (auto tile : ::llvm::enumerate(userProvidedTiles)) | ||
userTiles[tile.index()] = tile.value(); | ||
return userTiles; | ||
} | ||
// FIXME | ||
return ::mlir::failure(); | ||
} | ||
|
||
struct TileLoops final : public imex::impl::TileLoopsBase<TileLoops> { | ||
|
||
using TileLoopsBase::TileLoopsBase; | ||
|
||
void runOnOperation() override { | ||
|
||
::mlir::func::FuncOp func = getOperation(); | ||
::mlir::IRRewriter rewriter(&getContext()); | ||
transform(rewriter, func, this->tileSizes, this->minTileFactor); | ||
|
||
return; | ||
} | ||
|
||
private: | ||
void transform(::mlir::RewriterBase &rewriter, ::mlir::func::FuncOp func, | ||
::mlir::ArrayRef<int64_t> tileSizes, int64_t minTileFactor) { | ||
DEBUG_MSG("tile-loops", "Entering transform"); | ||
::mlir::SmallVector<::mlir::Operation *> allLinalgOps; | ||
func->walk([&](::mlir::linalg::LinalgOp linalgOp) { | ||
if (!inRegions || ::imex::region::isInGpuRegion(linalgOp)) { | ||
allLinalgOps.push_back(linalgOp); | ||
} | ||
}); | ||
DEBUG_OP_VEC("tile-loops", " Found linalg ops", allLinalgOps); | ||
|
||
for (auto op : allLinalgOps) { | ||
DEBUG_OP("tile-loops", " Tiling op:", op); | ||
auto tiles = getDefaultTileSizes( | ||
::llvm::cast<::mlir::linalg::LinalgOp>(op), tileSizes); | ||
if (failed(tiles)) { | ||
DEBUG_MSG("tile-loops", | ||
" Failed to compute default tile sizes. Aborting."); | ||
return; | ||
} | ||
DEBUG_MSG("tile-loops", " tile sizes:"); | ||
LLVM_DEBUG(llvm::dbgs() << "tile-loops: ("); | ||
LLVM_DEBUG(llvm::interleaveComma(*tiles, llvm::dbgs())); | ||
LLVM_DEBUG(llvm::dbgs() << ")\n"); | ||
|
||
auto tilesRes = | ||
::mlir::getAsOpFoldResult(rewriter.getI64ArrayAttr(*tiles)); | ||
::mlir::scf::SCFTilingOptions options; | ||
options.setTileSizes(tilesRes); | ||
options.setLoopType(::mlir::scf::SCFTilingOptions::LoopType::ForallOp); | ||
auto tileOp = ::mlir::cast<::mlir::TilingInterface>(op); | ||
::mlir::FailureOr<::mlir::scf::SCFTilingResult> tilingResult = | ||
mlir::scf::tileUsingSCF(rewriter, tileOp, options); | ||
if (failed(tilingResult)) { | ||
DEBUG_MSG("tile-loops", " Failed to tile op. Aborting."); | ||
return; | ||
} | ||
DEBUG_MSG("tile-loops", " Tiling applied successfully."); | ||
rewriter.replaceOp(op, tilingResult.value().replacements); | ||
} | ||
} | ||
}; | ||
|
||
} // end anonymous namespace | ||
|
||
namespace imex { | ||
std::unique_ptr<mlir::Pass> createTileLoopsPass() { | ||
return std::make_unique<TileLoops>(); | ||
} | ||
} // namespace imex |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
// RUN: imex-opt --split-input-file -tile-loops='tile-sizes=32' -tile-loops='tile-sizes=1' %s -verify-diagnostics -o -| FileCheck %s | ||
|
||
#map = affine_map<(d0) -> (d0)> | ||
module { | ||
func.func @add(%arg0: tensor<129xf32>, %arg1: tensor<129xf32>, %arg2: tensor<129xf32>) -> tensor<129xf32> { | ||
%0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<129xf32>, tensor<129xf32>) outs(%arg2 : tensor<129xf32>) { | ||
^bb0(%in: f32, %in_0: f32, %out: f32): | ||
%1 = arith.addf %in, %in_0 : f32 | ||
linalg.yield %1 : f32 | ||
} -> tensor<129xf32> | ||
return %0 : tensor<129xf32> | ||
} | ||
} | ||
// CHECK-LABEL: func.func @add | ||
// CHECK-NEXT: %[[FORALL:.*]] = scf.forall (%arg3) = (0) to (129) step (32) shared_outs(%arg4 = %arg2) -> (tensor<129xf32>) { | ||
// CHECK-NEXT: %[[C129:.*]] = arith.constant 129 : index | ||
// CHECK-NEXT: %[[MIN:.*]] = affine.min #map(%[[ARG3:.*]]) | ||
// CHECK-NEXT: %[[APPLY1:.*]] = affine.apply #map1(%[[MIN]]) | ||
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %arg0[%[[ARG3]]] [%[[MIN]]] [1] : tensor<129xf32> to tensor<?xf32> | ||
// CHECK-NEXT: %[[EXTRACTED_SLICE_0:.*]] = tensor.extract_slice %arg1[%[[ARG3]]] [%[[MIN]]] [1] : tensor<129xf32> to tensor<?xf32> | ||
// CHECK-NEXT: %[[EXTRACTED_SLICE_1:.*]] = tensor.extract_slice %arg4[%[[ARG3]]] [%[[MIN]]] [1] : tensor<129xf32> to tensor<?xf32> | ||
// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index | ||
// CHECK: %[[FORALL:.*]] = scf.forall (%[[ARG5:.*]]) in (%[[MIN]]) shared_outs(%[[ARG6:.*]] = %[[EXTRACTED_SLICE_1]]) -> (tensor<?xf32>) { | ||
// CHECK-NEXT: %[[EXTRACTED_SLICE_4:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG5]]] [1] [1] : tensor<?xf32> to tensor<1xf32> | ||
// CHECK-NEXT: %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_0]][%[[ARG5]]] [1] [1] : tensor<?xf32> to tensor<1xf32> | ||
// CHECK-NEXT: %[[EXTRACTED_SLICE_6:.*]] = tensor.extract_slice %[[ARG6]][%[[ARG5]]] [1] [1] : tensor<?xf32> to tensor<1xf32> | ||
// CHECK-NEXT: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel"]} ins(%[[EXTRACTED_SLICE_4]], %[[EXTRACTED_SLICE_5]] : tensor<1xf32>, tensor<1xf32>) outs(%[[EXTRACTED_SLICE_6]] : tensor<1xf32>) { | ||
// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN_7:.*]]: f32, %[[OUT:.*]]: f32): | ||
// CHECK-NEXT: %[[ADDF:.*]] = arith.addf %[[IN]], %[[IN_7]] : f32 | ||
// CHECK-NEXT: linalg.yield %[[ADDF]] : f32 | ||
// CHECK-NEXT: } -> tensor<1xf32> | ||
// CHECK-NEXT: scf.forall.in_parallel { | ||
// CHECK-NEXT: tensor.parallel_insert_slice %[[GENERIC]] into %[[ARG6]][%[[ARG5]]] [1] [1] : tensor<1xf32> into tensor<?xf32> | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
// CHECK: scf.forall.in_parallel { | ||
// CHECK-NEXT: tensor.parallel_insert_slice %[[FORALL]] into %arg4[%[[ARG3]]] [%[[MIN]]] [1] : tensor<?xf32> into tensor<129xf32> | ||
// CHECK-NEXT: } |