From e6fe5deb9cc5324e7cdb718cc332dd110b395914 Mon Sep 17 00:00:00 2001 From: Shuai Yang Date: Fri, 22 Nov 2024 13:05:36 -0800 Subject: [PATCH] Mark weights unbacked (#2583) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2583 This is to avoid recompilations caused by the shape changes of `_weights` in KJT. Reviewed By: TroyGarden Differential Revision: D66342695 fbshipit-source-id: 044db94404b285367bcbbe5d4b513bae5a463810 --- torchrec/pt2/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchrec/pt2/utils.py b/torchrec/pt2/utils.py index e62a9a6a4..55accff68 100644 --- a/torchrec/pt2/utils.py +++ b/torchrec/pt2/utils.py @@ -75,12 +75,15 @@ def kjt_for_pt2_tracing( values = kjt.values().long() torch._dynamo.decorators.mark_unbacked(values, 0) + weights = kjt.weights_or_none() + if weights is not None: + torch._dynamo.decorators.mark_unbacked(weights, 0) return KeyedJaggedTensor( keys=kjt.keys(), values=values, lengths=lengths, - weights=kjt.weights_or_none(), + weights=weights, stride=stride if not is_vb else None, stride_per_key_per_rank=kjt.stride_per_key_per_rank() if is_vb else None, inverse_indices=inverse_indices,