Skip to content

Commit

Permalink
use new op in KTRegroupAsDict module (pytorch#2210)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2210

# context
* the new op `permute_multi_embedding` outperforms the original op `permute_pooled_embs_auto_grad`
* this diff switches the `KTRegroupAsDict` module over to the new op
* benchmark results: D58907223

# benchmark
* [traces](https://drive.google.com/drive/folders/1v_kD9n1jOkGUmYyix3-dUYiBDE_C3Hiv?usp=drive_link)
* previous prod
 {F1747994738}
* new prod
 {F1747994032}
* metrics
|Operator|GPU runtime|GPU memory|notes|
|---|---|---|---|
|**[previous prod] permute_pooled_embs**|4.9 ms|1.5 K|GPU-bound, does **NOT** allow duplicates, PT2-incompatible `pin_and_move`|
|**[new prod] permute_multi_embedding**|2.0 ms|1.0 K|both CPU and GPU runtime/memory improved, **ALLOWS** duplicates, PT2 friendly|

Reviewed By: dstaay-fb

Differential Revision: D53590566
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Aug 7, 2024
1 parent 07dd9b9 commit 954d652
Showing 1 changed file with 56 additions and 44 deletions.
100 changes: 56 additions & 44 deletions torchrec/modules/regroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,19 @@

#!/usr/bin/env python3

from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

import torch
from torchrec.sparse.jagged_tensor import (
_all_keys_used_once,
_desugar_keyed_tensors,
_remap_to_groups,
_kt_regroup_arguments,
KeyedTensor,
)


@torch.fx.wrap
def _concat_values(kts: List[KeyedTensor], dim: int) -> torch.Tensor:
return torch.cat([kt.values() for kt in kts], dim=dim)
def _get_kts_values(kts: List[KeyedTensor]) -> List[torch.Tensor]:
return [kt.values() for kt in kts]


@torch.fx.wrap
Expand All @@ -36,11 +35,34 @@ def _permuted_values(

@torch.fx.wrap
def _build_dict(
keys: List[str], values: torch.Tensor, splits: List[int], dim: int
keys: List[str],
values: Union[torch.Tensor, List[torch.Tensor]],
splits: List[int],
dim: int,
) -> Dict[str, torch.Tensor]:
return {
key: tensor for key, tensor in zip(keys, torch.split(values, splits, dim=dim))
}
if isinstance(values, torch.Tensor):
return dict(zip(keys, torch.split(values, splits, dim=dim)))
else:
return dict(zip(keys, values))


@torch.fx.wrap
def module_init(module: "KTRegroupAsDict", keyed_tensors: List[KeyedTensor]) -> None:
assert len(keyed_tensors) > 0, "Empty list provided"
assert all(
kt.device() == keyed_tensors[0].device() for kt in keyed_tensors
), "All inputs should be on the same device."
module.device = keyed_tensors[0].device()
assert all(
kt.key_dim() == keyed_tensors[0].key_dim() for kt in keyed_tensors
), "All inputs should have the same key_dim"
module._dim = keyed_tensors[0].key_dim()

if module._dim == 1:
module._init_fbgemm_regroup(keyed_tensors)
else:
module._init_regroup(keyed_tensors)
module._is_inited = True


class KTRegroupAsDict(torch.nn.Module):
Expand Down Expand Up @@ -76,27 +98,32 @@ def __init__(self, groups: List[List[str]], keys: List[str]) -> None:

# cached values populated on first forward call
self.device: Optional[torch.device] = None
self._concat_dim: int = 1
self._dim: int = 1
self._use_fbgemm_regroup: bool = False
self._splits: List[int] = []
self._idx_key_pairs: List[Tuple[int, str]] = []
self._permute_tensor: Optional[torch.Tensor] = None
self._inv_permute_tensor: Optional[torch.Tensor] = None
self._offsets_tensor: Optional[torch.Tensor] = None
self._inv_offsets_tensor: Optional[torch.Tensor] = None
self.register_buffer(
"_permutes", torch.empty(0, device=self.device), persistent=True
)
self.register_buffer(
"_in_shapes", torch.empty(0, device=self.device), persistent=True
)
self.register_buffer(
"_out_shapes", torch.empty(0, device=self.device), persistent=True
)
self._out_lengths: Optional[List[int]] = None

def _init_fbgemm_regroup(self, kts: List[KeyedTensor]) -> None:
self._use_fbgemm_regroup = True
keys, lengths, values = _desugar_keyed_tensors(kts)
permute, inv_permute, offsets, inv_offsets, splits = _remap_to_groups(
keys, lengths, self._groups
self._permutes, self._in_shapes, self._out_shapes, self._out_lengths = (
_kt_regroup_arguments(
values[0],
keys,
lengths,
self._groups,
)
)
# no need to pin_memory() or to(..., non_blocking=True) since occurs only once
self._permute_tensor = permute.to(self.device)
self._inv_permute_tensor = inv_permute.to(self.device)
self._offsets_tensor = offsets.to(self.device)
self._inv_offsets_tensor = inv_offsets.to(self.device)
self._splits = splits

def _init_regroup(self, kts: List[KeyedTensor]) -> None:
lengths = [kt.length_per_key() for kt in kts]
Expand Down Expand Up @@ -127,34 +154,19 @@ def _init_regroup(self, kts: List[KeyedTensor]) -> None:

def forward(self, keyed_tensors: List[KeyedTensor]) -> Dict[str, torch.Tensor]:
if not self._is_inited:
assert len(keyed_tensors) > 0, "Empty list provided"
assert all(
kt.device() == keyed_tensors[0].device() for kt in keyed_tensors
), "All inputs should be on the same device."
self.device = keyed_tensors[0].device()
assert all(
kt.key_dim() == keyed_tensors[0].key_dim() for kt in keyed_tensors
), "All inputs should have the same key_dim"
self._dim = keyed_tensors[0].key_dim()

if _all_keys_used_once(keyed_tensors, self._groups) and self._dim == 1:
self._init_fbgemm_regroup(keyed_tensors)
else:
self._init_regroup(keyed_tensors)
self._is_inited = True
module_init(self, keyed_tensors)

if self._use_fbgemm_regroup:
values = _concat_values(keyed_tensors, self._dim)
permuted_values = torch.ops.fbgemm.permute_pooled_embs_auto_grad(
values = _get_kts_values(keyed_tensors)
permuted_values = torch.ops.fbgemm.permute_multi_embedding(
values,
self._offsets_tensor,
self._permute_tensor,
self._inv_offsets_tensor,
self._inv_permute_tensor,
self._permutes,
self._in_shapes,
self._out_shapes,
self._out_lengths,
)
else:
permuted_values = _permuted_values(
keyed_tensors, self._idx_key_pairs, self._dim
)

return _build_dict(self._keys, permuted_values, self._splits, self._dim)

0 comments on commit 954d652

Please sign in to comment.