From 709f5547402b4590ab9569313bd14be963efe6ed Mon Sep 17 00:00:00 2001
From: Ivan Kobzarev
Date: Mon, 20 May 2024 10:20:40 -0700
Subject: [PATCH] Use non-variable stride per rank path for dynamo (#2018)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2018

Pick the non-variable stride-per-rank path for dynamo for now. In the
future this will be solved with first-batch path memoization.

Reviewed By: gnahzg

Differential Revision: D57562279

fbshipit-source-id: 73bb55580875f0b11e2daf477d82a21fdd90220e
---
 torchrec/distributed/dist_data.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/torchrec/distributed/dist_data.py b/torchrec/distributed/dist_data.py
index 2b3a24f52..2bcc6a16f 100644
--- a/torchrec/distributed/dist_data.py
+++ b/torchrec/distributed/dist_data.py
@@ -192,9 +192,13 @@ def _get_recat_tensor_compute(
         + (feature_order.expand(LS, FO_S0) * LS)
     ).reshape(-1)

-    vb_condition = batch_size_per_rank is not None and any(
-        bs != batch_size_per_rank[0] for bs in batch_size_per_rank
-    )
+    # Use non-variable stride per rank path for dynamo
+    # TODO(ivankobzarev): Implement memoization of the path from the first batch.
+    vb_condition = False
+    if not is_torchdynamo_compiling():
+        vb_condition = batch_size_per_rank is not None and any(
+            bs != batch_size_per_rank[0] for bs in batch_size_per_rank
+        )

     if vb_condition:
         batch_size_per_rank_tensor = torch._refs.tensor(
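
Note: the guard added above follows a general pattern: evaluate the
data-dependent condition (here, whether per-rank batch sizes differ) in eager
mode only, and pin the traced graph to a single static path while Dynamo is
compiling. Below is a minimal, self-contained sketch of that pattern, not the
torchrec implementation: `pick_recat_path` is a hypothetical name, and
`torch.compiler.is_compiling()` stands in for the `is_torchdynamo_compiling()`
helper the patch calls.

    # Sketch of the eager-only variable-batch check (assumed names, not torchrec code).
    from typing import List, Optional

    import torch


    def pick_recat_path(batch_size_per_rank: Optional[List[int]]) -> str:
        # Default to the non-variable stride-per-rank path; this is the only
        # path Dynamo ever traces, so the graph stays static.
        vb_condition = False
        if not torch.compiler.is_compiling():
            # Eager mode only: any mismatch among per-rank batch sizes means
            # variable batch sizes, so take the variable-stride path.
            vb_condition = batch_size_per_rank is not None and any(
                bs != batch_size_per_rank[0] for bs in batch_size_per_rank
            )
        return "variable_batch" if vb_condition else "fixed_batch"


    if __name__ == "__main__":
        print(pick_recat_path([8, 8, 8]))  # fixed_batch
        print(pick_recat_path([8, 4, 8]))  # variable_batch (in eager mode)

The trade-off, as the TODO notes, is that compiled runs assume fixed per-rank
batch sizes until memoizing the path chosen for the first batch is implemented.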