From fed5202be56d7196f7ace818a02b5b904044caee Mon Sep 17 00:00:00 2001
From: Joshua Deng
Date: Mon, 5 Feb 2024 19:53:34 -0800
Subject: [PATCH] fix lints `all()` usages to short circuit (#1683)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/1683

tsia

Reviewed By: henrylhtsang

Differential Revision: D53443279

fbshipit-source-id: c9b8847c1ef49a0249a6ba4fddad592765ec2eed
---
 .../distributed/batched_embedding_kernel.py  |  2 +-
 torchrec/distributed/embedding_types.py      | 12 ++++------
 torchrec/distributed/quant_embedding.py      |  6 ++---
 torchrec/distributed/quant_embeddingbag.py   |  6 ++---
 .../tests/test_quant_model_parallel.py       | 24 +++++++------------
 .../distributed/tests/test_train_pipeline.py |  6 ++---
 torchrec/distributed/train_pipeline.py       |  2 +-
 torchrec/optim/rowwise_adagrad.py            |  2 +-
 8 files changed, 22 insertions(+), 38 deletions(-)

diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index d7d1d0d01..3fa3dd499 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -283,7 +283,7 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
             )
 
             if all(
-                [opt_state is not None for opt_state in shard_params.optimizer_states]
+                opt_state is not None for opt_state in shard_params.optimizer_states
             ):
                 # pyre-ignore
                 def get_sharded_optim_state(momentum_idx: int) -> ShardedTensor:
diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py
index 34e3cdcaf..a4c2c533a 100644
--- a/torchrec/distributed/embedding_types.py
+++ b/torchrec/distributed/embedding_types.py
@@ -468,15 +468,11 @@ def shardable_parameters(self, module: M) -> Dict[str, nn.Parameter]:
 
         if self._shardable_params:
             assert all(
-                [
-                    table_name in self._shardable_params
-                    for table_name in shardable_params.keys()
-                ]
+                table_name in self._shardable_params
+                for table_name in shardable_params.keys()
             ) or all(
-                [
-                    table_name not in self._shardable_params
-                    for table_name in shardable_params.keys()
-                ]
+                table_name not in self._shardable_params
+                for table_name in shardable_params.keys()
             ), f"Cannot partially shard {type(module)}, please check sharder kwargs"
             shardable_params = {
                 table_name: param
diff --git a/torchrec/distributed/quant_embedding.py b/torchrec/distributed/quant_embedding.py
index 311336594..bbe3c7ec5 100644
--- a/torchrec/distributed/quant_embedding.py
+++ b/torchrec/distributed/quant_embedding.py
@@ -429,10 +429,8 @@ def __init__(
             )
         else:
             table_wise_sharded_only: bool = all(
-                [
-                    sharding_type == ShardingType.TABLE_WISE.value
-                    for sharding_type in self._sharding_type_to_sharding.keys()
-                ]
+                sharding_type == ShardingType.TABLE_WISE.value
+                for sharding_type in self._sharding_type_to_sharding.keys()
             )
             assert (
                 table_wise_sharded_only
diff --git a/torchrec/distributed/quant_embeddingbag.py b/torchrec/distributed/quant_embeddingbag.py
index de1b58993..dbb40c3ef 100644
--- a/torchrec/distributed/quant_embeddingbag.py
+++ b/torchrec/distributed/quant_embeddingbag.py
@@ -182,10 +182,8 @@ def __init__(
             )
         else:
             table_wise_sharded_only: bool = all(
-                [
-                    sharding_type == ShardingType.TABLE_WISE.value
-                    for sharding_type in self._sharding_type_to_sharding.keys()
-                ]
+                sharding_type == ShardingType.TABLE_WISE.value
+                for sharding_type in self._sharding_type_to_sharding.keys()
             )
             assert (
                 table_wise_sharded_only
diff --git a/torchrec/distributed/tests/test_quant_model_parallel.py b/torchrec/distributed/tests/test_quant_model_parallel.py
index 643bcef08..98da1a4fc 100644
--- a/torchrec/distributed/tests/test_quant_model_parallel.py
+++ b/torchrec/distributed/tests/test_quant_model_parallel.py
@@ -583,14 +583,12 @@ def test_shard_one_ebc_cuda(
         )
 
         self.assertTrue(
-            all([param.device == device for param in dmp.module.sparse.ebc.buffers()])
+            all(param.device == device for param in dmp.module.sparse.ebc.buffers())
         )
         self.assertTrue(
             all(
-                [
-                    param.device == torch.device("cpu")
-                    for param in dmp.module.sparse.weighted_ebc.buffers()
-                ]
+                param.device == torch.device("cpu")
+                for param in dmp.module.sparse.weighted_ebc.buffers()
             )
         )
 
@@ -664,14 +662,12 @@ def test_shard_one_ebc_meta(
         )
 
         self.assertTrue(
-            all([param.device == device for param in dmp.module.sparse.ebc.buffers()])
+            all(param.device == device for param in dmp.module.sparse.ebc.buffers())
         )
         self.assertTrue(
             all(
-                [
-                    param.device == torch.device("meta")
-                    for param in dmp.module.sparse.weighted_ebc.buffers()
-                ]
+                param.device == torch.device("meta")
+                for param in dmp.module.sparse.weighted_ebc.buffers()
             )
         )
 
@@ -743,14 +739,12 @@ def test_shard_all_ebcs(
         )
 
         self.assertTrue(
-            all([param.device == device for param in dmp.module.sparse.ebc.buffers()])
+            all(param.device == device for param in dmp.module.sparse.ebc.buffers())
         )
         self.assertTrue(
             all(
-                [
-                    param.device == device
-                    for param in dmp.module.sparse.weighted_ebc.buffers()
-                ]
+                param.device == device
+                for param in dmp.module.sparse.weighted_ebc.buffers()
             )
         )
 
diff --git a/torchrec/distributed/tests/test_train_pipeline.py b/torchrec/distributed/tests/test_train_pipeline.py
index 6a4e357c0..496b10b4e 100644
--- a/torchrec/distributed/tests/test_train_pipeline.py
+++ b/torchrec/distributed/tests/test_train_pipeline.py
@@ -565,9 +565,7 @@ def test_multi_dataloader_pipelining(self) -> None:
         self.assertEqual(len(cpu_preds), len(gpu_preds))
         self.assertTrue(
             all(
-                [
-                    cpu_pred.size() == gpu_pred.size()
-                    for cpu_pred, gpu_pred in zip(cpu_preds, gpu_preds)
-                ]
+                cpu_pred.size() == gpu_pred.size()
+                for cpu_pred, gpu_pred in zip(cpu_preds, gpu_preds)
             )
         )
diff --git a/torchrec/distributed/train_pipeline.py b/torchrec/distributed/train_pipeline.py
index 50240de7b..e749af2ce 100644
--- a/torchrec/distributed/train_pipeline.py
+++ b/torchrec/distributed/train_pipeline.py
@@ -42,7 +42,7 @@
     KJTListSplitsAwaitable,
 )
 from torchrec.distributed.model_parallel import DistributedModelParallel, ShardedModule
-from torchrec.distributed.types import Awaitable, NoWait
+from torchrec.distributed.types import Awaitable
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 from torchrec.streamable import Multistreamable, Pipelineable
 
diff --git a/torchrec/optim/rowwise_adagrad.py b/torchrec/optim/rowwise_adagrad.py
index 7c3aa14fa..adfa3eccf 100644
--- a/torchrec/optim/rowwise_adagrad.py
+++ b/torchrec/optim/rowwise_adagrad.py
@@ -170,7 +170,7 @@ def adagrad(
 
     See :class:`~torch.optim.Adagrad` for details.
    """
-    if not all([isinstance(t, torch.Tensor) for t in state_steps]):
+    if not all(isinstance(t, torch.Tensor) for t in state_steps):
         raise RuntimeError(
             "API has changed, `state_steps` argument must contain a list of singleton tensors"
         )
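
Note for reviewers: the point of the lint is that `all()` (and `any()`) can only stop early when given a lazy iterable. A list comprehension evaluates every element before `all()` ever runs, while a bare generator expression lets `all()` return as soon as the first falsy value appears. Below is a minimal, self-contained sketch (not part of the patch; the helper names are illustrative only) demonstrating the difference in evaluation count:

# Illustrative sketch only -- not part of the patch above.
# Shows that all() over a generator expression short-circuits, while a
# list comprehension evaluates every element before all() is called.

def make_counting_predicate():
    """Return a predicate plus a counter recording how often it ran."""
    calls = {"n": 0}

    def is_small(x: int) -> bool:
        calls["n"] += 1
        return x < 3

    return is_small, calls


values = [1, 2, 100, 4, 5]

# List comprehension: the full list is built first, so the predicate
# runs for all five elements even though the third one already fails.
pred, calls = make_counting_predicate()
all([pred(v) for v in values])
print(calls["n"])  # 5

# Generator expression: all() stops at the first falsy result.
pred, calls = make_counting_predicate()
all(pred(v) for v in values)
print(calls["n"])  # 3

The behavior is identical either way here (the predicates have no side effects), so the change is purely a short-circuiting and allocation win.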