Skip to content

Commit

Permalink
Mark weights unbacked (pytorch#2583)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2583

This is to avoid recompilations caused by the shape changes of `_weights` in KJT.

Reviewed By: TroyGarden

Differential Revision: D66342695

fbshipit-source-id: 044db94404b285367bcbbe5d4b513bae5a463810
  • Loading branch information
Microve authored and facebook-github-bot committed Nov 22, 2024
1 parent 2962be0 commit e6fe5de
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion torchrec/pt2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,15 @@ def kjt_for_pt2_tracing(

values = kjt.values().long()
torch._dynamo.decorators.mark_unbacked(values, 0)
weights = kjt.weights_or_none()
if weights is not None:
torch._dynamo.decorators.mark_unbacked(weights, 0)

return KeyedJaggedTensor(
keys=kjt.keys(),
values=values,
lengths=lengths,
weights=kjt.weights_or_none(),
weights=weights,
stride=stride if not is_vb else None,
stride_per_key_per_rank=kjt.stride_per_key_per_rank() if is_vb else None,
inverse_indices=inverse_indices,
Expand Down

0 comments on commit e6fe5de

Please sign in to comment.