apply Black 2024 style in fbcode (4/16)
Summary:
Formats the covered files with pyfmt.

paintitblack

Reviewed By: aleivag

Differential Revision: D54447727

fbshipit-source-id: 8844b1caa08de94d04ac4df3c768dbf8c865fd2f
amyreese authored and facebook-github-bot committed Mar 3, 2024
1 parent 4168f11 commit 4e41029
Showing 86 changed files with 934 additions and 731 deletions.
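Most of the hunks below are mechanical layout changes with no behavioral effect. The most common pattern is the 2024 Black style wrapping a multi-line conditional expression (or other long right-hand side) in parentheses instead of letting it dangle after `=` or a keyword argument. A minimal runnable sketch of that layout, using simplified stand-ins for the config values in the first hunk rather than the project's real classes:

```python
# Illustrative sketch only; these names are stand-ins, not code from this commit.
num_embeddings_per_feature = [1000, 2000, 4000]
num_embeddings = None
feature_idx = 1

# 2024 layout: the whole conditional expression is parenthesized rather than
# left hanging below the assignment target.
resolved = (
    num_embeddings_per_feature[feature_idx]
    if num_embeddings is None
    else num_embeddings
)
print(resolved)  # 2000
```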
8 changes: 5 additions & 3 deletions examples/inference/dlrm_predict.py
@@ -139,9 +139,11 @@ def create_predict_module(self, world_size: int) -> torch.nn.Module:
EmbeddingBagConfig(
name=f"t_{feature_name}",
embedding_dim=self.model_config.embedding_dim,
num_embeddings=self.model_config.num_embeddings_per_feature[feature_idx]
if self.model_config.num_embeddings is None
else self.model_config.num_embeddings,
num_embeddings=(
self.model_config.num_embeddings_per_feature[feature_idx]
if self.model_config.num_embeddings is None
else self.model_config.num_embeddings
),
feature_names=[feature_name],
)
for feature_idx, feature_name in enumerate(
8 changes: 5 additions & 3 deletions examples/inference/dlrm_predict_single_gpu.py
@@ -50,9 +50,11 @@ def create_predict_module(self, world_size: int) -> torch.nn.Module:
EmbeddingBagConfig(
name=f"t_{feature_name}",
embedding_dim=self.model_config.embedding_dim,
num_embeddings=self.model_config.num_embeddings_per_feature[feature_idx]
if self.model_config.num_embeddings is None
else self.model_config.num_embeddings,
num_embeddings=(
self.model_config.num_embeddings_per_feature[feature_idx]
if self.model_config.num_embeddings is None
else self.model_config.num_embeddings
),
feature_names=[feature_name],
)
for feature_idx, feature_name in enumerate(
3 changes: 2 additions & 1 deletion examples/nvt_dataloader/nvt_binary_dataloader.py
@@ -94,7 +94,8 @@ def __getitem__(self, idx: int):
"""Numerical features are returned in the order they appear in the channel spec section
For performance reasons, this is required to be the order they are saved in, as specified
by the relevant chunk in source spec.
Categorical features are returned in the order they appear in the channel spec section"""
Categorical features are returned in the order they appear in the channel spec section
"""

if idx >= self._num_entries:
raise IndexError()
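The docstring hunk above appears to follow a rule where the closing triple quotes move onto their own line when the final docstring line runs near the length limit; at least, that is all this particular change does. A tiny illustrative function (not the dataloader's actual code):

```python
def ordering_note() -> str:
    """Numerical features are returned in the order they appear in the channel spec section.
    Categorical features are returned in the order they appear in the channel spec section
    """
    return "channel-spec order"  # placeholder body for the sketch


print(ordering_note())
```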
22 changes: 13 additions & 9 deletions examples/nvt_dataloader/train_torchrec.py
@@ -208,9 +208,11 @@ def main(argv: List[str]):
EmbeddingBagConfig(
name=f"t_{feature_name}",
embedding_dim=args.embedding_dim,
num_embeddings=none_throws(num_embeddings_per_feature)[feature_idx]
if num_embeddings_per_feature is not None
else args.num_embeddings,
num_embeddings=(
none_throws(num_embeddings_per_feature)[feature_idx]
if num_embeddings_per_feature is not None
else args.num_embeddings
),
feature_names=[feature_name],
)
for feature_idx, feature_name in enumerate(DEFAULT_CAT_NAMES)
@@ -232,9 +234,9 @@ def main(argv: List[str]):

train_model = fuse_embedding_optimizer(
train_model,
optimizer_type=torchrec.optim.RowWiseAdagrad
if args.adagrad
else torch.optim.SGD,
optimizer_type=(
torchrec.optim.RowWiseAdagrad if args.adagrad else torch.optim.SGD
),
optimizer_kwargs={"learning_rate": args.learning_rate},
device=torch.device("meta"),
)
@@ -270,9 +272,11 @@ def main(argv: List[str]):

non_fused_optimizer = KeyedOptimizerWrapper(
dict(in_backward_optimizer_filter(model.named_parameters())),
lambda params: torch.optim.Adagrad(params, lr=args.learning_rate)
if args.adagrad
else torch.optim.SGD(params, lr=args.learning_rate),
lambda params: (
torch.optim.Adagrad(params, lr=args.learning_rate)
if args.adagrad
else torch.optim.SGD(params, lr=args.learning_rate)
),
)

opt = trec_optim.keyed.CombinedOptimizer(
12 changes: 6 additions & 6 deletions examples/retrieval/modules/two_tower.py
@@ -71,12 +71,12 @@ def __init__(
embedding_dim: int = embedding_bag_collection.embedding_bag_configs()[
0
].embedding_dim
self._feature_names_query: List[
str
] = embedding_bag_collection.embedding_bag_configs()[0].feature_names
self._candidate_feature_names: List[
str
] = embedding_bag_collection.embedding_bag_configs()[1].feature_names
self._feature_names_query: List[str] = (
embedding_bag_collection.embedding_bag_configs()[0].feature_names
)
self._candidate_feature_names: List[str] = (
embedding_bag_collection.embedding_bag_configs()[1].feature_names
)
self.ebc = embedding_bag_collection
self.query_proj = MLP(
in_size=embedding_dim, layer_sizes=layer_sizes, device=device
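The two_tower.py hunk shows a related preference applied throughout this commit: when an annotated assignment overflows the line, the new style keeps the annotation (`List[str]`) intact and wraps the right-hand side in parentheses, instead of splitting inside the subscripted type. A schematic, runnable sketch with placeholder data (for lines this short Black would actually join them, so treat this as shape rather than exact formatter output):

```python
from typing import Dict, List

# Placeholder stand-in for embedding_bag_collection.embedding_bag_configs().
embedding_bag_configs: List[Dict[str, List[str]]] = [
    {"feature_names": ["query_feature"]},
    {"feature_names": ["candidate_feature"]},
]

# New layout: the annotation stays whole, the long RHS is parenthesized.
feature_names_query: List[str] = (
    embedding_bag_configs[0]["feature_names"]
)
candidate_feature_names: List[str] = (
    embedding_bag_configs[1]["feature_names"]
)
print(feature_names_query, candidate_feature_names)
```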
10 changes: 5 additions & 5 deletions tools/lint/black_linter.py
@@ -176,11 +176,11 @@ def main() -> None:

logging.basicConfig(
format="<%(threadName)s:%(levelname)s> %(message)s",
level=logging.NOTSET
if args.verbose
else logging.DEBUG
if len(args.filenames) < 1000
else logging.INFO,
level=(
logging.NOTSET
if args.verbose
else logging.DEBUG if len(args.filenames) < 1000 else logging.INFO
),
stream=sys.stderr,
)

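In the black_linter.py hunk, the new style additionally keeps the nested `if`/`else` on a single line once the outer conditional is parenthesized, provided it still fits within the line limit. A runnable sketch with made-up values in place of the linter's argparse results:

```python
import logging
import sys

# Stand-ins for args.verbose and len(args.filenames); illustrative only.
verbose = False
num_filenames = 42

logging.basicConfig(
    format="<%(threadName)s:%(levelname)s> %(message)s",
    level=(
        logging.NOTSET
        if verbose
        else logging.DEBUG if num_filenames < 1000 else logging.INFO
    ),
    stream=sys.stderr,
)
logging.debug("chosen level: %s", logging.getLevelName(logging.getLogger().level))
```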
14 changes: 8 additions & 6 deletions torchrec/datasets/criteo.py
@@ -234,11 +234,13 @@ def row_mapper(row: List[str]) -> Tuple[List[int], List[int], int]:
return dense, sparse, label

dense, sparse, labels = [], [], []
for (row_dense, row_sparse, row_label) in CriteoIterDataPipe(
for row_dense, row_sparse, row_label in CriteoIterDataPipe(
[in_file],
row_mapper=row_mapper
if not (dataset_name == "criteo_kaggle" and "test" in in_file)
else row_mapper_with_fake_label_constant,
row_mapper=(
row_mapper
if not (dataset_name == "criteo_kaggle" and "test" in in_file)
else row_mapper_with_fake_label_constant
),
):
dense.append(row_dense)
sparse.append(row_sparse)
@@ -261,7 +263,7 @@ def row_mapper(row: List[str]) -> Tuple[List[int], List[int], int]:
labels_np = labels_np.reshape((-1, 1))

path_manager = PathManagerFactory().get(path_manager_key)
for (fname, arr) in [
for fname, arr in [
(out_dense_file, dense_np),
(out_sparse_file, sparse_np),
(out_labels_file, labels_np),
@@ -665,7 +667,7 @@ def shuffle(
curr_first_row = curr_last_row

# Directly copy over the last day's files since they will be used for validation and testing.
for (part, input_dir) in [
for part, input_dir in [
("sparse", input_dir_sparse),
("dense", input_dir_labels_and_dense),
("labels", input_dir_labels_and_dense),
2 changes: 1 addition & 1 deletion torchrec/datasets/tests/test_criteo.py
@@ -70,7 +70,7 @@ def _validate_dataloader_sample(
) -> None:
unbatched_samples = [{} for _ in range(self._sample_len(sample))]
for k, batched_values in sample.items():
for (idx, value) in enumerate(batched_values):
for idx, value in enumerate(batched_values):
unbatched_samples[idx][k] = value
for sample in unbatched_samples:
self._validate_sample(sample, train=train)
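The criteo.py and test_criteo.py hunks carry a second, independent cleanup bundled into this reformat: redundant parentheses around tuple targets in `for` loops are dropped (`for (a, b) in ...` becomes `for a, b in ...`), with no change in behavior. A tiny runnable sketch over placeholder rows:

```python
# Placeholder rows standing in for (dense, sparse, label) triples.
rows = [([0.1, 0.2], [3, 4], 1), ([0.3, 0.4], [5, 6], 0)]

# New style: no parentheses around the unpacked loop target.
for row_dense, row_sparse, row_label in rows:
    print(row_dense, row_sparse, row_label)
```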
4 changes: 1 addition & 3 deletions torchrec/datasets/utils.py
@@ -77,9 +77,7 @@ def train_filter(
decimal_places: int,
idx: int,
) -> bool:
return (key_fn(idx) % 10**decimal_places) < round(
train_perc * 10**decimal_places
)
return (key_fn(idx) % 10**decimal_places) < round(train_perc * 10**decimal_places)


def val_filter(
4 changes: 3 additions & 1 deletion torchrec/distributed/benchmark/benchmark_inference.py
@@ -196,7 +196,9 @@ def main() -> None:
mb = int(float(num * dim) / 1024 / 1024)
tables_info += f"\nTABLE[{i}][{num:9}, {dim:4}] u8: {mb:6}Mb"

report: str = f"REPORT BENCHMARK {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n"
report: str = (
f"REPORT BENCHMARK {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n"
)
report += f"Module: {module_name}\n"
report += tables_info
report += "\n"
6 changes: 3 additions & 3 deletions torchrec/distributed/composable/tests/test_embedding.py
@@ -127,9 +127,9 @@ def _test_sharding( # noqa C901
kjt_input_per_rank[ctx.rank]
)

unsharded_model_pred_jt_dict_this_rank: Dict[
str, JaggedTensor
] = unsharded_model_pred_jt_dict[ctx.rank]
unsharded_model_pred_jt_dict_this_rank: Dict[str, JaggedTensor] = (
unsharded_model_pred_jt_dict[ctx.rank]
)

embedding_names = unsharded_model_pred_jt_dict_this_rank.keys()
assert set(unsharded_model_pred_jt_dict_this_rank.keys()) == set(
8 changes: 5 additions & 3 deletions torchrec/distributed/composable/tests/test_embeddingbag.py
@@ -385,9 +385,11 @@ def test_sharding_ebc(
},
kjt_input_per_rank=kjt_input_per_rank,
sharder=TestEmbeddingBagCollectionSharder(sharding_type=sharding_type),
backend="nccl"
if (torch.cuda.is_available() and torch.cuda.device_count() >= 2)
else "gloo",
backend=(
"nccl"
if (torch.cuda.is_available() and torch.cuda.device_count() >= 2)
else "gloo"
),
constraints=constraints,
is_data_parallel=(sharding_type == ShardingType.DATA_PARALLEL.value),
use_apply_optimizer_in_backward=use_apply_optimizer_in_backward,
62 changes: 40 additions & 22 deletions torchrec/distributed/composable/tests/test_fused_optim_nccl.py
@@ -77,60 +77,78 @@ def _test_sharded_fused_optimizer_state_dict(
0
].state_dict()["state"][""]["table_0.momentum1"].gather(
dst=0,
out=None if ctx.rank != 0
# sharded column, each shard will have rowwise state
else torch.empty((4 * tables[0].num_embeddings,), device=ctx.device),
out=(
None
if ctx.rank != 0
# sharded column, each shard will have rowwise state
else torch.empty((4 * tables[0].num_embeddings,), device=ctx.device)
),
)

ebc.embedding_bags["table_1"].weight._in_backward_optimizers[
0
].state_dict()["state"][""]["table_1.momentum1"].gather(
dst=0,
out=None if ctx.rank != 0
# sharded rowwise
else torch.empty((tables[1].num_embeddings,), device=ctx.device),
out=(
None
if ctx.rank != 0
# sharded rowwise
else torch.empty((tables[1].num_embeddings,), device=ctx.device)
),
)

ebc.embedding_bags["table_2"].weight._in_backward_optimizers[
0
].state_dict()["state"][""]["table_2.momentum1"].gather(
dst=0,
out=None if ctx.rank != 0
# Column wise - with partial rowwise adam, first state is point wise
else torch.empty(
(tables[2].num_embeddings, tables[2].embedding_dim),
device=ctx.device,
out=(
None
if ctx.rank != 0
# Column wise - with partial rowwise adam, first state is point wise
else torch.empty(
(tables[2].num_embeddings, tables[2].embedding_dim),
device=ctx.device,
)
),
)

ebc.embedding_bags["table_2"].weight._in_backward_optimizers[
0
].state_dict()["state"][""]["table_2.exp_avg_sq"].gather(
dst=0,
out=None if ctx.rank != 0
# Column wise - with partial rowwise adam, first state is point wise
else torch.empty((4 * tables[2].num_embeddings,), device=ctx.device),
out=(
None
if ctx.rank != 0
# Column wise - with partial rowwise adam, first state is point wise
else torch.empty((4 * tables[2].num_embeddings,), device=ctx.device)
),
)

ebc.embedding_bags["table_3"].weight._in_backward_optimizers[
0
].state_dict()["state"][""]["table_3.momentum1"].gather(
dst=0,
out=None if ctx.rank != 0
# Row wise - with partial rowwise adam, first state is point wise
else torch.empty(
(tables[3].num_embeddings, tables[3].embedding_dim),
device=ctx.device,
out=(
None
if ctx.rank != 0
# Row wise - with partial rowwise adam, first state is point wise
else torch.empty(
(tables[3].num_embeddings, tables[3].embedding_dim),
device=ctx.device,
)
),
)

ebc.embedding_bags["table_3"].weight._in_backward_optimizers[
0
].state_dict()["state"][""]["table_3.exp_avg_sq"].gather(
dst=0,
out=None if ctx.rank != 0
# Column wise - with partial rowwise adam, first state is point wise
else torch.empty((tables[2].num_embeddings,), device=ctx.device),
out=(
None
if ctx.rank != 0
# Column wise - with partial rowwise adam, first state is point wise
else torch.empty((tables[2].num_embeddings,), device=ctx.device)
),
)

# pyre-ignore
10 changes: 7 additions & 3 deletions torchrec/distributed/dist_data.py
@@ -539,9 +539,13 @@ def forward(self, kjt: KeyedJaggedTensor) -> KJTList:
fx_marker("KJT_ONE_TO_ALL_FORWARD_BEGIN", kjt)
kjts: List[KeyedJaggedTensor] = kjt.split(self._splits)
dist_kjts = [
kjts[rank]
if self._device_type == "meta"
else kjts[rank].to(torch.device(self._device_type, rank), non_blocking=True)
(
kjts[rank]
if self._device_type == "meta"
else kjts[rank].to(
torch.device(self._device_type, rank), non_blocking=True
)
)
for rank in range(self._world_size)
]
ret = KJTList(dist_kjts)
34 changes: 19 additions & 15 deletions torchrec/distributed/embedding.py
@@ -408,9 +408,11 @@ def __init__(
if isinstance(sharding, DpSequenceEmbeddingSharding):
self._lookups[index] = DistributedDataParallel(
module=lookup,
device_ids=[device]
if self._device and self._device.type == "cuda"
else None,
device_ids=(
[device]
if self._device and self._device.type == "cuda"
else None
),
process_group=env.process_group,
gradient_as_bucket_view=True,
broadcast_buffers=True,
@@ -510,9 +512,9 @@ def _initialize_torch_state(self) -> None: # noqa
if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value:
continue
self._model_parallel_name_to_local_shards[table_name] = []
model_parallel_name_to_compute_kernel[
table_name
] = parameter_sharding.compute_kernel
model_parallel_name_to_compute_kernel[table_name] = (
parameter_sharding.compute_kernel
)

self._name_to_table_size = {}
for table in self._embedding_configs:
@@ -556,12 +558,12 @@ def _initialize_torch_state(self) -> None: # noqa
EmptyFusedOptimizer()
]
# created ShardedTensors once in init, use in post_state_dict_hook
self._model_parallel_name_to_sharded_tensor[
table_name
] = ShardedTensor._init_from_local_shards(
local_shards,
self._name_to_table_size[table_name],
process_group=self._env.process_group,
self._model_parallel_name_to_sharded_tensor[table_name] = (
ShardedTensor._init_from_local_shards(
local_shards,
self._name_to_table_size[table_name],
process_group=self._env.process_group,
)
)

def post_state_dict_hook(
@@ -792,9 +794,11 @@ def input_dist(
ctx.sharding_contexts.append(
SequenceShardingContext(
features_before_input_dist=features,
unbucketize_permute_tensor=input_dist.unbucketize_permute_tensor
if isinstance(input_dist, RwSparseFeaturesDist)
else None,
unbucketize_permute_tensor=(
input_dist.unbucketize_permute_tensor
if isinstance(input_dist, RwSparseFeaturesDist)
else None
),
)
)
return KJTListSplitsAwaitable(awaitables, ctx)