use keyed jagged index select for lengths permute
Summary:
We've seen better performance and GPU utilization from switching to keyed_jagged_index_select_dim1 for permuting values.

Traces also show that there are further gains to be had by switching the lengths permute to keyed_jagged_index_select_dim1 as well.

For example, in the trace below the 1D permute of lengths now takes the most time; the 1D permute of values used to take even longer, but with keyed_jagged_index_select_dim1 it is barely visible.
 {F1456534243}
https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/dynocli/aps-combo2_uhm_igfm_200x_baseline-ac3cbf3708/0/rank-1.Feb_08_22_34_41.5280.pt.trace.json.gz&bucket=aps_traces
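
As background, here is a minimal plain-PyTorch sketch of what a segment permute does (illustrative only: the function and variable names below are made up, and this is not the fbgemm kernel, which performs the same reordering in a single fused op):

```python
import torch

def permute_by_segments_reference(
    tensor: torch.Tensor,         # flat values, segments laid out back to back
    segment_sizes: torch.Tensor,  # number of elements in each segment
    recat: torch.Tensor,          # desired segment order
) -> torch.Tensor:
    # Split the flat tensor into its segments, then concatenate them in recat order.
    segments = torch.split(tensor, segment_sizes.tolist())
    return torch.cat([segments[i] for i in recat.tolist()])

values = torch.arange(10)        # segments: [0,1,2], [3,4], [5,6,7,8,9]
sizes = torch.tensor([3, 2, 5])
order = torch.tensor([2, 0, 1])  # move the last segment to the front
print(permute_by_segments_reference(values, sizes, order))
# tensor([5, 6, 7, 8, 9, 0, 1, 2, 3, 4])
```

Permuting whole per-key segments rather than individual values means far fewer permute indices, each covering a longer contiguous run, which is the work pattern keyed_jagged_index_select_dim1 parallelizes well.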

Differential Revision: D53776313
joshuadeng authored and facebook-github-bot committed Feb 20, 2024
1 parent d1722e9 commit 7f7fa6b
Showing 1 changed file with 24 additions and 25 deletions.
torchrec/sparse/jagged_tensor.py
@@ -191,13 +191,15 @@ def _arange(*args, **kwargs) -> torch.Tensor:
     return torch.arange(*args, **kwargs)
 
 
-def _permute_variable_stride_values(
-    values: torch.Tensor,
-    length_per_key: torch.Tensor,
+def _permute_tensor_by_segments(
+    tensor: torch.Tensor,
+    segment_sizes: torch.Tensor,
     recat: torch.Tensor,
-    weights: Optional[torch.Tensor],
+    weights: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     """
+    Permutes a tensor by segments according to recat tensor.
+
     For variable stride tensors we permute across length per key, which reduces the
     number of permute indices and lengthens each sequence.
     `keyed_jagged_index_select_dim1` more efficiently parallelizes work for each permute
@@ -206,31 +208,30 @@ def _permute_variable_stride_values(
     NOTE:
         `keyed_jagged_index_select_dim1` is only supported for CUDA.
     """
-    if values.device.type == "cuda":
+    if tensor.device.type == "cuda":
         output = torch.ops.fbgemm.keyed_jagged_index_select_dim1(
-            values,
-            length_per_key,
-            _to_offsets(length_per_key),
+            tensor,
+            segment_sizes,
+            _to_offsets(segment_sizes),
             recat,
-            len(length_per_key),
+            segment_sizes.numel(),
             weights,
+            # TODO: add selected_lengths_sum once landed to prevent D2H sync
         )
-        permuted_values = output[0]
+        permuted_tensor = output[0]
         permuted_weights = None if weights is None else output[2]
     else:
         (
             _,
-            permuted_values,
+            permuted_tensor,
             permuted_weights,
         ) = torch.ops.fbgemm.permute_1D_sparse_data(
             recat,
-            length_per_key,
-            values,
+            segment_sizes,
+            tensor,
             weights,
             None,
         )
-    return permuted_values, permuted_weights
+    return permuted_tensor, permuted_weights
 
 
 class JaggedTensorMeta(abc.ABCMeta, torch.fx._symbolic_trace.ProxyableClassMeta):
@@ -1703,14 +1704,13 @@ def permute(
             stride_per_key_tensor = _pin_and_move(
                 torch.tensor(self.stride_per_key()), self.device()
             )
-            (_, permuted_lengths, _,) = torch.ops.fbgemm.permute_1D_sparse_data(
-                indices_tensor,
-                stride_per_key_tensor,
+            permuted_lengths, _ = _permute_tensor_by_segments(
                 self.lengths(),
-                None,
+                stride_per_key_tensor,
+                indices_tensor,
                 None,
             )
-            permuted_values, permuted_weights = _permute_variable_stride_values(
+            permuted_values, permuted_weights = _permute_tensor_by_segments(
                 self.values(),
                 length_per_key_tensor,
                 indices_tensor,
@@ -1987,14 +1987,13 @@ def dist_init(
             )
             with record_function("## all2all_data:recat_values ##"):
                 if recat is not None and recat.numel() > 0:
-                    (_, lengths, _,) = torch.ops.fbgemm.permute_1D_sparse_data(
-                        recat,
-                        stride_per_rank_per_key,
+                    lengths, _ = _permute_tensor_by_segments(
                         lengths,
-                        None,
+                        stride_per_rank_per_key,
+                        recat,
                         None,
                     )
-                    values, weights = _permute_variable_stride_values(
+                    values, weights = _permute_tensor_by_segments(
                         values,
                         length_per_key,
                         recat,
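
For reference, a usage sketch of the public entry point that exercises the changed permute path, KeyedJaggedTensor.permute (a minimal sketch, assuming torchrec and fbgemm_gpu are installed; the keys, values, and lengths are made up):

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two keys over a batch of two samples:
#   "f1": [1.0] and [2.0, 3.0]      (lengths 1, 2)
#   "f2": [4.0, 5.0, 6.0] and []    (lengths 3, 0)
kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]),
    lengths=torch.tensor([1, 2, 3, 0]),
)

# Reorder the keys to ["f2", "f1"]; internally lengths and values are permuted
# by whole per-key segments rather than element by element.
permuted = kjt.permute([1, 0])
print(permuted.keys())     # ['f2', 'f1']
print(permuted.values())   # tensor([4., 5., 6., 1., 2., 3.])
print(permuted.lengths())  # tensor([3, 0, 1, 2])
```

The dist_init hunk applies the same helper to the tensors received from the all-to-all collective, using stride_per_rank_per_key as the segment sizes for the lengths permute.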
