Skip to content

Commit

Permalink
Initial working implementation of debufferize flow for linalg with examples
Browse files Browse the repository at this point in the history
  • Loading branch information
arpitj1 committed Jan 13, 2025
1 parent e2b4b2d commit f2ab09e
Show file tree
Hide file tree
Showing 6 changed files with 285 additions and 1 deletion.
39 changes: 39 additions & 0 deletions debufferize.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//polygeist-opt --linalg-debufferize debufferize.mlir

#map16 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map17 = affine_map<(d0, d1, d2, d3) -> (d1 + d3, d0 + d2)>
#map18 = affine_map<(d0, d1, d2, d3) -> (d1, d0)>
#map19 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>

// In-place elementwise add: the same alloca'd buffer appears in both ins and
// outs of the linalg.generic, exercising debufferization when a buffer is
// read and written by the same op.
// NOTE(review): %c0 is unused — presumably left over from an earlier variant.
module @in_place_add{
func.func @in_place_add(%value: f32) {
%c0 = arith.constant 0 : index
%buffer = memref.alloca() : memref<128xf32>
// Reads %buffer[d0], adds the scalar %value, writes back to %buffer[d0].
linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]
} ins(%buffer : memref<128xf32>)
outs(%buffer : memref<128xf32>) {
^bb0(%in: f32, %out: f32):
%sum = arith.addf %in, %value : f32
linalg.yield %sum : f32
}
return
}
}

// 2-D convolution-style multiply-accumulate over a 515x67 input and a 4x4
// kernel into a 512x64 output, written as a 4-loop linalg.generic
// (two parallel, two reduction dims).
// NOTE(review): the kernel %1 uses #map18 (parallel dims d1, d0) and the
// output %2 uses #map19 (reduction dims d3, d2); conventionally the kernel
// would be indexed by the reduction dims and the output by the parallel
// dims — confirm the map assignment is intentional and accepted by the
// linalg verifier.
module @conv_2 {
func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage<external>} {
%c0_i32 = arith.constant 0 : i32
%0 = memref.alloca() : memref<515x67xi32>
%1 = memref.alloca() : memref<4x4xi32>
%2 = memref.alloca() : memref<512x64xi32>
linalg.generic {indexing_maps = [#map17, #map18, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0, %1 : memref<515x67xi32>, memref<4x4xi32>) outs(%2 : memref<512x64xi32>) {
^bb0(%in: i32, %in_0: i32, %out: i32):
%3 = arith.muli %in, %in_0 : i32
%4 = arith.addi %out, %3 : i32
linalg.yield %4 : i32
}
return %c0_i32 : i32
}
}
9 changes: 9 additions & 0 deletions include/polygeist/Passes/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ std::unique_ptr<Pass> createOpenMPOptPass();
std::unique_ptr<Pass> createCanonicalizeForPass();
std::unique_ptr<Pass> createRaiseSCFToAffinePass();
std::unique_ptr<Pass> createRaiseAffineToLinalgPass();
std::unique_ptr<Pass> createLinalgDebufferizePass();
std::unique_ptr<Pass> createRemoveIterArgsPass();
std::unique_ptr<Pass> createCPUifyPass(StringRef method = "");
std::unique_ptr<Pass> createBarrierRemovalContinuation();
Expand Down Expand Up @@ -129,6 +130,14 @@ namespace linalg {
class LinalgDialect;
}

namespace bufferization {
class BufferizationDialect;
}

// MLIR's tensor dialect lives in the lowercase `tensor` namespace
// (mlir::tensor::TensorDialect); the capitalized `Tensor` spelling declared
// a namespace that nothing else uses, making the forward declaration inert.
namespace tensor {
class TensorDialect;
} // namespace tensor

namespace LLVM {
class LLVMDialect;
}
Expand Down
11 changes: 11 additions & 0 deletions include/polygeist/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,17 @@ def RemoveIterArgs : Pass<"remove-iter-args"> {
];
}

def LinalgDebufferize : Pass<"linalg-debufferize"> {
  // Fixed copy-pasted summary (was "Raise affine to linalg", copied from
  // AffineRaiseToLinalg below).
  let summary = "Rewrite linalg ops on memref buffers into destination-passing-style linalg ops on tensors";
  let constructor = "mlir::polygeist::createLinalgDebufferizePass()";
  let dependentDialects = [
    "affine::AffineDialect",
    "linalg::LinalgDialect",
    "bufferization::BufferizationDialect",
    // The pass creates memref.copy ops, so the memref dialect must be loaded.
    "memref::MemRefDialect",
    "polygeist::PolygeistDialect",
  ];
}

def AffineRaiseToLinalg : Pass<"raise-affine-to-linalg"> {
let summary = "Raise affine to linalg";
let constructor = "mlir::polygeist::createRaiseAffineToLinalgPass()";
Expand Down
2 changes: 1 addition & 1 deletion lib/polygeist/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5843,7 +5843,7 @@ struct SubMapOpCanonicalize : public OpRewritePattern<polygeist::SubmapOp> {
//If inverse permutation exists, then we can canonicalize the linalg of submap to linalg
//TODO: Fails for:
// 1. Maps with symbols
// 2. Maps with non
// 2. Maps which are not resolvable 1 to 1 with memref for all dims
if(inversePermutation(concatAffineMaps(maps))) {
StringAttr empty = StringAttr::get(genericOp.getContext());
auto newGenericOp = rewriter.create<linalg::GenericOp>(genericOp.getLoc(), TypeRange(), listOfNewInputs, listOfNewOutputs, listOfNewMaps, genericOp.getIteratorTypesArray(),
Expand Down
1 change: 1 addition & 0 deletions lib/polygeist/Passes/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRPolygeistTransforms
RaiseToAffine.cpp
RemoveIterArgs.cpp
RaiseToLinalg.cpp
LinalgDebufferize.cpp
ParallelLower.cpp
TrivialUse.cpp
ConvertPolygeistToLLVM.cpp
Expand Down
224 changes: 224 additions & 0 deletions lib/polygeist/Passes/LinalgDebufferize.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
#include "PassDetails.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "polygeist/Passes/Passes.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-debufferize"

using namespace mlir;
using namespace mlir::arith;
using namespace polygeist;
using namespace affine;
using namespace linalg;
using namespace tensor;
using namespace bufferization;



//module @harris_score_with_gradient_extra_kernel {
// memref.global "private" @_ZL8coeffs_1 : memref<5x5xi32> = dense<1>
// memref.global "private" @_ZL8coeffs_y : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]>
// memref.global "private" @_ZL8coeffs_x : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]>
// memref.global @score : memref<512x512xi32> = uninitialized
// func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage<external>} {
// %c4_i32 = arith.constant 4 : i32
// %c0_i32 = arith.constant 0 : i32
// %alloca = memref.alloca() : memref<512x512xi32>
// %alloca_0 = memref.alloca() : memref<512x512xi32>
// %alloca_1 = memref.alloca() : memref<512x512xi32>
// %alloca_2 = memref.alloca() : memref<516x516xi32>
// %alloca_3 = memref.alloca() : memref<516x516xi32>
// %alloca_4 = memref.alloca() : memref<518x518xi32>
// %0 = memref.get_global @_ZL8coeffs_x : memref<3x3xi32>
// %1 = memref.get_global @_ZL8coeffs_y : memref<3x3xi32>
// %2 = memref.get_global @_ZL8coeffs_1 : memref<5x5xi32>
// // 2nd variant
// // %0 = memref.alloca() : memref<3x3xi32>
// // %1 = memref.alloca() : memref<3x3xi32>
// // %2 = memref.alloca() : memref<5x5xi32>
// linalg.generic {indexing_maps = [#map17, #map18, #map18, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_4, %0, %1 : memref<518x518xi32>, memref<3x3xi32>, memref<3x3xi32>) outs(%alloca_2, %alloca_3 : memref<516x516xi32>, memref<516x516xi32>) {
// ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32):
// %4 = arith.muli %in, %in_5 : i32
// %5 = arith.addi %out_7, %4 : i32
// %6 = arith.muli %in, %in_6 : i32
// %7 = arith.addi %out, %6 : i32
// linalg.yield %7, %5 : i32, i32
// }
// linalg.generic {indexing_maps = [#map17, #map17, #map18, #map19, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_3, %alloca_2, %2 : memref<516x516xi32>, memref<516x516xi32>, memref<5x5xi32>) outs(%alloca, %alloca_0, %alloca_1 : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) {
// ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32, %out_8: i32):
// %4 = arith.muli %in, %in : i32
// %5 = arith.muli %4, %in_6 : i32
// %6 = arith.addi %out_8, %5 : i32
// %7 = arith.muli %in_5, %in_5 : i32
// %8 = arith.muli %7, %in_6 : i32
// %9 = arith.addi %out_7, %8 : i32
// %10 = arith.muli %in, %in_5 : i32
// %11 = arith.muli %10, %in_6 : i32
// %12 = arith.addi %out, %11 : i32
// linalg.yield %12, %9, %6 : i32, i32, i32
// }
// %3 = memref.get_global @score : memref<512x512xi32>
// linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel"]} ins(%alloca_1, %alloca_0, %alloca : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) outs(%3 : memref<512x512xi32>) {
// ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32):
// %4 = arith.muli %in, %in_5 : i32
// %5 = arith.muli %in_6, %in_6 : i32
// %6 = arith.subi %4, %5 : i32
// %7 = arith.addi %in, %in_5 : i32
// %8 = arith.muli %7, %c4_i32 : i32
// %9 = arith.muli %8, %7 : i32
// %10 = arith.subi %6, %9 : i32
// linalg.yield %10 : i32
// }
// return %c0_i32 : i32
// }
// }
// Rewrites linalg.generic ops that read/write a memref.alloca buffer into
// tensor-based (destination-passing-style) linalg.generic ops.  For each
// alloca whose users are exclusively linalg.generic ops, the buffer is
// materialized as a tensor with bufferization.to_tensor, threaded through the
// generics as an SSA value, and finally copied back into the buffer via
// bufferization.to_memref + memref.copy.
struct LinalgDebufferization : public OpRewritePattern<func::FuncOp> {
  using OpRewritePattern<func::FuncOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                PatternRewriter &rewriter) const final {

    SmallVector<Operation *> opsToDelete;
    // Old generics that have already been rewritten; skipped when a later
    // alloca walks a user list that still contains them (they are only
    // erased after the walk completes).
    llvm::SmallPtrSet<Operation *, 16> processedGenericOps;
    bool changed = false;

    funcOp.walk([&](memref::AllocaOp allocaOp) -> WalkResult {
      // Only handle allocas whose users are all linalg.generic ops; any other
      // user (load/store, call, the to_tensor/copy ops created by a previous
      // application of this pattern, ...) observes the buffer directly and
      // would make the rewrite unsound.  TODO: extend this.
      if (!llvm::all_of(allocaOp->getUsers(), [](Operation *op) {
            return isa<linalg::GenericOp>(op);
          }))
        return WalkResult::advance();
      // Require at least one not-yet-processed generic so we never emit the
      // to_tensor/to_memref/copy scaffolding without rewriting anything.
      if (llvm::none_of(allocaOp->getUsers(), [&](Operation *op) {
            return !processedGenericOps.count(op);
          }))
        return WalkResult::advance();

      auto memrefType = allocaOp.getType();
      auto tensorType = RankedTensorType::get(memrefType.getShape(),
                                              memrefType.getElementType());
      rewriter.setInsertionPointAfter(allocaOp);
      Value currentTensor = rewriter.create<bufferization::ToTensorOp>(
          allocaOp.getLoc(), tensorType, allocaOp);

      // NOTE(review): users are visited in use-list order, not dominance
      // order — TODO: sort the generics by dominance before chaining the
      // tensor value through them.
      for (Operation *user : llvm::make_early_inc_range(allocaOp->getUsers())) {
        auto genericOp = dyn_cast<linalg::GenericOp>(user);
        if (!genericOp || processedGenericOps.count(genericOp))
          continue;
        rewriter.setInsertionPointAfter(genericOp);

        // Substitute the buffer with the current SSA tensor in the inputs.
        SmallVector<Value, 4> newInputs;
        for (Value input : genericOp.getInputs())
          newInputs.push_back(input == allocaOp ? currentTensor : input);

        // Substitute in the outputs.  Only tensor-typed outputs produce SSA
        // results, so collect result types (and the position of the result
        // that continues the tensor chain) from tensor outputs only.
        SmallVector<Value, 4> newOutputs;
        SmallVector<Type> resultTypes;
        int tensorResultIndex = -1;
        for (Value output : genericOp.getOutputs()) {
          Value newOutput = output == allocaOp ? currentTensor : output;
          newOutputs.push_back(newOutput);
          if (isa<TensorType>(newOutput.getType())) {
            if (output == allocaOp)
              tensorResultIndex = static_cast<int>(resultTypes.size());
            resultTypes.push_back(newOutput.getType());
          }
        }

        StringAttr empty = StringAttr::get(genericOp.getContext());
        auto newGenericOp = rewriter.create<linalg::GenericOp>(
            genericOp.getLoc(), resultTypes, newInputs, newOutputs,
            genericOp.getIndexingMaps(), genericOp.getIteratorTypes(), empty,
            empty);
        rewriter.cloneRegionBefore(genericOp.getRegion(),
                                   newGenericOp.getRegion(),
                                   newGenericOp.getRegion().end());

        // Remap results of the old op to the new one.  The new op gains one
        // extra result (for the debufferized output), so walk the outputs in
        // parallel, advancing each result counter only for tensor outputs.
        unsigned oldResIdx = 0, newResIdx = 0;
        for (Value output : genericOp.getOutputs()) {
          bool wasTensor = isa<TensorType>(output.getType());
          bool isTensor = wasTensor || output == allocaOp;
          if (wasTensor)
            rewriter.replaceAllUsesWith(genericOp->getResult(oldResIdx++),
                                        newGenericOp->getResult(newResIdx));
          if (isTensor)
            ++newResIdx;
        }

        // Defer erasure: the old op may still appear in another alloca's
        // user list during this walk.
        opsToDelete.push_back(genericOp.getOperation());
        processedGenericOps.insert(genericOp.getOperation());
        if (tensorResultIndex != -1)
          currentTensor = newGenericOp.getResult(tensorResultIndex);
        changed = true;
      }

      // Copy the final tensor value back into the original buffer so any
      // later readers of the alloca keep observing the computed data.
      auto toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
          allocaOp.getLoc(), allocaOp.getType(), currentTensor);
      rewriter.create<memref::CopyOp>(allocaOp.getLoc(), toMemrefOp, allocaOp);
      return WalkResult::advance();
    });

    // Erase through the rewriter so the greedy driver is notified.
    for (Operation *op : opsToDelete)
      rewriter.eraseOp(op);

    // Report failure when nothing changed; returning success unconditionally
    // would make the greedy driver re-apply this pattern forever.
    return success(changed);
  }
};

namespace {
// Pass wrapper generated from Passes.td; delegates all work to the
// LinalgDebufferization rewrite pattern.
struct LinalgDebufferize
    : public LinalgDebufferizeBase<LinalgDebufferize> {
  void runOnOperation() override;
};
} // namespace

void LinalgDebufferize::runOnOperation() {
RewritePatternSet patterns(&getContext());
patterns.insert<LinalgDebufferization>(&getContext());
GreedyRewriteConfig config;
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config);
}

namespace mlir {
namespace polygeist {
// Public factory for the -linalg-debufferize pass (declared in
// polygeist/Passes/Passes.h and referenced by the TableGen constructor).
std::unique_ptr<Pass> createLinalgDebufferizePass() {
  return std::make_unique<LinalgDebufferize>();
}
} // namespace polygeist
} // namespace mlir

0 comments on commit f2ab09e

Please sign in to comment.