From 45f9a334020845920c6dedae5c4081d20b9cc5df Mon Sep 17 00:00:00 2001
From: Huanyu He
Date: Mon, 5 Aug 2024 20:58:09 -0700
Subject: [PATCH] Use KTRegroupAsDict to Replace KeyedTensor.regroup_as_dict in EBC_sparse_arch (#2272)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2272

# context
* Currently in APF, the class method `KeyedTensor.regroup_as_dict` is used to permute and regroup the pooled embeddings.
* This function is not very efficient in training because every call recomputes the metadata arguments needed by the fbgemm operator. Moreover, every call also performs a host-to-device transfer to move those metadata tensors to the GPU: [codepointer](https://fburl.com/code/fmhkg6sn)
```
# // metadata calculation for permute, offsets, etc. all are tensors
permute, inv_permute, offsets, inv_offsets, splits = _remap_to_groups(
    keys, lengths, groups
)
values = torch.concat(values, dim=1)
device = values.device
permuted_values = torch.ops.fbgemm.permute_pooled_embs_auto_grad(
    values,
    _pin_and_move(offsets, device),  # // needs to use pinned_memory and
    _pin_and_move(permute, device),  # // initiate several H2D transfers
    _pin_and_move(inv_offsets, device),
    _pin_and_move(inv_permute, device),
)
```
* However, these metadata tensors do not change during training, so we can compute them on the first batch, cache the results, and re-use them afterwards.
* The recommended replacement is the `KTRegroupAsDict` module, which performs the same permute-and-regroup work. It stores the metadata tensors as instance variables during the first forward pass and re-uses them directly on every later call (see the usage sketch at the end of this summary).

NOTE: the `KTRegroupAsDict` module is also IR-compatible with the custom-op approach in D57578012.

# numbers with IG FM model
* [baseline trace](https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree%2Ftraces%2Fdynocli%2Faps-ig_ctr_aps_vanilla_baseline-a6e954af60%2F1%2Frank-0.Aug_04_00_01_16.3621.pt.trace.json.gz&bucket=aps_traces), [experimental trace](https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree%2Ftraces%2Fdynocli%2Faps-ig_ctr_aps_vanilla_both_diffs-3503324cb2%2F4%2Frank-0.Aug_03_20_16_01.3608.pt.trace.json.gz&bucket=aps_traces)
* The GPU runtime saving is about 1 ms, which is consistent with the benchmark results. The overall per-batch duration is 200 ms, so the improvement is about 0.5%. {F1792260523}
* The CPU runtime saving is 1.3 ms. {F1792260686}

# additional notes
1. We don't expect any change in results; the two approaches should produce exactly the same outputs.
2. `regroup_as_dict` operates on pooled embeddings, i.e. the outputs of the embedding-table lookup, so the length of each feature is constant, as defined in the config.
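
# usage sketch
A minimal, self-contained sketch of the intended swap, for illustration only. The toy `KeyedTensor`s, feature names, dimensions, and group/key names below are made up (they are not the APF call sites), and the import path for `KTRegroupAsDict` is assumed to be `torchrec.modules.regroup` based on the test file touched in this diff.
```
import torch

from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import KeyedTensor

# two pooled-embedding KeyedTensors, standing in for the EBC lookup outputs
kt_user = KeyedTensor(
    keys=["user_f1", "user_f2"],
    length_per_key=[8, 8],
    values=torch.randn(4, 16),
)
kt_object = KeyedTensor(
    keys=["object_f1"],
    length_per_key=[8],
    values=torch.randn(4, 8),
)
kts = [kt_user, kt_object]
groups = [["user_f1", "object_f1"], ["user_f2"]]
keys = ["user_object", "user_only"]

# before: the classmethod recomputes the permute metadata (and, on GPU, moves it
# host-to-device) on every call
out_ref = KeyedTensor.regroup_as_dict(keyed_tensors=kts, groups=groups, keys=keys)

# after: construct the module once; the first forward computes and caches the
# metadata tensors, and later batches re-use them directly
regroup = KTRegroupAsDict(groups=groups, keys=keys)
out = regroup(kts)

# both paths should produce identical regrouped tensors
assert all(torch.allclose(out[k], out_ref[k]) for k in keys)
```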
Reviewed By: really121

Differential Revision: D43405610
---
 torchrec/modules/tests/test_regroup.py | 39 ++++++++++++++++++++++++++
 1 file changed, 39 insertions(+)

diff --git a/torchrec/modules/tests/test_regroup.py b/torchrec/modules/tests/test_regroup.py
index 4f00b99c1..14a79605e 100644
--- a/torchrec/modules/tests/test_regroup.py
+++ b/torchrec/modules/tests/test_regroup.py
@@ -33,6 +33,17 @@ def setUp(self) -> None:
         self.keys = ["user", "object"]
         self.labels = torch.randint(0, 1, (128,), device=torch.device("cpu")).float()
 
+    def new_kts(self) -> None:
+        self.kts = build_kts(
+            dense_features=20,
+            sparse_features=20,
+            dim_dense=64,
+            dim_sparse=128,
+            batch_size=128,
+            device=torch.device("cpu"),
+            run_backward=True,
+        )
+
     def test_regroup_backward_skips_and_duplicates(self) -> None:
         groups = build_groups(
             kts=self.kts, num_groups=self.num_groups, skips=True, duplicates=True
@@ -40,6 +51,34 @@ def test_regroup_backward_skips_and_duplicates(self) -> None:
         assert _all_keys_used_once(self.kts, groups) is False
 
         regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys)
+
+        # first run
+        tensor_groups = regroup_module(self.kts)
+        pred0 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1))
+        loss = torch.nn.functional.l1_loss(pred0, self.labels).sum()
+        actual_kt_0_grad, actual_kt_1_grad = torch.autograd.grad(
+            loss, [self.kts[0].values(), self.kts[1].values()]
+        )
+
+        # clear grads so can reuse inputs
+        self.kts[0].values().grad = None
+        self.kts[1].values().grad = None
+
+        tensor_groups = KeyedTensor.regroup_as_dict(
+            keyed_tensors=self.kts, groups=groups, keys=self.keys
+        )
+        pred1 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1))
+        loss = torch.nn.functional.l1_loss(pred1, self.labels).sum()
+        expected_kt_0_grad, expected_kt_1_grad = torch.autograd.grad(
+            loss, [self.kts[0].values(), self.kts[1].values()]
+        )
+
+        torch.allclose(pred0, pred1)
+        torch.allclose(actual_kt_0_grad, expected_kt_0_grad)
+        torch.allclose(actual_kt_1_grad, expected_kt_1_grad)
+
+        # second run
+        self.new_kts()
         tensor_groups = regroup_module(self.kts)
         pred0 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1))
         loss = torch.nn.functional.l1_loss(pred0, self.labels).sum()