
Commit 4e41029

amyreese authored and facebook-github-bot committed
apply Black 2024 style in fbcode (4/16)
Summary: Formats the covered files with pyfmt.

paintitblack

Reviewed By: aleivag

Differential Revision: D54447727

fbshipit-source-id: 8844b1caa08de94d04ac4df3c768dbf8c865fd2f
1 parent 4168f11 commit 4e41029
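
Most of the hunks below apply a single Black 2024 rule: a conditional expression that has to span multiple lines is wrapped in its own pair of parentheses instead of letting the if/else branches hang at the surrounding indentation. The following standalone sketch (hypothetical names, not code from this repository) shows the before/after shapes and that the change is purely cosmetic:

from typing import List, Optional


def num_embeddings_for(
    per_feature: Optional[List[int]], feature_idx: int, default: int
) -> int:
    # Pre-2024 layout: the branches of the conditional hang at the same
    # depth as the other keyword arguments.
    old = dict(
        num_embeddings=per_feature[feature_idx]
        if per_feature is not None
        else default,
    )
    # Black 2024 layout: the same conditional, wrapped in parentheses so
    # the branches are grouped under the keyword argument.
    new = dict(
        num_embeddings=(
            per_feature[feature_idx] if per_feature is not None else default
        ),
    )
    assert old == new  # formatting only; the value is identical
    return new["num_embeddings"]


print(num_embeddings_for([100, 200], 1, 64))  # 200
print(num_embeddings_for(None, 1, 64))  # 64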

File tree: 86 files changed (+934, -731 lines)


examples/inference/dlrm_predict.py

Lines changed: 5 additions & 3 deletions
@@ -139,9 +139,11 @@ def create_predict_module(self, world_size: int) -> torch.nn.Module:
             EmbeddingBagConfig(
                 name=f"t_{feature_name}",
                 embedding_dim=self.model_config.embedding_dim,
-                num_embeddings=self.model_config.num_embeddings_per_feature[feature_idx]
-                if self.model_config.num_embeddings is None
-                else self.model_config.num_embeddings,
+                num_embeddings=(
+                    self.model_config.num_embeddings_per_feature[feature_idx]
+                    if self.model_config.num_embeddings is None
+                    else self.model_config.num_embeddings
+                ),
                 feature_names=[feature_name],
             )
             for feature_idx, feature_name in enumerate(

examples/inference/dlrm_predict_single_gpu.py

Lines changed: 5 additions & 3 deletions
@@ -50,9 +50,11 @@ def create_predict_module(self, world_size: int) -> torch.nn.Module:
             EmbeddingBagConfig(
                 name=f"t_{feature_name}",
                 embedding_dim=self.model_config.embedding_dim,
-                num_embeddings=self.model_config.num_embeddings_per_feature[feature_idx]
-                if self.model_config.num_embeddings is None
-                else self.model_config.num_embeddings,
+                num_embeddings=(
+                    self.model_config.num_embeddings_per_feature[feature_idx]
+                    if self.model_config.num_embeddings is None
+                    else self.model_config.num_embeddings
+                ),
                 feature_names=[feature_name],
             )
             for feature_idx, feature_name in enumerate(

examples/nvt_dataloader/nvt_binary_dataloader.py

Lines changed: 2 additions & 1 deletion
@@ -94,7 +94,8 @@ def __getitem__(self, idx: int):
         """Numerical features are returned in the order they appear in the channel spec section
         For performance reasons, this is required to be the order they are saved in, as specified
         by the relevant chunk in source spec.
-        Categorical features are returned in the order they appear in the channel spec section"""
+        Categorical features are returned in the order they appear in the channel spec section
+        """
 
         if idx >= self._num_entries:
             raise IndexError()

examples/nvt_dataloader/train_torchrec.py

Lines changed: 13 additions & 9 deletions
@@ -208,9 +208,11 @@ def main(argv: List[str]):
         EmbeddingBagConfig(
             name=f"t_{feature_name}",
             embedding_dim=args.embedding_dim,
-            num_embeddings=none_throws(num_embeddings_per_feature)[feature_idx]
-            if num_embeddings_per_feature is not None
-            else args.num_embeddings,
+            num_embeddings=(
+                none_throws(num_embeddings_per_feature)[feature_idx]
+                if num_embeddings_per_feature is not None
+                else args.num_embeddings
+            ),
             feature_names=[feature_name],
         )
         for feature_idx, feature_name in enumerate(DEFAULT_CAT_NAMES)
@@ -232,9 +234,9 @@ def main(argv: List[str]):
 
     train_model = fuse_embedding_optimizer(
         train_model,
-        optimizer_type=torchrec.optim.RowWiseAdagrad
-        if args.adagrad
-        else torch.optim.SGD,
+        optimizer_type=(
+            torchrec.optim.RowWiseAdagrad if args.adagrad else torch.optim.SGD
+        ),
         optimizer_kwargs={"learning_rate": args.learning_rate},
         device=torch.device("meta"),
     )
@@ -270,9 +272,11 @@ def main(argv: List[str]):
 
     non_fused_optimizer = KeyedOptimizerWrapper(
         dict(in_backward_optimizer_filter(model.named_parameters())),
-        lambda params: torch.optim.Adagrad(params, lr=args.learning_rate)
-        if args.adagrad
-        else torch.optim.SGD(params, lr=args.learning_rate),
+        lambda params: (
+            torch.optim.Adagrad(params, lr=args.learning_rate)
+            if args.adagrad
+            else torch.optim.SGD(params, lr=args.learning_rate)
+        ),
     )
 
     opt = trec_optim.keyed.CombinedOptimizer(

examples/retrieval/modules/two_tower.py

Lines changed: 6 additions & 6 deletions
@@ -71,12 +71,12 @@ def __init__(
         embedding_dim: int = embedding_bag_collection.embedding_bag_configs()[
             0
         ].embedding_dim
-        self._feature_names_query: List[
-            str
-        ] = embedding_bag_collection.embedding_bag_configs()[0].feature_names
-        self._candidate_feature_names: List[
-            str
-        ] = embedding_bag_collection.embedding_bag_configs()[1].feature_names
+        self._feature_names_query: List[str] = (
+            embedding_bag_collection.embedding_bag_configs()[0].feature_names
+        )
+        self._candidate_feature_names: List[str] = (
+            embedding_bag_collection.embedding_bag_configs()[1].feature_names
+        )
         self.ebc = embedding_bag_collection
         self.query_proj = MLP(
             in_size=embedding_dim, layer_sizes=layer_sizes, device=device

tools/lint/black_linter.py

Lines changed: 5 additions & 5 deletions
@@ -176,11 +176,11 @@ def main() -> None:
 
     logging.basicConfig(
         format="<%(threadName)s:%(levelname)s> %(message)s",
-        level=logging.NOTSET
-        if args.verbose
-        else logging.DEBUG
-        if len(args.filenames) < 1000
-        else logging.INFO,
+        level=(
+            logging.NOTSET
+            if args.verbose
+            else logging.DEBUG if len(args.filenames) < 1000 else logging.INFO
+        ),
         stream=sys.stderr,
     )
 
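One hunk above (tools/lint/black_linter.py) also folds a nested conditional onto a single line. Conditional expressions associate to the right, so the fold does not change which level is selected; a small self-contained check with hypothetical inputs (standard-library logging only):

import logging

verbose = False
num_filenames = 50

# Old layout: the nested conditional spread over five lines.
old_level = (
    logging.NOTSET
    if verbose
    else logging.DEBUG
    if num_filenames < 1000
    else logging.INFO
)
# New layout: the inner if/else folded onto one line.
new_level = (
    logging.NOTSET
    if verbose
    else logging.DEBUG if num_filenames < 1000 else logging.INFO
)
assert old_level == new_level == logging.DEBUG
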
torchrec/datasets/criteo.py

Lines changed: 8 additions & 6 deletions
@@ -234,11 +234,13 @@ def row_mapper(row: List[str]) -> Tuple[List[int], List[int], int]:
         return dense, sparse, label
 
     dense, sparse, labels = [], [], []
-    for (row_dense, row_sparse, row_label) in CriteoIterDataPipe(
+    for row_dense, row_sparse, row_label in CriteoIterDataPipe(
         [in_file],
-        row_mapper=row_mapper
-        if not (dataset_name == "criteo_kaggle" and "test" in in_file)
-        else row_mapper_with_fake_label_constant,
+        row_mapper=(
+            row_mapper
+            if not (dataset_name == "criteo_kaggle" and "test" in in_file)
+            else row_mapper_with_fake_label_constant
+        ),
     ):
         dense.append(row_dense)
         sparse.append(row_sparse)
@@ -261,7 +263,7 @@ def row_mapper(row: List[str]) -> Tuple[List[int], List[int], int]:
     labels_np = labels_np.reshape((-1, 1))
 
     path_manager = PathManagerFactory().get(path_manager_key)
-    for (fname, arr) in [
+    for fname, arr in [
         (out_dense_file, dense_np),
         (out_sparse_file, sparse_np),
         (out_labels_file, labels_np),
@@ -665,7 +667,7 @@ def shuffle(
         curr_first_row = curr_last_row
 
     # Directly copy over the last day's files since they will be used for validation and testing.
-    for (part, input_dir) in [
+    for part, input_dir in [
         ("sparse", input_dir_sparse),
         ("dense", input_dir_labels_and_dense),
         ("labels", input_dir_labels_and_dense),

torchrec/datasets/tests/test_criteo.py

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ def _validate_dataloader_sample(
     ) -> None:
         unbatched_samples = [{} for _ in range(self._sample_len(sample))]
         for k, batched_values in sample.items():
-            for (idx, value) in enumerate(batched_values):
+            for idx, value in enumerate(batched_values):
                 unbatched_samples[idx][k] = value
         for sample in unbatched_samples:
             self._validate_sample(sample, train=train)

torchrec/datasets/utils.py

Lines changed: 1 addition & 3 deletions
@@ -77,9 +77,7 @@ def train_filter(
     decimal_places: int,
     idx: int,
 ) -> bool:
-    return (key_fn(idx) % 10**decimal_places) < round(
-        train_perc * 10**decimal_places
-    )
+    return (key_fn(idx) % 10**decimal_places) < round(train_perc * 10**decimal_places)
 
 
 def val_filter(

torchrec/distributed/benchmark/benchmark_inference.py

Lines changed: 3 additions & 1 deletion
@@ -196,7 +196,9 @@ def main() -> None:
         mb = int(float(num * dim) / 1024 / 1024)
         tables_info += f"\nTABLE[{i}][{num:9}, {dim:4}] u8: {mb:6}Mb"
 
-    report: str = f"REPORT BENCHMARK {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n"
+    report: str = (
+        f"REPORT BENCHMARK {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n"
+    )
     report += f"Module: {module_name}\n"
     report += tables_info
     report += "\n"

torchrec/distributed/composable/tests/test_embedding.py

Lines changed: 3 additions & 3 deletions
@@ -127,9 +127,9 @@ def _test_sharding( # noqa C901
             kjt_input_per_rank[ctx.rank]
         )
 
-        unsharded_model_pred_jt_dict_this_rank: Dict[
-            str, JaggedTensor
-        ] = unsharded_model_pred_jt_dict[ctx.rank]
+        unsharded_model_pred_jt_dict_this_rank: Dict[str, JaggedTensor] = (
+            unsharded_model_pred_jt_dict[ctx.rank]
+        )
 
         embedding_names = unsharded_model_pred_jt_dict_this_rank.keys()
         assert set(unsharded_model_pred_jt_dict_this_rank.keys()) == set(

torchrec/distributed/composable/tests/test_embeddingbag.py

Lines changed: 5 additions & 3 deletions
@@ -385,9 +385,11 @@ def test_sharding_ebc(
             },
             kjt_input_per_rank=kjt_input_per_rank,
             sharder=TestEmbeddingBagCollectionSharder(sharding_type=sharding_type),
-            backend="nccl"
-            if (torch.cuda.is_available() and torch.cuda.device_count() >= 2)
-            else "gloo",
+            backend=(
+                "nccl"
+                if (torch.cuda.is_available() and torch.cuda.device_count() >= 2)
+                else "gloo"
+            ),
             constraints=constraints,
             is_data_parallel=(sharding_type == ShardingType.DATA_PARALLEL.value),
             use_apply_optimizer_in_backward=use_apply_optimizer_in_backward,

torchrec/distributed/composable/tests/test_fused_optim_nccl.py

Lines changed: 40 additions & 22 deletions
@@ -77,60 +77,78 @@ def _test_sharded_fused_optimizer_state_dict(
             0
         ].state_dict()["state"][""]["table_0.momentum1"].gather(
             dst=0,
-            out=None if ctx.rank != 0
-            # sharded column, each shard will have rowwise state
-            else torch.empty((4 * tables[0].num_embeddings,), device=ctx.device),
+            out=(
+                None
+                if ctx.rank != 0
+                # sharded column, each shard will have rowwise state
+                else torch.empty((4 * tables[0].num_embeddings,), device=ctx.device)
+            ),
         )
 
         ebc.embedding_bags["table_1"].weight._in_backward_optimizers[
             0
         ].state_dict()["state"][""]["table_1.momentum1"].gather(
             dst=0,
-            out=None if ctx.rank != 0
-            # sharded rowwise
-            else torch.empty((tables[1].num_embeddings,), device=ctx.device),
+            out=(
+                None
+                if ctx.rank != 0
+                # sharded rowwise
+                else torch.empty((tables[1].num_embeddings,), device=ctx.device)
+            ),
         )
 
         ebc.embedding_bags["table_2"].weight._in_backward_optimizers[
             0
         ].state_dict()["state"][""]["table_2.momentum1"].gather(
             dst=0,
-            out=None if ctx.rank != 0
-            # Column wise - with partial rowwise adam, first state is point wise
-            else torch.empty(
-                (tables[2].num_embeddings, tables[2].embedding_dim),
-                device=ctx.device,
+            out=(
+                None
+                if ctx.rank != 0
+                # Column wise - with partial rowwise adam, first state is point wise
+                else torch.empty(
+                    (tables[2].num_embeddings, tables[2].embedding_dim),
+                    device=ctx.device,
+                )
            ),
        )
 
         ebc.embedding_bags["table_2"].weight._in_backward_optimizers[
             0
         ].state_dict()["state"][""]["table_2.exp_avg_sq"].gather(
             dst=0,
-            out=None if ctx.rank != 0
-            # Column wise - with partial rowwise adam, first state is point wise
-            else torch.empty((4 * tables[2].num_embeddings,), device=ctx.device),
+            out=(
+                None
+                if ctx.rank != 0
+                # Column wise - with partial rowwise adam, first state is point wise
+                else torch.empty((4 * tables[2].num_embeddings,), device=ctx.device)
+            ),
         )
 
         ebc.embedding_bags["table_3"].weight._in_backward_optimizers[
             0
         ].state_dict()["state"][""]["table_3.momentum1"].gather(
             dst=0,
-            out=None if ctx.rank != 0
-            # Row wise - with partial rowwise adam, first state is point wise
-            else torch.empty(
-                (tables[3].num_embeddings, tables[3].embedding_dim),
-                device=ctx.device,
+            out=(
+                None
+                if ctx.rank != 0
+                # Row wise - with partial rowwise adam, first state is point wise
+                else torch.empty(
+                    (tables[3].num_embeddings, tables[3].embedding_dim),
+                    device=ctx.device,
+                )
            ),
        )
 
         ebc.embedding_bags["table_3"].weight._in_backward_optimizers[
             0
         ].state_dict()["state"][""]["table_3.exp_avg_sq"].gather(
             dst=0,
-            out=None if ctx.rank != 0
-            # Column wise - with partial rowwise adam, first state is point wise
-            else torch.empty((tables[2].num_embeddings,), device=ctx.device),
+            out=(
+                None
+                if ctx.rank != 0
+                # Column wise - with partial rowwise adam, first state is point wise
+                else torch.empty((tables[2].num_embeddings,), device=ctx.device)
+            ),
         )
 
         # pyre-ignore

torchrec/distributed/dist_data.py

Lines changed: 7 additions & 3 deletions
@@ -539,9 +539,13 @@ def forward(self, kjt: KeyedJaggedTensor) -> KJTList:
         fx_marker("KJT_ONE_TO_ALL_FORWARD_BEGIN", kjt)
         kjts: List[KeyedJaggedTensor] = kjt.split(self._splits)
         dist_kjts = [
-            kjts[rank]
-            if self._device_type == "meta"
-            else kjts[rank].to(torch.device(self._device_type, rank), non_blocking=True)
+            (
+                kjts[rank]
+                if self._device_type == "meta"
+                else kjts[rank].to(
+                    torch.device(self._device_type, rank), non_blocking=True
+                )
+            )
             for rank in range(self._world_size)
         ]
         ret = KJTList(dist_kjts)

torchrec/distributed/embedding.py

Lines changed: 19 additions & 15 deletions
@@ -408,9 +408,11 @@ def __init__(
             if isinstance(sharding, DpSequenceEmbeddingSharding):
                 self._lookups[index] = DistributedDataParallel(
                     module=lookup,
-                    device_ids=[device]
-                    if self._device and self._device.type == "cuda"
-                    else None,
+                    device_ids=(
+                        [device]
+                        if self._device and self._device.type == "cuda"
+                        else None
+                    ),
                     process_group=env.process_group,
                     gradient_as_bucket_view=True,
                     broadcast_buffers=True,
@@ -510,9 +512,9 @@ def _initialize_torch_state(self) -> None: # noqa
             if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value:
                 continue
             self._model_parallel_name_to_local_shards[table_name] = []
-            model_parallel_name_to_compute_kernel[
-                table_name
-            ] = parameter_sharding.compute_kernel
+            model_parallel_name_to_compute_kernel[table_name] = (
+                parameter_sharding.compute_kernel
+            )
 
         self._name_to_table_size = {}
         for table in self._embedding_configs:
@@ -556,12 +558,12 @@ def _initialize_torch_state(self) -> None: # noqa
                 EmptyFusedOptimizer()
             ]
             # created ShardedTensors once in init, use in post_state_dict_hook
-            self._model_parallel_name_to_sharded_tensor[
-                table_name
-            ] = ShardedTensor._init_from_local_shards(
-                local_shards,
-                self._name_to_table_size[table_name],
-                process_group=self._env.process_group,
+            self._model_parallel_name_to_sharded_tensor[table_name] = (
+                ShardedTensor._init_from_local_shards(
+                    local_shards,
+                    self._name_to_table_size[table_name],
+                    process_group=self._env.process_group,
+                )
             )
 
     def post_state_dict_hook(
@@ -792,9 +794,11 @@ def input_dist(
             ctx.sharding_contexts.append(
                 SequenceShardingContext(
                     features_before_input_dist=features,
-                    unbucketize_permute_tensor=input_dist.unbucketize_permute_tensor
-                    if isinstance(input_dist, RwSparseFeaturesDist)
-                    else None,
+                    unbucketize_permute_tensor=(
+                        input_dist.unbucketize_permute_tensor
+                        if isinstance(input_dist, RwSparseFeaturesDist)
+                        else None
+                    ),
                 )
             )
         return KJTListSplitsAwaitable(awaitables, ctx)
