diff --git a/torchrec/modules/tests/test_regroup.py b/torchrec/modules/tests/test_regroup.py index 4f00b99c1..14a79605e 100644 --- a/torchrec/modules/tests/test_regroup.py +++ b/torchrec/modules/tests/test_regroup.py @@ -33,6 +33,17 @@ def setUp(self) -> None: self.keys = ["user", "object"] self.labels = torch.randint(0, 1, (128,), device=torch.device("cpu")).float() + def new_kts(self) -> None: + self.kts = build_kts( + dense_features=20, + sparse_features=20, + dim_dense=64, + dim_sparse=128, + batch_size=128, + device=torch.device("cpu"), + run_backward=True, + ) + def test_regroup_backward_skips_and_duplicates(self) -> None: groups = build_groups( kts=self.kts, num_groups=self.num_groups, skips=True, duplicates=True @@ -40,6 +51,34 @@ def test_regroup_backward_skips_and_duplicates(self) -> None: assert _all_keys_used_once(self.kts, groups) is False regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys) + + # first run + tensor_groups = regroup_module(self.kts) + pred0 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1)) + loss = torch.nn.functional.l1_loss(pred0, self.labels).sum() + actual_kt_0_grad, actual_kt_1_grad = torch.autograd.grad( + loss, [self.kts[0].values(), self.kts[1].values()] + ) + + # clear grads so can reuse inputs + self.kts[0].values().grad = None + self.kts[1].values().grad = None + + tensor_groups = KeyedTensor.regroup_as_dict( + keyed_tensors=self.kts, groups=groups, keys=self.keys + ) + pred1 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1)) + loss = torch.nn.functional.l1_loss(pred1, self.labels).sum() + expected_kt_0_grad, expected_kt_1_grad = torch.autograd.grad( + loss, [self.kts[0].values(), self.kts[1].values()] + ) + + torch.allclose(pred0, pred1) + torch.allclose(actual_kt_0_grad, expected_kt_0_grad) + torch.allclose(actual_kt_1_grad, expected_kt_1_grad) + + # second run + self.new_kts() tensor_groups = regroup_module(self.kts) pred0 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1)) loss = torch.nn.functional.l1_loss(pred0, self.labels).sum()