Skip to content

Commit

Permalink
Support optimization barrier op (#1286)
Browse files Browse the repository at this point in the history
  • Loading branch information
eedalong authored Mar 6, 2024
1 parent 30cf3d3 commit f160eb2
Show file tree
Hide file tree
Showing 9 changed files with 227 additions and 2 deletions.
29 changes: 29 additions & 0 deletions tao_compiler/mlir/disc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,34 @@ cc_library(
alwayslink = 1,
)

cc_library(
    name = "disc_optimization_barrier_expand",
    srcs = ["transforms/disc_optimization_barrier_expand.cc"],
    hdrs = [
        "transforms/passes.h",
        "transforms/rewriters.h",
    ],
    # deps kept sorted (buildifier convention): local ":" targets first,
    # then external repositories alphabetically.
    deps = [
        ":lmhlo_disc",
        ":pass_details",
        ":placement_utils",
        ":shape_utils",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:MemRefDialect",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:ShapeDialect",
        "@llvm-project//mlir:ShapeTransforms",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TensorDialect",
        "@llvm-project//mlir:Transforms",
        "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:lhlo",
    ],
    alwayslink = 1,
)

cc_library(
name = "disc_lower_to_library_call",
srcs = ["transforms/disc_lower_to_library_call.cc"],
Expand Down Expand Up @@ -2356,6 +2384,7 @@ cc_library(
":disc_math_approximation",
":disc_memref_canonicalizer",
":disc_outline_cpu_kernel",
":disc_optimization_barrier_expand",
":disc_parallel_loop_collapsing",
":disc_parallel_loop_tiling",
":disc_remove_dead_buffer",
Expand Down
25 changes: 25 additions & 0 deletions tao_compiler/mlir/disc/IR/lhlo_disc_ops.td
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -295,5 +295,30 @@ def LHLO_ArgsMutationOp : LHLODISC_Op<"args_mutation", []> {
);
}

// Buffer-level counterpart of mhlo.optimization_barrier. NOTE: unlike most
// LHLO ops this op declares variadic *results*, so the identity relation
// (result_i == operand_i) survives bufferization; the op is later removed by
// the disc-optimization-barrier-expand pass, which forwards each result to
// its matching operand.
def LHLODISC_OptimizationBarrierOp : LHLODISC_Op<"optimization_barrier", []> {
  let summary = "OptimizationBarrier operation";
  let description = [{
    Ensures that the operations that produce the `operand` are executed before any
    operations that depend on the `result` and prevents compiler transformations
    from moving operations across the barrier. Other than that, the operation is
    an identity, i.e. `result` = `operand`.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#optimization_barrier

    Example:
    ```mlir
    %result0, %result1 = mhlo.optimization_barrier %operand0, %operand1 : tensor<f32>, tensor<f32>
    ```
  }];

  // Operands are read-only buffers; the op itself writes nothing.
  let arguments = (ins
    Arg<Variadic<LHLO_BufferOrIndexBuffer>, "", [MemRead]>:$args
  );

  let results = (outs Variadic<LHLO_BufferOrIndexBuffer>);

}


#endif // LMHLO_DISC_OPS
2 changes: 2 additions & 0 deletions tao_compiler/mlir/disc/disc_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,8 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) {
pm.addNestedPass<FuncOp>(disc_ral::createDiscStitchFusionPass());
}

pm.addPass(mhlo_disc::createDiscOptimizationBarrierExpandPass());

if (useTransformSchedule()) {
std::string transform_schedule;
tensorflow::ReadStringFromEnvVar("DISC_TRANSFORM_SCHEDULE_FILE", "",
Expand Down
4 changes: 2 additions & 2 deletions tao_compiler/mlir/disc/transforms/disc_assign_memory_space.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -593,8 +593,8 @@ LogicalResult DiscAssignMemorySpacePass::applyOperationAssignment(
// clang-format: off
Operation* newOp = tryReplaceResultType<
memref::AllocOp, memref::AllocaOp, memref::SubViewOp, memref::ViewOp,
memref::CastOp, memref::ReinterpretCastOp, lmhlo_disc::CustomCallV2Op>(
op, assignment);
memref::CastOp, memref::ReinterpretCastOp, lmhlo_disc::CustomCallV2Op,
lmhlo_disc::OptimizationBarrierOp>(op, assignment);
// clang-format: on

if (newOp) {
Expand Down
26 changes: 26 additions & 0 deletions tao_compiler/mlir/disc/transforms/disc_hlo_legalize_to_lhlo.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,30 @@ struct HloToLhloArgsMutationOpConverter
}
};

// Converts mhlo.optimization_barrier into lmhlo_disc.optimization_barrier,
// forwarding the already-bufferized operands and the original attributes.
// Each ranked tensor result is mapped to a memref of the same shape and
// element type.
struct HloToLhloOptimizationBarrierOpConverter
    : public BaseOpConversion<mhlo::OptimizationBarrierOp> {
 public:
  using BaseOpConversion<mhlo::OptimizationBarrierOp>::BaseOpConversion;
  LogicalResult matchAndRewrite(
      mhlo::OptimizationBarrierOp hloOp, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const override {
    // Build the memref result types from the tensor result types.
    // NOTE(review): assumes every result is a RankedTensorType — the cast
    // would assert on unranked tensors.
    SmallVector<Type> memrefTypes;
    memrefTypes.reserve(hloOp->getNumResults());
    for (Value result : hloOp.getResults()) {
      auto tensorTy = result.getType().cast<RankedTensorType>();
      memrefTypes.push_back(
          MemRefType::get(tensorTy.getShape(), tensorTy.getElementType()));
    }

    rewriter.replaceOpWithNewOp<lmhlo_disc::OptimizationBarrierOp>(
        hloOp, memrefTypes, adaptor.getOperands(),
        hloOp.getOperation()->getAttrs());
    return success();
  }
};

struct HloToLhloCustomCallOpConverter
: public BaseOpConversion<mhlo_disc::CustomCallOp> {
public:
Expand Down Expand Up @@ -382,6 +406,7 @@ struct DiscHloLegalizeToLhlo
target.addIllegalOp<disc_shape::TieShapeOp>();
target.addIllegalOp<mhlo_disc::ArgsMutationOp>();
target.addIllegalOp<mhlo::CustomCallOp>();
target.addIllegalOp<mhlo::OptimizationBarrierOp>();

bufferization::BufferizeTypeConverter converter;
populateDiscHLOToLHLOConversionPattern(&context, &converter, &patterns);
Expand Down Expand Up @@ -411,6 +436,7 @@ void populateDiscHLOToLHLOConversionPattern(
HloToLhloArgsMutationOpConverter,
HloToLhloCustomCallOpConverter,
HloToLhloCustomCallOpV2Converter,
HloToLhloOptimizationBarrierOpConverter,
TieShapeOpConverter
>(*converter, context);
// clang-format on
Expand Down
113 changes: 113 additions & 0 deletions tao_compiler/mlir/disc/transforms/disc_optimization_barrier_expand.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright 2021 The BladeDISC Authors. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// This file implements the DiscOptimizationBarrierExpand pass, which removes
// lmhlo_disc.optimization_barrier ops by forwarding each result to its
// matching operand once the barrier is no longer needed.

#include <algorithm>
#include <utility>

#include "lhlo/IR/lhlo_ops.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "mlir/disc/IR/disc_shape_ops.h"
#include "mlir/disc/IR/lhlo_disc_ops.h"
#include "mlir/disc/transforms/PassDetail.h"
#include "mlir/disc/transforms/placement_utils.h"
#include "mlir/disc/transforms/rewriters.h"
#include "mlir/disc/transforms/shape_utils.h"

namespace mlir {
using placement_utils::kDiscPlaceAssignment;
using placement_utils::kGpu;

namespace mhlo_disc {
namespace {

template <typename T>
using BaseOpConversion = OpConversionPattern<T>;

// Expands lmhlo_disc.optimization_barrier away: at this point in the
// pipeline the barrier has already done its job (keeping producers of its
// operands ordered before users of its results), and since the op is a pure
// identity on buffers (result_i == operand_i) each result can simply be
// replaced by the corresponding operand.
struct LhloDISCOptimizationBarrierOpConverter
    : public OpRewritePattern<lmhlo_disc::OptimizationBarrierOp> {
  explicit LhloDISCOptimizationBarrierOpConverter(MLIRContext* context)
      : OpRewritePattern(context) {}
  LogicalResult matchAndRewrite(lmhlo_disc::OptimizationBarrierOp lhloOp,
                                PatternRewriter& rewriter) const override {
    // replaceOp rewires all result uses to the operands AND erases the op,
    // while keeping the rewriter's bookkeeping consistent. The previous
    // Value::replaceAllUsesWith + eraseOp sequence mutated the IR behind the
    // rewriter's back, which is undefined behavior inside a pattern (the
    // driver's worklist/listeners are never notified).
    rewriter.replaceOp(lhloOp, lhloOp.getOperands());
    return success();
  }
};

// Pass that removes all lmhlo_disc.optimization_barrier ops from the module
// by forwarding each result to its matching operand (see
// LhloDISCOptimizationBarrierOpConverter above).
struct DiscOptimizationBarrierExpandPass
    : public DiscOptimizationBarrierExpandPassBase<
          DiscOptimizationBarrierExpandPass> {
  using DiscOptimizationBarrierExpandPassBase<
      DiscOptimizationBarrierExpandPass>::DiscOptimizationBarrierExpandPassBase;

  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<lmhlo_disc::LmhloDiscDialect, memref::MemRefDialect>();
  }

 public:
  DiscOptimizationBarrierExpandPass() = default;

  void runOnOperation() override {
    MLIRContext& context = getContext();
    RewritePatternSet patterns(&context);
    patterns.insert<LhloDISCOptimizationBarrierOpConverter>(&context);
    // NOTE: the greedy driver never consults a ConversionTarget, so the
    // previous legal/illegal-op setup here was dead code; the pattern alone
    // drives the rewrite until no barrier ops remain.
    if (failed(
            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

// Factory for the disc-optimization-barrier-expand pass; referenced by the
// TableGen-registered constructor and by the disc_compiler pipeline.
std::unique_ptr<OperationPass<ModuleOp>>
createDiscOptimizationBarrierExpandPass() {
  return std::make_unique<DiscOptimizationBarrierExpandPass>();
}

} // namespace mhlo_disc
} // namespace mlir
5 changes: 5 additions & 0 deletions tao_compiler/mlir/disc/transforms/mhlo_disc_passes.td
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,8 @@ def DiscLhloRewriterPass: Pass<"disc-lhlo-rewriter", "ModuleOp"> {
let summary = "rewrite lmhlo ops to lmhlo_disc ops.";
let constructor = "createDiscLhloRewriterPass()";
}

// Removes lmhlo_disc.optimization_barrier ops by forwarding each result to
// its matching operand; the op is a pure identity on buffers by then.
def DiscOptimizationBarrierExpandPass : Pass<"disc-optimization-barrier-expand", "ModuleOp"> {
  let summary = "Expand OptimizationBarrierOp";
  let constructor = "createDiscOptimizationBarrierExpandPass()";
}
3 changes: 3 additions & 0 deletions tao_compiler/mlir/disc/transforms/passes.h
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,9 @@ std::unique_ptr<OperationPass<ModuleOp>> createDiscLegalizeToLhloPass();

std::unique_ptr<OperationPass<ModuleOp>> createDiscLhloRewriterPass();

std::unique_ptr<OperationPass<ModuleOp>>
createDiscOptimizationBarrierExpandPass();

} // namespace mhlo_disc
} // namespace mlir

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@

// RUN: disc-opt -split-input-file -disc-hlo-legalize-to-lhlo -hlo-legalize-to-lhlo -disc-optimization-barrier-expand %s -o - | FileCheck %s


// CHECK-LABEL: @optimization_barrier_expand
// Verifies the barrier is bufferized and then fully expanded: after
// -disc-optimization-barrier-expand no optimization_barrier op remains, and
// the convert below reads directly from the buffer produced for %2 (the
// barrier's second operand), i.e. %alloc_0.
// CHECK-LABEL: @optimization_barrier_expand
func.func @optimization_barrier_expand(%arg0 : tensor<1x2048x4096xf32>, %arg1: tensor<1x2048x4096xf32>) -> tensor<2048x4096xf16> {
  // CHECK: %alloc = memref.alloc() : memref<1x2048x4096xf32>
  // CHECK: "lmhlo.add"(%arg0, %arg1, %alloc) : (memref<1x2048x4096xf32>, memref<1x2048x4096xf32>, memref<1x2048x4096xf32>) -> ()
  %1 = "mhlo.add"(%arg0, %arg1): (tensor<1x2048x4096xf32>, tensor<1x2048x4096xf32>) -> tensor<1x2048x4096xf32>
  // CHECK: %alloc_0 = memref.alloc() : memref<1x2048x4096xf32>
  // CHECK: "lmhlo.add"(%arg0, %arg1, %alloc_0) : (memref<1x2048x4096xf32>, memref<1x2048x4096xf32>, memref<1x2048x4096xf32>) -> ()
  %2 = "mhlo.add"(%arg0, %arg1): (tensor<1x2048x4096xf32>, tensor<1x2048x4096xf32>) -> tensor<1x2048x4096xf32>
  // CHECK: %alloc_1 = memref.alloc() : memref<1x2048x4096xf16>
  // CHECK: "lmhlo.convert"(%alloc_0, %alloc_1) : (memref<1x2048x4096xf32>, memref<1x2048x4096xf16>) -> ()
  %3:2 = "mhlo.optimization_barrier"(%1, %2): (tensor<1x2048x4096xf32>, tensor<1x2048x4096xf32>) -> (tensor<1x2048x4096xf32>, tensor<1x2048x4096xf32>)
  %4 = "mhlo.convert"(%3#1): (tensor<1x2048x4096xf32>) -> tensor<1x2048x4096xf16>
  // CHECK: %alloc_2 = memref.alloc() : memref<2048x4096xf16>
  // CHECK: "lmhlo.reshape"(%alloc_1, %alloc_2) : (memref<1x2048x4096xf16>, memref<2048x4096xf16>) -> ()
  %5 = "mhlo.reshape"(%4) : (tensor<1x2048x4096xf16>) -> tensor<2048x4096xf16>
  // CHECK: return %alloc_2 : memref<2048x4096xf16>
  return %5: tensor<2048x4096xf16>
}

0 comments on commit f160eb2

Please sign in to comment.