Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Yancey1989 committed Jul 21, 2023
1 parent 79506e8 commit a8f9a30
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -173,16 +173,11 @@ const std::unordered_set<std::string> &GetTorchMlirWhiteList() {

std::call_once(white, [&]() {
auto list = StrSplit(env::ReadStringFromEnvVar("TORCH_MHLO_OP_WHITE_LIST", ""), ';');
for (auto s : list) {
std::cout << "white insert: " << s << std::endl;
white_list.insert(s);
}
for (auto s : list) white_list.insert(s);
});
std::call_once(black, [&]() {
auto list = StrSplit(env::ReadStringFromEnvVar("TORCH_MHLO_OP_BLACK_LIST", ""), ';');
for (auto s : list) {
white_list.erase(s);
}
for (auto s : list) white_list.erase(s);
});

std::ostringstream ostr;
Expand Down
3 changes: 0 additions & 3 deletions pytorch_blade/torch_blade/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,6 @@ def _optimize_common(c_module):
torch._C._jit_pass_remove_dropout(c_module)
_fixup_for_dynamic_shape(cfg, c_module)
graph = c_module.forward.graph
print("after fixup for dynamic shape", graph)
_jit_pass_remove_nograd(graph)
_jit_pass_freeze_requires_grad(graph)
if hasattr(torch._C, "_jit_pass_fold_frozen_conv_bn"):
Expand Down Expand Up @@ -525,10 +524,8 @@ def _collect_all_inplace_nodes(block):
graph.appendNode(copy_op)
if list(graph.return_node().inputs())[0].node().kind() == "prim::TupleConstruct":
copy_op.moveBefore(list(graph.return_node().inputs())[0].node())

list(copy_op.inputs())[0].replaceAllUsesAfterNodeWith(copy_op, slice_scatter.output())
node.destroy()
print("after inplace mutation: \n", graph)

def _jit_pass_hack_cpu_device(graph):
cfg = Config.get_current_context_or_new()
Expand Down
2 changes: 1 addition & 1 deletion tao_compiler/mlir/disc/transforms/codegen_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ Value emitNumElementsComputation(OpBuilder& b, Location loc, Operation* op) {
// only const rank is supported for now
assert(op->getDialect()->getNamespace() == "lmhlo");
int num_operands = op->getNumOperands();
if (isa<lmhlo::DynamicUpdateSliceOp>(op)) {
if (isa<lmhlo::DynamicUpdateSliceOp>(op) && isInplaceOperator(op)) {
return emitNumElementsComputation(b, loc, op->getOperand(1));
}
Value result_memref = op->getOperand(num_operands - 1);
Expand Down

0 comments on commit a8f9a30

Please sign in to comment.