diff --git a/tao_compiler/mlir/disc/BUILD b/tao_compiler/mlir/disc/BUILD
index 84545752a38..d3dfd3fbc9b 100755
--- a/tao_compiler/mlir/disc/BUILD
+++ b/tao_compiler/mlir/disc/BUILD
@@ -1665,6 +1665,29 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "disc_shape_propagate",
+    srcs = ["transforms/disc_shape_propagate.cc"],
+    hdrs = [
+    ],
+    deps = [
+        ":mhlo_disc",
+        ":pass_details",
+        ":shape_utils",
+        "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:mlir_hlo",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:MemRefDialect",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:ShapeDialect",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
+        "@llvm-project//mlir:Transforms",
+    ],
+    alwayslink = 1,
+)
+
 cc_library(
     name = "disc_dot_merge",
     srcs = ["transforms/disc_dot_merge.cc"],
@@ -2471,6 +2494,7 @@ cc_library(
         ":disc_remove_shape_constraints",
         ":disc_shape_optimization",
         ":disc_shape_simplifier",
+        ":disc_shape_propagate",
         ":disc_shape_to_std",
         ":disc_sparse_op_rewriter",
         ":disc_specialize_fusion_with_speculation",
diff --git a/tao_compiler/mlir/disc/disc_compiler.cc b/tao_compiler/mlir/disc/disc_compiler.cc
index 1a978308d80..e8f0eb06ae3 100644
--- a/tao_compiler/mlir/disc/disc_compiler.cc
+++ b/tao_compiler/mlir/disc/disc_compiler.cc
@@ -235,7 +235,8 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) {
   auto printingFlags = OpPrintingFlags();
   printingFlags.elideLargeElementsAttrs(16);
   pm.enableIRPrinting(
-      /*shouldPrintBeforePass=*/nullptr,
+      /*shouldPrintBeforePass=*/
+      nullptr,
       /*shouldPrintAfterPass=*/
       [](Pass* pass, Operation*) { return VLOG_IS_ON(1); },
       /*printModuleScope=*/false,
@@ -244,6 +245,7 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) {
 
   pm.addNestedPass<FuncOp>(disc_ral::createDiscAlgebraicSimplifierPass());
   pm.addPass(disc_ral::createDiscInputOutputAliasPass());
+  pm.addPass(disc_ral::createDiscShapePropagatePass());
   pm.addPass(mlir::createInlinerPass());
   // TODO(disc): Lower HLO shape constraints instead of eliding them here.
   pm.addNestedPass<FuncOp>(disc_ral::createDiscCollectiveOpsRewriterPass());
@@ -1015,7 +1017,8 @@ Status ConvertTF2MlirHlo(mlir::ModuleOp module_op) {
   auto printingFlags = mlir::OpPrintingFlags();
   printingFlags.elideLargeElementsAttrs(16);
   pm.enableIRPrinting(
-      /*shouldPrintBeforePass=*/nullptr,
+      /*shouldPrintBeforePass=*/
+      nullptr,
       /*shouldPrintAfterPass=*/
       [](mlir::Pass* pass, mlir::Operation*) { return VLOG_IS_ON(1); },
       /*printModuleScope=*/false,
diff --git a/tao_compiler/mlir/disc/tools/disc-opt/disc-opt.cc b/tao_compiler/mlir/disc/tools/disc-opt/disc-opt.cc
index 6a1080591a0..3bbe0a6d61f 100644
--- a/tao_compiler/mlir/disc/tools/disc-opt/disc-opt.cc
+++ b/tao_compiler/mlir/disc/tools/disc-opt/disc-opt.cc
@@ -51,6 +51,7 @@ int main(int argc, char** argv) {
   mlir::DialectRegistry registry;
   mlir::registerAllDialects(registry);
   mlir::disc_ral::registerTransformDialectCommonExtension(registry);
+  registry.insert();
   registry.insert();
   registry.insert();
   registry.insert();
diff --git a/tao_compiler/mlir/disc/transforms/disc_passes.td b/tao_compiler/mlir/disc/transforms/disc_passes.td
index c2d3108c9d0..3352f7d5e74 100755
--- a/tao_compiler/mlir/disc/transforms/disc_passes.td
+++ b/tao_compiler/mlir/disc/transforms/disc_passes.td
@@ -672,4 +672,9 @@ def DiscInputOutputAliasPass : Pass<"disc-input-output-alias", "ModuleOp"> {
 def DiscReduceBufferLiveRangePass : Pass<"disc-reduce-buffer-live-range", "mlir::func::FuncOp"> {
   let summary = "reduce buffer live range";
   let constructor = "createDiscReduceBufferLiveRangePass()";
+}
+
+def DiscShapePropagatePass : Pass<"disc-shape-propagate", "ModuleOp"> {
+  let summary = "propagate dynamic dims specified on entry function arguments";
+  let constructor = "createDiscShapePropagatePass()";
 }
\ No newline at end of file
diff --git a/tao_compiler/mlir/disc/transforms/disc_shape_propagate.cc b/tao_compiler/mlir/disc/transforms/disc_shape_propagate.cc
new file mode 100644
index 00000000000..ca9185a13c2
--- /dev/null
+++ b/tao_compiler/mlir/disc/transforms/disc_shape_propagate.cc
@@ -0,0 +1,296 @@
+/* Copyright 2022 The BladeDISC Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file implements the logic to propagate dynamic shape information on
+// tensor level.
+
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "absl/strings/str_split.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Debug.h"
+#include "mhlo/IR/hlo_ops.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"  // TF:llvm-project
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"  // TF:llvm-project
+#include "mlir/disc/IR/disc_shape_ops.h"
+#include "mlir/disc/IR/hlo_disc_ops.h"
+#include "mlir/disc/disc_util.h"
+#include "mlir/disc/transforms/PassDetail.h"
+#include "mlir/disc/transforms/shape_utils.h"
+
+namespace mlir {
+namespace disc_ral {
+
+using ::mlir::func::FuncOp;
+
+namespace {
+
+std::string kDynamicDimsAttr = "input_dynamic_dims";
+
+struct ShapeContext {
+  ShapeContext() = default;
+  ShapeContext(Value value, SmallVector<int64_t> shape)
+      : value(value), shape(shape){};
+
+  Value value;
+  SmallVector<int64_t> shape;
+};
+
+struct DiscShapePropagatePass
+    : public DiscShapePropagatePassBase<DiscShapePropagatePass> {
+  DiscShapePropagatePass()
+      : DiscShapePropagatePassBase<
+            DiscShapePropagatePass>::DiscShapePropagatePassBase() {}
+  void getDependentDialects(DialectRegistry& registry) const override {
+    DiscShapePropagatePassBase<DiscShapePropagatePass>::getDependentDialects(
+        registry);
+    registry.insert<shape::ShapeDialect>();
+  }
+  void runOnOperation() override;
+};
+
+bool isBinaryOp(Operation* op) {
+  return isa<mhlo::AddOp>(*op) || isa<mhlo::SubtractOp>(*op) ||
+         isa<mhlo::MulOp>(*op) || isa<mhlo::CompareOp>(*op);
+}
+
+bool isUnaryOp(Operation* op) { return isa<mhlo::ConvertOp>(op); }
+
+bool isConcreteShape(ShapeContext& ctx) {
+  for (auto dim : ctx.shape) {
+    if (dim == ShapedType::kDynamic) return false;
+  }
+  return true;
+}
+
+std::optional<Value> getConstTensor(OpBuilder& b, Operation* op,
+                                    ArrayRef<int64_t> vec,
+                                    ArrayRef<int64_t> shape) {
+  uint64_t num_total_elements = 1;
+  for (int64_t a : shape) {
+    num_total_elements *= a;
+  }
+
+  if (vec.size() != num_total_elements) {
+    op->emitOpError("getConstTensor(): number of elements mismatch.");
+    return std::nullopt;
+  }
+  auto const_type = RankedTensorType::get(shape, b.getI64Type());
+  auto const_attr = DenseElementsAttr::get(const_type, vec);
+  auto const_op =
+      b.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
+  return const_op.getResult();
+}
+
+std::optional<ShapeContext> HandleBinaryOp(OpBuilder& b, Operation* op,
+                                           ShapeContext& inputCtx) {
+  if (!isBinaryOp(op)) return std::nullopt;
+  if (op->getOperand(1).isa<BlockArgument>()) {
+    return ShapeContext(op->getResult(0), inputCtx.shape);
+  }
+  // If the other operand is a constant, replace it with a scalar constant
+  // broadcast to the (now dynamic) shape of the first operand.
+  if (auto const_op =
+          dyn_cast<mhlo::ConstantOp>(op->getOperand(1).getDefiningOp())) {
+    auto elemTy =
+        op->getOperand(0).getType().cast<RankedTensorType>().getElementType();
+    b.setInsertionPoint(op);
+    auto dense_attr = const_op.getValue().dyn_cast<DenseIntElementsAttr>();
+    int64_t value = (*dense_attr.getValues<APInt>().begin()).getSExtValue();
+    auto scalar_const_op = getConstTensor(b, op, {value}, {});
+    Value inputShape =
+        b.create<shape::ShapeOfOp>(op->getLoc(), op->getOperand(0));
+    auto rank = inputCtx.shape.size();
+    SmallVector<int64_t> broadcast_dim;
+    broadcast_dim.push_back(static_cast<int64_t>(rank));
+
+    auto bcast_op = b.create<mhlo::DynamicBroadcastInDimOp>(
+        op->getLoc(), RankedTensorType::get(inputCtx.shape, elemTy),
+        scalar_const_op.value(), inputShape, b.getI64TensorAttr({}));
+    const_op.getResult().replaceAllUsesWith(bcast_op.getResult());
+    const_op.erase();
+  }
+  return ShapeContext(op->getResult(0), inputCtx.shape);
+}
+
+template <typename OpTy>
+std::optional<ShapeContext> propagateHelper(OpBuilder& b, Operation* op,
+                                            ShapeContext& inputCtx) {
+  return std::nullopt;
+}
+
+template <>
+std::optional<ShapeContext> propagateHelper<mhlo::DotOp>(
+    OpBuilder& b, Operation* op, ShapeContext& inputCtx) {
+  auto dot_op = cast<mhlo::DotOp>(op);
+  auto lhs_shape =
+      dot_op.getOperand(0).getType().cast<RankedTensorType>().getShape();
+  auto rhs_shape =
+      dot_op.getOperand(1).getType().cast<RankedTensorType>().getShape();
+  auto result_shape =
+      dot_op.getResult().getType().cast<RankedTensorType>().getShape();
+  SmallVector<int64_t> new_shape;
+  new_shape.push_back(lhs_shape[0]);
+  new_shape.push_back(rhs_shape[1]);
+  return ShapeContext(op->getResult(0), new_shape);
+}
+
+LogicalResult parseInputDynamicDims(
+    func::FuncOp main,
+    std::vector<std::pair<int, std::vector<int>>>& input_dynamic_dims) {
+  auto dict_attr = main->getAttrOfType<DictionaryAttr>("tf.entry_function");
+  if (!dict_attr) {
+    return failure();
+  }
+  if (!dict_attr.get(kDynamicDimsAttr)) {
+    return failure();
+  }
+  StringRef param_str =
+      dict_attr.get(kDynamicDimsAttr).dyn_cast<StringAttr>();
+
+  // The attribute has the form "arg_idx:dim,dim|arg_idx:dim,dim", e.g.
+  // "0:1|1:1" marks dim 1 of arg 0 and dim 1 of arg 1 as dynamic.
+  SmallVector<StringRef> parsed_dynamic_dims;
+  param_str.split(parsed_dynamic_dims, "|");
+  for (auto kv : parsed_dynamic_dims) {
+    SmallVector<StringRef> pair;
+    kv.split(pair, ":");
+    if (pair.size() != 2) {
+      return failure();
+    }
+    int arg_index = std::stoi(pair[0].str());
+    SmallVector<StringRef> dims;
+    pair[1].split(dims, ",");
+    std::vector<int> dim_vec;
+    for (auto dim : dims) {
+      dim_vec.push_back(std::stoi(dim.str()));
+    }
+    input_dynamic_dims.push_back({arg_index, dim_vec});
+  }
+  return success();
+}
+
+void applyShapeContext(ShapeContext& ctx) {
+  auto res_ty = ctx.value.getType().dyn_cast<RankedTensorType>();
+  if (!res_ty) return;
+  auto elemTy = res_ty.getElementType();
+  ctx.value.setType(RankedTensorType::get(ctx.shape, elemTy));
+}
+
+std::optional<ShapeContext> propagateOpShape(OpBuilder& rewriter, Operation* op,
+                                             ShapeContext& inputCtx) {
+  if (isUnaryOp(op)) {
+    return ShapeContext(op->getResult(0), inputCtx.shape);
+  }
+  if (auto ctx = HandleBinaryOp(rewriter, op, inputCtx)) {
+    return ctx;
+  }
+  using PropagationFunc =
+      std::optional<ShapeContext> (*)(OpBuilder&, Operation*, ShapeContext&);
+  const std::vector<PropagationFunc> propagationFunctions = {
+      propagateHelper<mhlo::DotOp>,
+  };
+  // Iterate over the propagation functions and apply each one.
+  for (const auto& propagate : propagationFunctions) {
+    if (auto ctx = propagate(rewriter, op, inputCtx)) {
+      return ctx;
+    }
+  }
+  return std::nullopt;
+}
+
+void visitOperator(ModuleOp& m, OpBuilder& rewriter, Operation* op,
+                   ShapeContext& ctx) {
+  if (isConcreteShape(ctx)) return;
+  // Return operators are handled later when the function signature is updated.
+  if (isa<func::ReturnOp>(op)) return;
+
+  auto resultShapeCtx = propagateOpShape(rewriter, op, ctx);
+  if (!resultShapeCtx) {
+    m.emitError("failed to update shape context on op: " +
+                op->getName().stripDialect().str());
+    return;
+  }
+  for (auto user : op->getResult(0).getUsers()) {
+    visitOperator(m, rewriter, user, resultShapeCtx.value());
+  }
+  applyShapeContext(*resultShapeCtx);
+}
+
+void DiscShapePropagatePass::runOnOperation() {
+  ModuleOp m = getOperation();
+  auto main = m.lookupSymbol<FuncOp>("main");
+  if (!main) {
+    m.emitError("entry func: main not found");
+    signalPassFailure();
+    return;
+  }
+  MLIRContext* context = &getContext();
+  mlir::OpBuilder rewriter(context);
+  OpBuilder b(main);
+  SmallVector<Type, 4> new_arg_types, new_return_types;
+  for (auto arg : main.getArguments()) {
+    new_arg_types.push_back(arg.getType());
+  }
+  // stage1: parse the input_dynamic_dims attribute into (arg_index, dims)
+  // pairs.
+  std::vector<std::pair<int, std::vector<int>>> input_dynamic_dims;
+  if (failed(parseInputDynamicDims(main, input_dynamic_dims))) {
+    return;
+  }
+  // Skip this pass if there is no dynamic dims attribute.
+  if (input_dynamic_dims.size() == 0) return;
+  // stage2: visit all operators to propagate shape.
+  for (auto pair : input_dynamic_dims) {
+    int argIdx = pair.first;
+    Value value = main.getArgument(argIdx);
+    auto ty = value.getType().cast<RankedTensorType>();
+    SmallVector<int64_t> newShape;
+    std::copy(ty.getShape().begin(), ty.getShape().end(),
+              std::back_inserter(newShape));
+    for (auto dim : pair.second) {
+      newShape[dim] = ShapedType::kDynamic;
+    }
+    ShapeContext ctx(value, newShape);
+    auto newType = RankedTensorType::get(newShape, ty.getElementType());
+    for (auto user : main.getArgument(argIdx).getUsers()) {
+      visitOperator(m, rewriter, user, ctx);
+    }
+    new_arg_types[argIdx] = newType;
+    applyShapeContext(ctx);
+  }
+
+  // stage3: visit all return operators to update the function signature.
+  main.walk([&](Operation* op) {
+    if (isa<func::ReturnOp>(*op)) {
+      for (auto operand : op->getOperands()) {
+        new_return_types.push_back(operand.getType());
+      }
+    }
+  });
+  main.setType(
+      FunctionType::get(main.getContext(), new_arg_types, new_return_types));
+}
+
+}  // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>> createDiscShapePropagatePass() {
+  return std::make_unique<DiscShapePropagatePass>();
+}
+
+}  // namespace disc_ral
+}  // namespace mlir
diff --git a/tao_compiler/mlir/disc/transforms/passes.h b/tao_compiler/mlir/disc/transforms/passes.h
index a6dc7344aff..b09b0bef411 100644
--- a/tao_compiler/mlir/disc/transforms/passes.h
+++ b/tao_compiler/mlir/disc/transforms/passes.h
@@ -191,6 +191,8 @@ createDiscUnhandledAtomicRMWConverterPass();
 std::unique_ptr<OperationPass<ModuleOp>> createDiscShapeSimplifierPass(
     const std::string& entry_func_name = "main", bool insert_tie_shape = false);
 
+std::unique_ptr<OperationPass<ModuleOp>> createDiscShapePropagatePass();
+
 // Using approximation impl for some special math ops.
 std::unique_ptr<OperationPass<ModuleOp>> createDiscMathApproximationPass();
diff --git a/tao_compiler/mlir/disc/transforms/tests/disc-shape-propagate.mlir b/tao_compiler/mlir/disc/transforms/tests/disc-shape-propagate.mlir
new file mode 100644
index 00000000000..ae112a3a903
--- /dev/null
+++ b/tao_compiler/mlir/disc/transforms/tests/disc-shape-propagate.mlir
@@ -0,0 +1,19 @@
+// RUN: disc-opt -split-input-file --disc-shape-propagate %s | FileCheck %s
+
+// CHECK-LABEL: main
+func.func @main(%arg0: tensor<4x101xi64>, %arg1: tensor<4x101xi64>) -> tensor<4x101xi1> attributes {tf.entry_function = {input_dynamic_dims = "0:1|1:1"}} {
+  // CHECK: %0 = mhlo.compare LT, %arg0, %arg1 : (tensor<4x?xi64>, tensor<4x?xi64>) -> tensor<4x?xi1>
+  %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<4x101xi64>, tensor<4x101xi64>) -> tensor<4x101xi1>
+  // CHECK: return %0 : tensor<4x?xi1>
+  return %0 : tensor<4x101xi1>
+}
+
+// -----
+// CHECK-LABEL: main
+func.func @main(%arg0: tensor<4x101xi64>) -> tensor<4x101xi1> attributes {tf.entry_function = {input_dynamic_dims = "0:1"}} {
+  // CHECK: %1 = shape.shape_of %arg0 : tensor<4x?xi64> -> tensor<2xindex>
+  // CHECK: %2 = "mhlo.dynamic_broadcast_in_dim"(%0, %1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i64>, tensor<2xindex>) -> tensor<4x?xi64>
+  %0 = mhlo.constant dense<0> : tensor<4x101xi64>
+  %1 = "mhlo.compare"(%arg0, %0) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<4x101xi64>, tensor<4x101xi64>) -> tensor<4x101xi1>
+  return %1 : tensor<4x101xi1>
+}
\ No newline at end of file