From 761cae9c7195be901beb09450db0b8a447f7ccba Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 9 Sep 2023 13:46:27 -0400 Subject: [PATCH] [Unity] Dynamic-shape param support in LazyTransformParams (#15713) This PR brings the support for dynamic-shape parameters to pass LazyTransformParams. Prior to this PR, the symbolic variables in the dynamic-shape parameters are not properly popped out. This PR uses MatchCast to make sure the symbolic variables are always popped out and thereby support the dynamic-shape parameters. This PR also fixes a previvously failed test. --- .../relax/transform/lazy_transform_params.py | 13 +- .../test_transform_lazy_transform_params.py | 120 ++++++++++++++++-- 2 files changed, 118 insertions(+), 15 deletions(-) diff --git a/python/tvm/relax/transform/lazy_transform_params.py b/python/tvm/relax/transform/lazy_transform_params.py index 69f724067c..90e56c8dbb 100644 --- a/python/tvm/relax/transform/lazy_transform_params.py +++ b/python/tvm/relax/transform/lazy_transform_params.py @@ -164,12 +164,15 @@ def visit_tuple_getitem_(self, op: relax.TupleGetItem) -> relax.Expr: # rewrite get item tuple_get_item = super().visit_tuple_getitem_(op) if tuple_get_item.tuple_value == self.input_tuple_param: - return relax.Call( - relax.ExternFunc("get_item"), - [relax.PrimValue(tuple_get_item.index)], - None, - [relax.ObjectStructInfo()], + get_item_result = self.builder_.emit( + relax.Call( + relax.ExternFunc("get_item"), + [relax.PrimValue(tuple_get_item.index)], + None, + [relax.ObjectStructInfo()], + ) ) + return self.builder_.match_cast(get_item_result, op.struct_info) else: return tuple_get_item diff --git a/tests/python/relax/test_transform_lazy_transform_params.py b/tests/python/relax/test_transform_lazy_transform_params.py index 478580ff8d..94f2181daf 100644 --- a/tests/python/relax/test_transform_lazy_transform_params.py +++ b/tests/python/relax/test_transform_lazy_transform_params.py @@ -79,15 +79,23 @@ def main_transform_params() -> R.Tuple: R.func_attr({"relax.force_pure": True}) cls = Expected lv: R.Object = R.call_packed("get_item", R.prim_value(1), sinfo_args=(R.Object,)) - _: R.Object = R.call_packed("set_item", R.prim_value(0), lv, sinfo_args=(R.Object,)) - _1: R.Tuple = R.vm.kill_object(lv) + gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast( + lv, R.Tensor((16, 16, 3, 3), dtype="float32") + ) + lv_m: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1 + _: R.Object = R.call_packed("set_item", R.prim_value(0), lv_m, sinfo_args=(R.Object,)) + _1: R.Tuple = R.vm.kill_object(lv_m) lv1: R.Object = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,)) + gv3: R.Tensor((3, 16, 3, 3), dtype="float32") = R.match_cast( + lv1, R.Tensor((3, 16, 3, 3), dtype="float32") + ) + lv1_m: R.Tensor((3, 16, 3, 3), dtype="float32") = gv3 lv2 = R.call_tir( cls.transform_layout_IOHW_to_OIHW, - (lv1,), + (lv1_m,), out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"), ) - _2: R.Tuple = R.vm.kill_object(lv1) + _2: R.Tuple = R.vm.kill_object(lv1_m) _3: R.Object = R.call_packed("set_item", R.prim_value(1), lv2, sinfo_args=(R.Object,)) gv: R.Tuple = R.tuple() return gv @@ -146,13 +154,17 @@ def main_transform_params(slice_shape_expr: R.Shape(["slice_index"])): slice_index = T.int64() param = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,)) + gv: R.Tensor((16, 16), dtype="float32") = R.match_cast( + param, R.Tensor((16, 16), dtype="float32") + ) + param_m: R.Tensor((16, 16), dtype="float32") = gv transformed = R.call_tir( cls.slice_buffer, - (param,), + (param_m,), tir_vars=[slice_index], out_sinfo=R.Tensor((16,), dtype="float32"), ) - unused_1_ = R.vm.kill_object(param) + unused_1_ = R.vm.kill_object(param_m) unused_2_ = R.call_packed( "set_item", R.prim_value(0), transformed, sinfo_args=(R.Object,) ) @@ -175,14 +187,100 @@ def slice_buffer( tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True) -# TODO(tvm-team): remove once regression get fixed -@pytest.mark.skip("temp disable, minor regression on read/write region in zero dim buffer") +def test_param_shape_symbolic(): + @I.ir_module + class Before: + @T.prim_func + def transform_layout_IOHW_to_OIHW(var_w1: T.handle, var_out: T.handle): + ic = T.int32() + w1 = T.match_buffer(var_w1, (ic, 16, 3, 3), "float32") + out = T.match_buffer(var_out, (16, ic, 3, 3), "float32") + for ax0, ax1, ax2, ax3 in T.grid(16, ic, 3, 3): + with T.block("layout_transform"): + o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(w1[i, o, h, w]) + T.writes(out[o, i, h, w]) + out[o, i, h, w] = w1[i, o, h, w] + + @R.function + def main_transform_params( + params: R.Tuple( + R.Tensor((3, "ic", 3, 3), dtype="float32"), + R.Tensor((16, 16, 3, 3), dtype="float32"), + ) + ) -> R.Tuple( + R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor(("ic", 3, 3, 3), dtype="float32") + ): + ic = T.int64() + # we expect ToNonDataflow and RemovePurityTracking to be invoked first + R.func_attr({"relax.force_pure": True}) + cls = Before + lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1] + lv1: R.Tensor((3, ic, 3, 3), dtype="float32") = params[0] + lv2 = R.call_tir( + cls.transform_layout_IOHW_to_OIHW, + (lv1,), + out_sinfo=R.Tensor((ic, 3, 3, 3), dtype="float32"), + ) + gv: R.Tuple( + R.Tensor((16, 16, 3, 3), dtype="float32"), + R.Tensor((ic, 3, 3, 3), dtype="float32"), + ) = (lv, lv2) + return gv + + @I.ir_module + class Expected: + @T.prim_func + def transform_layout_IOHW_to_OIHW(var_w1: T.handle, var_out: T.handle): + ic = T.int32() + w1 = T.match_buffer(var_w1, (ic, 16, 3, 3), "float32") + out = T.match_buffer(var_out, (16, ic, 3, 3), "float32") + for ax0, ax1, ax2, ax3 in T.grid(16, ic, 3, 3): + with T.block("layout_transform"): + o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(w1[i, o, h, w]) + T.writes(out[o, i, h, w]) + out[o, i, h, w] = w1[i, o, h, w] + + @R.function + def main_transform_params() -> R.Tuple: + R.func_attr({"relax.force_pure": True}) + ic = T.int64() + cls = Expected + gv: R.Object = R.call_packed("get_item", R.prim_value(1), sinfo_args=(R.Object,)) + gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast( + gv, R.Tensor((16, 16, 3, 3), dtype="float32") + ) + lv: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1 + _: R.Object = R.call_packed("set_item", R.prim_value(0), lv, sinfo_args=(R.Object,)) + _1: R.Tuple = R.vm.kill_object(lv) + gv2: R.Object = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,)) + gv3: R.Tensor((3, ic, 3, 3), dtype="float32") = R.match_cast( + gv2, R.Tensor((3, ic, 3, 3), dtype="float32") + ) + lv1: R.Tensor((3, ic, 3, 3), dtype="float32") = gv3 + lv2 = R.call_tir( + cls.transform_layout_IOHW_to_OIHW, + (lv1,), + out_sinfo=R.Tensor((ic, 3, 3, 3), dtype="float32"), + ) + _2: R.Tuple = R.vm.kill_object(lv1) + _3: R.Object = R.call_packed("set_item", R.prim_value(1), lv2, sinfo_args=(R.Object,)) + gv4: R.Tuple = R.tuple() + return gv4 + + after = LazyTransformParams()(Before) + tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True) + + def test_output_with_use_site(): @I.ir_module class Module: @T.prim_func def copy(x: T.Buffer((), "float32"), y: T.Buffer((), "float32")): with T.block("block"): + T.reads(x[()]) + T.writes(y[()]) y[()] = x[()] @R.function @@ -212,8 +310,10 @@ def main_transform_params() -> R.Tuple: R.func_attr({"relax.force_pure": True}) cls = Expected x: R.Object = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,)) - y = R.call_tir(cls.copy, (x,), out_sinfo=R.Tensor((), dtype="float32")) - _: R.Tuple = R.vm.kill_object(x) + gv: R.Tensor((), dtype="float32") = R.match_cast(x, R.Tensor((), dtype="float32")) + x_m: R.Tensor((), dtype="float32") = gv + y = R.call_tir(cls.copy, (x_m,), out_sinfo=R.Tensor((), dtype="float32")) + _: R.Tuple = R.vm.kill_object(x_m) z = R.call_tir(cls.copy, (y,), out_sinfo=R.Tensor((), dtype="float32")) _1: R.Object = R.call_packed("set_item", R.prim_value(0), y, sinfo_args=(R.Object,)) _2: R.Object = R.call_packed("set_item", R.prim_value(1), z, sinfo_args=(R.Object,))