diff --git a/torchrec/distributed/test_utils/infer_utils.py b/torchrec/distributed/test_utils/infer_utils.py
index 9a10d6939..9f264ef75 100644
--- a/torchrec/distributed/test_utils/infer_utils.py
+++ b/torchrec/distributed/test_utils/infer_utils.py
@@ -959,9 +959,14 @@ def assert_close(expected, actual) -> None:
         assert sorted(expected.keys()) == sorted(actual.keys())
         for feature, jt_e in expected.items():
             jt_got = actual[feature]
-            assert_close(jt_e.lengths(), jt_got.lengths())
-            assert_close(jt_e.values(), jt_got.values())
-            assert_close(jt_e.offsets(), jt_got.offsets())
+            if isinstance(jt_e, torch.Tensor) and isinstance(jt_got, torch.Tensor):
+                if jt_got.device != jt_e.device:
+                    jt_got = jt_got.to(jt_e.device)
+                assert_close(jt_e, jt_got)
+            else:
+                assert_close(jt_e.lengths(), jt_got.lengths())
+                assert_close(jt_e.values(), jt_got.values())
+                assert_close(jt_e.offsets(), jt_got.offsets())
     else:
         if isinstance(expected, torch.Tensor) and isinstance(actual, torch.Tensor):
             if actual.device != expected.device:
diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py
index 8a79aadb4..fbeff23dc 100644
--- a/torchrec/distributed/test_utils/test_sharding.py
+++ b/torchrec/distributed/test_utils/test_sharding.py
@@ -146,6 +146,7 @@ def gen_model_and_input(
     feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None,
     long_indices: bool = True,
     global_constant_batch: bool = False,
+    num_inputs: int = 1,
 ) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]:
     torch.manual_seed(0)
     if dedup_feature_names:
@@ -175,29 +176,31 @@ def gen_model_and_input(
         sparse_device=sparse_device,
         feature_processor_modules=feature_processor_modules,
     )
-    inputs = [
-        (
-            cast(VariableBatchModelInputCallable, generate)(
-                average_batch_size=batch_size,
-                world_size=world_size,
-                num_float_features=num_float_features,
-                tables=tables,
-                weighted_tables=weighted_tables or [],
-                global_constant_batch=global_constant_batch,
-            )
-            if generate == ModelInput.generate_variable_batch_input
-            else cast(ModelInputCallable, generate)(
-                world_size=world_size,
-                tables=tables,
-                dedup_tables=dedup_tables,
-                weighted_tables=weighted_tables or [],
-                num_float_features=num_float_features,
-                variable_batch_size=variable_batch_size,
-                batch_size=batch_size,
-                long_indices=long_indices,
+    inputs = []
+    for _ in range(num_inputs):
+        inputs.append(
+            (
+                cast(VariableBatchModelInputCallable, generate)(
+                    average_batch_size=batch_size,
+                    world_size=world_size,
+                    num_float_features=num_float_features,
+                    tables=tables,
+                    weighted_tables=weighted_tables or [],
+                    global_constant_batch=global_constant_batch,
+                )
+                if generate == ModelInput.generate_variable_batch_input
+                else cast(ModelInputCallable, generate)(
+                    world_size=world_size,
+                    tables=tables,
+                    dedup_tables=dedup_tables,
+                    weighted_tables=weighted_tables or [],
+                    num_float_features=num_float_features,
+                    variable_batch_size=variable_batch_size,
+                    batch_size=batch_size,
+                    long_indices=long_indices,
+                )
             )
         )
-    ]
     return (model, inputs)
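A minimal sketch of what the first hunk enables: `assert_close` previously assumed every value in a dict was a JaggedTensor-like object with `lengths()`/`values()`/`offsets()`, so dicts of plain tensors raised `AttributeError`. With this change, plain per-feature tensors are compared directly (after moving the actual tensor to the expected device). The feature name and values below are invented for illustration.

```python
# Illustrative only, not part of the diff: exercises the new dict-of-plain-tensors
# path in assert_close. Feature name and values are made up.
import torch

from torchrec.distributed.test_utils.infer_utils import assert_close

expected = {"feature_0": torch.tensor([1.0, 2.0, 3.0])}
actual = {"feature_0": torch.tensor([1.0, 2.0, 3.0])}

# Before this change, the dict branch unconditionally called
# .lengths()/.values()/.offsets() on each value; now tensor values are
# compared directly, with a device move if expected and actual differ.
assert_close(expected, actual)
```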
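On the `test_sharding.py` side, a hedged usage sketch of the new `num_inputs` parameter: callers that need several independently generated (global, per-rank) input pairs can request them in one call, and the default of 1 keeps existing callers unchanged. Only `num_inputs` comes from this diff; the imports, `TestSparseNN`, and the table/group setup follow the usual torchrec test pattern and are assumptions here.

```python
# Hypothetical call site, assuming the common torchrec sharding-test setup;
# only num_inputs is new in this diff.
from torchrec.distributed.embedding_types import EmbeddingTableConfig
from torchrec.distributed.test_utils.test_model import TestSparseNN
from torchrec.distributed.test_utils.test_sharding import gen_model_and_input

tables = [
    EmbeddingTableConfig(
        num_embeddings=10,
        embedding_dim=8,
        name="table_0",
        feature_names=["feature_0"],
    )
]
model, inputs = gen_model_and_input(
    model_class=TestSparseNN,
    tables=tables,
    embedding_groups={"group_0": ["feature_0"]},
    world_size=2,
    num_inputs=3,  # new: generate three independent (global, per-rank) batches
)
assert len(inputs) == 3  # previously the returned list always had one entry
```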