Fix a bug in the torch._export.aot_load API (#1652)
Summary:
Pull Request resolved: #1652

X-link: pytorch/pytorch#118039

X-link: pytorch/benchmark#2128

tree_flatten_spec should use args instead of *args

clone of pytorch/pytorch#117948 but with some fbcode specific changes

Reviewed By: angelayi

Differential Revision: D52982401

fbshipit-source-id: 503326c8b4d7316315153a4ab3873f6c8bcac9fb
desertfire authored and facebook-github-bot committed Jan 23, 2024
1 parent f9eaff8 commit be17fd7
Showing 1 changed file with 2 additions and 2 deletions.
torchrec/distributed/tests/test_pt2.py (2 additions, 2 deletions)

@@ -127,7 +127,7 @@ def _test_kjt_input_module(
         device = "cuda"
         # pyre-ignore
         aot_inductor_module = AOTIRunnerUtil.load(device, so_path)
-        aot_actual_output = aot_inductor_module(inputs)
+        aot_actual_output = aot_inductor_module(*inputs)
         assert_close(eager_output, aot_actual_output)

     def test_kjt_split(self) -> None:
@@ -204,7 +204,7 @@ def kjt_to_inputs(kjt):
         )
         # pyre-ignore
         aot_inductor_module = AOTIRunnerUtil.load(device, so_path)
-        aot_inductor_module(example_inputs)
+        aot_inductor_module(*example_inputs)

         aot_actual_outputs = [
             aot_inductor_module(*kjt_to_inputs(kjt)) for kjt in input_kjts[1:]
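The change is the standard positional-argument unpacking fix: the loaded module expects its flattened inputs as separate positional arguments, so the call site must splat the tuple with `*` rather than pass it as a single argument. A minimal sketch of the failure mode (the `compiled_module` stand-in is hypothetical, not the TorchRec test code or the AOTInductor API):

```python
# Hypothetical stand-in for a module loaded via AOTIRunnerUtil.load: it takes
# each flattened input as its own positional parameter.
def compiled_module(a, b):
    return a + b

inputs = (1, 2)

# Bug pattern: the whole tuple arrives as the single argument `a`,
# and the call fails because `b` is missing.
try:
    compiled_module(inputs)
except TypeError:
    print("calling with the tuple itself raises TypeError")

# Fixed pattern: `*inputs` unpacks the tuple into positional arguments.
print(compiled_module(*inputs))  # -> 3
```

The same reasoning applies to the commit-message note that tree_flatten_spec should receive `args` (the argument tuple as one pytree) rather than `*args` (the tuple splatted into separate parameters).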
