From a8f9a30f6cfb8c5669dfcb9417ced01d1dc6f884 Mon Sep 17 00:00:00 2001
From: Yan Xu
Date: Thu, 20 Jul 2023 17:16:15 +0800
Subject: [PATCH] update

---
 .../compiler/mlir/converters/torch_mlir_op_filter.cpp | 9 ++-------
 pytorch_blade/torch_blade/pass_manager.py             | 3 ---
 tao_compiler/mlir/disc/transforms/codegen_utils.cc    | 2 +-
 3 files changed, 3 insertions(+), 11 deletions(-)

diff --git a/pytorch_blade/pytorch_blade/compiler/mlir/converters/torch_mlir_op_filter.cpp b/pytorch_blade/pytorch_blade/compiler/mlir/converters/torch_mlir_op_filter.cpp
index 20d267aca27..1111b3ba0fd 100644
--- a/pytorch_blade/pytorch_blade/compiler/mlir/converters/torch_mlir_op_filter.cpp
+++ b/pytorch_blade/pytorch_blade/compiler/mlir/converters/torch_mlir_op_filter.cpp
@@ -173,16 +173,11 @@ const std::unordered_set<std::string>& GetTorchMlirWhiteList() {
   std::call_once(white, [&]() {
     auto list = StrSplit(
         env::ReadStringFromEnvVar("TORCH_MHLO_OP_WHITE_LIST", ""), ';');
-    for (auto s : list) {
-      std::cout << "white insert: " << s << std::endl;
-      white_list.insert(s);
-    }
+    for (auto s : list) white_list.insert(s);
   });
   std::call_once(black, [&]() {
     auto list = StrSplit(
         env::ReadStringFromEnvVar("TORCH_MHLO_OP_BLACK_LIST", ""), ';');
-    for (auto s : list) {
-      white_list.erase(s);
-    }
+    for (auto s : list) white_list.erase(s);
   });
   std::ostringstream ostr;
diff --git a/pytorch_blade/torch_blade/pass_manager.py b/pytorch_blade/torch_blade/pass_manager.py
index 3fa1e52cf3b..ce4636f61a6 100644
--- a/pytorch_blade/torch_blade/pass_manager.py
+++ b/pytorch_blade/torch_blade/pass_manager.py
@@ -363,7 +363,6 @@ def _optimize_common(c_module):
         torch._C._jit_pass_remove_dropout(c_module)
     _fixup_for_dynamic_shape(cfg, c_module)
     graph = c_module.forward.graph
-    print("after fixup for dynamic shape", graph)
     _jit_pass_remove_nograd(graph)
     _jit_pass_freeze_requires_grad(graph)
     if hasattr(torch._C, "_jit_pass_fold_frozen_conv_bn"):
@@ -525,10 +524,8 @@ def _collect_all_inplace_nodes(block):
         graph.appendNode(copy_op)
         if list(graph.return_node().inputs())[0].node().kind() == "prim::TupleConstruct":
             copy_op.moveBefore(list(graph.return_node().inputs())[0].node())
-
         list(copy_op.inputs())[0].replaceAllUsesAfterNodeWith(copy_op, slice_scatter.output())
         node.destroy()
-    print("after inplace mutation: \n", graph)

 def _jit_pass_hack_cpu_device(graph):
     cfg = Config.get_current_context_or_new()
diff --git a/tao_compiler/mlir/disc/transforms/codegen_utils.cc b/tao_compiler/mlir/disc/transforms/codegen_utils.cc
index b772e0d9f70..5d402ca050a 100755
--- a/tao_compiler/mlir/disc/transforms/codegen_utils.cc
+++ b/tao_compiler/mlir/disc/transforms/codegen_utils.cc
@@ -104,7 +104,7 @@ Value emitNumElementsComputation(OpBuilder& b, Location loc, Operation* op) {
   // only const rank is supported for now
   assert(op->getDialect()->getNamespace() == "lmhlo");
   int num_operands = op->getNumOperands();
-  if (isa(op)) {
+  if (isa(op) && isInplaceOperator(op)) {
     return emitNumElementsComputation(b, loc, op->getOperand(1));
   }
   Value result_memref = op->getOperand(num_operands - 1);