diff --git a/tao_compiler/mlir/disc/BUILD b/tao_compiler/mlir/disc/BUILD index 622d2d5a6e3..bccadfd8565 100755 --- a/tao_compiler/mlir/disc/BUILD +++ b/tao_compiler/mlir/disc/BUILD @@ -933,6 +933,34 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "disc_optimization_barrier_expand", + srcs = ["transforms/disc_optimization_barrier_expand.cc"], + hdrs = [ + "transforms/passes.h", + "transforms/rewriters.h", + ], + deps = [ + ":lmhlo_disc", + ":pass_details", + ":placement_utils", + ":shape_utils", + "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:lhlo", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:ShapeTransforms", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:SCFDialect", + ], + alwayslink = 1, +) + cc_library( name = "disc_lower_to_library_call", srcs = ["transforms/disc_lower_to_library_call.cc"], @@ -2356,6 +2384,7 @@ cc_library( ":disc_math_approximation", ":disc_memref_canonicalizer", ":disc_outline_cpu_kernel", + ":disc_optimization_barrier_expand", ":disc_parallel_loop_collapsing", ":disc_parallel_loop_tiling", ":disc_remove_dead_buffer", diff --git a/tao_compiler/mlir/disc/IR/lhlo_disc_ops.td b/tao_compiler/mlir/disc/IR/lhlo_disc_ops.td old mode 100644 new mode 100755 index 48f8afe5ce0..3cb14895865 --- a/tao_compiler/mlir/disc/IR/lhlo_disc_ops.td +++ b/tao_compiler/mlir/disc/IR/lhlo_disc_ops.td @@ -295,5 +295,30 @@ def LHLO_ArgsMutationOp : LHLODISC_Op<"args_mutation", []> { ); } +def LHLODISC_OptimizationBarrierOp : LHLODISC_Op<"optimization_barrier", []> { + let summary = "OptimizationBarrier operation"; + let description = [{ + Ensures that the operations that produce the `operand` are executed before any + operations that depend on the `result` and prevents compiler transformations + from moving operations across the barrier. Other than that, the operation is + an identity, i.e. `result` = `operand`. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#optimization_barrier + + Example: + ```mlir + %result0, %result1 = mhlo.optimization_barrier %operand0, %operand1 : tensor, tensor + ``` + }]; + + let arguments = (ins + Arg, "", [MemRead]>:$args + ); + + let results = (outs Variadic); + +} + #endif // LMHLO_DISC_OPS diff --git a/tao_compiler/mlir/disc/disc_compiler.cc b/tao_compiler/mlir/disc/disc_compiler.cc index 582f6846d1e..9e681c8b738 100644 --- a/tao_compiler/mlir/disc/disc_compiler.cc +++ b/tao_compiler/mlir/disc/disc_compiler.cc @@ -514,6 +514,8 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) { pm.addNestedPass(disc_ral::createDiscStitchFusionPass()); } + pm.addPass(mhlo_disc::createDiscOptimizationBarrierExpandPass()); + if (useTransformSchedule()) { std::string transform_schedule; tensorflow::ReadStringFromEnvVar("DISC_TRANSFORM_SCHEDULE_FILE", "", diff --git a/tao_compiler/mlir/disc/transforms/disc_assign_memory_space.cc b/tao_compiler/mlir/disc/transforms/disc_assign_memory_space.cc old mode 100644 new mode 100755 index 901fe835df6..0cfafb2a893 --- a/tao_compiler/mlir/disc/transforms/disc_assign_memory_space.cc +++ b/tao_compiler/mlir/disc/transforms/disc_assign_memory_space.cc @@ -593,8 +593,8 @@ LogicalResult DiscAssignMemorySpacePass::applyOperationAssignment( // clang-format: off Operation* newOp = tryReplaceResultType< memref::AllocOp, memref::AllocaOp, memref::SubViewOp, memref::ViewOp, - memref::CastOp, memref::ReinterpretCastOp, lmhlo_disc::CustomCallV2Op>( - op, assignment); + memref::CastOp, memref::ReinterpretCastOp, lmhlo_disc::CustomCallV2Op, + lmhlo_disc::OptimizationBarrierOp>(op, assignment); // clang-format: on if (newOp) { diff --git a/tao_compiler/mlir/disc/transforms/disc_hlo_legalize_to_lhlo.cc b/tao_compiler/mlir/disc/transforms/disc_hlo_legalize_to_lhlo.cc old mode 100644 new mode 100755 index a01ad7e9025..21ad976ae22 --- a/tao_compiler/mlir/disc/transforms/disc_hlo_legalize_to_lhlo.cc +++ b/tao_compiler/mlir/disc/transforms/disc_hlo_legalize_to_lhlo.cc @@ -178,6 +178,30 @@ struct HloToLhloArgsMutationOpConverter } }; +struct HloToLhloOptimizationBarrierOpConverter + : public BaseOpConversion { + public: + using BaseOpConversion::BaseOpConversion; + LogicalResult matchAndRewrite( + mhlo::OptimizationBarrierOp hloOp, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + Operation* op = hloOp.getOperation(); + auto operands = adaptor.getOperands(); + + SmallVector resultTypes; + for (Value v : hloOp.getResults()) { + auto ty = v.getType().cast(); + resultTypes.push_back( + MemRefType::get(ty.getShape(), ty.getElementType())); + } + + rewriter.replaceOpWithNewOp( + hloOp, resultTypes, operands, op->getAttrs()); + + return success(); + } +}; + struct HloToLhloCustomCallOpConverter : public BaseOpConversion { public: @@ -382,6 +406,7 @@ struct DiscHloLegalizeToLhlo target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); bufferization::BufferizeTypeConverter converter; populateDiscHLOToLHLOConversionPattern(&context, &converter, &patterns); @@ -411,6 +436,7 @@ void populateDiscHLOToLHLOConversionPattern( HloToLhloArgsMutationOpConverter, HloToLhloCustomCallOpConverter, HloToLhloCustomCallOpV2Converter, + HloToLhloOptimizationBarrierOpConverter, TieShapeOpConverter >(*converter, context); // clang-format on diff --git a/tao_compiler/mlir/disc/transforms/disc_optimization_barrier_expand.cc b/tao_compiler/mlir/disc/transforms/disc_optimization_barrier_expand.cc new file mode 100755 index 00000000000..eb986bc3cb2 --- /dev/null +++ b/tao_compiler/mlir/disc/transforms/disc_optimization_barrier_expand.cc @@ -0,0 +1,113 @@ +// Copyright 2021 The BladeDISC Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file implements logic for lowering HLO DISC dialect to LHLO DISC +// dialect. + +#include +#include + +#include "lhlo/IR/lhlo_ops.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Shape/Transforms/Passes.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "mlir/disc/IR/disc_shape_ops.h" +#include "mlir/disc/IR/lhlo_disc_ops.h" +#include "mlir/disc/transforms/PassDetail.h" +#include "mlir/disc/transforms/placement_utils.h" +#include "mlir/disc/transforms/rewriters.h" +#include "mlir/disc/transforms/shape_utils.h" + +namespace mlir { +using placement_utils::kDiscPlaceAssignment; +using placement_utils::kGpu; + +namespace mhlo_disc { +namespace { + +template +using BaseOpConversion = OpConversionPattern; + +struct LhloDISCOptimizationBarrierOpConverter + : public OpRewritePattern { + explicit LhloDISCOptimizationBarrierOpConverter(MLIRContext* context) + : OpRewritePattern(context) {} + LogicalResult matchAndRewrite(lmhlo_disc::OptimizationBarrierOp lhloOp, + PatternRewriter& rewriter) const override { + Operation* op = lhloOp.getOperation(); + + auto operands = op->getOperands(); + auto results = op->getResults(); + + for (int i = 0; i < operands.size(); i++) { + results[i].replaceAllUsesWith(operands[i]); + } + + rewriter.eraseOp(op); + + return success(); + } +}; + +struct DiscOptimizationBarrierExpandPass + : public DiscOptimizationBarrierExpandPassBase< + DiscOptimizationBarrierExpandPass> { + using DiscOptimizationBarrierExpandPassBase< + DiscOptimizationBarrierExpandPass>::DiscOptimizationBarrierExpandPassBase; + + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + + public: + DiscOptimizationBarrierExpandPass() = default; + + void runOnOperation() override { + auto& context = getContext(); + RewritePatternSet patterns(&context); + ConversionTarget target(context); + target.addLegalDialect(); + target.addIllegalOp(); + patterns.insert(&context); + if (failed( + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); + } +}; +} // namespace + +std::unique_ptr> +createDiscOptimizationBarrierExpandPass() { + return std::make_unique(); +} + +} // namespace mhlo_disc +} // namespace mlir diff --git a/tao_compiler/mlir/disc/transforms/mhlo_disc_passes.td b/tao_compiler/mlir/disc/transforms/mhlo_disc_passes.td old mode 100644 new mode 100755 index 44cb765beb7..c0e1c254bbd --- a/tao_compiler/mlir/disc/transforms/mhlo_disc_passes.td +++ b/tao_compiler/mlir/disc/transforms/mhlo_disc_passes.td @@ -24,3 +24,8 @@ def DiscLhloRewriterPass: Pass<"disc-lhlo-rewriter", "ModuleOp"> { let summary = "rewrite lmhlo ops to lmhlo_disc ops."; let constructor = "createDiscLhloRewriterPass()"; } + +def DiscOptimizationBarrierExpandPass : Pass<"disc-optimization-barrier-expand", "ModuleOp"> { + let summary = "Expand OptimizationBarrierOp"; + let constructor = "createDiscOptimizationBarrierExpandPass()"; +} diff --git a/tao_compiler/mlir/disc/transforms/passes.h b/tao_compiler/mlir/disc/transforms/passes.h old mode 100755 new mode 100644 index 05def8a81b8..e533d442ac2 --- a/tao_compiler/mlir/disc/transforms/passes.h +++ b/tao_compiler/mlir/disc/transforms/passes.h @@ -340,6 +340,9 @@ std::unique_ptr> createDiscLegalizeToLhloPass(); std::unique_ptr> createDiscLhloRewriterPass(); +std::unique_ptr> +createDiscOptimizationBarrierExpandPass(); + } // namespace mhlo_disc } // namespace mlir diff --git a/tao_compiler/mlir/disc/transforms/tests/disc-optimization-barrier-expand.mlir b/tao_compiler/mlir/disc/transforms/tests/disc-optimization-barrier-expand.mlir new file mode 100644 index 00000000000..60447bab9ec --- /dev/null +++ b/tao_compiler/mlir/disc/transforms/tests/disc-optimization-barrier-expand.mlir @@ -0,0 +1,22 @@ + +// RUN: disc-opt -split-input-file -disc-hlo-legalize-to-lhlo -hlo-legalize-to-lhlo -disc-optimization-barrier-expand %s -o - | FileCheck %s + + +// CHECK-LABEL: @optimization_barrier_expand +func.func @optimization_barrier_expand(%arg0 : tensor<1x2048x4096xf32>, %arg1: tensor<1x2048x4096xf32>) -> tensor<2048x4096xf16> { + // CHECK: %alloc = memref.alloc() : memref<1x2048x4096xf32> + // CHECK: "lmhlo.add"(%arg0, %arg1, %alloc) : (memref<1x2048x4096xf32>, memref<1x2048x4096xf32>, memref<1x2048x4096xf32>) -> () + %1 = "mhlo.add"(%arg0, %arg1): (tensor<1x2048x4096xf32>, tensor<1x2048x4096xf32>) -> tensor<1x2048x4096xf32> + // CHECK: %alloc_0 = memref.alloc() : memref<1x2048x4096xf32> + // CHECK: "lmhlo.add"(%arg0, %arg1, %alloc_0) : (memref<1x2048x4096xf32>, memref<1x2048x4096xf32>, memref<1x2048x4096xf32>) -> () + %2 = "mhlo.add"(%arg0, %arg1): (tensor<1x2048x4096xf32>, tensor<1x2048x4096xf32>) -> tensor<1x2048x4096xf32> + // CHECK: %alloc_1 = memref.alloc() : memref<1x2048x4096xf16> + // CHECK: "lmhlo.convert"(%alloc_0, %alloc_1) : (memref<1x2048x4096xf32>, memref<1x2048x4096xf16>) -> () + %3:2 = "mhlo.optimization_barrier"(%1, %2): (tensor<1x2048x4096xf32>, tensor<1x2048x4096xf32>) -> (tensor<1x2048x4096xf32>, tensor<1x2048x4096xf32>) + %4 = "mhlo.convert"(%3#1): (tensor<1x2048x4096xf32>) -> tensor<1x2048x4096xf16> + // CHECK: %alloc_2 = memref.alloc() : memref<2048x4096xf16> + // CHECK: "lmhlo.reshape"(%alloc_1, %alloc_2) : (memref<1x2048x4096xf16>, memref<2048x4096xf16>) -> () + %5 = "mhlo.reshape"(%4) : (tensor<1x2048x4096xf16>) -> tensor<2048x4096xf16> + // CHECK: return %alloc_2 : memref<2048x4096xf16> + return %5: tensor<2048x4096xf16> +} \ No newline at end of file