[Unity] Dynamic-shape param support in LazyTransformParams (#15713)
This PR brings support for dynamic-shape parameters to the
LazyTransformParams pass.

Prior to this PR, the symbolic shape variables of dynamic-shape parameters
were not properly defined in the rewritten function. This PR uses MatchCast
to ensure the symbolic variables are always defined, thereby supporting
dynamic-shape parameters.

This PR also fixes a previously failing test.
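To illustrate the mechanism outside the pass, here is a minimal BlockBuilder sketch. It is not code from this commit: the function name `transform`, the parameter index, and the `(3, ic, 3, 3)` shape are invented for the example; only the `get_item` convention and the MatchCast insertion mirror what the pass emits.

```python
# Minimal sketch (not from this commit): `get_item` returns an opaque
# object, and the MatchCast binding both refines its struct info and
# defines the symbolic dim `ic` so later bindings can refer to it.
from tvm import relax, tir

bb = relax.BlockBuilder()
ic = tir.Var("ic", "int64")  # symbolic dimension of the parameter

with bb.function("transform", params=[], attrs={"relax.force_pure": True}):
    # Fetch the parameter lazily as an untyped object, as the pass does.
    param_obj = bb.emit(
        relax.Call(
            relax.ExternFunc("get_item"),
            [relax.PrimValue(0)],
            None,
            [relax.ObjectStructInfo()],
        )
    )
    # Without this MatchCast, `ic` would be undefined below.
    param = bb.match_cast(param_obj, relax.TensorStructInfo([3, ic, 3, 3], "float32"))
    bb.emit_func_output(relax.Tuple([param]))

print(bb.get())  # the printed function contains the R.match_cast binding
```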
MasterJH5574 committed Sep 9, 2023
1 parent 6d1932b commit 761cae9
Showing 2 changed files with 118 additions and 15 deletions.
13 changes: 8 additions & 5 deletions python/tvm/relax/transform/lazy_transform_params.py
@@ -164,12 +164,15 @@ def visit_tuple_getitem_(self, op: relax.TupleGetItem) -> relax.Expr:
         # rewrite get item
         tuple_get_item = super().visit_tuple_getitem_(op)
         if tuple_get_item.tuple_value == self.input_tuple_param:
-            return relax.Call(
-                relax.ExternFunc("get_item"),
-                [relax.PrimValue(tuple_get_item.index)],
-                None,
-                [relax.ObjectStructInfo()],
+            get_item_result = self.builder_.emit(
+                relax.Call(
+                    relax.ExternFunc("get_item"),
+                    [relax.PrimValue(tuple_get_item.index)],
+                    None,
+                    [relax.ObjectStructInfo()],
+                )
             )
+            return self.builder_.match_cast(get_item_result, op.struct_info)
         else:
             return tuple_get_item

120 changes: 110 additions & 10 deletions tests/python/relax/test_transform_lazy_transform_params.py
@@ -79,15 +79,23 @@ def main_transform_params() -> R.Tuple:
             R.func_attr({"relax.force_pure": True})
             cls = Expected
             lv: R.Object = R.call_packed("get_item", R.prim_value(1), sinfo_args=(R.Object,))
-            _: R.Object = R.call_packed("set_item", R.prim_value(0), lv, sinfo_args=(R.Object,))
-            _1: R.Tuple = R.vm.kill_object(lv)
+            gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast(
+                lv, R.Tensor((16, 16, 3, 3), dtype="float32")
+            )
+            lv_m: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1
+            _: R.Object = R.call_packed("set_item", R.prim_value(0), lv_m, sinfo_args=(R.Object,))
+            _1: R.Tuple = R.vm.kill_object(lv_m)
             lv1: R.Object = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,))
+            gv3: R.Tensor((3, 16, 3, 3), dtype="float32") = R.match_cast(
+                lv1, R.Tensor((3, 16, 3, 3), dtype="float32")
+            )
+            lv1_m: R.Tensor((3, 16, 3, 3), dtype="float32") = gv3
             lv2 = R.call_tir(
                 cls.transform_layout_IOHW_to_OIHW,
-                (lv1,),
+                (lv1_m,),
                 out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
             )
-            _2: R.Tuple = R.vm.kill_object(lv1)
+            _2: R.Tuple = R.vm.kill_object(lv1_m)
             _3: R.Object = R.call_packed("set_item", R.prim_value(1), lv2, sinfo_args=(R.Object,))
             gv: R.Tuple = R.tuple()
             return gv
@@ -146,13 +154,17 @@ def main_transform_params(slice_shape_expr: R.Shape(["slice_index"])):
             slice_index = T.int64()
 
             param = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,))
+            gv: R.Tensor((16, 16), dtype="float32") = R.match_cast(
+                param, R.Tensor((16, 16), dtype="float32")
+            )
+            param_m: R.Tensor((16, 16), dtype="float32") = gv
             transformed = R.call_tir(
                 cls.slice_buffer,
-                (param,),
+                (param_m,),
                 tir_vars=[slice_index],
                 out_sinfo=R.Tensor((16,), dtype="float32"),
             )
-            unused_1_ = R.vm.kill_object(param)
+            unused_1_ = R.vm.kill_object(param_m)
             unused_2_ = R.call_packed(
                 "set_item", R.prim_value(0), transformed, sinfo_args=(R.Object,)
             )
@@ -175,14 +187,100 @@ def slice_buffer(
     tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)
 
 
-# TODO(tvm-team): remove once regression get fixed
-@pytest.mark.skip("temp disable, minor regression on read/write region in zero dim buffer")
+def test_param_shape_symbolic():
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def transform_layout_IOHW_to_OIHW(var_w1: T.handle, var_out: T.handle):
+            ic = T.int32()
+            w1 = T.match_buffer(var_w1, (ic, 16, 3, 3), "float32")
+            out = T.match_buffer(var_out, (16, ic, 3, 3), "float32")
+            for ax0, ax1, ax2, ax3 in T.grid(16, ic, 3, 3):
+                with T.block("layout_transform"):
+                    o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(w1[i, o, h, w])
+                    T.writes(out[o, i, h, w])
+                    out[o, i, h, w] = w1[i, o, h, w]
+
+        @R.function
+        def main_transform_params(
+            params: R.Tuple(
+                R.Tensor((3, "ic", 3, 3), dtype="float32"),
+                R.Tensor((16, 16, 3, 3), dtype="float32"),
+            )
+        ) -> R.Tuple(
+            R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor(("ic", 3, 3, 3), dtype="float32")
+        ):
+            ic = T.int64()
+            # we expect ToNonDataflow and RemovePurityTracking to be invoked first
+            R.func_attr({"relax.force_pure": True})
+            cls = Before
+            lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
+            lv1: R.Tensor((3, ic, 3, 3), dtype="float32") = params[0]
+            lv2 = R.call_tir(
+                cls.transform_layout_IOHW_to_OIHW,
+                (lv1,),
+                out_sinfo=R.Tensor((ic, 3, 3, 3), dtype="float32"),
+            )
+            gv: R.Tuple(
+                R.Tensor((16, 16, 3, 3), dtype="float32"),
+                R.Tensor((ic, 3, 3, 3), dtype="float32"),
+            ) = (lv, lv2)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def transform_layout_IOHW_to_OIHW(var_w1: T.handle, var_out: T.handle):
+            ic = T.int32()
+            w1 = T.match_buffer(var_w1, (ic, 16, 3, 3), "float32")
+            out = T.match_buffer(var_out, (16, ic, 3, 3), "float32")
+            for ax0, ax1, ax2, ax3 in T.grid(16, ic, 3, 3):
+                with T.block("layout_transform"):
+                    o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(w1[i, o, h, w])
+                    T.writes(out[o, i, h, w])
+                    out[o, i, h, w] = w1[i, o, h, w]
+
+        @R.function
+        def main_transform_params() -> R.Tuple:
+            R.func_attr({"relax.force_pure": True})
+            ic = T.int64()
+            cls = Expected
+            gv: R.Object = R.call_packed("get_item", R.prim_value(1), sinfo_args=(R.Object,))
+            gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast(
+                gv, R.Tensor((16, 16, 3, 3), dtype="float32")
+            )
+            lv: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1
+            _: R.Object = R.call_packed("set_item", R.prim_value(0), lv, sinfo_args=(R.Object,))
+            _1: R.Tuple = R.vm.kill_object(lv)
+            gv2: R.Object = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,))
+            gv3: R.Tensor((3, ic, 3, 3), dtype="float32") = R.match_cast(
+                gv2, R.Tensor((3, ic, 3, 3), dtype="float32")
+            )
+            lv1: R.Tensor((3, ic, 3, 3), dtype="float32") = gv3
+            lv2 = R.call_tir(
+                cls.transform_layout_IOHW_to_OIHW,
+                (lv1,),
+                out_sinfo=R.Tensor((ic, 3, 3, 3), dtype="float32"),
+            )
+            _2: R.Tuple = R.vm.kill_object(lv1)
+            _3: R.Object = R.call_packed("set_item", R.prim_value(1), lv2, sinfo_args=(R.Object,))
+            gv4: R.Tuple = R.tuple()
+            return gv4
+
+    after = LazyTransformParams()(Before)
+    tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)
+
+
 def test_output_with_use_site():
     @I.ir_module
     class Module:
         @T.prim_func
         def copy(x: T.Buffer((), "float32"), y: T.Buffer((), "float32")):
             with T.block("block"):
                 T.reads(x[()])
                 T.writes(y[()])
                 y[()] = x[()]
 
         @R.function
@@ -212,8 +310,10 @@ def main_transform_params() -> R.Tuple:
             R.func_attr({"relax.force_pure": True})
             cls = Expected
             x: R.Object = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,))
-            y = R.call_tir(cls.copy, (x,), out_sinfo=R.Tensor((), dtype="float32"))
-            _: R.Tuple = R.vm.kill_object(x)
+            gv: R.Tensor((), dtype="float32") = R.match_cast(x, R.Tensor((), dtype="float32"))
+            x_m: R.Tensor((), dtype="float32") = gv
+            y = R.call_tir(cls.copy, (x_m,), out_sinfo=R.Tensor((), dtype="float32"))
+            _: R.Tuple = R.vm.kill_object(x_m)
             z = R.call_tir(cls.copy, (y,), out_sinfo=R.Tensor((), dtype="float32"))
             _1: R.Object = R.call_packed("set_item", R.prim_value(0), y, sinfo_args=(R.Object,))
             _2: R.Object = R.call_packed("set_item", R.prim_value(1), z, sinfo_args=(R.Object,))
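For orientation, here is a hedged end-to-end sketch of applying the pass, in the spirit of the tests above. The module `Tiny` and its single dynamic-shape parameter are invented for the example and are not part of this commit; only the `main_transform_params` convention and the pass invocation follow the tests.

```python
# Hypothetical usage sketch mirroring the tests above; `Tiny` is invented.
import tvm
from tvm.relax.transform import LazyTransformParams
from tvm.script import ir as I, relax as R, tir as T

@I.ir_module
class Tiny:
    @R.function
    def main_transform_params(
        params: R.Tuple(R.Tensor(("n", 16), dtype="float32"))
    ) -> R.Tuple(R.Tensor(("n", 16), dtype="float32")):
        n = T.int64()
        # As in the tests, ToNonDataflow/RemovePurityTracking are assumed
        # to have run already, hence the force_pure attribute.
        R.func_attr({"relax.force_pure": True})
        gv: R.Tuple(R.Tensor((n, 16), dtype="float32")) = (params[0],)
        return gv

# The pass rewrites params[0] into a packed "get_item" call followed by
# R.match_cast, which defines the symbolic dim `n`, and stores the result
# via a packed "set_item" call.
after = LazyTransformParams()(Tiny)
print(after)
```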
