From 8b7048ebc034a8b40e48c39df3f33ba9f60feda9 Mon Sep 17 00:00:00 2001 From: Kasper Nielsen Date: Thu, 8 Aug 2024 19:58:33 +0200 Subject: [PATCH] Fix reshape optimizations in nested blocks (#2297) * Fixed reshape optimization in nested blocks * Address comments --- .../mil/passes/defs/optimize_repeat_ops.py | 1 + .../mil/mil/passes/tests/test_passes.py | 37 +++++++++++++++++++ coremltools/converters/mil/testing_utils.py | 13 +++++-- 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/coremltools/converters/mil/mil/passes/defs/optimize_repeat_ops.py b/coremltools/converters/mil/mil/passes/defs/optimize_repeat_ops.py index 76be08a6f..dbb61d290 100644 --- a/coremltools/converters/mil/mil/passes/defs/optimize_repeat_ops.py +++ b/coremltools/converters/mil/mil/passes/defs/optimize_repeat_ops.py @@ -269,6 +269,7 @@ def _match_pattern(reshape_op): @block_context_manager def _merge_consecutive_reshapes_block(self, block): + @block_context_manager def help_merge_consecutive_reshapes_block(block): fusion_happens = False for op in list(block.operations): diff --git a/coremltools/converters/mil/mil/passes/tests/test_passes.py b/coremltools/converters/mil/mil/passes/tests/test_passes.py index fed90fd55..3954a1566 100644 --- a/coremltools/converters/mil/mil/passes/tests/test_passes.py +++ b/coremltools/converters/mil/mil/passes/tests/test_passes.py @@ -1674,6 +1674,43 @@ def prog(x): backend=backend, ) + @pytest.mark.parametrize( + "backend", + backends, + ) + def test_merge_reshape_in_nested_block(self, backend): + INPUT_SHAPE = (6, 7) + OUTPUT_SHAPE = (7, 6) + + @mb.program(input_specs=[mb.TensorSpec(shape=INPUT_SHAPE)]) + def prog(x): + loop_var = np.int32(2) + def while_cond(loop_var, _x): + return mb.equal(x=loop_var, y=np.int32(0)) + + def while_body(loop_var, x): + # Do reshapes of the input + y1 = mb.reshape(x=x, shape=(3, 2, 7)) + y2 = mb.reshape(x=y1, shape=(7, 2, 3)) + y3 = mb.reshape(x=y2, shape=(14, 3)) + y4 = mb.reshape(x=y3, shape=OUTPUT_SHAPE) + return 
mb.add(x=loop_var, y=np.int32(-1)), y4 + + while_results = mb.while_loop(_cond=while_cond, _body=while_body, loop_vars=(loop_var, x)) + return while_results[1] + + prev_prog, _, block = apply_pass_and_basic_check(prog, "common::merge_consecutive_reshapes") + assert get_op_types_in_program(prev_prog, recurse=True) == ["while_loop", "equal", "reshape", "reshape", "reshape", "reshape", "add"] + assert get_op_types_in_program(prog, recurse=True) == ["while_loop", "equal", "reshape", "add"] + + assert len(block.outputs) == 1 + assert_model_is_valid( + prog, + {"x": INPUT_SHAPE}, + expected_output_shapes={block.outputs[0].name: OUTPUT_SHAPE}, + backend=backend, + ) + class TestCastOptimizationReduendantCastRemoval: """ Test single cast op removal. diff --git a/coremltools/converters/mil/testing_utils.py b/coremltools/converters/mil/testing_utils.py index 4ab44c59e..ecfbe210a 100644 --- a/coremltools/converters/mil/testing_utils.py +++ b/coremltools/converters/mil/testing_utils.py @@ -271,7 +271,7 @@ def get_op_names_in_program(prog, func_name="main", skip_const_ops=True): return op_names_in_program -def get_op_types_in_block(block: Block, skip_const_ops: bool = True): +def get_op_types_in_block(block: Block, skip_const_ops: bool = True, recurse: bool = False): """ Return the operation types in block, in the same order as they are stored (topological) @@ -282,16 +282,23 @@ def get_op_types_in_block(block: Block, skip_const_ops: bool = True): if op.op_type == "const": continue op_types_in_block.append(op.op_type) + + if recurse: + for child_block in op.blocks: + child_ops = get_op_types_in_block(child_block, skip_const_ops, recurse) + op_types_in_block += child_ops + return op_types_in_block -def get_op_types_in_program(prog: Program, func_name: str = "main", skip_const_ops: bool = True): +def get_op_types_in_program(prog: Program, func_name: str = "main", skip_const_ops: bool = True, recurse: bool = False): """ Return the operation types in prog[func_name], in the same
order as they are stored (topological) If ``skip_const_ops = True``, const ops are not returned. + If ``recurse = True``, the ops of all nested blocks are returned. """ - return get_op_types_in_block(prog[func_name], skip_const_ops) + return get_op_types_in_block(prog[func_name], skip_const_ops, recurse) def random_gen( shape,