Skip to content

Commit

Permalink
additional changes
Browse files Browse the repository at this point in the history
  • Loading branch information
srkreddy1238 committed Nov 11, 2024
1 parent 68e613d commit 29e0561
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 4 deletions.
4 changes: 2 additions & 2 deletions python/tvm/driver/tvmc/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ def autotvm_get_tuning_tasks(
"""
target, target_host = Target.canon_target_and_host(target, target_host)

mod = apply_graph_transforms(mod, params, transform_args)
mod = apply_graph_transforms(mod, transform_args, params)

tasks = autotvm.task.extract_from_program(
mod["main"],
Expand Down Expand Up @@ -718,7 +718,7 @@ def autoscheduler_get_tuning_tasks(
"""
target, target_host = Target.canon_target_and_host(target, target_host)

mod = apply_graph_transforms(mod, params, transform_args)
mod = apply_graph_transforms(mod, transform_args, params)

# Extract the tasks
tasks, task_weights = auto_scheduler.extract_tasks(
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/driver/tvmc/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def compile_model(
instruments=instruments,
):
transform_args = parse_graph_transform_args(locals())
mod = apply_graph_transforms(mod, params, transform_args)
mod = apply_graph_transforms(mod, transform_args, params)

for partition_function, opts in zip(partition_functions, partition_opts):
mod = partition_function(mod, params, mod_name=mod_name, **opts)
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/driver/tvmc/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def layout_helper(layout):
raise TVMCException("Error converting layouts: {}".format(str(err)))


def apply_graph_transforms(mod, params, args):
def apply_graph_transforms(mod, args, params=None):
"""Alter the layout of the input graph.
Parameters
Expand All @@ -172,6 +172,8 @@ def apply_graph_transforms(mod, params, args):
The relay module to convert.
args : dict
The transform arguments.
params: dict
Module params
Returns
-------
Expand Down
2 changes: 2 additions & 0 deletions tests/python/driver/tvmc/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def check(self, func):
"mixed_precision_calculation_type": "float16",
"mixed_precision_acc_type": "float16",
},
params,
)
ret = CheckOpMutator("float16", "float16", "nn.conv2d").check(mod["main"])
assert ret
Expand All @@ -240,6 +241,7 @@ def check(self, func):
"mixed_precision_calculation_type": "float16",
"mixed_precision_acc_type": "float32",
},
params,
)
ret = CheckOpMutator("float16", "float32", "nn.conv2d").check(mod["main"])
assert ret
Expand Down
1 change: 1 addition & 0 deletions tests/python/relay/opencl_texture/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def _test_mobilenet_v1(remote, target, calc_dtype, executor_type, acc_dtype):
"mixed_precision_calculation_type": calc_dtype,
"mixed_precision_acc_type": acc_dtype,
},
params,
)

if executor_type == "ge":
Expand Down

0 comments on commit 29e0561

Please sign in to comment.