diff --git a/torchrec/distributed/dist_data.py b/torchrec/distributed/dist_data.py
index 2b3a24f52..2bcc6a16f 100644
--- a/torchrec/distributed/dist_data.py
+++ b/torchrec/distributed/dist_data.py
@@ -192,9 +192,13 @@ def _get_recat_tensor_compute(
         + (feature_order.expand(LS, FO_S0) * LS)
     ).reshape(-1)

-    vb_condition = batch_size_per_rank is not None and any(
-        bs != batch_size_per_rank[0] for bs in batch_size_per_rank
-    )
+    # Use non variable stride per rank path for dynamo
+    # TODO(ivankobzarev): Implement memorization of the path from the first batch.
+    vb_condition = False
+    if not is_torchdynamo_compiling():
+        vb_condition = batch_size_per_rank is not None and any(
+            bs != batch_size_per_rank[0] for bs in batch_size_per_rank
+        )

    if vb_condition:
        batch_size_per_rank_tensor = torch._refs.tensor(
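
For context, here is a minimal sketch of the guard pattern this hunk relies on: the data-dependent check over batch_size_per_rank is skipped while Dynamo is tracing, so the traced graph always takes the fixed-stride-per-rank path instead of branching on runtime batch sizes. The import shown is an assumption (the hunk does not show where is_torchdynamo_compiling comes from in dist_data.py); torch._dynamo.is_compiling() is a stock PyTorch helper that returns True during Dynamo tracing.

```python
# Sketch only; the actual import path used by torchrec is not visible in this hunk.
from typing import List, Optional

try:
    from torch._dynamo import is_compiling as is_torchdynamo_compiling
except ImportError:  # older PyTorch builds without the helper
    def is_torchdynamo_compiling() -> bool:
        return False


def _is_variable_batch(batch_size_per_rank: Optional[List[int]]) -> bool:
    # Hypothetical helper mirroring the patched logic: under torch.compile the
    # check is bypassed so the graph does not specialize on per-rank batch sizes.
    if is_torchdynamo_compiling():
        return False
    return batch_size_per_rank is not None and any(
        bs != batch_size_per_rank[0] for bs in batch_size_per_rank
    )
```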