Add unit test for PEA FX traceability (#2368)
Summary:
Pull Request resolved: #2368

While converting VLE to TBE to use the TorchRec Eager Mode Transform, we realized that the underlying PEA module is not FX traceable.

This adds a unit test to check whether the sharded, quantized PEA module is FX traceable; a minimal sketch of such a traceability check follows the change list below.

Main changes:
1. An almost identical copy of the existing FX trace unit test for EBC, following D42192307. This test traces over the quantized, sharded PEA module.
2. A new method to generate a model and inputs whose sparse arch is a quantized, sharded PEA, used by the FX trace unit test.
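
For readers new to FX tracing, the sketch below shows the general shape of such a traceability check, using plain torch.fx on a toy module. ToyArch is a hypothetical stand-in: the real test traces the quantized, sharded PEA with TorchRec's own tracer and input types.

import torch
import torch.fx


class ToyArch(torch.nn.Module):
    # Hypothetical stand-in for the sparse arch under test.
    def __init__(self) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(8, 4)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        return self.proj(features).sum(dim=1)


model = ToyArch()
gm = torch.fx.symbolic_trace(model)  # raises if the module cannot be traced
gm.graph.lint()  # sanity-check the captured graph
sample = torch.randn(2, 8)
assert torch.allclose(gm(sample), model(sample))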

Reviewed By: PaulZhang12

Differential Revision: D61943896

fbshipit-source-id: f563fc4dea2518a1aaf858811fa50916e4c785a4
aporialiao authored and facebook-github-bot committed Sep 9, 2024
1 parent 906cc26 commit 4abc147
Showing 2 changed files with 32 additions and 24 deletions.
11 changes: 8 additions & 3 deletions torchrec/distributed/test_utils/infer_utils.py
@@ -959,9 +959,14 @@ def assert_close(expected, actual) -> None:
         assert sorted(expected.keys()) == sorted(actual.keys())
         for feature, jt_e in expected.items():
             jt_got = actual[feature]
-            assert_close(jt_e.lengths(), jt_got.lengths())
-            assert_close(jt_e.values(), jt_got.values())
-            assert_close(jt_e.offsets(), jt_got.offsets())
+            if isinstance(jt_e, torch.Tensor) and isinstance(jt_got, torch.Tensor):
+                if jt_got.device != jt_e.device:
+                    jt_got = jt_got.to(jt_e.device)
+                assert_close(jt_e, jt_got)
+            else:
+                assert_close(jt_e.lengths(), jt_got.lengths())
+                assert_close(jt_e.values(), jt_got.values())
+                assert_close(jt_e.offsets(), jt_got.offsets())
     else:
         if isinstance(expected, torch.Tensor) and isinstance(actual, torch.Tensor):
             if actual.device != expected.device:
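
As context for the hunk above, here is a standalone sketch of the device-normalizing comparison it introduces. The helper name check_feature_tensors_close and the plain dict-of-tensors signature are illustrative assumptions, not TorchRec API; the real assert_close also handles jagged values via lengths/values/offsets.

from typing import Dict

import torch


def check_feature_tensors_close(
    expected: Dict[str, torch.Tensor], actual: Dict[str, torch.Tensor]
) -> None:
    assert sorted(expected.keys()) == sorted(actual.keys())
    for feature, exp in expected.items():
        got = actual[feature]
        # Move the produced tensor onto the expected tensor's device so
        # CPU reference outputs can be checked against GPU outputs.
        if got.device != exp.device:
            got = got.to(exp.device)
        torch.testing.assert_close(exp, got)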
45 changes: 24 additions & 21 deletions torchrec/distributed/test_utils/test_sharding.py
@@ -146,6 +146,7 @@ def gen_model_and_input(
     feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None,
     long_indices: bool = True,
     global_constant_batch: bool = False,
+    num_inputs: int = 1,
 ) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]:
     torch.manual_seed(0)
     if dedup_feature_names:
@@ -175,29 +176,31 @@ def gen_model_and_input(
         sparse_device=sparse_device,
         feature_processor_modules=feature_processor_modules,
     )
-    inputs = [
-        (
-            cast(VariableBatchModelInputCallable, generate)(
-                average_batch_size=batch_size,
-                world_size=world_size,
-                num_float_features=num_float_features,
-                tables=tables,
-                weighted_tables=weighted_tables or [],
-                global_constant_batch=global_constant_batch,
-            )
-            if generate == ModelInput.generate_variable_batch_input
-            else cast(ModelInputCallable, generate)(
-                world_size=world_size,
-                tables=tables,
-                dedup_tables=dedup_tables,
-                weighted_tables=weighted_tables or [],
-                num_float_features=num_float_features,
-                variable_batch_size=variable_batch_size,
-                batch_size=batch_size,
-                long_indices=long_indices,
+    inputs = []
+    for _ in range(num_inputs):
+        inputs.append(
+            (
+                cast(VariableBatchModelInputCallable, generate)(
+                    average_batch_size=batch_size,
+                    world_size=world_size,
+                    num_float_features=num_float_features,
+                    tables=tables,
+                    weighted_tables=weighted_tables or [],
+                    global_constant_batch=global_constant_batch,
+                )
+                if generate == ModelInput.generate_variable_batch_input
+                else cast(ModelInputCallable, generate)(
+                    world_size=world_size,
+                    tables=tables,
+                    dedup_tables=dedup_tables,
+                    weighted_tables=weighted_tables or [],
+                    num_float_features=num_float_features,
+                    variable_batch_size=variable_batch_size,
+                    batch_size=batch_size,
+                    long_indices=long_indices,
+                )
             )
         )
-    ]
     return (model, inputs)


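To make the new num_inputs behavior concrete, here is a hedged sketch of the pattern the second hunk adopts: re-invoking one generator callable N times to collect independent inputs. gen_inputs and the lambda are hypothetical stand-ins for gen_model_and_input and its generate callable, not TorchRec helpers.

from typing import Callable, List, TypeVar

import torch

T = TypeVar("T")


def gen_inputs(make_batch: Callable[[], T], num_inputs: int = 1) -> List[T]:
    # Mirror the diff: append one freshly generated input per iteration.
    inputs = []
    for _ in range(num_inputs):
        inputs.append(make_batch())
    return inputs


batches = gen_inputs(lambda: torch.randn(4, 8), num_inputs=3)
assert len(batches) == 3

Because gen_model_and_input calls torch.manual_seed(0) before generating, the sequence of inputs stays deterministic across runs even with num_inputs > 1.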
