diff --git a/torchrec/models/dlrm.py b/torchrec/models/dlrm.py index aef8b0d2f..98df57216 100644 --- a/torchrec/models/dlrm.py +++ b/torchrec/models/dlrm.py @@ -183,8 +183,10 @@ class InteractionArch(nn.Module): def __init__(self, num_sparse_features: int) -> None: super().__init__() self.F: int = num_sparse_features - self.triu_indices: torch.Tensor = torch.triu_indices( - self.F + 1, self.F + 1, offset=1 + self.register_buffer( + "triu_indices", + torch.triu_indices(self.F + 1, self.F + 1, offset=1), + persistent=False, ) def forward(