Fix reshape optimizations in nested blocks (#2297)
* Fixed reshape optimization in nested blocks

* Address comments
kasper0406 authored Aug 8, 2024
1 parent 2be4673 commit 8b7048e
Showing 3 changed files with 48 additions and 3 deletions.
@@ -269,6 +269,7 @@ def _match_pattern(reshape_op):

     @block_context_manager
     def _merge_consecutive_reshapes_block(self, block):
+        @block_context_manager
         def help_merge_consecutive_reshapes_block(block):
             fusion_happens = False
             for op in list(block.operations):
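For context, merge_consecutive_reshapes rewrites a chain of reshape ops, where each intermediate result feeds only the next reshape, into a single reshape to the final shape; the new test below checks that this also happens inside a while_loop body. A minimal numpy sketch (not coremltools code, just the shape arithmetic used by that test) of why the rewrite preserves values:

    import numpy as np

    x = np.arange(42).reshape(6, 7)  # same (6, 7) input shape as the test below
    # Chain the same intermediate shapes the test uses...
    chained = x.reshape(3, 2, 7).reshape(7, 2, 3).reshape(14, 3).reshape(7, 6)
    # ...and compare with reshaping straight to the final (7, 6) shape.
    direct = x.reshape(7, 6)
    assert np.array_equal(chained, direct)  # element order is unchanged, so one reshape suffices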
37 changes: 37 additions & 0 deletions coremltools/converters/mil/mil/passes/tests/test_passes.py
@@ -1674,6 +1674,43 @@ def prog(x):
             backend=backend,
         )

+    @pytest.mark.parametrize(
+        "backend",
+        backends,
+    )
+    def test_merge_reshape_in_nested_block(self, backend):
+        INPUT_SHAPE = (6, 7)
+        OUTPUT_SHAPE = (7, 6)
+
+        @mb.program(input_specs=[mb.TensorSpec(shape=INPUT_SHAPE)])
+        def prog(x):
+            loop_var = np.int32(2)
+            def while_cond(loop_var, _x):
+                return mb.equal(x=loop_var, y=np.int32(0))
+
+            def while_body(loop_var, x):
+                # Do reshapes of the input
+                y1 = mb.reshape(x=x, shape=(3, 2, 7))
+                y2 = mb.reshape(x=y1, shape=(7, 2, 3))
+                y3 = mb.reshape(x=y2, shape=(14, 3))
+                y4 = mb.reshape(x=y3, shape=OUTPUT_SHAPE)
+                return mb.add(x=loop_var, y=np.int32(-1)), y4
+
+            while_results = mb.while_loop(_cond=while_cond, _body=while_body, loop_vars=(loop_var, x))
+            return while_results[1]
+
+        prev_prog, _, block = apply_pass_and_basic_check(prog, "common::merge_consecutive_reshapes")
+        assert get_op_types_in_program(prev_prog, recurse=True) == ["while_loop", "equal", "reshape", "reshape", "reshape", "reshape", "add"]
+        assert get_op_types_in_program(prog, recurse=True) == ["while_loop", "equal", "reshape", "add"]
+
+        assert len(block.outputs) == 1
+        assert_model_is_valid(
+            prog,
+            {"x": INPUT_SHAPE},
+            expected_output_shapes={block.outputs[0].name: OUTPUT_SHAPE},
+            backend=backend,
+        )

 class TestCastOptimizationReduendantCastRemoval:
     """
     Test single cast op removal.
13 changes: 10 additions & 3 deletions coremltools/converters/mil/testing_utils.py
@@ -271,7 +271,7 @@ def get_op_names_in_program(prog, func_name="main", skip_const_ops=True):
     return op_names_in_program


-def get_op_types_in_block(block: Block, skip_const_ops: bool = True):
+def get_op_types_in_block(block: Block, skip_const_ops: bool = True, recurse: bool = False):
     """
     Return the operation types in block,
     in the same order as they are stored (topological)
@@ -282,16 +282,23 @@ def get_op_types_in_block(block: Block, skip_const_ops: bool = True):
         if op.op_type == "const":
             continue
         op_types_in_block.append(op.op_type)
+
+        if recurse:
+            for child_block in op.blocks:
+                child_ops = get_op_types_in_block(child_block, skip_const_ops, recurse)
+                op_types_in_block += child_ops
+
     return op_types_in_block


-def get_op_types_in_program(prog: Program, func_name: str = "main", skip_const_ops: bool = True):
+def get_op_types_in_program(prog: Program, func_name: str = "main", skip_const_ops: bool = True, recurse: bool = False):
     """
     Return the operation types in prog[func_name],
     in the same order as they are stored (topological)
     If ``skip_const_ops = True``, const ops are not returned.
+    If ``recurse = True``, the ops of all nested blocks are returned.
     """
-    return get_op_types_in_block(prog[func_name], skip_const_ops)
+    return get_op_types_in_block(prog[func_name], skip_const_ops, recurse)

 def random_gen(
     shape,
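For reference, a hedged usage sketch of the new recurse flag against the `prog` built in the test above, after the pass has run; the recursive op list is the one asserted in that test, and the non-recursive result is an assumption based on the program having only a while_loop (plus skipped const ops) in its main block:

    # Without recurse, only ops in the main block are listed, so nested reshapes stay invisible.
    get_op_types_in_program(prog)                # expected: ['while_loop'] (const ops skipped by default)
    # With recurse=True, ops from the while_loop's cond/body blocks are appended in traversal order.
    get_op_types_in_program(prog, recurse=True)  # ['while_loop', 'equal', 'reshape', 'add']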
