diff --git a/debufferize.mlir b/debufferize.mlir
new file mode 100644
index 000000000000..3e310644f4bc
--- /dev/null
+++ b/debufferize.mlir
@@ -0,0 +1,39 @@
+//polygeist-opt --linalg-debufferize debufferize.mlir
+
+#map16 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map17 = affine_map<(d0, d1, d2, d3) -> (d1 + d3, d0 + d2)>
+#map18 = affine_map<(d0, d1, d2, d3) -> (d1, d0)>
+#map19 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+
+module @in_place_add {
+  func.func @in_place_add(%value: f32) {
+    %c0 = arith.constant 0 : index
+    %buffer = memref.alloca() : memref<128xf32>
+    linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]
+    } ins(%buffer : memref<128xf32>)
+      outs(%buffer : memref<128xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %sum = arith.addf %in, %value : f32
+      linalg.yield %sum : f32
+    }
+    return
+  }
+}
+
+module @conv_2 {
+  func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage<external>} {
+    %c0_i32 = arith.constant 0 : i32
+    %0 = memref.alloca() : memref<515x67xi32>
+    %1 = memref.alloca() : memref<4x4xi32>
+    %2 = memref.alloca() : memref<512x64xi32>
+    linalg.generic {indexing_maps = [#map17, #map18, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0, %1 : memref<515x67xi32>, memref<4x4xi32>) outs(%2 : memref<512x64xi32>) {
+    ^bb0(%in: i32, %in_0: i32, %out: i32):
+      %3 = arith.muli %in, %in_0 : i32
+      %4 = arith.addi %out, %3 : i32
+      linalg.yield %4 : i32
+    }
+    return %c0_i32 : i32
+  }
+}
\ No newline at end of file
diff --git a/include/polygeist/Passes/Passes.h b/include/polygeist/Passes/Passes.h
index ad7e2fc14fc2..7a95484a2fdb 100644
--- a/include/polygeist/Passes/Passes.h
+++ b/include/polygeist/Passes/Passes.h
@@ -32,6 +32,7 @@ std::unique_ptr<Pass> createOpenMPOptPass();
 std::unique_ptr<Pass> createCanonicalizeForPass();
 std::unique_ptr<Pass> createRaiseSCFToAffinePass();
 std::unique_ptr<Pass> createRaiseAffineToLinalgPass();
+std::unique_ptr<Pass> createLinalgDebufferizePass();
 std::unique_ptr<Pass> createRemoveIterArgsPass();
 std::unique_ptr<Pass> createCPUifyPass(StringRef method = "");
 std::unique_ptr<Pass> createBarrierRemovalContinuation();
@@ -129,6 +130,14 @@ namespace linalg {
 class LinalgDialect;
 }
 
+namespace bufferization {
+class BufferizationDialect;
+}
+
+namespace tensor {
+class TensorDialect;
+}
+
 namespace LLVM {
 class LLVMDialect;
 }
diff --git a/include/polygeist/Passes/Passes.td b/include/polygeist/Passes/Passes.td
index 7d5f2315f4ce..5b8251c616b8 100644
--- a/include/polygeist/Passes/Passes.td
+++ b/include/polygeist/Passes/Passes.td
@@ -160,6 +160,18 @@ def RemoveIterArgs : Pass<"remove-iter-args"> {
   ];
 }
 
+def LinalgDebufferize : Pass<"linalg-debufferize"> {
+  let summary = "Rewrite linalg ops on memref allocas to destination-passing-style linalg ops on tensors";
+  let constructor = "mlir::polygeist::createLinalgDebufferizePass()";
+  let dependentDialects = [
+    "affine::AffineDialect",
+    "linalg::LinalgDialect",
+    "bufferization::BufferizationDialect",
+    "memref::MemRefDialect",
+    "polygeist::PolygeistDialect",
+  ];
+}
+
 def AffineRaiseToLinalg : Pass<"raise-affine-to-linalg"> {
   let summary = "Raise affine to linalg";
   let constructor = "mlir::polygeist::createRaiseAffineToLinalgPass()";
diff --git a/lib/polygeist/Ops.cpp b/lib/polygeist/Ops.cpp
index c694c8520ef4..4010e58330cb 100644
--- a/lib/polygeist/Ops.cpp
+++ b/lib/polygeist/Ops.cpp
@@ -5843,7 +5843,7 @@ struct SubMapOpCanonicalize : public OpRewritePattern<SubmapOp> {
   //If inverse permutation exists, then we can canonicalize the linalg of submap to linalg
   //TODO: Fails for:
   // 1. Maps with symbols
-  // 2. Maps with non
+  // 2. Maps which are not resolvable 1 to 1 with memref for all dims
   if(inversePermutation(concatAffineMaps(maps))) {
     StringAttr empty = StringAttr::get(genericOp.getContext());
     auto newGenericOp = rewriter.create<linalg::GenericOp>(genericOp.getLoc(), TypeRange(), listOfNewInputs, listOfNewOutputs, listOfNewMaps, genericOp.getIteratorTypesArray(),
diff --git a/lib/polygeist/Passes/CMakeLists.txt b/lib/polygeist/Passes/CMakeLists.txt
index f98813fb15b5..ae74300af7a1 100644
--- a/lib/polygeist/Passes/CMakeLists.txt
+++ b/lib/polygeist/Passes/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRPolygeistTransforms
   RaiseToAffine.cpp
   RemoveIterArgs.cpp
   RaiseToLinalg.cpp
+  LinalgDebufferize.cpp
   ParallelLower.cpp
   TrivialUse.cpp
   ConvertPolygeistToLLVM.cpp
diff --git a/lib/polygeist/Passes/LinalgDebufferize.cpp b/lib/polygeist/Passes/LinalgDebufferize.cpp
new file mode 100644
index 000000000000..c5e04a67af5b
--- /dev/null
+++ b/lib/polygeist/Passes/LinalgDebufferize.cpp
@@ -0,0 +1,215 @@
+#include "PassDetails.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "polygeist/Passes/Passes.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "linalg-debufferize"
+
+using namespace mlir;
+using namespace mlir::arith;
+using namespace polygeist;
+using namespace affine;
+using namespace linalg;
+using namespace tensor;
+using namespace bufferization;
+
+// Motivating example (kept for reference):
+//module @harris_score_with_gradient_extra_kernel {
+//  memref.global "private" @_ZL8coeffs_1 : memref<5x5xi32> = dense<1>
+//  memref.global "private" @_ZL8coeffs_y : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]>
+//  memref.global "private" @_ZL8coeffs_x : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]>
+//  memref.global @score : memref<512x512xi32> = uninitialized
+//  func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage<external>} {
+//    %c4_i32 = arith.constant 4 : i32
+//    %c0_i32 = arith.constant 0 : i32
+//    %alloca = memref.alloca() : memref<512x512xi32>
+//    %alloca_0 = memref.alloca() : memref<512x512xi32>
+//    %alloca_1 = memref.alloca() : memref<512x512xi32>
+//    %alloca_2 = memref.alloca() : memref<516x516xi32>
+//    %alloca_3 = memref.alloca() : memref<516x516xi32>
+//    %alloca_4 = memref.alloca() : memref<518x518xi32>
+//    %0 = memref.get_global @_ZL8coeffs_x : memref<3x3xi32>
+//    %1 = memref.get_global @_ZL8coeffs_y : memref<3x3xi32>
+//    %2 = memref.get_global @_ZL8coeffs_1 : memref<5x5xi32>
+//    // 2nd variant
+//    // %0 = memref.alloca() : memref<3x3xi32>
+//    // %1 = memref.alloca() : memref<3x3xi32>
+//    // %2 = memref.alloca() : memref<5x5xi32>
+//    linalg.generic {indexing_maps = [#map17, #map18, #map18, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_4, %0, %1 : memref<518x518xi32>, memref<3x3xi32>, memref<3x3xi32>) outs(%alloca_2, %alloca_3 : memref<516x516xi32>, memref<516x516xi32>) {
+//    ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32):
+//      %4 = arith.muli %in, %in_5 : i32
+//      %5 = arith.addi %out_7, %4 : i32
+//      %6 = arith.muli %in, %in_6 : i32
+//      %7 = arith.addi %out, %6 : i32
+//      linalg.yield %7, %5 : i32, i32
+//    }
+//    linalg.generic {indexing_maps = [#map17, #map17, #map18, #map19, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_3, %alloca_2, %2 : memref<516x516xi32>, memref<516x516xi32>, memref<5x5xi32>) outs(%alloca, %alloca_0, %alloca_1 : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) {
+//    ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32, %out_8: i32):
+//      %4 = arith.muli %in, %in : i32
+//      %5 = arith.muli %4, %in_6 : i32
+//      %6 = arith.addi %out_8, %5 : i32
+//      %7 = arith.muli %in_5, %in_5 : i32
+//      %8 = arith.muli %7, %in_6 : i32
+//      %9 = arith.addi %out_7, %8 : i32
+//      %10 = arith.muli %in, %in_5 : i32
+//      %11 = arith.muli %10, %in_6 : i32
+//      %12 = arith.addi %out, %11 : i32
+//      linalg.yield %12, %9, %6 : i32, i32, i32
+//    }
+//    %3 = memref.get_global @score : memref<512x512xi32>
+//    linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel"]} ins(%alloca_1, %alloca_0, %alloca : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) outs(%3 : memref<512x512xi32>) {
+//    ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32):
+//      %4 = arith.muli %in, %in_5 : i32
+//      %5 = arith.muli %in_6, %in_6 : i32
+//      %6 = arith.subi %4, %5 : i32
+//      %7 = arith.addi %in, %in_5 : i32
+//      %8 = arith.muli %7, %c4_i32 : i32
+//      %9 = arith.muli %8, %7 : i32
+//      %10 = arith.subi %6, %9 : i32
+//      linalg.yield %10 : i32
+//    }
+//    return %c0_i32 : i32
+//  }
+//}
+
+// Rewrites linalg.generic ops that read/write a memref.alloca so that they
+// operate on tensors instead: the alloca is wrapped in a
+// bufferization.to_tensor, every generic that uses it is recreated in
+// destination-passing style with a tensor operand/result, and the final
+// tensor value is copied back into the alloca via bufferization.to_memref +
+// memref.copy so later readers of the buffer stay correct.
+struct LinalgDebufferization : public OpRewritePattern<func::FuncOp> {
+  using OpRewritePattern<func::FuncOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(func::FuncOp funcOp,
+                                PatternRewriter &rewriter) const final {
+    SmallVector<Operation *> opsToDelete;
+    // Tracks generics that were already rewritten so they are not visited
+    // again when several allocas feed the same op.
+    llvm::SmallPtrSet<Operation *, 8> processedGenericOps;
+
+    bool modified = false;
+    funcOp.walk([&](memref::AllocaOp allocaOp) {
+      // Only handle allocas whose every user is a linalg.generic for now;
+      // skip the others instead of aborting the whole function.
+      // TODO: extend this to other users (subviews, copies, ...).
+      if (!llvm::all_of(allocaOp->getUsers(), [](Operation *op) {
+            return isa<linalg::GenericOp>(op);
+          }))
+        return;
+
+      rewriter.setInsertionPointAfter(allocaOp);
+      auto tensorType = RankedTensorType::get(
+          allocaOp.getType().getShape(), allocaOp.getType().getElementType());
+      auto toTensorOp = rewriter.create<bufferization::ToTensorOp>(
+          allocaOp.getLoc(), tensorType, allocaOp);
+      Value currentTensor = toTensorOp;
+
+      for (auto user : allocaOp->getUsers()) {
+        auto genericOp = dyn_cast<linalg::GenericOp>(user);
+        if (!genericOp || processedGenericOps.count(genericOp))
+          continue;
+        rewriter.setInsertionPointAfter(genericOp);
+
+        // Substitute the alloca with the current tensor SSA value in ins.
+        SmallVector<Value> newInputs;
+        for (auto input : genericOp.getInputs())
+          newInputs.push_back(input == allocaOp ? currentTensor : input);
+
+        // Same for outs; a tensor out yields a result, so remember which
+        // result position corresponds to the alloca.
+        SmallVector<Value> newOutputs;
+        SmallVector<Type> resultTypes;
+        int newCurrentTensorIndex = -1;
+        for (auto output : genericOp.getOutputs()) {
+          if (output == allocaOp) {
+            newOutputs.push_back(currentTensor);
+            newCurrentTensorIndex = resultTypes.size();
+            resultTypes.push_back(currentTensor.getType());
+          } else {
+            newOutputs.push_back(output);
+            // Only tensor outs produce SSA results on a linalg.generic.
+            if (isa<RankedTensorType>(output.getType()))
+              resultTypes.push_back(output.getType());
+          }
+        }
+
+        StringAttr empty = StringAttr::get(genericOp.getContext());
+        auto newGenericOp = rewriter.create<linalg::GenericOp>(
+            genericOp.getLoc(), resultTypes, newInputs, newOutputs,
+            genericOp.getIndexingMaps(), genericOp.getIteratorTypes(), empty,
+            empty);
+        rewriter.cloneRegionBefore(genericOp.getRegion(),
+                                   newGenericOp.getRegion(),
+                                   newGenericOp.getRegion().end());
+
+        // Re-wire the old results, skipping the result that was introduced
+        // for the debufferized alloca (it has no old counterpart).
+        unsigned newIdx = 0;
+        for (unsigned oldIdx = 0; oldIdx < genericOp->getNumResults();
+             ++oldIdx, ++newIdx) {
+          if (static_cast<int>(newIdx) == newCurrentTensorIndex)
+            ++newIdx;
+          genericOp->getResult(oldIdx).replaceAllUsesWith(
+              newGenericOp->getResult(newIdx));
+        }
+
+        opsToDelete.push_back(genericOp.getOperation());
+        if (newCurrentTensorIndex != -1)
+          currentTensor = newGenericOp.getResult(newCurrentTensorIndex);
+        processedGenericOps.insert(genericOp.getOperation());
+      }
+
+      // Materialize the final tensor value back into the original buffer so
+      // any remaining reader of the alloca observes the computed data.
+      auto toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
+          allocaOp.getLoc(), allocaOp.getType(), currentTensor);
+      rewriter.create<memref::CopyOp>(allocaOp.getLoc(), toMemrefOp,
+                                      allocaOp);
+      modified = true;
+    });
+    for (Operation *op : opsToDelete)
+      rewriter.eraseOp(op);
+
+    // Report "no match" when nothing changed so the greedy driver
+    // terminates instead of re-applying this pattern forever.
+    return success(modified);
+  }
+};
+
+namespace {
+struct LinalgDebufferize
+    : public LinalgDebufferizeBase<LinalgDebufferize> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void LinalgDebufferize::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  patterns.insert<LinalgDebufferization>(&getContext());
+  GreedyRewriteConfig config;
+  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+                                     config);
+}
+
+namespace mlir {
+namespace polygeist {
+std::unique_ptr<Pass> createLinalgDebufferizePass() {
+  return std::make_unique<LinalgDebufferize>();
+}
+} // namespace polygeist
+} // namespace mlir