
Commit 709f554

Use non-variable stride per rank path for dynamo (#2018)

Authored by Ivan Kobzarev, committed by facebook-github-bot

Summary: Pull Request resolved: #2018. Pick the non-variable stride per rank path for dynamo for now. In the future this will be handled by memoizing the path taken on the first batch.

Reviewed By: gnahzg
Differential Revision: D57562279
fbshipit-source-id: 73bb55580875f0b11e2daf477d82a21fdd90220e

1 parent cc14baa

File tree

1 file changed: 7 additions, 3 deletions


torchrec/distributed/dist_data.py

Lines changed: 7 additions & 3 deletions
@@ -192,9 +192,13 @@ def _get_recat_tensor_compute(
         + (feature_order.expand(LS, FO_S0) * LS)
     ).reshape(-1)
 
-    vb_condition = batch_size_per_rank is not None and any(
-        bs != batch_size_per_rank[0] for bs in batch_size_per_rank
-    )
+    # Use non variable stride per rank path for dynamo
+    # TODO(ivankobzarev): Implement memorization of the path from the first batch.
+    vb_condition = False
+    if not is_torchdynamo_compiling():
+        vb_condition = batch_size_per_rank is not None and any(
+            bs != batch_size_per_rank[0] for bs in batch_size_per_rank
+        )
 
     if vb_condition:
         batch_size_per_rank_tensor = torch._refs.tensor(
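The pattern in this hunk can be sketched in isolation: a data-dependent Python branch (checking whether ranks report different batch sizes) is skipped entirely when tracing under dynamo, forcing the non-variable-stride path. The sketch below is a hypothetical standalone version; `is_torchdynamo_compiling` is stubbed to always report eager mode so the example runs without PyTorch, and `pick_vb_condition` is an illustrative name, not a torchrec API.

```python
def is_torchdynamo_compiling() -> bool:
    # Stub for the PyTorch helper; the real one reports True while
    # dynamo is tracing. Here we always run in "eager" mode.
    return False


def pick_vb_condition(batch_size_per_rank):
    # Default to the non-variable-stride path; under dynamo this
    # default is always used, avoiding a data-dependent branch.
    vb_condition = False
    if not is_torchdynamo_compiling():
        # Eager mode may still take the variable-batch-size path when
        # ranks report differing batch sizes.
        vb_condition = batch_size_per_rank is not None and any(
            bs != batch_size_per_rank[0] for bs in batch_size_per_rank
        )
    return vb_condition


print(pick_vb_condition([8, 8, 8]))  # prints False: uniform batch sizes
print(pick_vb_condition([8, 4, 8]))  # prints True: variable batch sizes
```

The design choice is that compiled graphs get one stable code path regardless of runtime batch sizes; per the commit summary, memoizing the path chosen on the first batch is left as future work.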

0 commit comments