From 4e41029116bca69b81371cd94ad1a9cb251b1cce Mon Sep 17 00:00:00 2001 From: Amethyst Reese Date: Sat, 2 Mar 2024 17:31:19 -0800 Subject: [PATCH] apply Black 2024 style in fbcode (4/16) Summary: Formats the covered files with pyfmt. paintitblack Reviewed By: aleivag Differential Revision: D54447727 fbshipit-source-id: 8844b1caa08de94d04ac4df3c768dbf8c865fd2f --- examples/inference/dlrm_predict.py | 8 +- examples/inference/dlrm_predict_single_gpu.py | 8 +- .../nvt_dataloader/nvt_binary_dataloader.py | 3 +- examples/nvt_dataloader/train_torchrec.py | 22 +-- examples/retrieval/modules/two_tower.py | 12 +- tools/lint/black_linter.py | 10 +- torchrec/datasets/criteo.py | 14 +- torchrec/datasets/tests/test_criteo.py | 2 +- torchrec/datasets/utils.py | 4 +- .../benchmark/benchmark_inference.py | 4 +- .../composable/tests/test_embedding.py | 6 +- .../composable/tests/test_embeddingbag.py | 8 +- .../composable/tests/test_fused_optim_nccl.py | 62 ++++++--- torchrec/distributed/dist_data.py | 10 +- torchrec/distributed/embedding.py | 34 ++--- torchrec/distributed/embedding_kernel.py | 12 +- torchrec/distributed/embedding_lookup.py | 16 ++- .../distributed/embedding_tower_sharding.py | 26 ++-- torchrec/distributed/embeddingbag.py | 32 ++--- torchrec/distributed/fbgemm_qcomm_codec.py | 22 +-- .../distributed/grouped_position_weighted.py | 1 + torchrec/distributed/mc_embedding_modules.py | 16 ++- torchrec/distributed/mc_modules.py | 76 ++++++----- torchrec/distributed/planner/planners.py | 28 ++-- torchrec/distributed/planner/proposers.py | 6 +- .../distributed/planner/shard_estimators.py | 52 +++++--- torchrec/distributed/planner/stats.py | 8 +- .../planner/tests/test_proposers.py | 8 +- torchrec/distributed/planner/types.py | 21 +-- torchrec/distributed/planner/utils.py | 4 +- torchrec/distributed/quant_embedding.py | 28 ++-- .../distributed/quant_embedding_kernel.py | 126 ++++++++++-------- torchrec/distributed/quant_embeddingbag.py | 12 +- torchrec/distributed/quant_state.py | 28 ++-- torchrec/distributed/sharding/cw_sharding.py | 6 +- torchrec/distributed/sharding/dp_sharding.py | 12 +- .../sharding/rw_sequence_sharding.py | 12 +- torchrec/distributed/sharding/rw_sharding.py | 56 ++++---- .../distributed/sharding/sequence_sharding.py | 12 +- .../sharding/tw_sequence_sharding.py | 12 +- torchrec/distributed/sharding/tw_sharding.py | 18 +-- .../distributed/sharding/twrw_sharding.py | 20 +-- torchrec/distributed/sharding_plan.py | 83 +++++++----- .../distributed/test_utils/infer_utils.py | 4 +- torchrec/distributed/test_utils/test_model.py | 38 +++--- .../test_utils/test_model_parallel.py | 8 +- .../distributed/test_utils/test_sharding.py | 56 ++++---- torchrec/distributed/tests/test_dist_data.py | 67 ++++++---- .../tests/test_embedding_sharding.py | 38 +++--- .../distributed/tests/test_fp_embeddingbag.py | 10 +- .../tests/test_fp_embeddingbag_utils.py | 42 +++--- torchrec/distributed/tests/test_fx_jit.py | 2 +- .../distributed/tests/test_infer_shardings.py | 6 +- .../distributed/tests/test_mc_embedding.py | 40 +++--- .../distributed/tests/test_mc_embeddingbag.py | 40 +++--- .../tests/test_qcomms_embedding_modules.py | 16 ++- .../tests/test_sequence_model_parallel.py | 8 +- ...est_sequence_model_parallel_single_rank.py | 8 +- torchrec/distributed/tests/test_utils.py | 24 ++-- torchrec/distributed/train_pipeline.py | 44 +++--- torchrec/distributed/types.py | 11 +- torchrec/distributed/utils.py | 1 - torchrec/fx/utils.py | 1 + torchrec/inference/modules.py | 10 +- torchrec/metrics/auc.py | 2 +- torchrec/metrics/auprc.py | 2 +- torchrec/metrics/metric_module.py | 6 +- torchrec/metrics/metrics_config.py | 6 +- torchrec/metrics/rauc.py | 2 +- torchrec/metrics/rec_metric.py | 10 +- torchrec/metrics/tests/test_recall_session.py | 6 +- torchrec/modules/feature_processor_.py | 1 - torchrec/modules/fp_embedding_modules.py | 8 +- torchrec/modules/fused_embedding_modules.py | 2 +- torchrec/modules/lazy_extension.py | 3 +- torchrec/modules/mc_embedding_modules.py | 6 +- torchrec/modules/mc_modules.py | 20 +-- torchrec/modules/tests/test_crossnet.py | 1 + .../tests/test_fused_embedding_modules.py | 2 + torchrec/modules/utils.py | 8 +- torchrec/optim/fused.py | 9 +- torchrec/optim/rowwise_adagrad.py | 2 +- torchrec/optim/tests/test_keyed.py | 12 +- torchrec/quant/embedding_modules.py | 14 +- torchrec/sparse/jagged_tensor.py | 106 +++++++++------ torchrec/types.py | 3 +- 86 files changed, 934 insertions(+), 731 deletions(-) diff --git a/examples/inference/dlrm_predict.py b/examples/inference/dlrm_predict.py index c36bb613d..e2bdcef9d 100644 --- a/examples/inference/dlrm_predict.py +++ b/examples/inference/dlrm_predict.py @@ -139,9 +139,11 @@ def create_predict_module(self, world_size: int) -> torch.nn.Module: EmbeddingBagConfig( name=f"t_{feature_name}", embedding_dim=self.model_config.embedding_dim, - num_embeddings=self.model_config.num_embeddings_per_feature[feature_idx] - if self.model_config.num_embeddings is None - else self.model_config.num_embeddings, + num_embeddings=( + self.model_config.num_embeddings_per_feature[feature_idx] + if self.model_config.num_embeddings is None + else self.model_config.num_embeddings + ), feature_names=[feature_name], ) for feature_idx, feature_name in enumerate( diff --git a/examples/inference/dlrm_predict_single_gpu.py b/examples/inference/dlrm_predict_single_gpu.py index 753cb0dcc..ba5323247 100644 --- a/examples/inference/dlrm_predict_single_gpu.py +++ b/examples/inference/dlrm_predict_single_gpu.py @@ -50,9 +50,11 @@ def create_predict_module(self, world_size: int) -> torch.nn.Module: EmbeddingBagConfig( name=f"t_{feature_name}", embedding_dim=self.model_config.embedding_dim, - num_embeddings=self.model_config.num_embeddings_per_feature[feature_idx] - if self.model_config.num_embeddings is None - else self.model_config.num_embeddings, + num_embeddings=( + self.model_config.num_embeddings_per_feature[feature_idx] + if self.model_config.num_embeddings is None + else self.model_config.num_embeddings + ), feature_names=[feature_name], ) for feature_idx, feature_name in enumerate( diff --git a/examples/nvt_dataloader/nvt_binary_dataloader.py b/examples/nvt_dataloader/nvt_binary_dataloader.py index 2286b5610..f88ff1feb 100644 --- a/examples/nvt_dataloader/nvt_binary_dataloader.py +++ b/examples/nvt_dataloader/nvt_binary_dataloader.py @@ -94,7 +94,8 @@ def __getitem__(self, idx: int): """Numerical features are returned in the order they appear in the channel spec section For performance reasons, this is required to be the order they are saved in, as specified by the relevant chunk in source spec. - Categorical features are returned in the order they appear in the channel spec section""" + Categorical features are returned in the order they appear in the channel spec section + """ if idx >= self._num_entries: raise IndexError() diff --git a/examples/nvt_dataloader/train_torchrec.py b/examples/nvt_dataloader/train_torchrec.py index 42563795d..abdf3a67d 100644 --- a/examples/nvt_dataloader/train_torchrec.py +++ b/examples/nvt_dataloader/train_torchrec.py @@ -208,9 +208,11 @@ def main(argv: List[str]): EmbeddingBagConfig( name=f"t_{feature_name}", embedding_dim=args.embedding_dim, - num_embeddings=none_throws(num_embeddings_per_feature)[feature_idx] - if num_embeddings_per_feature is not None - else args.num_embeddings, + num_embeddings=( + none_throws(num_embeddings_per_feature)[feature_idx] + if num_embeddings_per_feature is not None + else args.num_embeddings + ), feature_names=[feature_name], ) for feature_idx, feature_name in enumerate(DEFAULT_CAT_NAMES) @@ -232,9 +234,9 @@ def main(argv: List[str]): train_model = fuse_embedding_optimizer( train_model, - optimizer_type=torchrec.optim.RowWiseAdagrad - if args.adagrad - else torch.optim.SGD, + optimizer_type=( + torchrec.optim.RowWiseAdagrad if args.adagrad else torch.optim.SGD + ), optimizer_kwargs={"learning_rate": args.learning_rate}, device=torch.device("meta"), ) @@ -270,9 +272,11 @@ def main(argv: List[str]): non_fused_optimizer = KeyedOptimizerWrapper( dict(in_backward_optimizer_filter(model.named_parameters())), - lambda params: torch.optim.Adagrad(params, lr=args.learning_rate) - if args.adagrad - else torch.optim.SGD(params, lr=args.learning_rate), + lambda params: ( + torch.optim.Adagrad(params, lr=args.learning_rate) + if args.adagrad + else torch.optim.SGD(params, lr=args.learning_rate) + ), ) opt = trec_optim.keyed.CombinedOptimizer( diff --git a/examples/retrieval/modules/two_tower.py b/examples/retrieval/modules/two_tower.py index 224bcbfc7..447967780 100644 --- a/examples/retrieval/modules/two_tower.py +++ b/examples/retrieval/modules/two_tower.py @@ -71,12 +71,12 @@ def __init__( embedding_dim: int = embedding_bag_collection.embedding_bag_configs()[ 0 ].embedding_dim - self._feature_names_query: List[ - str - ] = embedding_bag_collection.embedding_bag_configs()[0].feature_names - self._candidate_feature_names: List[ - str - ] = embedding_bag_collection.embedding_bag_configs()[1].feature_names + self._feature_names_query: List[str] = ( + embedding_bag_collection.embedding_bag_configs()[0].feature_names + ) + self._candidate_feature_names: List[str] = ( + embedding_bag_collection.embedding_bag_configs()[1].feature_names + ) self.ebc = embedding_bag_collection self.query_proj = MLP( in_size=embedding_dim, layer_sizes=layer_sizes, device=device diff --git a/tools/lint/black_linter.py b/tools/lint/black_linter.py index cfdc3d4e8..7c9a75f9c 100644 --- a/tools/lint/black_linter.py +++ b/tools/lint/black_linter.py @@ -176,11 +176,11 @@ def main() -> None: logging.basicConfig( format="<%(threadName)s:%(levelname)s> %(message)s", - level=logging.NOTSET - if args.verbose - else logging.DEBUG - if len(args.filenames) < 1000 - else logging.INFO, + level=( + logging.NOTSET + if args.verbose + else logging.DEBUG if len(args.filenames) < 1000 else logging.INFO + ), stream=sys.stderr, ) diff --git a/torchrec/datasets/criteo.py b/torchrec/datasets/criteo.py index 14d012332..7318dda3f 100644 --- a/torchrec/datasets/criteo.py +++ b/torchrec/datasets/criteo.py @@ -234,11 +234,13 @@ def row_mapper(row: List[str]) -> Tuple[List[int], List[int], int]: return dense, sparse, label dense, sparse, labels = [], [], [] - for (row_dense, row_sparse, row_label) in CriteoIterDataPipe( + for row_dense, row_sparse, row_label in CriteoIterDataPipe( [in_file], - row_mapper=row_mapper - if not (dataset_name == "criteo_kaggle" and "test" in in_file) - else row_mapper_with_fake_label_constant, + row_mapper=( + row_mapper + if not (dataset_name == "criteo_kaggle" and "test" in in_file) + else row_mapper_with_fake_label_constant + ), ): dense.append(row_dense) sparse.append(row_sparse) @@ -261,7 +263,7 @@ def row_mapper(row: List[str]) -> Tuple[List[int], List[int], int]: labels_np = labels_np.reshape((-1, 1)) path_manager = PathManagerFactory().get(path_manager_key) - for (fname, arr) in [ + for fname, arr in [ (out_dense_file, dense_np), (out_sparse_file, sparse_np), (out_labels_file, labels_np), @@ -665,7 +667,7 @@ def shuffle( curr_first_row = curr_last_row # Directly copy over the last day's files since they will be used for validation and testing. - for (part, input_dir) in [ + for part, input_dir in [ ("sparse", input_dir_sparse), ("dense", input_dir_labels_and_dense), ("labels", input_dir_labels_and_dense), diff --git a/torchrec/datasets/tests/test_criteo.py b/torchrec/datasets/tests/test_criteo.py index e75dd032f..c93e25f45 100644 --- a/torchrec/datasets/tests/test_criteo.py +++ b/torchrec/datasets/tests/test_criteo.py @@ -70,7 +70,7 @@ def _validate_dataloader_sample( ) -> None: unbatched_samples = [{} for _ in range(self._sample_len(sample))] for k, batched_values in sample.items(): - for (idx, value) in enumerate(batched_values): + for idx, value in enumerate(batched_values): unbatched_samples[idx][k] = value for sample in unbatched_samples: self._validate_sample(sample, train=train) diff --git a/torchrec/datasets/utils.py b/torchrec/datasets/utils.py index faedc498b..a671caf68 100644 --- a/torchrec/datasets/utils.py +++ b/torchrec/datasets/utils.py @@ -77,9 +77,7 @@ def train_filter( decimal_places: int, idx: int, ) -> bool: - return (key_fn(idx) % 10**decimal_places) < round( - train_perc * 10**decimal_places - ) + return (key_fn(idx) % 10**decimal_places) < round(train_perc * 10**decimal_places) def val_filter( diff --git a/torchrec/distributed/benchmark/benchmark_inference.py b/torchrec/distributed/benchmark/benchmark_inference.py index 4f09df0fc..7c3a11661 100644 --- a/torchrec/distributed/benchmark/benchmark_inference.py +++ b/torchrec/distributed/benchmark/benchmark_inference.py @@ -196,7 +196,9 @@ def main() -> None: mb = int(float(num * dim) / 1024 / 1024) tables_info += f"\nTABLE[{i}][{num:9}, {dim:4}] u8: {mb:6}Mb" - report: str = f"REPORT BENCHMARK {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n" + report: str = ( + f"REPORT BENCHMARK {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n" + ) report += f"Module: {module_name}\n" report += tables_info report += "\n" diff --git a/torchrec/distributed/composable/tests/test_embedding.py b/torchrec/distributed/composable/tests/test_embedding.py index c7ca713c8..7991b49be 100644 --- a/torchrec/distributed/composable/tests/test_embedding.py +++ b/torchrec/distributed/composable/tests/test_embedding.py @@ -127,9 +127,9 @@ def _test_sharding( # noqa C901 kjt_input_per_rank[ctx.rank] ) - unsharded_model_pred_jt_dict_this_rank: Dict[ - str, JaggedTensor - ] = unsharded_model_pred_jt_dict[ctx.rank] + unsharded_model_pred_jt_dict_this_rank: Dict[str, JaggedTensor] = ( + unsharded_model_pred_jt_dict[ctx.rank] + ) embedding_names = unsharded_model_pred_jt_dict_this_rank.keys() assert set(unsharded_model_pred_jt_dict_this_rank.keys()) == set( diff --git a/torchrec/distributed/composable/tests/test_embeddingbag.py b/torchrec/distributed/composable/tests/test_embeddingbag.py index 7047662c4..673ff79e4 100644 --- a/torchrec/distributed/composable/tests/test_embeddingbag.py +++ b/torchrec/distributed/composable/tests/test_embeddingbag.py @@ -385,9 +385,11 @@ def test_sharding_ebc( }, kjt_input_per_rank=kjt_input_per_rank, sharder=TestEmbeddingBagCollectionSharder(sharding_type=sharding_type), - backend="nccl" - if (torch.cuda.is_available() and torch.cuda.device_count() >= 2) - else "gloo", + backend=( + "nccl" + if (torch.cuda.is_available() and torch.cuda.device_count() >= 2) + else "gloo" + ), constraints=constraints, is_data_parallel=(sharding_type == ShardingType.DATA_PARALLEL.value), use_apply_optimizer_in_backward=use_apply_optimizer_in_backward, diff --git a/torchrec/distributed/composable/tests/test_fused_optim_nccl.py b/torchrec/distributed/composable/tests/test_fused_optim_nccl.py index 21f5e8e12..0fb3b62cd 100644 --- a/torchrec/distributed/composable/tests/test_fused_optim_nccl.py +++ b/torchrec/distributed/composable/tests/test_fused_optim_nccl.py @@ -77,29 +77,38 @@ def _test_sharded_fused_optimizer_state_dict( 0 ].state_dict()["state"][""]["table_0.momentum1"].gather( dst=0, - out=None if ctx.rank != 0 - # sharded column, each shard will have rowwise state - else torch.empty((4 * tables[0].num_embeddings,), device=ctx.device), + out=( + None + if ctx.rank != 0 + # sharded column, each shard will have rowwise state + else torch.empty((4 * tables[0].num_embeddings,), device=ctx.device) + ), ) ebc.embedding_bags["table_1"].weight._in_backward_optimizers[ 0 ].state_dict()["state"][""]["table_1.momentum1"].gather( dst=0, - out=None if ctx.rank != 0 - # sharded rowwise - else torch.empty((tables[1].num_embeddings,), device=ctx.device), + out=( + None + if ctx.rank != 0 + # sharded rowwise + else torch.empty((tables[1].num_embeddings,), device=ctx.device) + ), ) ebc.embedding_bags["table_2"].weight._in_backward_optimizers[ 0 ].state_dict()["state"][""]["table_2.momentum1"].gather( dst=0, - out=None if ctx.rank != 0 - # Column wise - with partial rowwise adam, first state is point wise - else torch.empty( - (tables[2].num_embeddings, tables[2].embedding_dim), - device=ctx.device, + out=( + None + if ctx.rank != 0 + # Column wise - with partial rowwise adam, first state is point wise + else torch.empty( + (tables[2].num_embeddings, tables[2].embedding_dim), + device=ctx.device, + ) ), ) @@ -107,20 +116,26 @@ def _test_sharded_fused_optimizer_state_dict( 0 ].state_dict()["state"][""]["table_2.exp_avg_sq"].gather( dst=0, - out=None if ctx.rank != 0 - # Column wise - with partial rowwise adam, first state is point wise - else torch.empty((4 * tables[2].num_embeddings,), device=ctx.device), + out=( + None + if ctx.rank != 0 + # Column wise - with partial rowwise adam, first state is point wise + else torch.empty((4 * tables[2].num_embeddings,), device=ctx.device) + ), ) ebc.embedding_bags["table_3"].weight._in_backward_optimizers[ 0 ].state_dict()["state"][""]["table_3.momentum1"].gather( dst=0, - out=None if ctx.rank != 0 - # Row wise - with partial rowwise adam, first state is point wise - else torch.empty( - (tables[3].num_embeddings, tables[3].embedding_dim), - device=ctx.device, + out=( + None + if ctx.rank != 0 + # Row wise - with partial rowwise adam, first state is point wise + else torch.empty( + (tables[3].num_embeddings, tables[3].embedding_dim), + device=ctx.device, + ) ), ) @@ -128,9 +143,12 @@ def _test_sharded_fused_optimizer_state_dict( 0 ].state_dict()["state"][""]["table_3.exp_avg_sq"].gather( dst=0, - out=None if ctx.rank != 0 - # Column wise - with partial rowwise adam, first state is point wise - else torch.empty((tables[2].num_embeddings,), device=ctx.device), + out=( + None + if ctx.rank != 0 + # Column wise - with partial rowwise adam, first state is point wise + else torch.empty((tables[2].num_embeddings,), device=ctx.device) + ), ) # pyre-ignore diff --git a/torchrec/distributed/dist_data.py b/torchrec/distributed/dist_data.py index 0e845a80d..8f7a55adf 100644 --- a/torchrec/distributed/dist_data.py +++ b/torchrec/distributed/dist_data.py @@ -539,9 +539,13 @@ def forward(self, kjt: KeyedJaggedTensor) -> KJTList: fx_marker("KJT_ONE_TO_ALL_FORWARD_BEGIN", kjt) kjts: List[KeyedJaggedTensor] = kjt.split(self._splits) dist_kjts = [ - kjts[rank] - if self._device_type == "meta" - else kjts[rank].to(torch.device(self._device_type, rank), non_blocking=True) + ( + kjts[rank] + if self._device_type == "meta" + else kjts[rank].to( + torch.device(self._device_type, rank), non_blocking=True + ) + ) for rank in range(self._world_size) ] ret = KJTList(dist_kjts) diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index 137789e3e..d82a36cee 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -408,9 +408,11 @@ def __init__( if isinstance(sharding, DpSequenceEmbeddingSharding): self._lookups[index] = DistributedDataParallel( module=lookup, - device_ids=[device] - if self._device and self._device.type == "cuda" - else None, + device_ids=( + [device] + if self._device and self._device.type == "cuda" + else None + ), process_group=env.process_group, gradient_as_bucket_view=True, broadcast_buffers=True, @@ -510,9 +512,9 @@ def _initialize_torch_state(self) -> None: # noqa if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value: continue self._model_parallel_name_to_local_shards[table_name] = [] - model_parallel_name_to_compute_kernel[ - table_name - ] = parameter_sharding.compute_kernel + model_parallel_name_to_compute_kernel[table_name] = ( + parameter_sharding.compute_kernel + ) self._name_to_table_size = {} for table in self._embedding_configs: @@ -556,12 +558,12 @@ def _initialize_torch_state(self) -> None: # noqa EmptyFusedOptimizer() ] # created ShardedTensors once in init, use in post_state_dict_hook - self._model_parallel_name_to_sharded_tensor[ - table_name - ] = ShardedTensor._init_from_local_shards( - local_shards, - self._name_to_table_size[table_name], - process_group=self._env.process_group, + self._model_parallel_name_to_sharded_tensor[table_name] = ( + ShardedTensor._init_from_local_shards( + local_shards, + self._name_to_table_size[table_name], + process_group=self._env.process_group, + ) ) def post_state_dict_hook( @@ -792,9 +794,11 @@ def input_dist( ctx.sharding_contexts.append( SequenceShardingContext( features_before_input_dist=features, - unbucketize_permute_tensor=input_dist.unbucketize_permute_tensor - if isinstance(input_dist, RwSparseFeaturesDist) - else None, + unbucketize_permute_tensor=( + input_dist.unbucketize_permute_tensor + if isinstance(input_dist, RwSparseFeaturesDist) + else None + ), ) ) return KJTListSplitsAwaitable(awaitables, ctx) diff --git a/torchrec/distributed/embedding_kernel.py b/torchrec/distributed/embedding_kernel.py index 985f3c4b0..a7302ebf5 100644 --- a/torchrec/distributed/embedding_kernel.py +++ b/torchrec/distributed/embedding_kernel.py @@ -124,12 +124,12 @@ def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str: # Populate the remaining destinations that have a global metadata for key in key_to_local_shards: global_metadata = key_to_global_metadata[key] - destination[ - key - ] = ShardedTensor._init_from_local_shards_and_global_metadata( - local_shards=key_to_local_shards[key], - sharded_tensor_metadata=global_metadata, - process_group=pg, + destination[key] = ( + ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards=key_to_local_shards[key], + sharded_tensor_metadata=global_metadata, + process_group=pg, + ) ) return destination diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py index 68d1176c7..5fcba78a2 100644 --- a/torchrec/distributed/embedding_lookup.py +++ b/torchrec/distributed/embedding_lookup.py @@ -563,9 +563,11 @@ def _create_lookup( "_dummy_embs_tensor", torch.empty( [0], - dtype=fused_params["output_dtype"].as_dtype() - if fused_params and "output_dtype" in fused_params - else torch.float16, + dtype=( + fused_params["output_dtype"].as_dtype() + if fused_params and "output_dtype" in fused_params + else torch.float16 + ), device=device, ), ) @@ -696,9 +698,11 @@ def _create_lookup( "_dummy_embs_tensor", torch.empty( [0], - dtype=fused_params["output_dtype"].as_dtype() - if fused_params and "output_dtype" in fused_params - else torch.float16, + dtype=( + fused_params["output_dtype"].as_dtype() + if fused_params and "output_dtype" in fused_params + else torch.float16 + ), device=device, ), ) diff --git a/torchrec/distributed/embedding_tower_sharding.py b/torchrec/distributed/embedding_tower_sharding.py index 0f6dbf83d..7c8a6f512 100644 --- a/torchrec/distributed/embedding_tower_sharding.py +++ b/torchrec/distributed/embedding_tower_sharding.py @@ -69,7 +69,7 @@ def _replace_sharding_with_intra_node( value.ranks = [rank % local_size for rank in value.ranks] if value.sharding_spec: # pyre-ignore [6, 16] - for (shard, rank) in zip(value.sharding_spec.shards, value.ranks): + for shard, rank in zip(value.sharding_spec.shards, value.ranks): shard.placement._rank = rank @@ -338,11 +338,13 @@ def _create_output_dist( # `List[Union[bool, float, int]]`. dim_sum_per_rank=dim_sum_per_rank, device=self._device, - codecs=self.qcomm_codecs_registry.get( - CommOp.POOLED_EMBEDDINGS_ALL_TO_ALL.name, None - ) - if self.qcomm_codecs_registry - else None, + codecs=( + self.qcomm_codecs_registry.get( + CommOp.POOLED_EMBEDDINGS_ALL_TO_ALL.name, None + ) + if self.qcomm_codecs_registry + else None + ), ) def output_dist( @@ -725,11 +727,13 @@ def _create_output_dist(self, output: torch.Tensor) -> None: # pyre-ignore dim_sum_per_rank=dim_sum_per_rank, device=self._device, - codecs=self.qcomm_codecs_registry.get( - CommOp.POOLED_EMBEDDINGS_ALL_TO_ALL.name, None - ) - if self.qcomm_codecs_registry - else None, + codecs=( + self.qcomm_codecs_registry.get( + CommOp.POOLED_EMBEDDINGS_ALL_TO_ALL.name, None + ) + if self.qcomm_codecs_registry + else None + ), ) def output_dist( diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 022328e38..83363e2ae 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -408,9 +408,9 @@ def __init__( qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, ) -> None: super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) - self._embedding_bag_configs: List[ - EmbeddingBagConfig - ] = module.embedding_bag_configs() + self._embedding_bag_configs: List[EmbeddingBagConfig] = ( + module.embedding_bag_configs() + ) self._table_names: List[str] = [ config.name for config in self._embedding_bag_configs ] @@ -498,9 +498,11 @@ def __init__( if isinstance(sharding, DpPooledEmbeddingSharding): self._lookups[index] = DistributedDataParallel( module=lookup, - device_ids=[device] - if self._device and (self._device.type in {"cuda", "mtia"}) - else None, + device_ids=( + [device] + if self._device and (self._device.type in {"cuda", "mtia"}) + else None + ), process_group=env.process_group, gradient_as_bucket_view=True, broadcast_buffers=True, @@ -605,9 +607,9 @@ def _initialize_torch_state(self) -> None: # noqa if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value: continue self._model_parallel_name_to_local_shards[table_name] = [] - model_parallel_name_to_compute_kernel[ - table_name - ] = parameter_sharding.compute_kernel + model_parallel_name_to_compute_kernel[table_name] = ( + parameter_sharding.compute_kernel + ) self._name_to_table_size = {} for table in self._embedding_bag_configs: @@ -651,12 +653,12 @@ def _initialize_torch_state(self) -> None: # noqa EmptyFusedOptimizer() ] # created ShardedTensors once in init, use in post_state_dict_hook - self._model_parallel_name_to_sharded_tensor[ - table_name - ] = ShardedTensor._init_from_local_shards( - local_shards, - self._name_to_table_size[table_name], - process_group=self._env.process_group, + self._model_parallel_name_to_sharded_tensor[table_name] = ( + ShardedTensor._init_from_local_shards( + local_shards, + self._name_to_table_size[table_name], + process_group=self._env.process_group, + ) ) def post_state_dict_hook( diff --git a/torchrec/distributed/fbgemm_qcomm_codec.py b/torchrec/distributed/fbgemm_qcomm_codec.py index df88c28df..be78370a8 100644 --- a/torchrec/distributed/fbgemm_qcomm_codec.py +++ b/torchrec/distributed/fbgemm_qcomm_codec.py @@ -98,9 +98,11 @@ def get_qcomm_codecs(qcomms_config: Optional[QCommsConfig]) -> QuantizedCommCode ), loss_scale=qcomms_config.forward_loss_scale, is_fwd=True, - row_dim=qcomms_config.fp8_quantize_dim - if qcomms_config.forward_precision == CommType.FP8 - else None, + row_dim=( + qcomms_config.fp8_quantize_dim + if qcomms_config.forward_precision == CommType.FP8 + else None + ), ), ) codecs.backward = cast( @@ -110,13 +112,15 @@ def get_qcomm_codecs(qcomms_config: Optional[QCommsConfig]) -> QuantizedCommCode qcomms_config.backward_precision ), loss_scale=qcomms_config.backward_loss_scale, - is_fwd=True - if qcomms_config.fp8_bwd_uses_143 - else False, # if fp8_bwd_uses_143 is True, bwd will use 1-4-3 + is_fwd=( + True if qcomms_config.fp8_bwd_uses_143 else False + ), # if fp8_bwd_uses_143 is True, bwd will use 1-4-3 # if fp8_bwd_uses_143 is False/None, bwd will use 1-5-2 - row_dim=qcomms_config.fp8_quantize_dim_bwd - if qcomms_config.backward_precision == CommType.FP8 - else None, + row_dim=( + qcomms_config.fp8_quantize_dim_bwd + if qcomms_config.backward_precision == CommType.FP8 + else None + ), ), ) return codecs diff --git a/torchrec/distributed/grouped_position_weighted.py b/torchrec/distributed/grouped_position_weighted.py index 3230aa37d..8e8939086 100644 --- a/torchrec/distributed/grouped_position_weighted.py +++ b/torchrec/distributed/grouped_position_weighted.py @@ -14,6 +14,7 @@ from torchrec.distributed.utils import append_prefix from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + # Will be deprecated soon, please use PositionWeightedProcessor, see the full # doc under modules/feature_processor.py class GroupedPositionWeightedModule(BaseGroupedFeatureProcessor): diff --git a/torchrec/distributed/mc_embedding_modules.py b/torchrec/distributed/mc_embedding_modules.py index 0d62241ab..ff047cb5e 100644 --- a/torchrec/distributed/mc_embedding_modules.py +++ b/torchrec/distributed/mc_embedding_modules.py @@ -107,13 +107,15 @@ def __init__( # TODO: This is a hack since _embedding_module doesn't need input # dist, so eliminating it so all fused a2a will ignore it. self._embedding_module._has_uninitialized_input_dist = False - self._managed_collision_collection: ShardedManagedCollisionCollection = mc_sharder.shard( - module._managed_collision_collection, - table_name_to_parameter_sharding, - env=env, - device=device, - # pyre-ignore - sharding_type_to_sharding=self._embedding_module._sharding_type_to_sharding, + self._managed_collision_collection: ShardedManagedCollisionCollection = ( + mc_sharder.shard( + module._managed_collision_collection, + table_name_to_parameter_sharding, + env=env, + device=device, + # pyre-ignore + sharding_type_to_sharding=self._embedding_module._sharding_type_to_sharding, + ) ) self._return_remapped_features: bool = module._return_remapped_features diff --git a/torchrec/distributed/mc_modules.py b/torchrec/distributed/mc_modules.py index be959c2f2..d1d1eed42 100644 --- a/torchrec/distributed/mc_modules.py +++ b/torchrec/distributed/mc_modules.py @@ -147,9 +147,9 @@ def __init__( self._device = device self._env = env - self._table_name_to_parameter_sharding: Dict[ - str, ParameterSharding - ] = copy.deepcopy(table_name_to_parameter_sharding) + self._table_name_to_parameter_sharding: Dict[str, ParameterSharding] = ( + copy.deepcopy(table_name_to_parameter_sharding) + ) # TODO: create a MCSharding type instead of leveraging EmbeddingSharding self._sharding_type_to_sharding = sharding_type_to_sharding @@ -192,27 +192,27 @@ def _initialize_torch_state(self) -> None: if name not in shardable_buffers: continue - self._model_parallel_mc_buffer_name_to_sharded_tensor[ - name - ] = ShardedTensor._init_from_local_shards( - [ - Shard( - tensor=tensor, - metadata=ShardMetadata( - # pyre-ignore [6] - shard_offsets=[shard_offset], - # pyre-ignore [6] - shard_sizes=[shard_size], - placement=( - f"rank:{self._env.rank}/cuda:" - f"{get_local_rank(self._env.world_size, self._env.rank)}" + self._model_parallel_mc_buffer_name_to_sharded_tensor[name] = ( + ShardedTensor._init_from_local_shards( + [ + Shard( + tensor=tensor, + metadata=ShardMetadata( + # pyre-ignore [6] + shard_offsets=[shard_offset], + # pyre-ignore [6] + shard_sizes=[shard_size], + placement=( + f"rank:{self._env.rank}/cuda:" + f"{get_local_rank(self._env.world_size, self._env.rank)}" + ), ), - ), - ) - ], - # pyre-ignore [6] - torch.Size([global_size]), - process_group=self._env.process_group, + ) + ], + # pyre-ignore [6] + torch.Size([global_size]), + process_group=self._env.process_group, + ) ) def _post_state_dict_hook( @@ -266,9 +266,9 @@ def _create_managed_collision_modules( if sharding_type == ShardingType.ROW_WISE.value: assert isinstance(sharding, BaseRwEmbeddingSharding) - grouped_embedding_configs: List[ - GroupedEmbeddingConfig - ] = sharding._grouped_embedding_configs + grouped_embedding_configs: List[GroupedEmbeddingConfig] = ( + sharding._grouped_embedding_configs + ) for group_config in grouped_embedding_configs: for table in group_config.embedding_tables: # pyre-ignore [16] @@ -282,14 +282,14 @@ def _create_managed_collision_modules( # 1) need to make TBE accept global indices for now force to local indices # 2) MCH is particularly nasty with a portion of each shard; ideally dont do this # 3) now create a feature_to_offset and pass into awaitable callbacks to act as raw id adder - self._managed_collision_modules[ - table.name - ] = mc_module.rebuild_with_output_id_range( - output_id_range=( - 0, # new_min_output_id, - new_range_size, # new_min_output_id + new_range_size, - ), - device=self._device, + self._managed_collision_modules[table.name] = ( + mc_module.rebuild_with_output_id_range( + output_id_range=( + 0, # new_min_output_id, + new_range_size, # new_min_output_id + new_range_size, + ), + device=self._device, + ) ) zch_size = self._managed_collision_modules[table.name]._zch_size @@ -424,9 +424,11 @@ def input_dist( ctx.sharding_contexts.append( SequenceShardingContext( features_before_input_dist=features, - unbucketize_permute_tensor=input_dist.unbucketize_permute_tensor - if isinstance(input_dist, RwSparseFeaturesDist) - else None, + unbucketize_permute_tensor=( + input_dist.unbucketize_permute_tensor + if isinstance(input_dist, RwSparseFeaturesDist) + else None + ), ) ) diff --git a/torchrec/distributed/planner/planners.py b/torchrec/distributed/planner/planners.py index af174dd75..fff0f7226 100644 --- a/torchrec/distributed/planner/planners.py +++ b/torchrec/distributed/planner/planners.py @@ -76,19 +76,21 @@ def _to_sharding_plan( module_plan = plan.get(sharding_option.path, EmbeddingModuleShardingPlan()) module_plan[sharding_option.name] = ParameterSharding( - sharding_spec=None - if sharding_type == ShardingType.DATA_PARALLEL.value - else EnumerableShardingSpec( - [ - ShardMetadata( - shard_sizes=shard.size, - shard_offsets=shard.offset, - placement=placement( - compute_device, cast(int, shard.rank), local_size - ), - ) - for shard in shards - ] + sharding_spec=( + None + if sharding_type == ShardingType.DATA_PARALLEL.value + else EnumerableShardingSpec( + [ + ShardMetadata( + shard_sizes=shard.size, + shard_offsets=shard.offset, + placement=placement( + compute_device, cast(int, shard.rank), local_size + ), + ) + for shard in shards + ] + ) ), sharding_type=sharding_type, compute_kernel=sharding_option.compute_kernel, diff --git a/torchrec/distributed/planner/proposers.py b/torchrec/distributed/planner/proposers.py index bc355af3b..12e77fb26 100644 --- a/torchrec/distributed/planner/proposers.py +++ b/torchrec/distributed/planner/proposers.py @@ -149,9 +149,9 @@ def load( ) -> None: self._reset() all_fqns = set() - sharding_options_by_type_and_fqn: Dict[ - str, Dict[str, List[ShardingOption]] - ] = {} + sharding_options_by_type_and_fqn: Dict[str, Dict[str, List[ShardingOption]]] = ( + {} + ) for sharding_option in search_space: sharding_type = sharding_option.sharding_type diff --git a/torchrec/distributed/planner/shard_estimators.py b/torchrec/distributed/planner/shard_estimators.py index 8fc5b3908..f4c732322 100644 --- a/torchrec/distributed/planner/shard_estimators.py +++ b/torchrec/distributed/planner/shard_estimators.py @@ -971,9 +971,11 @@ def calculate_shard_storages( ) ] ddr_sizes: List[int] = [ - input_size + output_size + ddr_specific_size - if compute_device in {"cpu", "mtia"} - else ddr_specific_size + ( + input_size + output_size + ddr_specific_size + if compute_device in {"cpu", "mtia"} + else ddr_specific_size + ) for input_size, output_size, ddr_specific_size in zip( input_sizes, output_sizes, @@ -1175,17 +1177,21 @@ def _calculate_rw_shard_io_sizes( ) input_sizes = [ - math.ceil(batch_inputs * world_size * input_data_type_size) - if prod(shard) != 0 - else 0 + ( + math.ceil(batch_inputs * world_size * input_data_type_size) + if prod(shard) != 0 + else 0 + ) for shard in shard_sizes ] output_sizes = [ - math.ceil( - batch_outputs * world_size * shard_sizes[i][1] * output_data_type_size + ( + math.ceil( + batch_outputs * world_size * shard_sizes[i][1] * output_data_type_size + ) + if prod(shard) != 0 + else 0 ) - if prod(shard) != 0 - else 0 for i, shard in enumerate(shard_sizes) ] @@ -1214,17 +1220,21 @@ def _calculate_twrw_shard_io_sizes( ) input_sizes = [ - math.ceil(batch_inputs * world_size * input_data_type_size) - if prod(shard) != 0 - else 0 + ( + math.ceil(batch_inputs * world_size * input_data_type_size) + if prod(shard) != 0 + else 0 + ) for shard in shard_sizes ] output_sizes = [ - math.ceil( - batch_outputs * world_size * shard_sizes[i][1] * output_data_type_size + ( + math.ceil( + batch_outputs * world_size * shard_sizes[i][1] * output_data_type_size + ) + if prod(shard) != 0 + else 0 ) - if prod(shard) != 0 - else 0 for i, shard in enumerate(shard_sizes) ] @@ -1239,9 +1249,11 @@ def _calculate_storage_specific_sizes( optimizer_class: Optional[Type[torch.optim.Optimizer]] = None, ) -> List[int]: tensor_sizes: List[int] = [ - math.ceil(storage * prod(size) / prod(shape)) - if sharding_type != ShardingType.DATA_PARALLEL.value - else storage + ( + math.ceil(storage * prod(size) / prod(shape)) + if sharding_type != ShardingType.DATA_PARALLEL.value + else storage + ) for size in shard_sizes ] optimizer_multipler: float = _get_optimizer_multipler(optimizer_class, shape) diff --git a/torchrec/distributed/planner/stats.py b/torchrec/distributed/planner/stats.py index 1cd637149..6712e9e18 100644 --- a/torchrec/distributed/planner/stats.py +++ b/torchrec/distributed/planner/stats.py @@ -277,9 +277,11 @@ def log( ], ] feat_batch_sizes = [ - constraints[so.name].batch_sizes - if constraints and constraints.get(so.name) - else None + ( + constraints[so.name].batch_sizes + if constraints and constraints.get(so.name) + else None + ) for so in best_plan ] diff --git a/torchrec/distributed/planner/tests/test_proposers.py b/torchrec/distributed/planner/tests/test_proposers.py index 817d90c70..7573ec122 100644 --- a/torchrec/distributed/planner/tests/test_proposers.py +++ b/torchrec/distributed/planner/tests/test_proposers.py @@ -496,9 +496,11 @@ def test_scaleup(self) -> None: ( candidate.name, candidate.compute_kernel, - candidate.cache_params.load_factor - if candidate.cache_params - else None, + ( + candidate.cache_params.load_factor + if candidate.cache_params + else None + ), ) for candidate in proposal ], diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py index f3fe32df5..02f786d79 100644 --- a/torchrec/distributed/planner/types.py +++ b/torchrec/distributed/planner/types.py @@ -478,14 +478,12 @@ def reserve( module: nn.Module, sharders: List[ModuleSharder[nn.Module]], constraints: Optional[Dict[str, ParameterConstraints]] = None, - ) -> Topology: - ... + ) -> Topology: ... class PerfModel(abc.ABC): @abc.abstractmethod - def rate(self, plan: List[ShardingOption]) -> float: - ... + def rate(self, plan: List[ShardingOption]) -> float: ... class ShardEstimator(abc.ABC): @@ -498,8 +496,7 @@ def __init__( self, topology: Topology, constraints: Optional[Dict[str, ParameterConstraints]] = None, - ) -> None: - ... + ) -> None: ... @abc.abstractmethod def estimate( @@ -524,8 +521,7 @@ def __init__( batch_size: int = BATCH_SIZE, constraints: Optional[Dict[str, ParameterConstraints]] = None, estimator: Optional[Union[ShardEstimator, List[ShardEstimator]]] = None, - ) -> None: - ... + ) -> None: ... @abc.abstractmethod def enumerate( @@ -557,8 +553,7 @@ def load( self, search_space: List[ShardingOption], enumerator: Optional[Enumerator] = None, - ) -> None: - ... + ) -> None: ... @abc.abstractmethod def feedback( @@ -567,12 +562,10 @@ def feedback( plan: Optional[List[ShardingOption]] = None, perf_rating: Optional[float] = None, storage_constraint: Optional[Topology] = None, - ) -> None: - ... + ) -> None: ... @abc.abstractmethod - def propose(self) -> Optional[List[ShardingOption]]: - ... + def propose(self) -> Optional[List[ShardingOption]]: ... class Partitioner(abc.ABC): diff --git a/torchrec/distributed/planner/utils.py b/torchrec/distributed/planner/utils.py index 2e33d31d1..5dd7851af 100644 --- a/torchrec/distributed/planner/utils.py +++ b/torchrec/distributed/planner/utils.py @@ -14,6 +14,7 @@ from torchrec.distributed.planner.types import Perf, ShardingOption, Storage from torchrec.distributed.types import ShardingType + # pyre-ignore[2] def sharder_name(t: Type[Any]) -> str: return t.__module__ + "." + t.__name__ @@ -135,7 +136,8 @@ class BinarySearchPredicate: def __init__(self, A: int, B: int, tolerance: int) -> None: """A = lower boundary (inclusive) B = upper boundary (inclusive) - tolerance = stop search early if remaining search range is less than tolerance""" + tolerance = stop search early if remaining search range is less than tolerance + """ self.left = A self.right = B self.tolerance = tolerance diff --git a/torchrec/distributed/quant_embedding.py b/torchrec/distributed/quant_embedding.py index d2fb0a6db..0dba016b3 100644 --- a/torchrec/distributed/quant_embedding.py +++ b/torchrec/distributed/quant_embedding.py @@ -306,9 +306,11 @@ def output_jt_dict( embedding_names_per_rank=embedding_names_per_rank, features_before_input_dist=features_before_input_dist, need_indices=need_indices, - rw_unbucketize_tensor=unbucketize_tensors[unbucketize_tensor_idx] - if unbucketize_tensor_idx != -1 - else None, + rw_unbucketize_tensor=( + unbucketize_tensors[unbucketize_tensor_idx] + if unbucketize_tensor_idx != -1 + else None + ), cw_features_to_permute_indices=features_to_permute_indices, key_to_feature_permuted_coordinates=key_to_feature_permuted_coordinates, ) @@ -405,9 +407,9 @@ def __init__( self._fused_params = fused_params - tbes: Dict[ - IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig - ] = get_tbes_to_register_from_iterable(self._lookups) + tbes: Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig] = ( + get_tbes_to_register_from_iterable(self._lookups) + ) self._tbes_configs: Dict[ IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig @@ -529,9 +531,9 @@ def _generate_permute_coordinates_per_feature_per_sharding( ].tolist() for i, permute_idx in enumerate(permute_indices): permuted_coordinates[i] = coordinates[permute_idx] - self._key_to_feature_permuted_coordinates_per_sharding[idx][ - key - ] = torch.tensor(permuted_coordinates) + self._key_to_feature_permuted_coordinates_per_sharding[idx][key] = ( + torch.tensor(permuted_coordinates) + ) def _create_input_dist( self, @@ -617,9 +619,11 @@ def input_dist( InferSequenceShardingContext( features=input_dist_result, features_before_input_dist=features_by_sharding[i], - unbucketize_permute_tensor=input_dist.unbucketize_permute_tensor - if isinstance(input_dist, InferRwSparseFeaturesDist) - else None, + unbucketize_permute_tensor=( + input_dist.unbucketize_permute_tensor + if isinstance(input_dist, InferRwSparseFeaturesDist) + else None + ), ) ) return ListOfKJTList(ret) diff --git a/torchrec/distributed/quant_embedding_kernel.py b/torchrec/distributed/quant_embedding_kernel.py index 5b4975a66..b147df15c 100644 --- a/torchrec/distributed/quant_embedding_kernel.py +++ b/torchrec/distributed/quant_embedding_kernel.py @@ -161,29 +161,36 @@ def __init__( ] if all(v is None for v in index_remapping): index_remapping = None - self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = IntNBitTableBatchedEmbeddingBagsCodegen( - embedding_specs=[ - ( - table.name, - local_rows, - local_cols - if self._quant_state_dict_split_scale_bias - else table.embedding_dim, - data_type_to_sparse_type(config.data_type), - location, - ) - for local_rows, local_cols, table, location in zip( - self._local_rows, self._local_cols, config.embedding_tables, managed - ) - ], - device=device, - # pyre-ignore - index_remapping=index_remapping, - pooling_mode=self._pooling, - feature_table_map=self._feature_table_map, - row_alignment=16, - uvm_host_mapped=True, # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue - **(tbe_fused_params(fused_params) or {}), + self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = ( + IntNBitTableBatchedEmbeddingBagsCodegen( + embedding_specs=[ + ( + table.name, + local_rows, + ( + local_cols + if self._quant_state_dict_split_scale_bias + else table.embedding_dim + ), + data_type_to_sparse_type(config.data_type), + location, + ) + for local_rows, local_cols, table, location in zip( + self._local_rows, + self._local_cols, + config.embedding_tables, + managed, + ) + ], + device=device, + # pyre-ignore + index_remapping=index_remapping, + pooling_mode=self._pooling, + feature_table_map=self._feature_table_map, + row_alignment=16, + uvm_host_mapped=True, # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue + **(tbe_fused_params(fused_params) or {}), + ) ) if device is not None: self._emb_module.initialize_weights() @@ -230,9 +237,9 @@ def named_buffers( for config, (weight, weight_qscale, weight_qbias) in zip( self._config.embedding_tables, self.emb_module.split_embedding_weights_with_scale_bias( - split_scale_bias_mode=2 - if self._quant_state_dict_split_scale_bias - else 0 + split_scale_bias_mode=( + 2 if self._quant_state_dict_split_scale_bias else 0 + ) ), ): yield append_prefix(prefix, f"{config.name}.weight"), weight @@ -248,9 +255,9 @@ def split_embedding_weights( return [ (weight, qscale, qbias) for weight, qscale, qbias in self.emb_module.split_embedding_weights_with_scale_bias( - split_scale_bias_mode=2 - if self._quant_state_dict_split_scale_bias - else 0 + split_scale_bias_mode=( + 2 if self._quant_state_dict_split_scale_bias else 0 + ) ) ] @@ -309,27 +316,34 @@ def __init__( self._quant_state_dict_split_scale_bias: bool = ( is_fused_param_quant_state_dict_split_scale_bias(fused_params) ) - self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = IntNBitTableBatchedEmbeddingBagsCodegen( - embedding_specs=[ - ( - table.name, - local_rows, - local_cols - if self._quant_state_dict_split_scale_bias - else table.embedding_dim, - data_type_to_sparse_type(config.data_type), - location, - ) - for local_rows, local_cols, table, location in zip( - self._local_rows, self._local_cols, config.embedding_tables, managed - ) - ], - device=device, - pooling_mode=PoolingMode.NONE, - feature_table_map=self._feature_table_map, - row_alignment=16, - uvm_host_mapped=True, # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue - **(tbe_fused_params(fused_params) or {}), + self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = ( + IntNBitTableBatchedEmbeddingBagsCodegen( + embedding_specs=[ + ( + table.name, + local_rows, + ( + local_cols + if self._quant_state_dict_split_scale_bias + else table.embedding_dim + ), + data_type_to_sparse_type(config.data_type), + location, + ) + for local_rows, local_cols, table, location in zip( + self._local_rows, + self._local_cols, + config.embedding_tables, + managed, + ) + ], + device=device, + pooling_mode=PoolingMode.NONE, + feature_table_map=self._feature_table_map, + row_alignment=16, + uvm_host_mapped=True, # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue + **(tbe_fused_params(fused_params) or {}), + ) ) if device is not None: self._emb_module.initialize_weights() @@ -351,9 +365,9 @@ def split_embedding_weights( return [ (weight, qscale, qbias) for weight, qscale, qbias in self.emb_module.split_embedding_weights_with_scale_bias( - split_scale_bias_mode=2 - if self._quant_state_dict_split_scale_bias - else 0 + split_scale_bias_mode=( + 2 if self._quant_state_dict_split_scale_bias else 0 + ) ) ] @@ -376,9 +390,9 @@ def named_buffers( for config, (weight, weight_qscale, weight_qbias) in zip( self._config.embedding_tables, self.emb_module.split_embedding_weights_with_scale_bias( - split_scale_bias_mode=2 - if self._quant_state_dict_split_scale_bias - else 0 + split_scale_bias_mode=( + 2 if self._quant_state_dict_split_scale_bias else 0 + ) ), ): yield append_prefix(prefix, f"{config.name}.weight"), weight diff --git a/torchrec/distributed/quant_embeddingbag.py b/torchrec/distributed/quant_embeddingbag.py index dbb40c3ef..ac8b3c9c2 100644 --- a/torchrec/distributed/quant_embeddingbag.py +++ b/torchrec/distributed/quant_embeddingbag.py @@ -117,9 +117,9 @@ def __init__( device: Optional[torch.device] = None, ) -> None: super().__init__() - self._embedding_bag_configs: List[ - EmbeddingBagConfig - ] = module.embedding_bag_configs() + self._embedding_bag_configs: List[EmbeddingBagConfig] = ( + module.embedding_bag_configs() + ) self._sharding_type_to_sharding_infos: Dict[ str, List[EmbeddingShardingInfo] ] = create_sharding_infos_by_sharding( @@ -158,9 +158,9 @@ def __init__( self._has_uninitialized_output_dist: bool = True self._has_features_permute: bool = True - tbes: Dict[ - IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig - ] = get_tbes_to_register_from_iterable(self._lookups) + tbes: Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig] = ( + get_tbes_to_register_from_iterable(self._lookups) + ) self._tbes_configs: Dict[ IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig diff --git a/torchrec/distributed/quant_state.py b/torchrec/distributed/quant_state.py index 277f5f4f3..31022d0ea 100644 --- a/torchrec/distributed/quant_state.py +++ b/torchrec/distributed/quant_state.py @@ -183,7 +183,11 @@ def _initialize_torch_state( # noqa: C901 ) # end of weight_qscale & weight_qbias section if table.pruning_indices_remapping is not None: - for (qparam, table_name_to_local_shards, _,) in [ + for ( + qparam, + table_name_to_local_shards, + _, + ) in [ ( table.pruning_indices_remapping, self._table_name_to_local_shards_pruning_index_remappings, @@ -250,11 +254,11 @@ def _initialize_torch_state( # noqa: C901 shards_metadata=[ls.metadata for ls in local_shards], size=torch.Size([global_rows, global_cols]), ) - table_name_to_sharded_tensor[ - table_name - ] = ShardedTensorBase._init_from_local_shards_and_global_metadata( - local_shards=local_shards, - sharded_tensor_metadata=global_metadata, + table_name_to_sharded_tensor[table_name] = ( + ShardedTensorBase._init_from_local_shards_and_global_metadata( + local_shards=local_shards, + sharded_tensor_metadata=global_metadata, + ) ) for table_name_to_local_shards, table_name_to_sharded_tensor in [ @@ -279,9 +283,9 @@ def post_state_dict_hook( table_name, sharded_t, ) in module._table_name_to_sharded_tensor.items(): - destination[ - f"{prefix}{tables_weights_prefix}.{table_name}.weight" - ] = sharded_t + destination[f"{prefix}{tables_weights_prefix}.{table_name}.weight"] = ( + sharded_t + ) for sfx, dict_sharded_t, dict_t_list in [ ( @@ -439,9 +443,9 @@ def sharded_tbes_weights_spec( tbes_configs: Dict[ IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig ] = module.tbes_configs() - sharding_type_to_sharding_infos: Dict[ - str, List[EmbeddingShardingInfo] - ] = module.sharding_type_to_sharding_infos() + sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = ( + module.sharding_type_to_sharding_infos() + ) table_shardings: Dict[str, str] = {} for ( diff --git a/torchrec/distributed/sharding/cw_sharding.py b/torchrec/distributed/sharding/cw_sharding.py index f58a6972e..b60a57b87 100644 --- a/torchrec/distributed/sharding/cw_sharding.py +++ b/torchrec/distributed/sharding/cw_sharding.py @@ -87,9 +87,9 @@ def _init_combined_embeddings(self) -> None: embedding_names: List[str] = super().embedding_names() embedding_dims: List[int] = super().embedding_dims() - embedding_shard_metadata: List[ - Optional[ShardMetadata] - ] = super().embedding_shard_metadata() + embedding_shard_metadata: List[Optional[ShardMetadata]] = ( + super().embedding_shard_metadata() + ) embedding_name_to_index_offset_tuples: Dict[str, List[Tuple[int, int]]] = {} for i, (name, metadata) in enumerate( diff --git a/torchrec/distributed/sharding/dp_sharding.py b/torchrec/distributed/sharding/dp_sharding.py index 773129526..2824f147a 100644 --- a/torchrec/distributed/sharding/dp_sharding.py +++ b/torchrec/distributed/sharding/dp_sharding.py @@ -52,13 +52,13 @@ def __init__( self._rank: int = self._env.rank self._world_size: int = self._env.world_size sharded_tables_per_rank = self._shard(sharding_infos) - self._grouped_embedding_configs_per_rank: List[ - List[GroupedEmbeddingConfig] - ] = [] + self._grouped_embedding_configs_per_rank: List[List[GroupedEmbeddingConfig]] = ( + [] + ) self._grouped_embedding_configs_per_rank = group_tables(sharded_tables_per_rank) - self._grouped_embedding_configs: List[ - GroupedEmbeddingConfig - ] = self._grouped_embedding_configs_per_rank[env.rank] + self._grouped_embedding_configs: List[GroupedEmbeddingConfig] = ( + self._grouped_embedding_configs_per_rank[env.rank] + ) def _shard( self, diff --git a/torchrec/distributed/sharding/rw_sequence_sharding.py b/torchrec/distributed/sharding/rw_sequence_sharding.py index 69b2a5077..559798dd7 100644 --- a/torchrec/distributed/sharding/rw_sequence_sharding.py +++ b/torchrec/distributed/sharding/rw_sequence_sharding.py @@ -60,11 +60,13 @@ def __init__( pg, [num_features] * pg.size(), device, - codecs=qcomm_codecs_registry.get( - CommOp.SEQUENCE_EMBEDDINGS_ALL_TO_ALL.name, None - ) - if qcomm_codecs_registry - else None, + codecs=( + qcomm_codecs_registry.get( + CommOp.SEQUENCE_EMBEDDINGS_ALL_TO_ALL.name, None + ) + if qcomm_codecs_registry + else None + ), ) def forward( diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py index e8cf19715..3b507a968 100644 --- a/torchrec/distributed/sharding/rw_sharding.py +++ b/torchrec/distributed/sharding/rw_sharding.py @@ -124,13 +124,13 @@ def __init__( self._device: torch.device = device sharded_tables_per_rank = self._shard(sharding_infos) self._need_pos = need_pos - self._grouped_embedding_configs_per_rank: List[ - List[GroupedEmbeddingConfig] - ] = [] + self._grouped_embedding_configs_per_rank: List[List[GroupedEmbeddingConfig]] = ( + [] + ) self._grouped_embedding_configs_per_rank = group_tables(sharded_tables_per_rank) - self._grouped_embedding_configs: List[ - GroupedEmbeddingConfig - ] = self._grouped_embedding_configs_per_rank[self._rank] + self._grouped_embedding_configs: List[GroupedEmbeddingConfig] = ( + self._grouped_embedding_configs_per_rank[self._rank] + ) self._has_feature_processor: bool = False for group_config in self._grouped_embedding_configs: @@ -313,9 +313,11 @@ def forward( num_buckets=self._world_size, block_sizes=self._feature_block_sizes_tensor, output_permute=self._is_sequence, - bucketize_pos=self._has_feature_processor - if sparse_features.weights_or_none() is None - else self._need_pos, + bucketize_pos=( + self._has_feature_processor + if sparse_features.weights_or_none() is None + else self._need_pos + ), ) return self._dist(bucketized_features) @@ -515,16 +517,18 @@ def get_block_sizes_runtime_device( device=runtime_device, dtype=dtype, ), - [] - if embedding_shard_metadata is None - else [ - torch.tensor( - row_pos, - device=runtime_device, - dtype=dtype, - ) - for row_pos in embedding_shard_metadata - ], + ( + [] + if embedding_shard_metadata is None + else [ + torch.tensor( + row_pos, + device=runtime_device, + dtype=dtype, + ) + for row_pos in embedding_shard_metadata + ] + ), ) return tensor_cache[cache_key] @@ -563,9 +567,9 @@ def __init__( self._need_pos = need_pos self.unbucketize_permute_tensor: Optional[torch.Tensor] = None - self._embedding_shard_metadata: Optional[ - List[List[int]] - ] = embedding_shard_metadata + self._embedding_shard_metadata: Optional[List[List[int]]] = ( + embedding_shard_metadata + ) def forward( self, @@ -585,9 +589,11 @@ def forward( num_buckets=self._world_size, block_sizes=block_sizes, output_permute=self._is_sequence, - bucketize_pos=self._has_feature_processor - if sparse_features.weights_or_none() is None - else self._need_pos, + bucketize_pos=( + self._has_feature_processor + if sparse_features.weights_or_none() is None + else self._need_pos + ), block_bucketize_row_pos=_fx_wrap_block_bucketize_row_pos( block_bucketize_row_pos ), diff --git a/torchrec/distributed/sharding/sequence_sharding.py b/torchrec/distributed/sharding/sequence_sharding.py index bdfd1a21e..03f32e409 100644 --- a/torchrec/distributed/sharding/sequence_sharding.py +++ b/torchrec/distributed/sharding/sequence_sharding.py @@ -56,17 +56,17 @@ def __init__( batch_size_per_feature_pre_a2a, variable_batch_per_feature, ) - self.features_before_input_dist: Optional[ - KeyedJaggedTensor - ] = features_before_input_dist + self.features_before_input_dist: Optional[KeyedJaggedTensor] = ( + features_before_input_dist + ) self.input_splits: List[int] = input_splits if input_splits is not None else [] self.output_splits: List[int] = ( output_splits if output_splits is not None else [] ) self.sparse_features_recat: Optional[torch.Tensor] = sparse_features_recat - self.unbucketize_permute_tensor: Optional[ - torch.Tensor - ] = unbucketize_permute_tensor + self.unbucketize_permute_tensor: Optional[torch.Tensor] = ( + unbucketize_permute_tensor + ) self.lengths_after_input_dist: Optional[torch.Tensor] = lengths_after_input_dist def record_stream(self, stream: torch.cuda.streams.Stream) -> None: diff --git a/torchrec/distributed/sharding/tw_sequence_sharding.py b/torchrec/distributed/sharding/tw_sequence_sharding.py index e4498d604..c5c8730ba 100644 --- a/torchrec/distributed/sharding/tw_sequence_sharding.py +++ b/torchrec/distributed/sharding/tw_sequence_sharding.py @@ -61,11 +61,13 @@ def __init__( pg, features_per_rank, device, - codecs=qcomm_codecs_registry.get( - CommOp.SEQUENCE_EMBEDDINGS_ALL_TO_ALL.name, None - ) - if qcomm_codecs_registry - else None, + codecs=( + qcomm_codecs_registry.get( + CommOp.SEQUENCE_EMBEDDINGS_ALL_TO_ALL.name, None + ) + if qcomm_codecs_registry + else None + ), ) def forward( diff --git a/torchrec/distributed/sharding/tw_sharding.py b/torchrec/distributed/sharding/tw_sharding.py index eaf06d46d..6d600a061 100644 --- a/torchrec/distributed/sharding/tw_sharding.py +++ b/torchrec/distributed/sharding/tw_sharding.py @@ -75,17 +75,17 @@ def __init__( self._rank: int = self._env.rank sharded_tables_per_rank = self._shard(sharding_infos) - self._sharded_tables_per_rank: List[ - List[ShardedEmbeddingTable] - ] = sharded_tables_per_rank + self._sharded_tables_per_rank: List[List[ShardedEmbeddingTable]] = ( + sharded_tables_per_rank + ) - self._grouped_embedding_configs_per_rank: List[ - List[GroupedEmbeddingConfig] - ] = [] + self._grouped_embedding_configs_per_rank: List[List[GroupedEmbeddingConfig]] = ( + [] + ) self._grouped_embedding_configs_per_rank = group_tables(sharded_tables_per_rank) - self._grouped_embedding_configs: List[ - GroupedEmbeddingConfig - ] = self._grouped_embedding_configs_per_rank[self._rank] + self._grouped_embedding_configs: List[GroupedEmbeddingConfig] = ( + self._grouped_embedding_configs_per_rank[self._rank] + ) def _shard( self, diff --git a/torchrec/distributed/sharding/twrw_sharding.py b/torchrec/distributed/sharding/twrw_sharding.py index e97c3742f..fc6776bee 100644 --- a/torchrec/distributed/sharding/twrw_sharding.py +++ b/torchrec/distributed/sharding/twrw_sharding.py @@ -83,12 +83,12 @@ def __init__( ) sharded_tables_per_rank = self._shard(sharding_infos) - self._grouped_embedding_configs_per_rank: List[ - List[GroupedEmbeddingConfig] - ] = [] - self._grouped_embedding_configs_per_node: List[ - List[GroupedEmbeddingConfig] - ] = [] + self._grouped_embedding_configs_per_rank: List[List[GroupedEmbeddingConfig]] = ( + [] + ) + self._grouped_embedding_configs_per_node: List[List[GroupedEmbeddingConfig]] = ( + [] + ) self._grouped_embedding_configs_per_rank = group_tables(sharded_tables_per_rank) self._grouped_embedding_configs_per_node = [ self._grouped_embedding_configs_per_rank[rank] @@ -344,9 +344,11 @@ def forward( num_buckets=self._local_size, block_sizes=self._feature_block_sizes_tensor, output_permute=False, - bucketize_pos=self._has_feature_processor - if sparse_features.weights_or_none() is None - else self._need_pos, + bucketize_pos=( + self._has_feature_processor + if sparse_features.weights_or_none() is None + else self._need_pos + ), )[0].permute( self._sf_staggered_shuffle, self._sf_staggered_shuffle_tensor, diff --git a/torchrec/distributed/sharding_plan.py b/torchrec/distributed/sharding_plan.py index d0f96fbed..e41b00ef2 100644 --- a/torchrec/distributed/sharding_plan.py +++ b/torchrec/distributed/sharding_plan.py @@ -162,9 +162,11 @@ def _calculate_cw_shard_sizes_and_offsets( col_wise_shard_dim: Optional[int] = None, ) -> Tuple[List[List[int]], List[List[int]]]: block_size: int = min( - _find_base_dim(col_wise_shard_dim, columns) - if col_wise_shard_dim - else _find_base_dim(MIN_CW_DIM, columns), + ( + _find_base_dim(col_wise_shard_dim, columns) + if col_wise_shard_dim + else _find_base_dim(MIN_CW_DIM, columns) + ), columns, ) @@ -192,7 +194,10 @@ def _get_parameter_size_offsets( world_size: int, col_wise_shard_dim: Optional[int] = None, ) -> List[Tuple[List[int], List[int]]]: - (shard_sizes, shard_offsets,) = calculate_shard_sizes_and_offsets( + ( + shard_sizes, + shard_offsets, + ) = calculate_shard_sizes_and_offsets( tensor=none_throws(param), world_size=world_size, local_world_size=local_size, @@ -252,31 +257,37 @@ def _get_parameter_sharding( compute_kernel: Optional[str] = None, ) -> ParameterSharding: return ParameterSharding( - sharding_spec=None - if sharding_type == ShardingType.DATA_PARALLEL.value - else EnumerableShardingSpec( - [ - ShardMetadata( - shard_sizes=size, - shard_offsets=offset, - placement=placement( - device_type, - none_throws(rank), - none_throws(local_size), + sharding_spec=( + None + if sharding_type == ShardingType.DATA_PARALLEL.value + else EnumerableShardingSpec( + [ + ShardMetadata( + shard_sizes=size, + shard_offsets=offset, + placement=( + placement( + device_type, + none_throws(rank), + none_throws(local_size), + ) + if not device_placement + else device_placement + ), ) - if not device_placement - else device_placement, - ) - for (size, offset, rank), device_placement in zip( - size_offset_ranks, - placements if placements else [None] * len(size_offset_ranks), - ) - ] + for (size, offset, rank), device_placement in zip( + size_offset_ranks, + placements if placements else [None] * len(size_offset_ranks), + ) + ] + ) ), sharding_type=sharding_type, - compute_kernel=compute_kernel - if compute_kernel - else _get_compute_kernel(sharder, param, sharding_type, device_type), + compute_kernel=( + compute_kernel + if compute_kernel + else _get_compute_kernel(sharder, param, sharding_type, device_type) + ), ranks=[rank for (_, _, rank) in size_offset_ranks], ) @@ -459,15 +470,17 @@ def placement_helper(device_type: str, index: int = 0) -> str: local_size, device_type, sharder, - placements=[ - placement_helper(sizes_placement[1], i) - for i in range(len(sizes_placement[0])) - ] - if sizes_placement - else None, - compute_kernel=EmbeddingComputeKernel.QUANT.value - if sizes_placement - else None, + placements=( + [ + placement_helper(sizes_placement[1], i) + for i in range(len(sizes_placement[0])) + ] + if sizes_placement + else None + ), + compute_kernel=( + EmbeddingComputeKernel.QUANT.value if sizes_placement else None + ), ) return _parameter_sharding_generator diff --git a/torchrec/distributed/test_utils/infer_utils.py b/torchrec/distributed/test_utils/infer_utils.py index 5e5ecaab5..a0c98631b 100644 --- a/torchrec/distributed/test_utils/infer_utils.py +++ b/torchrec/distributed/test_utils/infer_utils.py @@ -779,7 +779,9 @@ def assert_weight_spec( tbe_idx: int = 0 if "rank:1/cuda:1" == placement: tbe_idx = 1 - sharded_weight_fqn: str = f"{ebc_fqn}.tbes.{tbe_idx}.{tbe_table_idxs[tbe_idx]}.{table_name}.weight" + sharded_weight_fqn: str = ( + f"{ebc_fqn}.tbes.{tbe_idx}.{tbe_table_idxs[tbe_idx]}.{table_name}.weight" + ) tbe_table_idxs[tbe_idx] += 1 assert sharded_weight_fqn in weights_spec wspec = weights_spec[sharded_weight_fqn] diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py index f6c59bcae..abe37a4dd 100644 --- a/torchrec/distributed/test_utils/test_model.py +++ b/torchrec/distributed/test_utils/test_model.py @@ -103,11 +103,11 @@ def generate( if variable_batch_size: lengths = torch.zeros(batch_size * world_size).int() for r in range(world_size): - lengths[ - r * batch_size : r * batch_size + batch_size_by_rank[r] - ] = lengths_[ - r * batch_size : r * batch_size + batch_size_by_rank[r] - ] + lengths[r * batch_size : r * batch_size + batch_size_by_rank[r]] = ( + lengths_[ + r * batch_size : r * batch_size + batch_size_by_rank[r] + ] + ) else: lengths = lengths_ num_indices = cast(int, torch.sum(lengths).item()) @@ -132,11 +132,11 @@ def generate( if variable_batch_size: lengths = torch.zeros(batch_size * world_size).int() for r in range(world_size): - lengths[ - r * batch_size : r * batch_size + batch_size_by_rank[r] - ] = lengths_[ - r * batch_size : r * batch_size + batch_size_by_rank[r] - ] + lengths[r * batch_size : r * batch_size + batch_size_by_rank[r]] = ( + lengths_[ + r * batch_size : r * batch_size + batch_size_by_rank[r] + ] + ) else: lengths = lengths_ num_indices = cast(int, torch.sum(lengths).item()) @@ -382,11 +382,11 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "ModelInput": idlist_features=self.idlist_features.to( device=device, non_blocking=non_blocking ), - idscore_features=self.idscore_features.to( - device=device, non_blocking=non_blocking - ) - if self.idscore_features is not None - else None, + idscore_features=( + self.idscore_features.to(device=device, non_blocking=non_blocking) + if self.idscore_features is not None + else None + ), label=self.label.to(device=device, non_blocking=non_blocking), ) @@ -656,9 +656,11 @@ def __init__( [ PositionWeightedProcessor( max_feature_lengths=max_feature_lengths, - device=device - if device != torch.device("meta") - else torch.device("cpu"), + device=( + device + if device != torch.device("meta") + else torch.device("cpu") + ), ) for max_feature_lengths in max_feature_lengths_list ] diff --git a/torchrec/distributed/test_utils/test_model_parallel.py b/torchrec/distributed/test_utils/test_model_parallel.py index bef0ff9af..e96d315f7 100644 --- a/torchrec/distributed/test_utils/test_model_parallel.py +++ b/torchrec/distributed/test_utils/test_model_parallel.py @@ -69,9 +69,11 @@ def setUp(self, backend: str = "nccl") -> None: self.shared_features = [f"feature_{i}" for i in range(shared_features)] self.embedding_groups = { "group_0": [ - f"{feature}@{table.name}" - if feature in self.shared_features - else feature + ( + f"{feature}@{table.name}" + if feature in self.shared_features + else feature + ) for table in self.tables for feature in table.feature_names ] diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py index ee142f20d..3355233c0 100644 --- a/torchrec/distributed/test_utils/test_sharding.py +++ b/torchrec/distributed/test_utils/test_sharding.py @@ -108,8 +108,7 @@ def __call__( ] = None, variable_batch_size: bool = False, long_indices: bool = True, - ) -> Tuple["ModelInput", List["ModelInput"]]: - ... + ) -> Tuple["ModelInput", List["ModelInput"]]: ... class VariableBatchModelInputCallable(Protocol): @@ -121,8 +120,7 @@ def __call__( tables: Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]], pooling_avg: int = 10, global_constant_batch: bool = False, - ) -> Tuple["ModelInput", List["ModelInput"]]: - ... + ) -> Tuple["ModelInput", List["ModelInput"]]: ... def gen_model_and_input( @@ -174,23 +172,25 @@ def gen_model_and_input( feature_processor_modules=feature_processor_modules, ) inputs = [ - cast(VariableBatchModelInputCallable, generate)( - average_batch_size=batch_size, - world_size=world_size, - num_float_features=num_float_features, - tables=tables, - global_constant_batch=global_constant_batch, - ) - if generate == ModelInput.generate_variable_batch_input - else cast(ModelInputCallable, generate)( - world_size=world_size, - tables=tables, - dedup_tables=dedup_tables, - weighted_tables=weighted_tables or [], - num_float_features=num_float_features, - variable_batch_size=variable_batch_size, - batch_size=batch_size, - long_indices=long_indices, + ( + cast(VariableBatchModelInputCallable, generate)( + average_batch_size=batch_size, + world_size=world_size, + num_float_features=num_float_features, + tables=tables, + global_constant_batch=global_constant_batch, + ) + if generate == ModelInput.generate_variable_batch_input + else cast(ModelInputCallable, generate)( + world_size=world_size, + tables=tables, + dedup_tables=dedup_tables, + weighted_tables=weighted_tables or [], + num_float_features=num_float_features, + variable_batch_size=variable_batch_size, + batch_size=batch_size, + long_indices=long_indices, + ) ) ] return (model, inputs) @@ -261,12 +261,14 @@ def sharding_single_rank_test( (global_model, inputs) = gen_model_and_input( model_class=model_class, tables=tables, - generate=cast( - VariableBatchModelInputCallable, - ModelInput.generate_variable_batch_input, - ) - if variable_batch_per_feature - else ModelInput.generate, + generate=( + cast( + VariableBatchModelInputCallable, + ModelInput.generate_variable_batch_input, + ) + if variable_batch_per_feature + else ModelInput.generate + ), weighted_tables=weighted_tables, embedding_groups=embedding_groups, world_size=world_size, diff --git a/torchrec/distributed/tests/test_dist_data.py b/torchrec/distributed/tests/test_dist_data.py index b1c9421d8..9907df308 100644 --- a/torchrec/distributed/tests/test_dist_data.py +++ b/torchrec/distributed/tests/test_dist_data.py @@ -37,6 +37,7 @@ T = TypeVar("T", int, float, List[int]) + # Lightly adapted from Stack Overflow #10823877 def _flatten(iterable: Iterable[T]) -> Generator[T, None, None]: iterator, sentinel, stack = iter(iterable), object(), [] @@ -96,9 +97,11 @@ def _generate_sparse_features_batch( keys=keys, lengths=_to_tensor([lengths[key][i] for key in keys], torch.int), values=_to_tensor([values[key][i] for key in keys], torch.int), - weights=_to_tensor([weights[key][i] for key in keys], torch.float) - if weights - else None, + weights=( + _to_tensor([weights[key][i] for key in keys], torch.float) + if weights + else None + ), ) ) key_index = [] @@ -117,12 +120,14 @@ def _generate_sparse_features_batch( [values[key][j] for key, j in key_index], torch.int, ), - weights=_to_tensor( - [weights[key][j] for key, j in key_index], - torch.float, - ) - if weights - else None, + weights=( + _to_tensor( + [weights[key][j] for key, j in key_index], + torch.float, + ) + if weights + else None + ), ) ) return in_jagged, out_jagged @@ -168,9 +173,11 @@ def _generate_variable_batch_sparse_features_batch( stride_per_key_per_rank=batch_size_per_rank_per_feature[i], lengths=_to_tensor([lengths[key][i] for key in keys], torch.int), values=_to_tensor([values[key][i] for key in keys], torch.int), - weights=_to_tensor([weights[key][i] for key in keys], torch.float) - if weights - else None, + weights=( + _to_tensor([weights[key][i] for key in keys], torch.float) + if weights + else None + ), ) ) key_index = [] @@ -196,12 +203,14 @@ def _generate_variable_batch_sparse_features_batch( [values[key][j] for key, j in key_index], torch.int, ), - weights=_to_tensor( - [weights[key][j] for key, j in key_index], - torch.float, - ) - if weights - else None, + weights=( + _to_tensor( + [weights[key][j] for key, j in key_index], + torch.float, + ) + if weights + else None + ), ) ) return in_jagged, out_jagged @@ -269,12 +278,16 @@ def _validate( expected_output.values().cpu(), ) torch.testing.assert_close( - actual_output.weights().cpu() - if actual_output.weights_or_none() is not None - else [], - expected_output.weights().cpu() - if expected_output.weights_or_none() is not None - else [], + ( + actual_output.weights().cpu() + if actual_output.weights_or_none() is not None + else [] + ), + ( + expected_output.weights().cpu() + if expected_output.weights_or_none() is not None + else [] + ), ) torch.testing.assert_close( actual_output.lengths().cpu(), @@ -896,9 +909,9 @@ def _generate_sequence_embedding_batch( ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: world_size = len(splits) - tensor_by_feature: Dict[ - str, List[torch.Tensor] - ] = {} # Model parallel, key as feature + tensor_by_feature: Dict[str, List[torch.Tensor]] = ( + {} + ) # Model parallel, key as feature tensor_by_rank: Dict[str, List[torch.Tensor]] = {} # Data parallel, key as rank emb_by_rank_feature = {} diff --git a/torchrec/distributed/tests/test_embedding_sharding.py b/torchrec/distributed/tests/test_embedding_sharding.py index a95995e6d..b7cf6e9a2 100644 --- a/torchrec/distributed/tests/test_embedding_sharding.py +++ b/torchrec/distributed/tests/test_embedding_sharding.py @@ -343,22 +343,30 @@ def test_should_not_group_together( tables = [ ShardedEmbeddingTable( name=f"table_{i}", - data_type=data_types[i] - if distinct_key == "data_type" - else data_types[0], - pooling=pooling_types[i] - if distinct_key == "pooling_type" - else pooling_types[0], - has_feature_processor=has_feature_processors[i] - if distinct_key == "has_feature_processor" - else has_feature_processors[0], + data_type=( + data_types[i] if distinct_key == "data_type" else data_types[0] + ), + pooling=( + pooling_types[i] + if distinct_key == "pooling_type" + else pooling_types[0] + ), + has_feature_processor=( + has_feature_processors[i] + if distinct_key == "has_feature_processor" + else has_feature_processors[0] + ), fused_params=fused_params_group, # can't hash dicts - compute_kernel=compute_kernels[i] - if distinct_key == "compute_kernel" - else compute_kernels[0], - embedding_dim=embedding_dims[i] - if distinct_key == "embedding_dim" - else embedding_dims[0], + compute_kernel=( + compute_kernels[i] + if distinct_key == "compute_kernel" + else compute_kernels[0] + ), + embedding_dim=( + embedding_dims[i] + if distinct_key == "embedding_dim" + else embedding_dims[0] + ), num_embeddings=10000, ) for i in range(2) diff --git a/torchrec/distributed/tests/test_fp_embeddingbag.py b/torchrec/distributed/tests/test_fp_embeddingbag.py index 9fed37aa3..abacbdd2f 100644 --- a/torchrec/distributed/tests/test_fp_embeddingbag.py +++ b/torchrec/distributed/tests/test_fp_embeddingbag.py @@ -165,7 +165,7 @@ def _test_sharding( # noqa C901 unsharded_named_parameters = dict(sparse_arch.named_parameters()) sharded_named_parameters = dict(sharded_sparse_arch.named_parameters()) - for (fqn, param) in unsharded_named_parameters.items(): + for fqn, param in unsharded_named_parameters.items(): if "_feature_processors" not in fqn: continue @@ -242,9 +242,11 @@ def test_sharding_ebc( tables=embedding_bag_config, kjt_input_per_rank=kjt_input_per_rank, sharder=FeatureProcessedEmbeddingBagCollectionSharder(), - backend="nccl" - if (torch.cuda.is_available() and torch.cuda.device_count() >= 2) - else "gloo", + backend=( + "nccl" + if (torch.cuda.is_available() and torch.cuda.device_count() >= 2) + else "gloo" + ), set_gradient_division=set_gradient_division, use_dmp=use_dmp, use_fp_collection=use_fp_collection, diff --git a/torchrec/distributed/tests/test_fp_embeddingbag_utils.py b/torchrec/distributed/tests/test_fp_embeddingbag_utils.py index 79e3930da..50eb32ecc 100644 --- a/torchrec/distributed/tests/test_fp_embeddingbag_utils.py +++ b/torchrec/distributed/tests/test_fp_embeddingbag_utils.py @@ -43,23 +43,25 @@ def __init__( device=device, is_weighted=True, ), - cast( - Dict[str, FeatureProcessor], - { - "feature_0": PositionWeightedModule(max_feature_length=10), - "feature_1": PositionWeightedModule(max_feature_length=10), - "feature_2": PositionWeightedModule(max_feature_length=12), - "feature_3": PositionWeightedModule(max_feature_length=12), - }, - ) - if not use_fp_collection - else PositionWeightedModuleCollection( - max_feature_lengths={ - "feature_0": 10, - "feature_1": 10, - "feature_2": 12, - "feature_3": 12, - } + ( + cast( + Dict[str, FeatureProcessor], + { + "feature_0": PositionWeightedModule(max_feature_length=10), + "feature_1": PositionWeightedModule(max_feature_length=10), + "feature_2": PositionWeightedModule(max_feature_length=12), + "feature_3": PositionWeightedModule(max_feature_length=12), + }, + ) + if not use_fp_collection + else PositionWeightedModuleCollection( + max_feature_lengths={ + "feature_0": 10, + "feature_1": 10, + "feature_2": 12, + "feature_3": 12, + } + ) ), ).to(device) ) @@ -135,9 +137,9 @@ def compute_kernels( return [self._kernel_type] -def get_configs_and_kjt_inputs() -> Tuple[ - List[EmbeddingBagConfig], List[KeyedJaggedTensor] -]: +def get_configs_and_kjt_inputs() -> ( + Tuple[List[EmbeddingBagConfig], List[KeyedJaggedTensor]] +): embedding_bag_config = [ EmbeddingBagConfig( name="table_0", diff --git a/torchrec/distributed/tests/test_fx_jit.py b/torchrec/distributed/tests/test_fx_jit.py index f0ad08bad..2f2bae941 100644 --- a/torchrec/distributed/tests/test_fx_jit.py +++ b/torchrec/distributed/tests/test_fx_jit.py @@ -231,7 +231,7 @@ def DMP_QEBC( world_size: int, sharding_type: str, quant_state_dict_split_scale_bias: bool, - unwrap_dmp: bool + unwrap_dmp: bool, # pyre-ignore ) -> Tuple[torch.nn.Module, torch.nn.Module, List[Tuple]]: model_info = self._set_up_qebc(sharding_type, quant_state_dict_split_scale_bias) diff --git a/torchrec/distributed/tests/test_infer_shardings.py b/torchrec/distributed/tests/test_infer_shardings.py index 8a2957dda..4cfe82ecd 100755 --- a/torchrec/distributed/tests/test_infer_shardings.py +++ b/torchrec/distributed/tests/test_infer_shardings.py @@ -108,9 +108,9 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor: return KeyedJaggedTensor( keys=features.keys(), values=features.values(), - weights=torch.cat(scores_list) - if scores_list - else features.weights_or_none(), + weights=( + torch.cat(scores_list) if scores_list else features.weights_or_none() + ), lengths=features.lengths(), stride=features.stride(), ) diff --git a/torchrec/distributed/tests/test_mc_embedding.py b/torchrec/distributed/tests/test_mc_embedding.py index 165e0d7e5..53e2d6758 100644 --- a/torchrec/distributed/tests/test_mc_embedding.py +++ b/torchrec/distributed/tests/test_mc_embedding.py @@ -57,9 +57,11 @@ def mch_hash_func(id: torch.Tensor, hash_size: int) -> torch.Tensor: mc_modules = {} mc_modules["table_0"] = MCHManagedCollisionModule( - zch_size=tables[0].num_embeddings - mch_size - if mch_size - else tables[0].num_embeddings, + zch_size=( + tables[0].num_embeddings - mch_size + if mch_size + else tables[0].num_embeddings + ), mch_size=mch_size, mch_hash_func=mch_hash_func if mch_size else None, input_hash_size=4000, @@ -69,9 +71,11 @@ def mch_hash_func(id: torch.Tensor, hash_size: int) -> torch.Tensor: ) mc_modules["table_1"] = MCHManagedCollisionModule( - zch_size=tables[1].num_embeddings - mch_size - if mch_size - else tables[1].num_embeddings, + zch_size=( + tables[1].num_embeddings - mch_size + if mch_size + else tables[1].num_embeddings + ), mch_size=mch_size, mch_hash_func=mch_hash_func if mch_size else None, device=device, @@ -80,17 +84,19 @@ def mch_hash_func(id: torch.Tensor, hash_size: int) -> torch.Tensor: eviction_policy=DistanceLFU_EvictionPolicy(), ) - self._mc_ec: ManagedCollisionEmbeddingCollection = ManagedCollisionEmbeddingCollection( - EmbeddingCollection( - tables=tables, - device=device, - ), - ManagedCollisionCollection( - managed_collision_modules=mc_modules, - # pyre-ignore - embedding_configs=tables, - ), - return_remapped_features=self._return_remapped, + self._mc_ec: ManagedCollisionEmbeddingCollection = ( + ManagedCollisionEmbeddingCollection( + EmbeddingCollection( + tables=tables, + device=device, + ), + ManagedCollisionCollection( + managed_collision_modules=mc_modules, + # pyre-ignore + embedding_configs=tables, + ), + return_remapped_features=self._return_remapped, + ) ) def forward( diff --git a/torchrec/distributed/tests/test_mc_embeddingbag.py b/torchrec/distributed/tests/test_mc_embeddingbag.py index fccd9b631..ba28bb270 100644 --- a/torchrec/distributed/tests/test_mc_embeddingbag.py +++ b/torchrec/distributed/tests/test_mc_embeddingbag.py @@ -57,9 +57,11 @@ def mch_hash_func(id: torch.Tensor, hash_size: int) -> torch.Tensor: mc_modules = {} mc_modules["table_0"] = MCHManagedCollisionModule( - zch_size=tables[0].num_embeddings - mch_size - if mch_size - else tables[0].num_embeddings, + zch_size=( + tables[0].num_embeddings - mch_size + if mch_size + else tables[0].num_embeddings + ), mch_size=mch_size, mch_hash_func=mch_hash_func if mch_size else None, input_hash_size=4000, @@ -69,9 +71,11 @@ def mch_hash_func(id: torch.Tensor, hash_size: int) -> torch.Tensor: ) mc_modules["table_1"] = MCHManagedCollisionModule( - zch_size=tables[1].num_embeddings - mch_size - if mch_size - else tables[1].num_embeddings, + zch_size=( + tables[1].num_embeddings - mch_size + if mch_size + else tables[1].num_embeddings + ), mch_size=mch_size, mch_hash_func=mch_hash_func if mch_size else None, device=device, @@ -80,17 +84,19 @@ def mch_hash_func(id: torch.Tensor, hash_size: int) -> torch.Tensor: eviction_policy=DistanceLFU_EvictionPolicy(), ) - self._mc_ebc: ManagedCollisionEmbeddingBagCollection = ManagedCollisionEmbeddingBagCollection( - EmbeddingBagCollection( - tables=tables, - device=device, - ), - ManagedCollisionCollection( - managed_collision_modules=mc_modules, - # pyre-ignore - embedding_configs=tables, - ), - return_remapped_features=self._return_remapped, + self._mc_ebc: ManagedCollisionEmbeddingBagCollection = ( + ManagedCollisionEmbeddingBagCollection( + EmbeddingBagCollection( + tables=tables, + device=device, + ), + ManagedCollisionCollection( + managed_collision_modules=mc_modules, + # pyre-ignore + embedding_configs=tables, + ), + return_remapped_features=self._return_remapped, + ) ) def forward( diff --git a/torchrec/distributed/tests/test_qcomms_embedding_modules.py b/torchrec/distributed/tests/test_qcomms_embedding_modules.py index 067fac906..4a41e88b3 100644 --- a/torchrec/distributed/tests/test_qcomms_embedding_modules.py +++ b/torchrec/distributed/tests/test_qcomms_embedding_modules.py @@ -225,9 +225,11 @@ def test_parameter_sharding_ebc( ] sharder = EmbeddingBagCollectionSharder( - qcomm_codecs_registry=get_qcomm_codecs_registry(qcomms_config) - if qcomms_config is not None - else None + qcomm_codecs_registry=( + get_qcomm_codecs_registry(qcomms_config) + if qcomms_config is not None + else None + ) ) ebc = EmbeddingBagCollection(tables=embedding_bag_config) @@ -262,9 +264,11 @@ def test_parameter_sharding_ebc( ), }, kjt_input_per_rank=kjt_input_per_rank, - backend="nccl" - if (torch.cuda.is_available() and torch.cuda.device_count() >= 2) - else "gloo", + backend=( + "nccl" + if (torch.cuda.is_available() and torch.cuda.device_count() >= 2) + else "gloo" + ), sharder=sharder, parameter_sharding_plan=parameter_sharding_plan, ) diff --git a/torchrec/distributed/tests/test_sequence_model_parallel.py b/torchrec/distributed/tests/test_sequence_model_parallel.py index cf04ff192..05e14130b 100644 --- a/torchrec/distributed/tests/test_sequence_model_parallel.py +++ b/torchrec/distributed/tests/test_sequence_model_parallel.py @@ -322,9 +322,11 @@ def setUp(self) -> None: self.embedding_groups = { "group_0": [ - f"{feature}@{table.name}" - if feature in self.shared_features - else feature + ( + f"{feature}@{table.name}" + if feature in self.shared_features + else feature + ) for table in self.tables for feature in table.feature_names ] diff --git a/torchrec/distributed/tests/test_sequence_model_parallel_single_rank.py b/torchrec/distributed/tests/test_sequence_model_parallel_single_rank.py index a737974f4..b23b2ffed 100644 --- a/torchrec/distributed/tests/test_sequence_model_parallel_single_rank.py +++ b/torchrec/distributed/tests/test_sequence_model_parallel_single_rank.py @@ -64,9 +64,11 @@ def setUp(self, backend: str = "nccl") -> None: self.embedding_groups = { "group_0": [ - f"{feature}@{table.name}" - if feature in self.shared_features - else feature + ( + f"{feature}@{table.name}" + if feature in self.shared_features + else feature + ) for table in self.tables for feature in table.feature_names ] diff --git a/torchrec/distributed/tests/test_utils.py b/torchrec/distributed/tests/test_utils.py index 171fe4a64..a1d47d72e 100644 --- a/torchrec/distributed/tests/test_utils.py +++ b/torchrec/distributed/tests/test_utils.py @@ -240,9 +240,11 @@ def block_bucketize_ref( values=torch.tensor( translated_indices, dtype=keyed_jagged_tensor.values().dtype ).cuda(), - weights=torch.tensor(translated_weights).float().cuda() - if weights_list - else None, + weights=( + torch.tensor(translated_weights).float().cuda() + if weights_list + else None + ), ) else: return KeyedJaggedTensor( @@ -308,9 +310,11 @@ def test_kjt_bucketize_before_all2all( # for each feature, calculate the minimum block size needed to # distribute all rows to the available trainers block_sizes_list = [ - math.ceil((max(feature_indices_list) + 1) / world_size) - if feature_indices_list - else 1 + ( + math.ceil((max(feature_indices_list) + 1) / world_size) + if feature_indices_list + else 1 + ) for feature_indices_list in indices_lists ] @@ -392,9 +396,11 @@ def test_kjt_bucketize_before_all2all_cpu( # for each feature, calculate the minimum block size needed to # distribute all rows to the available trainers block_sizes_list = [ - math.ceil((max(feature_indices_list) + 1) / world_size) - if feature_indices_list - else 1 + ( + math.ceil((max(feature_indices_list) + 1) / world_size) + if feature_indices_list + else 1 + ) for feature_indices_list in indices_lists ] block_bucketize_row_pos = [] if variable_bucket_pos else None diff --git a/torchrec/distributed/train_pipeline.py b/torchrec/distributed/train_pipeline.py index 5bf53ac04..2ebbd04d7 100644 --- a/torchrec/distributed/train_pipeline.py +++ b/torchrec/distributed/train_pipeline.py @@ -359,9 +359,11 @@ def __init__( len(request.awaitables) for request in requests ] self._lengths: List[int] = [ - len(awaitable.splits_tensors) - if isinstance(awaitable, KJTSplitsAllToAllMeta) - else 0 + ( + len(awaitable.splits_tensors) + if isinstance(awaitable, KJTSplitsAllToAllMeta) + else 0 + ) for awaitable in self._awaitables ] splits_tensors = [ @@ -451,9 +453,9 @@ class TrainPipelineContext: input_dist_tensors_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict) module_contexts: Dict[str, Multistreamable] = field(default_factory=dict) module_contexts_next_batch: Dict[str, Multistreamable] = field(default_factory=dict) - fused_splits_awaitables: List[ - Tuple[List[str], FusedKJTListSplitsAwaitable] - ] = field(default_factory=list) + fused_splits_awaitables: List[Tuple[List[str], FusedKJTListSplitsAwaitable]] = ( + field(default_factory=list) + ) @dataclass @@ -1002,12 +1004,12 @@ def __init__( self._apply_jit = apply_jit # use two data streams to support two concurrent batches if device.type == "cuda": - self._memcpy_stream: Optional[ - torch.cuda.streams.Stream - ] = torch.cuda.Stream(priority=-1) - self._data_dist_stream: Optional[ - torch.cuda.streams.Stream - ] = torch.cuda.Stream(priority=-1) + self._memcpy_stream: Optional[torch.cuda.streams.Stream] = ( + torch.cuda.Stream(priority=-1) + ) + self._data_dist_stream: Optional[torch.cuda.streams.Stream] = ( + torch.cuda.Stream(priority=-1) + ) else: self._memcpy_stream: Optional[torch.cuda.streams.Stream] = None self._data_dist_stream: Optional[torch.cuda.streams.Stream] = None @@ -1264,12 +1266,12 @@ def __init__( ) self._context = PrefetchTrainPipelineContext() if self._device.type == "cuda": - self._prefetch_stream: Optional[ - torch.cuda.streams.Stream - ] = torch.cuda.Stream() - self._default_stream: Optional[ - torch.cuda.streams.Stream - ] = torch.cuda.current_stream() + self._prefetch_stream: Optional[torch.cuda.streams.Stream] = ( + torch.cuda.Stream() + ) + self._default_stream: Optional[torch.cuda.streams.Stream] = ( + torch.cuda.current_stream() + ) else: self._prefetch_stream: Optional[torch.cuda.streams.Stream] = None self._default_stream: Optional[torch.cuda.streams.Stream] = None @@ -1404,6 +1406,6 @@ def _prefetch(self, batch: Optional[In]) -> None: dist_input=data, forward_stream=self._default_stream ) self._context.module_input_post_prefetch[forward._name] = data - self._context.module_contexts_post_prefetch[ - forward._name - ] = self._context.module_contexts[forward._name] + self._context.module_contexts_post_prefetch[forward._name] = ( + self._context.module_contexts[forward._name] + ) diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index e504791f9..1512f4754 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -170,13 +170,11 @@ class QuantizedCommCodec(Generic[QuantizationContext]): def encode( self, input_tensor: torch.Tensor, ctx: Optional[QuantizationContext] = None - ) -> torch.Tensor: - ... + ) -> torch.Tensor: ... def decode( self, input_grad: torch.Tensor, ctx: Optional[QuantizationContext] = None - ) -> torch.Tensor: - ... + ) -> torch.Tensor: ... @property def quantized_dtype(self) -> torch.dtype: @@ -295,7 +293,6 @@ def _wait_impl(self) -> W: class _LazyAwaitableMeta( GenericMeta, abc.ABCMeta, torch.fx._symbolic_trace.ProxyableClassMeta ): - """ The _LazyAwaitableMeta class that inherits both ABCMeta and ProxyableClassMeta This is because ABCMeta/ProxyableClassMeta are both non-trival metaclasses @@ -430,6 +427,7 @@ def impl(*args, **kwargs): # install reflective magic methods for orig_method_name in torch.fx.graph.reflectable_magic_methods: as_magic = f"__r{orig_method_name}__" + # pyre-ignore [2, 3] def scope(method): # pyre-ignore [2, 3, 53] @@ -812,8 +810,7 @@ def shard( @property @abc.abstractmethod - def module_type(self) -> Type[M]: - ... + def module_type(self) -> Type[M]: ... @property def qcomm_codecs_registry(self) -> Optional[Dict[str, QuantizedCommCodecs]]: diff --git a/torchrec/distributed/utils.py b/torchrec/distributed/utils.py index 57db6d315..1008d21dd 100644 --- a/torchrec/distributed/utils.py +++ b/torchrec/distributed/utils.py @@ -344,7 +344,6 @@ def merge_fused_params( fused_params: Optional[Dict[str, Any]] = None, param_fused_params: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: - """ Configure the fused_params including cache_precision if the value is not preset. diff --git a/torchrec/fx/utils.py b/torchrec/fx/utils.py index 31a900611..bfb6607c5 100644 --- a/torchrec/fx/utils.py +++ b/torchrec/fx/utils.py @@ -15,6 +15,7 @@ # Not importing DistributedModelParallel here to avoid circular dependencies as DMP depends on torchrec.fx.tracer # def dmp_fx_trace_forward(dmp: DistributedModelParallel) + # pyre-ignore def fake_range(): # pyre-fixme[16]: Module `_C` has no attribute `_jit_tree_views`. diff --git a/torchrec/inference/modules.py b/torchrec/inference/modules.py index 8e10df849..0977ec728 100644 --- a/torchrec/inference/modules.py +++ b/torchrec/inference/modules.py @@ -34,10 +34,12 @@ def quantize_feature( ) -> Tuple[torch.Tensor, ...]: return tuple( [ - input.half() - if isinstance(input, torch.Tensor) - and input.dtype in [torch.float32, torch.float64] - else input + ( + input.half() + if isinstance(input, torch.Tensor) + and input.dtype in [torch.float32, torch.float64] + else input + ) for input in inputs ] ) diff --git a/torchrec/metrics/auc.py b/torchrec/metrics/auc.py index 0637eda52..5dcc5801b 100644 --- a/torchrec/metrics/auc.py +++ b/torchrec/metrics/auc.py @@ -125,7 +125,7 @@ def compute_auc_per_group( # get unique group indices group_indices = torch.unique(grouping_keys) - for (predictions_i, labels_i, weights_i) in zip(preds_t, labels_t, weights_t): + for predictions_i, labels_i, weights_i in zip(preds_t, labels_t, weights_t): # Loop over each group auc_groups_sum = torch.tensor([0], dtype=torch.float32) for group_idx in group_indices: diff --git a/torchrec/metrics/auprc.py b/torchrec/metrics/auprc.py index d556ce533..dc4d1e438 100644 --- a/torchrec/metrics/auprc.py +++ b/torchrec/metrics/auprc.py @@ -120,7 +120,7 @@ def compute_auprc_per_group( # get unique group indices group_indices = torch.unique(grouping_keys) - for (predictions_i, labels_i, weights_i) in zip(predictions, labels, weights): + for predictions_i, labels_i, weights_i in zip(predictions, labels, weights): # Loop over each group auprc_groups_sum = torch.tensor([0], dtype=torch.float32) for group_idx in group_indices: diff --git a/torchrec/metrics/metric_module.py b/torchrec/metrics/metric_module.py index edd6093ff..ef5fb9771 100644 --- a/torchrec/metrics/metric_module.py +++ b/torchrec/metrics/metric_module.py @@ -423,9 +423,9 @@ def _generate_state_metrics( ) -> Dict[str, StateMetric]: state_metrics: Dict[str, StateMetric] = {} for metric_enum in metrics_config.state_metrics: - metric_namespace: Optional[ - MetricNamespace - ] = STATE_METRICS_NAMESPACE_MAPPING.get(metric_enum, None) + metric_namespace: Optional[MetricNamespace] = ( + STATE_METRICS_NAMESPACE_MAPPING.get(metric_enum, None) + ) if metric_namespace is None: raise ValueError(f"Unknown StateMetrics {metric_enum}") full_namespace = compose_metric_namespace( diff --git a/torchrec/metrics/metrics_config.py b/torchrec/metrics/metrics_config.py index db39dc03a..6c7f5320a 100644 --- a/torchrec/metrics/metrics_config.py +++ b/torchrec/metrics/metrics_config.py @@ -54,9 +54,9 @@ class RecTaskInfo: label_name: str = "label" prediction_name: str = "prediction" weight_name: str = "weight" - session_metric_def: Optional[ - SessionMetricDef - ] = None # used for session level metrics + session_metric_def: Optional[SessionMetricDef] = ( + None # used for session level metrics + ) class RecComputeMode(Enum): diff --git a/torchrec/metrics/rauc.py b/torchrec/metrics/rauc.py index 9c38daf95..29ce77f98 100644 --- a/torchrec/metrics/rauc.py +++ b/torchrec/metrics/rauc.py @@ -171,7 +171,7 @@ def compute_rauc_per_group( # get unique group indices group_indices = torch.unique(grouping_keys) - for (predictions_i, labels_i, weights_i) in zip(preds_t, labels_t, weights_t): + for predictions_i, labels_i, weights_i in zip(preds_t, labels_t, weights_t): # Loop over each group rauc_groups_sum = torch.tensor([0], dtype=torch.float32) for group_idx in group_indices: diff --git a/torchrec/metrics/rec_metric.py b/torchrec/metrics/rec_metric.py index 422a6bd6a..359fdbcb3 100644 --- a/torchrec/metrics/rec_metric.py +++ b/torchrec/metrics/rec_metric.py @@ -120,6 +120,7 @@ class RecMetricComputation(Metric, abc.ABC): process_group (Optional[ProcessGroup]): the process group used for the communication. Will use the default process group if not specified. """ + _batch_window_buffers: Optional[Dict[str, WindowBuffer]] def __init__( @@ -300,6 +301,7 @@ class RecMetric(nn.Module, abc.ABC): tasks=DefaultTaskInfo, ) """ + _computation_class: Type[RecMetricComputation] _namespace: MetricNamespaceBase _metrics_computations: nn.ModuleList @@ -417,10 +419,10 @@ def _fused_tasks_iter(self, compute_scope: str) -> ComputeIterType: for task, metric_value, has_valid_update in zip( self._tasks, metric_report.value, - self._metrics_computations[0].has_valid_update - if self._should_validate_update - else itertools.repeat( - 1 + ( + self._metrics_computations[0].has_valid_update + if self._should_validate_update + else itertools.repeat(1) ), # has_valid_update > 0 means the update is valid ): # The attribute has_valid_update is a tensor whose length equals to the diff --git a/torchrec/metrics/tests/test_recall_session.py b/torchrec/metrics/tests/test_recall_session.py index af0c09401..0034a0d09 100644 --- a/torchrec/metrics/tests/test_recall_session.py +++ b/torchrec/metrics/tests/test_recall_session.py @@ -48,9 +48,9 @@ def generate_model_output_test2() -> Dict[str, torch._tensor.Tensor]: } -def generate_model_output_with_no_positive_examples() -> Dict[ - str, torch._tensor.Tensor -]: +def generate_model_output_with_no_positive_examples() -> ( + Dict[str, torch._tensor.Tensor] +): return { "predictions": torch.tensor( [[1.0, 0.0, 0.51, 0.8, 1.0, 0.0, 0.51, 0.8, 1.0, 0.0, 0.51, 0.8]] diff --git a/torchrec/modules/feature_processor_.py b/torchrec/modules/feature_processor_.py index 609bf8ab5..380e3e21a 100644 --- a/torchrec/modules/feature_processor_.py +++ b/torchrec/modules/feature_processor_.py @@ -77,7 +77,6 @@ def forward( self, features: JaggedTensor, ) -> JaggedTensor: - """ Args: features (JaggedTensor]): feature representation diff --git a/torchrec/modules/fp_embedding_modules.py b/torchrec/modules/fp_embedding_modules.py index 9dd9d0d0c..2bf654bdd 100644 --- a/torchrec/modules/fp_embedding_modules.py +++ b/torchrec/modules/fp_embedding_modules.py @@ -39,9 +39,11 @@ def apply_feature_processors_to_kjt( return KeyedJaggedTensor( keys=features.keys(), values=features.values(), - weights=torch.cat(processed_weights) - if processed_weights - else features.weights_or_none(), + weights=( + torch.cat(processed_weights) + if processed_weights + else features.weights_or_none() + ), lengths=features.lengths(), offsets=features._offsets, stride=features._stride, diff --git a/torchrec/modules/fused_embedding_modules.py b/torchrec/modules/fused_embedding_modules.py index 5ecf3965d..bf232e107 100644 --- a/torchrec/modules/fused_embedding_modules.py +++ b/torchrec/modules/fused_embedding_modules.py @@ -197,7 +197,7 @@ def _init_parameters(self) -> None: assert len(self._num_embeddings) == len( self._emb_module.split_embedding_weights() ) - for (rows, emb_dim, weight_init_min, weight_init_max, param) in zip( + for rows, emb_dim, weight_init_min, weight_init_max, param in zip( self._rows, self._cols, self._weight_init_mins, diff --git a/torchrec/modules/lazy_extension.py b/torchrec/modules/lazy_extension.py index 3dc7c2481..b87bc50fa 100644 --- a/torchrec/modules/lazy_extension.py +++ b/torchrec/modules/lazy_extension.py @@ -89,8 +89,7 @@ def init_weights(m): class _LazyExtensionProtocol(_LazyProtocol): # pyre-ignore[2,3] - def _call_impl(self, *input, **kwargs): - ... + def _call_impl(self, *input, **kwargs): ... class LazyModuleExtensionMixin(LazyModuleMixin): diff --git a/torchrec/modules/mc_embedding_modules.py b/torchrec/modules/mc_embedding_modules.py index 8b66e1e45..b1616253c 100644 --- a/torchrec/modules/mc_embedding_modules.py +++ b/torchrec/modules/mc_embedding_modules.py @@ -49,9 +49,9 @@ def __init__( super().__init__() self._managed_collision_collection = managed_collision_collection self._return_remapped_features = return_remapped_features - self._embedding_module: Union[ - EmbeddingBagCollection, EmbeddingCollection - ] = embedding_module + self._embedding_module: Union[EmbeddingBagCollection, EmbeddingCollection] = ( + embedding_module + ) if isinstance(embedding_module, EmbeddingBagCollection): assert ( diff --git a/torchrec/modules/mc_modules.py b/torchrec/modules/mc_modules.py index c002deb83..b41940721 100644 --- a/torchrec/modules/mc_modules.py +++ b/torchrec/modules/mc_modules.py @@ -562,11 +562,11 @@ def update_metadata_and_generate_eviction_scores( ] # update metadata for matching ids - mch_last_access_iter[ - coalesced_history_mch_matching_indices - ] = coalesced_history_sorted_uniq_ids_last_access_iter[ - coalesced_history_mch_matching_elements_mask - ] + mch_last_access_iter[coalesced_history_mch_matching_indices] = ( + coalesced_history_sorted_uniq_ids_last_access_iter[ + coalesced_history_mch_matching_elements_mask + ] + ) # incoming non-matching ids new_sorted_uniq_ids_last_access = ( @@ -708,11 +708,11 @@ def update_metadata_and_generate_eviction_scores( ] += coalesced_history_sorted_unique_ids_counts[ coalesced_history_mch_matching_elements_mask ] - mch_last_access_iter[ - coalesced_history_mch_matching_indices - ] = coalesced_history_sorted_uniq_ids_last_access_iter[ - coalesced_history_mch_matching_elements_mask - ] + mch_last_access_iter[coalesced_history_mch_matching_indices] = ( + coalesced_history_sorted_uniq_ids_last_access_iter[ + coalesced_history_mch_matching_elements_mask + ] + ) # incoming non-matching ids new_sorted_uniq_ids_counts = coalesced_history_sorted_unique_ids_counts[ diff --git a/torchrec/modules/tests/test_crossnet.py b/torchrec/modules/tests/test_crossnet.py index 9686bbe3f..1aefc649c 100644 --- a/torchrec/modules/tests/test_crossnet.py +++ b/torchrec/modules/tests/test_crossnet.py @@ -16,6 +16,7 @@ VectorCrossNet, ) + # unit test for Full Rank CrossNet: CrossNet class TestCrossNet(unittest.TestCase): def test_cross_net_numercial_forward(self) -> None: diff --git a/torchrec/modules/tests/test_fused_embedding_modules.py b/torchrec/modules/tests/test_fused_embedding_modules.py index 1ef53a881..4ae80632a 100644 --- a/torchrec/modules/tests/test_fused_embedding_modules.py +++ b/torchrec/modules/tests/test_fused_embedding_modules.py @@ -572,6 +572,7 @@ def test_optimizer_fusion( ).to(device) opt = optimizer_type(ebc.parameters(), **optimizer_kwargs) + # pyre-ignore def run_one_training_step() -> None: fused_pooled_embeddings = fused_ebc(features) @@ -963,6 +964,7 @@ def test_optimizer_fusion( ).to(device) opt = optimizer_type(ec.parameters(), **optimizer_kwargs) + # pyre-ignore def run_one_training_step() -> None: fused_embeddings = fused_ec(features) diff --git a/torchrec/modules/utils.py b/torchrec/modules/utils.py index 070bf9f44..a335e03a5 100644 --- a/torchrec/modules/utils.py +++ b/torchrec/modules/utils.py @@ -156,9 +156,11 @@ def construct_jagged_tensors( ) ret[key] = JaggedTensor( lengths=lengths_tuple[indices[0]], - values=embeddings_list[indices[0]] - if len(indices) == 1 - else torch.cat([embeddings_list[i] for i in indices], dim=1), + values=( + embeddings_list[indices[0]] + if len(indices) == 1 + else torch.cat([embeddings_list[i] for i in indices], dim=1) + ), # pyre-ignore weights=values_list[indices[0]] if need_indices else None, ) diff --git a/torchrec/optim/fused.py b/torchrec/optim/fused.py index f0df895f9..29693fd1d 100644 --- a/torchrec/optim/fused.py +++ b/torchrec/optim/fused.py @@ -20,12 +20,10 @@ class FusedOptimizer(KeyedOptimizer, abc.ABC): @abc.abstractmethod # pyre-ignore [2] - def step(self, closure: Any = None) -> None: - ... + def step(self, closure: Any = None) -> None: ... @abc.abstractmethod - def zero_grad(self, set_to_none: bool = False) -> None: - ... + def zero_grad(self, set_to_none: bool = False) -> None: ... def __repr__(self) -> str: return optim.Optimizer.__repr__(self) @@ -54,5 +52,4 @@ class FusedOptimizerModule(abc.ABC): @property @abc.abstractmethod - def fused_optimizer(self) -> KeyedOptimizer: - ... + def fused_optimizer(self) -> KeyedOptimizer: ... diff --git a/torchrec/optim/rowwise_adagrad.py b/torchrec/optim/rowwise_adagrad.py index adfa3eccf..b987afc72 100644 --- a/torchrec/optim/rowwise_adagrad.py +++ b/torchrec/optim/rowwise_adagrad.py @@ -201,7 +201,7 @@ def _single_tensor_adagrad( maximize: bool, ) -> None: - for (param, grad, state_sum, step_t) in zip(params, grads, state_sums, state_steps): + for param, grad, state_sum, step_t in zip(params, grads, state_sums, state_steps): if grad.is_sparse: raise RuntimeError("RowWise adagrad cannot be used with sparse gradients") # update step diff --git a/torchrec/optim/tests/test_keyed.py b/torchrec/optim/tests/test_keyed.py index 99ced8afe..74ec0c044 100644 --- a/torchrec/optim/tests/test_keyed.py +++ b/torchrec/optim/tests/test_keyed.py @@ -161,13 +161,13 @@ def test_load_state_dict(self) -> None: fill_value=10.0, ) # pyre-ignore [6] - expected_state_dict["state"]["param_1"]["nested_dictionary"][ - "tensor" - ] = torch.tensor([70.0, 80.0]) + expected_state_dict["state"]["param_1"]["nested_dictionary"]["tensor"] = ( + torch.tensor([70.0, 80.0]) + ) # pyre-ignore [6] - expected_state_dict["state"]["param_1"]["optimizer_module"][ - "tensor" - ] = torch.tensor([90.0, 100.0]) + expected_state_dict["state"]["param_1"]["optimizer_module"]["tensor"] = ( + torch.tensor([90.0, 100.0]) + ) # pyre-ignore [6] expected_state_dict["param_groups"][0]["param_group_val_0"] = 8.0 # pyre-ignore [6] diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py index 60ca3cef7..b463627f2 100644 --- a/torchrec/quant/embedding_modules.py +++ b/torchrec/quant/embedding_modules.py @@ -359,9 +359,9 @@ def __init__( row_alignment=row_alignment, feature_table_map=feature_table_map, # pyre-ignore - index_remapping=index_remappings - if index_remappings_non_none_count > 0 - else None, + index_remapping=( + index_remappings if index_remappings_non_none_count > 0 else None + ), ) if weight_lists is None: emb_module.initialize_weights() @@ -761,9 +761,11 @@ def __init__( # noqa C901 table.num_embeddings, table.embedding_dim, data_type_to_sparse_type(data_type), - EmbeddingLocation.HOST - if device.type == "cpu" - else EmbeddingLocation.DEVICE, + ( + EmbeddingLocation.HOST + if device.type == "cpu" + else EmbeddingLocation.DEVICE + ), ) ) if table_name_to_quantized_weights: diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 23a16bf8d..2bff0f3d8 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -124,7 +124,7 @@ def _regroup_keyed_tensors( key_dim = keyed_tensors[0].key_dim() key_to_idx: dict[str, int] = {} - for (i, keyed_tensor) in enumerate(keyed_tensors): + for i, keyed_tensor in enumerate(keyed_tensors): for key in keyed_tensor.keys(): key_to_idx[key] = i @@ -570,15 +570,21 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "JaggedTensor" offsets = self._offsets return JaggedTensor( values=self._values.to(device, non_blocking=non_blocking), - weights=weights.to(device, non_blocking=non_blocking) - if weights is not None - else None, - lengths=lengths.to(device, non_blocking=non_blocking) - if lengths is not None - else None, - offsets=offsets.to(device, non_blocking=non_blocking) - if offsets is not None - else None, + weights=( + weights.to(device, non_blocking=non_blocking) + if weights is not None + else None + ), + lengths=( + lengths.to(device, non_blocking=non_blocking) + if lengths is not None + else None + ), + offsets=( + offsets.to(device, non_blocking=non_blocking) + if offsets is not None + else None + ), ) @torch.jit.unused @@ -936,9 +942,11 @@ def _maybe_compute_kjt_to_jt_dict( lengths.view(-1, stride) if lengths.numel() != 0 else lengths, dim=0 ) split_offsets = torch.unbind( - _batched_lengths_to_offsets(lengths.view(-1, stride)) - if lengths.numel() != 0 - else lengths, + ( + _batched_lengths_to_offsets(lengths.view(-1, stride)) + if lengths.numel() != 0 + else lengths + ), dim=0, ) @@ -1252,9 +1260,9 @@ def __init__( self._offset_per_key: Optional[List[int]] = offset_per_key self._index_per_key: Optional[Dict[str, int]] = index_per_key self._jt_dict: Optional[Dict[str, JaggedTensor]] = jt_dict - self._inverse_indices: Optional[ - Tuple[List[str], torch.Tensor] - ] = inverse_indices + self._inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = ( + inverse_indices + ) self._lengths_offset_per_key: List[int] = [] @staticmethod @@ -1353,9 +1361,9 @@ def concat( weights=torch.cat(weight_list, dim=0) if is_weighted else None, lengths=torch.cat(length_list, dim=0), stride=stride, - stride_per_key_per_rank=stride_per_key_per_rank - if variable_stride_per_key - else None, + stride_per_key_per_rank=( + stride_per_key_per_rank if variable_stride_per_key else None + ), length_per_key=length_per_key if has_length_per_key else None, ) @@ -1388,9 +1396,11 @@ def empty_like(kjt: "KeyedJaggedTensor") -> "KeyedJaggedTensor": return KeyedJaggedTensor( keys=[], values=torch.empty(0, device=kjt.device(), dtype=kjt.values().dtype), - weights=None - if kjt.weights_or_none() is None - else torch.empty(0, device=kjt.device(), dtype=kjt.weights().dtype), + weights=( + None + if kjt.weights_or_none() is None + else torch.empty(0, device=kjt.device(), dtype=kjt.weights().dtype) + ), lengths=torch.empty(0, device=kjt.device(), dtype=kjt.lengths().dtype), stride=stride, stride_per_key_per_rank=stride_per_key_per_rank, @@ -1625,12 +1635,14 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]: device=self.device(), dtype=self._values.dtype, ), - weights=None - if self.weights_or_none() is None - else torch.tensor( - empty_int_list, - device=self.device(), - dtype=self.weights().dtype, + weights=( + None + if self.weights_or_none() is None + else torch.tensor( + empty_int_list, + device=self.device(), + dtype=self.weights().dtype, + ) ), lengths=torch.tensor( empty_int_list, device=self.device(), dtype=torch.int @@ -1663,9 +1675,11 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]: KeyedJaggedTensor( keys=keys, values=self._values[start_offset:end_offset], - weights=None - if self.weights_or_none() is None - else self.weights()[start_offset:end_offset], + weights=( + None + if self.weights_or_none() is None + else self.weights()[start_offset:end_offset] + ), lengths=self.lengths()[ self.lengths_offset_per_key()[ start @@ -1758,9 +1772,9 @@ def permute( offset_per_key=None, index_per_key=None, jt_dict=None, - inverse_indices=self.inverse_indices_or_none() - if include_inverse_indices - else None, + inverse_indices=( + self.inverse_indices_or_none() if include_inverse_indices else None + ), ) return kjt @@ -1796,9 +1810,11 @@ def __getitem__(self, key: str) -> JaggedTensor: ) return JaggedTensor( values=self._values[start_offset:end_offset], - weights=None - if self.weights_or_none() is None - else self.weights()[start_offset:end_offset], + weights=( + None + if self.weights_or_none() is None + else self.weights()[start_offset:end_offset] + ), lengths=self.lengths()[ self.lengths_offset_per_key()[index] : self.lengths_offset_per_key()[ index + 1 @@ -1871,12 +1887,16 @@ def to( keys=self._keys, values=self._values.to(device, non_blocking=non_blocking), weights=weights, - lengths=lengths.to(device, non_blocking=non_blocking) - if lengths is not None - else None, - offsets=offsets.to(device, non_blocking=non_blocking) - if offsets is not None - else None, + lengths=( + lengths.to(device, non_blocking=non_blocking) + if lengths is not None + else None + ), + offsets=( + offsets.to(device, non_blocking=non_blocking) + if offsets is not None + else None + ), stride=stride, stride_per_key_per_rank=stride_per_key_per_rank, length_per_key=length_per_key, diff --git a/torchrec/types.py b/torchrec/types.py index 391cf2ca5..f00832aff 100644 --- a/torchrec/types.py +++ b/torchrec/types.py @@ -14,8 +14,7 @@ class CopyMixIn: @abstractmethod - def copy(self, device: torch.device) -> nn.Module: - ... + def copy(self, device: torch.device) -> nn.Module: ... class ModuleCopyMixin(CopyMixIn):