diff --git a/torchrec/modules/regroup.py b/torchrec/modules/regroup.py index 4fcf590d0..e73bb6452 100644 --- a/torchrec/modules/regroup.py +++ b/torchrec/modules/regroup.py @@ -9,20 +9,19 @@ #!/usr/bin/env python3 -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import torch from torchrec.sparse.jagged_tensor import ( - _all_keys_used_once, _desugar_keyed_tensors, - _remap_to_groups, + _kt_regroup_arguments, KeyedTensor, ) @torch.fx.wrap -def _concat_values(kts: List[KeyedTensor], dim: int) -> torch.Tensor: - return torch.cat([kt.values() for kt in kts], dim=dim) +def _get_kts_values(kts: List[KeyedTensor]) -> List[torch.Tensor]: + return [kt.values() for kt in kts] @torch.fx.wrap @@ -36,11 +35,34 @@ def _permuted_values( @torch.fx.wrap def _build_dict( - keys: List[str], values: torch.Tensor, splits: List[int], dim: int + keys: List[str], + values: Union[torch.Tensor, List[torch.Tensor]], + splits: List[int], + dim: int, ) -> Dict[str, torch.Tensor]: - return { - key: tensor for key, tensor in zip(keys, torch.split(values, splits, dim=dim)) - } + if isinstance(values, torch.Tensor): + return dict(zip(keys, torch.split(values, splits, dim=dim))) + else: + return dict(zip(keys, values)) + + +@torch.fx.wrap +def module_init(module: "KTRegroupAsDict", keyed_tensors: List[KeyedTensor]) -> None: + assert len(keyed_tensors) > 0, "Empty list provided" + assert all( + kt.device() == keyed_tensors[0].device() for kt in keyed_tensors + ), "All inputs should be on the same device." + module.device = keyed_tensors[0].device() + assert all( + kt.key_dim() == keyed_tensors[0].key_dim() for kt in keyed_tensors + ), "All inputs should have the same key_dim" + module._dim = keyed_tensors[0].key_dim() + + if module._dim == 1: + module._init_fbgemm_regroup(keyed_tensors) + else: + module._init_regroup(keyed_tensors) + module._is_inited = True class KTRegroupAsDict(torch.nn.Module): @@ -76,27 +98,32 @@ def __init__(self, groups: List[List[str]], keys: List[str]) -> None: # cached values populated on first forward call self.device: Optional[torch.device] = None - self._concat_dim: int = 1 + self._dim: int = 1 self._use_fbgemm_regroup: bool = False self._splits: List[int] = [] self._idx_key_pairs: List[Tuple[int, str]] = [] - self._permute_tensor: Optional[torch.Tensor] = None - self._inv_permute_tensor: Optional[torch.Tensor] = None - self._offsets_tensor: Optional[torch.Tensor] = None - self._inv_offsets_tensor: Optional[torch.Tensor] = None + self.register_buffer( + "_permutes", torch.empty(0, device=self.device), persistent=True + ) + self.register_buffer( + "_in_shapes", torch.empty(0, device=self.device), persistent=True + ) + self.register_buffer( + "_out_shapes", torch.empty(0, device=self.device), persistent=True + ) + self._out_lengths: Optional[List[int]] = None def _init_fbgemm_regroup(self, kts: List[KeyedTensor]) -> None: self._use_fbgemm_regroup = True keys, lengths, values = _desugar_keyed_tensors(kts) - permute, inv_permute, offsets, inv_offsets, splits = _remap_to_groups( - keys, lengths, self._groups + self._permutes, self._in_shapes, self._out_shapes, self._out_lengths = ( + _kt_regroup_arguments( + values[0], + keys, + lengths, + self._groups, + ) ) - # no need to pin_memory() or to(..., non_blocking=True) since occurs only once - self._permute_tensor = permute.to(self.device) - self._inv_permute_tensor = inv_permute.to(self.device) - self._offsets_tensor = offsets.to(self.device) - self._inv_offsets_tensor = inv_offsets.to(self.device) - self._splits = splits def _init_regroup(self, kts: List[KeyedTensor]) -> None: lengths = [kt.length_per_key() for kt in kts] @@ -127,34 +154,19 @@ def _init_regroup(self, kts: List[KeyedTensor]) -> None: def forward(self, keyed_tensors: List[KeyedTensor]) -> Dict[str, torch.Tensor]: if not self._is_inited: - assert len(keyed_tensors) > 0, "Empty list provided" - assert all( - kt.device() == keyed_tensors[0].device() for kt in keyed_tensors - ), "All inputs should be on the same device." - self.device = keyed_tensors[0].device() - assert all( - kt.key_dim() == keyed_tensors[0].key_dim() for kt in keyed_tensors - ), "All inputs should have the same key_dim" - self._dim = keyed_tensors[0].key_dim() - - if _all_keys_used_once(keyed_tensors, self._groups) and self._dim == 1: - self._init_fbgemm_regroup(keyed_tensors) - else: - self._init_regroup(keyed_tensors) - self._is_inited = True + module_init(self, keyed_tensors) if self._use_fbgemm_regroup: - values = _concat_values(keyed_tensors, self._dim) - permuted_values = torch.ops.fbgemm.permute_pooled_embs_auto_grad( + values = _get_kts_values(keyed_tensors) + permuted_values = torch.ops.fbgemm.permute_multi_embedding( values, - self._offsets_tensor, - self._permute_tensor, - self._inv_offsets_tensor, - self._inv_permute_tensor, + self._permutes, + self._in_shapes, + self._out_shapes, + self._out_lengths, ) else: permuted_values = _permuted_values( keyed_tensors, self._idx_key_pairs, self._dim ) - return _build_dict(self._keys, permuted_values, self._splits, self._dim)