diff --git a/torchrec/distributed/tests/test_pt2.py b/torchrec/distributed/tests/test_pt2.py index dfdd21618..97639a425 100644 --- a/torchrec/distributed/tests/test_pt2.py +++ b/torchrec/distributed/tests/test_pt2.py @@ -127,7 +127,7 @@ def _test_kjt_input_module( device = "cuda" # pyre-ignore aot_inductor_module = AOTIRunnerUtil.load(device, so_path) - aot_actual_output = aot_inductor_module(inputs) + aot_actual_output = aot_inductor_module(*inputs) assert_close(eager_output, aot_actual_output) def test_kjt_split(self) -> None: @@ -204,7 +204,7 @@ def kjt_to_inputs(kjt): ) # pyre-ignore aot_inductor_module = AOTIRunnerUtil.load(device, so_path) - aot_inductor_module(example_inputs) + aot_inductor_module(*example_inputs) aot_actual_outputs = [ aot_inductor_module(*kjt_to_inputs(kjt)) for kjt in input_kjts[1:]