diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index 4f1eaa44c..e616ae16e 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -191,13 +191,15 @@ def _arange(*args, **kwargs) -> torch.Tensor:
     return torch.arange(*args, **kwargs)
 
 
-def _permute_variable_stride_values(
-    values: torch.Tensor,
-    length_per_key: torch.Tensor,
+def _permute_tensor_by_segments(
+    tensor: torch.Tensor,
+    segment_sizes: torch.Tensor,
     recat: torch.Tensor,
-    weights: Optional[torch.Tensor],
+    weights: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     """
+    Permutes a tensor by segments according to recat tensor.
+
     For variable stride tensors we permute across length per key, which reduces the
     number of permute indices and lengthens each sequence.
     `keyed_jagged_index_select_dim1` more efficiently parallelizes work for each permute
@@ -206,31 +208,30 @@ def _permute_variable_stride_values(
 
     NOTE: `keyed_jagged_index_select_dim1` is only supported for CUDA.
     """
-    if values.device.type == "cuda":
+    if tensor.device.type == "cuda":
         output = torch.ops.fbgemm.keyed_jagged_index_select_dim1(
-            values,
-            length_per_key,
-            _to_offsets(length_per_key),
+            tensor,
+            segment_sizes,
+            _to_offsets(segment_sizes),
             recat,
-            len(length_per_key),
+            segment_sizes.numel(),
             weights,
-            # TODO: add selected_lengths_sum once landed to prevent D2H sync
         )
-        permuted_values = output[0]
+        permuted_tensor = output[0]
         permuted_weights = None if weights is None else output[2]
     else:
         (
             _,
-            permuted_values,
+            permuted_tensor,
             permuted_weights,
         ) = torch.ops.fbgemm.permute_1D_sparse_data(
             recat,
-            length_per_key,
-            values,
+            segment_sizes,
+            tensor,
             weights,
             None,
         )
-    return permuted_values, permuted_weights
+    return permuted_tensor, permuted_weights
 
 
 class JaggedTensorMeta(abc.ABCMeta, torch.fx._symbolic_trace.ProxyableClassMeta):
@@ -1703,14 +1704,13 @@ def permute(
         stride_per_key_tensor = _pin_and_move(
             torch.tensor(self.stride_per_key()), self.device()
         )
-        (_, permuted_lengths, _,) = torch.ops.fbgemm.permute_1D_sparse_data(
-            indices_tensor,
-            stride_per_key_tensor,
+        permuted_lengths, _ = _permute_tensor_by_segments(
             self.lengths(),
-            None,
+            stride_per_key_tensor,
+            indices_tensor,
             None,
         )
-        permuted_values, permuted_weights = _permute_variable_stride_values(
+        permuted_values, permuted_weights = _permute_tensor_by_segments(
             self.values(),
             length_per_key_tensor,
             indices_tensor,
@@ -1987,14 +1987,13 @@ def dist_init(
        )
        with record_function("## all2all_data:recat_values ##"):
            if recat is not None and recat.numel() > 0:
-                lengths_permute = (_, lengths, _,) = torch.ops.fbgemm.permute_1D_sparse_data(
-                    recat,
-                    stride_per_rank_per_key,
+                lengths, _ = _permute_tensor_by_segments(
                    lengths,
-                    None,
+                    stride_per_rank_per_key,
+                    recat,
                    None,
                )
-                values, weights = _permute_variable_stride_values(
+                values, weights = _permute_tensor_by_segments(
                    values,
                    length_per_key,
                    recat,
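For reference, here is the segment-permute contract the renamed helper implements, sketched in plain PyTorch. This is an illustrative reference only, not the shipped implementation: `_permute_tensor_by_segments` dispatches to fbgemm's fused kernels (`keyed_jagged_index_select_dim1` on CUDA, `permute_1D_sparse_data` otherwise), and the function name, variable names, and example data below are invented for the sketch.

```python
import torch
from typing import Optional, Tuple


def permute_by_segments_reference(
    tensor: torch.Tensor,
    segment_sizes: torch.Tensor,
    recat: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Hypothetical reference: reorder contiguous segments of a flat tensor by `recat`."""
    # Start offset of each segment within the flat tensor: [0, s0, s0+s1, ...].
    starts = torch.cat(
        [torch.zeros(1, dtype=segment_sizes.dtype), segment_sizes.cumsum(0)]
    ).tolist()
    sizes = segment_sizes.tolist()
    # For destination slot i, gather the contiguous run of source indices
    # that makes up segment recat[i].
    indices = torch.cat(
        [
            torch.arange(starts[r], starts[r] + sizes[r], dtype=torch.long)
            for r in recat.tolist()
        ]
    )
    permuted_weights = weights[indices] if weights is not None else None
    return tensor[indices], permuted_weights


# Three segments of sizes [2, 1, 3], reordered as [segment 2, segment 0, segment 1]:
values = torch.tensor([10, 11, 20, 30, 31, 32])
out, _ = permute_by_segments_reference(
    values, torch.tensor([2, 1, 3]), torch.tensor([2, 0, 1])
)
assert out.tolist() == [30, 31, 32, 10, 11, 20]
```

Permuting whole segments (stride per key in `permute`, stride per rank per key in `dist_init`) rather than individual jagged lengths keeps `recat` short while each gathered run stays long, which is what the docstring's note about `keyed_jagged_index_select_dim1` parallelizing work per permute index refers to.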