Use non-variable stride per rank path for dynamo (#2018)
Summary:
Pull Request resolved: #2018

Pick non-variable stride per rank path for dynamo for now.
In the future this will be solved with first-batch path memorization.

Reviewed By: gnahzg

Differential Revision: D57562279

fbshipit-source-id: 73bb55580875f0b11e2daf477d82a21fdd90220e
Ivan Kobzarev authored and facebook-github-bot committed May 20, 2024
1 parent cc14baa commit 709f554
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions torchrec/distributed/dist_data.py
@@ -192,9 +192,13 @@ def _get_recat_tensor_compute(
         + (feature_order.expand(LS, FO_S0) * LS)
     ).reshape(-1)
 
-    vb_condition = batch_size_per_rank is not None and any(
-        bs != batch_size_per_rank[0] for bs in batch_size_per_rank
-    )
+    # Use non variable stride per rank path for dynamo
+    # TODO(ivankobzarev): Implement memorization of the path from the first batch.
+    vb_condition = False
+    if not is_torchdynamo_compiling():
+        vb_condition = batch_size_per_rank is not None and any(
+            bs != batch_size_per_rank[0] for bs in batch_size_per_rank
+        )
 
     if vb_condition:
         batch_size_per_rank_tensor = torch._refs.tensor(
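For context, a minimal sketch of the pattern the diff above applies: skip a branch that depends on runtime per-rank batch sizes whenever dynamo is tracing, so compiled runs always take the fixed-stride path. This is not the TorchRec code itself; pick_variable_batch_path is a hypothetical name, and torch._dynamo.is_compiling() stands in for the is_torchdynamo_compiling helper used in the diff.

# A sketch, not the TorchRec implementation: guard a data-dependent,
# per-rank branch so it is skipped while dynamo is tracing.
from typing import List, Optional

import torch


def pick_variable_batch_path(batch_size_per_rank: Optional[List[int]]) -> bool:
    # torch._dynamo.is_compiling() is True only while dynamo traces this code.
    if torch._dynamo.is_compiling():
        # Branching on the contents of a Python list gets specialized into
        # the graph and can trigger recompiles when batch sizes change, so
        # compiled runs always use the non-variable-stride path.
        return False
    # Eager mode keeps the variable-batch path when ranks disagree on size.
    return batch_size_per_rank is not None and any(
        bs != batch_size_per_rank[0] for bs in batch_size_per_rank
    )


# Eager examples: uniform sizes pick the fixed path, mixed sizes the variable one.
assert pick_variable_batch_path([4, 4, 4]) is False
assert pick_variable_batch_path([4, 5, 4]) is True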
