Skip to content

Commit

Permalink
Support optimization barrier
Browse files Browse the repository at this point in the history
  • Loading branch information
eedalong committed Mar 4, 2024
2 parents c3222e6 + 30cf3d3 commit 3269aa0
Show file tree
Hide file tree
Showing 9 changed files with 165 additions and 21 deletions.
23 changes: 20 additions & 3 deletions tao_compiler/mlir/disc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -941,10 +941,7 @@ cc_library(
"transforms/rewriters.h",
],
deps = [
":mhlo_disc",
":lmhlo_disc",
":disc_ral",
":disc_map_hlo_to_lhlo_op",
":pass_details",
":placement_utils",
":shape_utils",
Expand Down Expand Up @@ -2121,6 +2118,25 @@ cc_library(
alwayslink = 1,
)

cc_library(
name = "disc_reduce_buffer_live_range",
srcs = ["transforms/disc_reduce_buffer_live_range.cc"],
deps = [
":lmhlo_disc",
":disc_util",
":pass_details",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:BufferizationTransforms",
],
alwayslink = 1,
)

cc_library(
name = "disc_bf16_expansion",
srcs = ["transforms/disc_bf16_expansion.cc"],
Expand Down Expand Up @@ -2337,6 +2353,7 @@ cc_library(
":disc_assign_memory_space",
":disc_bf16_expansion",
":disc_buffer_deallocation",
":disc_reduce_buffer_live_range",
":disc_canonicalizer",
":disc_comp_intens_fusion_to_cuda_source",
":disc_comp_intens_fusion_to_func",
Expand Down
25 changes: 25 additions & 0 deletions tao_compiler/mlir/disc/IR/lhlo_disc_ops.td
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -295,5 +295,30 @@ def LHLO_ArgsMutationOp : LHLODISC_Op<"args_mutation", []> {
);
}

def LHLODISC_OptimizationBarrierOp : LHLODISC_Op<"optimization_barrier", []> {
let summary = "OptimizationBarrier operation";
let description = [{
Ensures that the operations that produce the `operand` are executed before any
operations that depend on the `result` and prevents compiler transformations
from moving operations across the barrier. Other than that, the operation is
an identity, i.e. `result` = `operand`.

See:
https://github.com/openxla/stablehlo/blob/main/docs/spec.md#optimization_barrier

Example:
```mlir
%result0, %result1 = mhlo.optimization_barrier %operand0, %operand1 : tensor<f32>, tensor<f32>
```
}];

let arguments = (ins
Arg<Variadic<LHLO_BufferOrIndexBuffer>, "", [MemRead]>:$args
);

let results = (outs Variadic<LHLO_BufferOrIndexBuffer>);

}


#endif // LMHLO_DISC_OPS
1 change: 1 addition & 0 deletions tao_compiler/mlir/disc/disc_compiler.cc
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,7 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) {
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
pm.addNestedPass<FuncOp>(createCSEPass());
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
pm.addNestedPass<FuncOp>(disc_ral::createDiscReduceBufferLiveRangePass());
pm.addNestedPass<FuncOp>(bufferization::createBufferDeallocationPass());
pm.addNestedPass<FuncOp>(disc_ral::createDiscBufferDeallocationPass());

Expand Down
10 changes: 4 additions & 6 deletions tao_compiler/mlir/disc/transforms/disc_hlo_legalize_to_lhlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,6 @@ struct HloToLhloOptimizationBarrierOpConverter
LogicalResult matchAndRewrite(
mhlo::OptimizationBarrierOp hloOp, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {

llvm::dbgs() << "Converting mhlo::OptimizationBarrierOp \n";

Operation* op = hloOp.getOperation();
auto operands = adaptor.getOperands();

Expand All @@ -197,9 +194,9 @@ struct HloToLhloOptimizationBarrierOpConverter
resultTypes.push_back(
MemRefType::get(ty.getShape(), ty.getElementType()));
}
llvm::dbgs() << "Replace Op With lmhlo_disc::OptimizationBarrierOp\n";
rewriter.replaceOpWithNewOp<lmhlo_disc::OptimizationBarrierOp>(hloOp, resultTypes, operands, op->getAttrs());

rewriter.replaceOpWithNewOp<lmhlo_disc::OptimizationBarrierOp>(
hloOp, resultTypes, operands, op->getAttrs());

return success();
}
Expand Down Expand Up @@ -247,6 +244,7 @@ struct HloToLhloCustomCallOpV2Converter
resultTypes.push_back(
MemRefType::get(ty.getShape(), ty.getElementType()));
}

rewriter.replaceOpWithNewOp<lmhlo_disc::CustomCallV2Op>(
hloOp, resultTypes, adaptor.getOperands(), hloOp->getAttrs());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,9 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "mlir/disc/IR/disc_ral_ops.h"
#include "mlir/disc/IR/disc_shape_ops.h"
#include "mlir/disc/IR/hlo_disc_ops.h"
#include "mlir/disc/IR/lhlo_disc_ops.h"
#include "mlir/disc/transforms/PassDetail.h"
#include "mlir/disc/transforms/disc_map_hlo_to_lhlo_op.h"
#include "mlir/disc/transforms/placement_utils.h"
#include "mlir/disc/transforms/rewriters.h"
#include "mlir/disc/transforms/shape_utils.h"
Expand All @@ -57,19 +54,18 @@ namespace {
template <typename T>
using BaseOpConversion = OpConversionPattern<T>;

struct LhloDISCOptimizationBarrierOpConverter : public OpRewritePattern<lmhlo_disc::OptimizationBarrierOp> {
struct LhloDISCOptimizationBarrierOpConverter
: public OpRewritePattern<lmhlo_disc::OptimizationBarrierOp> {
explicit LhloDISCOptimizationBarrierOpConverter(MLIRContext* context)
: OpRewritePattern(context) {}
LogicalResult matchAndRewrite(lmhlo_disc::OptimizationBarrierOp lhloOp,
PatternRewriter& rewriter) const override {

llvm::dbgs() << "Expand OptimizationBarrierPass \n";
Operation* op = lhloOp.getOperation();

auto operands = op->getOperands();
auto results = op->getResults();

for(int i=0; i<operands.size(); i++) {
for (int i = 0; i < operands.size(); i++) {
results[i].replaceAllUsesWith(operands[i]);
}

Expand All @@ -80,13 +76,13 @@ struct LhloDISCOptimizationBarrierOpConverter : public OpRewritePattern<lmhlo_di
};

struct DiscOptimizationBarrierExpandPass
: public DiscOptimizationBarrierExpandPassBase<DiscOptimizationBarrierExpandPass> {
: public DiscOptimizationBarrierExpandPassBase<
DiscOptimizationBarrierExpandPass> {
using DiscOptimizationBarrierExpandPassBase<
DiscOptimizationBarrierExpandPass>::DiscOptimizationBarrierExpandPassBase;

void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<lmhlo_disc::LmhloDiscDialect, memref::MemRefDialect,
disc_ral::RalDialect, lmhlo::LmhloDialect>();
registry.insert<lmhlo_disc::LmhloDiscDialect, memref::MemRefDialect>();
}

public:
Expand All @@ -108,7 +104,8 @@ struct DiscOptimizationBarrierExpandPass
};
} // namespace

std::unique_ptr<OperationPass<ModuleOp>> createDiscOptimizationBarrierExpandPass() {
std::unique_ptr<OperationPass<ModuleOp>>
createDiscOptimizationBarrierExpandPass() {
return std::make_unique<DiscOptimizationBarrierExpandPass>();
}

Expand Down
5 changes: 5 additions & 0 deletions tao_compiler/mlir/disc/transforms/disc_passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -662,4 +662,9 @@ def DiscEraseBufferDeallocationPass : Pass<"disc-erase-buffer-deallocation", "ml
def DiscInputOutputAliasPass : Pass<"disc-input-output-alias", "ModuleOp"> {
let summary = "Input and output alias information for buffer reuse";
let constructor = "createDiscInputOutputAliasPass()";
}

def DiscReduceBufferLiveRangePass : Pass<"disc-reduce-buffer-live-range", "mlir::func::FuncOp"> {
let summary = "reduce buffer live range";
let constructor = "createDiscReduceBufferLiveRangePass()";
}
90 changes: 90 additions & 0 deletions tao_compiler/mlir/disc/transforms/disc_reduce_buffer_live_range.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "lhlo/IR/lhlo_ops.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/disc/disc_util.h"
#include "mlir/disc/transforms/PassDetail.h"

namespace mlir {
namespace disc_ral {

using lmhlo::FusionOp;
using memref::AllocOp;

namespace {

LogicalResult moveBufferAllocator(AllocOp allocOp) {
Value alloc = allocOp.getResult();
BufferViewFlowAnalysis aliasAnalysis(allocOp);
PostDominanceInfo postDominators(allocOp);
auto aliasesSet = aliasAnalysis.resolve(alloc);
// Determine the actual block to place the alloc and get liveness
// information.
Block* placementBlock =
bufferization::BufferPlacementTransformationBase::findCommonDominator(
alloc, aliasesSet, postDominators);

Operation* toMoveBefore = nullptr;
for (auto user : alloc.getUsers()) {
if (isa<func::ReturnOp>(user)) continue;
// user maybe in the sub-block of the placementBlock,
// find the closest parent op inside of placementBlock
while (user->getBlock() != placementBlock) {
user = user->getParentOp();
}
if (toMoveBefore == nullptr || user->isBeforeInBlock(toMoveBefore)) {
toMoveBefore = user;
}
}
allocOp->moveBefore(toMoveBefore);
return success();
}

struct DiscReduceBufferLiveRangePass
: public DiscReduceBufferLiveRangePassBase<DiscReduceBufferLiveRangePass> {
void runOnOperation() override {
SmallVector<AllocOp, 4> candidateBuffers;
func::FuncOp func = getOperation();

func.walk([&](AllocOp op) { candidateBuffers.push_back(op); });

for (int i = 0; i < candidateBuffers.size(); ++i) {
if (failed(moveBufferAllocator(candidateBuffers[i]))) {
return signalPassFailure();
}
}
}
};

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
createDiscReduceBufferLiveRangePass() {
return std::make_unique<DiscReduceBufferLiveRangePass>();
}

} // namespace disc_ral
} // namespace mlir
4 changes: 3 additions & 1 deletion tao_compiler/mlir/disc/transforms/passes.h
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ createDiscEraseBufferDeallocationPass();

// Insert ArgsMutationOp for buffer reuse
std::unique_ptr<OperationPass<ModuleOp>> createDiscInputOutputAliasPass();
std::unique_ptr<OperationPass<ModuleOp>> createDiscReduceBufferLiveRangePass();
} // namespace disc_ral
} // namespace mlir

Expand All @@ -339,7 +340,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createDiscLegalizeToLhloPass();

std::unique_ptr<OperationPass<ModuleOp>> createDiscLhloRewriterPass();

std::unique_ptr<OperationPass<ModuleOp>> createDiscOptimizationBarrierExpandPass();
std::unique_ptr<OperationPass<ModuleOp>>
createDiscOptimizationBarrierExpandPass();

} // namespace mhlo_disc
} // namespace mlir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,19 @@

// CHECK-LABEL: @optimization_barrier_expand
func.func @optimization_barrier_expand(%arg0 : tensor<1x2048x4096xf32>, %arg1: tensor<1x2048x4096xf32>) -> tensor<2048x4096xf16> {
// CHECK: %alloc = memref.alloc() : memref<1x2048x4096xf32>
// CHECK: "lmhlo.add"(%arg0, %arg1, %alloc) : (memref<1x2048x4096xf32>, memref<1x2048x4096xf32>, memref<1x2048x4096xf32>) -> ()
%1 = "mhlo.add"(%arg0, %arg1): (tensor<1x2048x4096xf32>, tensor<1x2048x4096xf32>) -> tensor<1x2048x4096xf32>
// CHECK: %alloc_0 = memref.alloc() : memref<1x2048x4096xf32>
// CHECK: "lmhlo.add"(%arg0, %arg1, %alloc_0) : (memref<1x2048x4096xf32>, memref<1x2048x4096xf32>, memref<1x2048x4096xf32>) -> ()
%2 = "mhlo.add"(%arg0, %arg1): (tensor<1x2048x4096xf32>, tensor<1x2048x4096xf32>) -> tensor<1x2048x4096xf32>
// CHECK: %alloc_1 = memref.alloc() : memref<1x2048x4096xf16>
// CHECK: "lmhlo.convert"(%alloc_0, %alloc_1) : (memref<1x2048x4096xf32>, memref<1x2048x4096xf16>) -> ()
%3:2 = "mhlo.optimization_barrier"(%1, %2): (tensor<1x2048x4096xf32>, tensor<1x2048x4096xf32>) -> (tensor<1x2048x4096xf32>, tensor<1x2048x4096xf32>)
%4 = "mhlo.convert"(%3#1): (tensor<1x2048x4096xf32>) -> tensor<1x2048x4096xf16>
// CHECK: %alloc_2 = memref.alloc() : memref<2048x4096xf16>
// CHECK: "lmhlo.reshape"(%alloc_1, %alloc_2) : (memref<1x2048x4096xf16>, memref<2048x4096xf16>) -> ()
%5 = "mhlo.reshape"(%4) : (tensor<1x2048x4096xf16>) -> tensor<2048x4096xf16>
// CHECK: return %alloc_2 : memref<2048x4096xf16>
return %5: tensor<2048x4096xf16>
}

0 comments on commit 3269aa0

Please sign in to comment.