fix lints all() usages to short circuit (#1683)
Summary:
Pull Request resolved: #1683

tsia (title says it all): replace list comprehensions passed to all() with generator expressions so that all() can short-circuit.

Reviewed By: henrylhtsang

Differential Revision: D53443279

fbshipit-source-id: c9b8847c1ef49a0249a6ba4fddad592765ec2eed
joshuadeng authored and facebook-github-bot committed Feb 6, 2024
1 parent 02771aa commit fed5202
Showing 8 changed files with 22 additions and 38 deletions.
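For context, a minimal sketch (not part of the commit) of why the change matters: when all() receives a list comprehension, the whole list is built before any check runs, whereas a generator expression lets all() stop at the first falsy element. The is_valid helper below is hypothetical and used only to make the evaluation order visible.

def is_valid(x: int) -> bool:
    # Hypothetical predicate; the print call shows which elements get evaluated.
    print(f"checking {x}")
    return x < 2

values = [0, 1, 2, 3, 4]

# List comprehension: every element is evaluated before all() sees any of them.
all([is_valid(x) for x in values])  # prints "checking 0" through "checking 4"

# Generator expression: all() short-circuits at the first False (x == 2).
all(is_valid(x) for x in values)  # prints "checking 0", "checking 1", "checking 2"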
2 changes: 1 addition & 1 deletion torchrec/distributed/batched_embedding_kernel.py
@@ -283,7 +283,7 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
             )
 
             if all(
-                [opt_state is not None for opt_state in shard_params.optimizer_states]
+                opt_state is not None for opt_state in shard_params.optimizer_states
             ):
                 # pyre-ignore
                 def get_sharded_optim_state(momentum_idx: int) -> ShardedTensor:
12 changes: 4 additions & 8 deletions torchrec/distributed/embedding_types.py
@@ -468,15 +468,11 @@ def shardable_parameters(self, module: M) -> Dict[str, nn.Parameter]:
 
         if self._shardable_params:
             assert all(
-                [
-                    table_name in self._shardable_params
-                    for table_name in shardable_params.keys()
-                ]
+                table_name in self._shardable_params
+                for table_name in shardable_params.keys()
             ) or all(
-                [
-                    table_name not in self._shardable_params
-                    for table_name in shardable_params.keys()
-                ]
+                table_name not in self._shardable_params
+                for table_name in shardable_params.keys()
             ), f"Cannot partially shard {type(module)}, please check sharder kwargs"
             shardable_params = {
                 table_name: param
6 changes: 2 additions & 4 deletions torchrec/distributed/quant_embedding.py
@@ -429,10 +429,8 @@ def __init__(
             )
         else:
             table_wise_sharded_only: bool = all(
-                [
-                    sharding_type == ShardingType.TABLE_WISE.value
-                    for sharding_type in self._sharding_type_to_sharding.keys()
-                ]
+                sharding_type == ShardingType.TABLE_WISE.value
+                for sharding_type in self._sharding_type_to_sharding.keys()
             )
             assert (
                 table_wise_sharded_only
6 changes: 2 additions & 4 deletions torchrec/distributed/quant_embeddingbag.py
@@ -182,10 +182,8 @@ def __init__(
             )
         else:
             table_wise_sharded_only: bool = all(
-                [
-                    sharding_type == ShardingType.TABLE_WISE.value
-                    for sharding_type in self._sharding_type_to_sharding.keys()
-                ]
+                sharding_type == ShardingType.TABLE_WISE.value
+                for sharding_type in self._sharding_type_to_sharding.keys()
             )
             assert (
                 table_wise_sharded_only
24 changes: 9 additions & 15 deletions torchrec/distributed/tests/test_quant_model_parallel.py
@@ -583,14 +583,12 @@ def test_shard_one_ebc_cuda(
         )
 
         self.assertTrue(
-            all([param.device == device for param in dmp.module.sparse.ebc.buffers()])
+            all(param.device == device for param in dmp.module.sparse.ebc.buffers())
         )
         self.assertTrue(
             all(
-                [
-                    param.device == torch.device("cpu")
-                    for param in dmp.module.sparse.weighted_ebc.buffers()
-                ]
+                param.device == torch.device("cpu")
+                for param in dmp.module.sparse.weighted_ebc.buffers()
             )
         )
 
@@ -664,14 +662,12 @@ def test_shard_one_ebc_meta(
         )
 
         self.assertTrue(
-            all([param.device == device for param in dmp.module.sparse.ebc.buffers()])
+            all(param.device == device for param in dmp.module.sparse.ebc.buffers())
        )
         self.assertTrue(
             all(
-                [
-                    param.device == torch.device("meta")
-                    for param in dmp.module.sparse.weighted_ebc.buffers()
-                ]
+                param.device == torch.device("meta")
+                for param in dmp.module.sparse.weighted_ebc.buffers()
             )
         )
 
@@ -743,14 +739,12 @@ def test_shard_all_ebcs(
         )
 
         self.assertTrue(
-            all([param.device == device for param in dmp.module.sparse.ebc.buffers()])
+            all(param.device == device for param in dmp.module.sparse.ebc.buffers())
         )
         self.assertTrue(
             all(
-                [
-                    param.device == device
-                    for param in dmp.module.sparse.weighted_ebc.buffers()
-                ]
+                param.device == device
+                for param in dmp.module.sparse.weighted_ebc.buffers()
             )
         )
 
6 changes: 2 additions & 4 deletions torchrec/distributed/tests/test_train_pipeline.py
@@ -565,9 +565,7 @@ def test_multi_dataloader_pipelining(self) -> None:
         self.assertEqual(len(cpu_preds), len(gpu_preds))
         self.assertTrue(
             all(
-                [
-                    cpu_pred.size() == gpu_pred.size()
-                    for cpu_pred, gpu_pred in zip(cpu_preds, gpu_preds)
-                ]
+                cpu_pred.size() == gpu_pred.size()
+                for cpu_pred, gpu_pred in zip(cpu_preds, gpu_preds)
             )
         )
2 changes: 1 addition & 1 deletion torchrec/distributed/train_pipeline.py
@@ -42,7 +42,7 @@
     KJTListSplitsAwaitable,
 )
 from torchrec.distributed.model_parallel import DistributedModelParallel, ShardedModule
-from torchrec.distributed.types import Awaitable, NoWait
+from torchrec.distributed.types import Awaitable
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 from torchrec.streamable import Multistreamable, Pipelineable
 
2 changes: 1 addition & 1 deletion torchrec/optim/rowwise_adagrad.py
@@ -170,7 +170,7 @@ def adagrad(
     See :class:`~torch.optim.Adagrad` for details.
     """
 
-    if not all([isinstance(t, torch.Tensor) for t in state_steps]):
+    if not all(isinstance(t, torch.Tensor) for t in state_steps):
         raise RuntimeError(
             "API has changed, `state_steps` argument must contain a list of singleton tensors"
         )
