Add feature pools to torchrec OSS (pytorch#2126)
Summary:
Pull Request resolved: pytorch#2126

We are open sourcing new TorchRec modules for fast, scalable, and efficient indexing of tensors: TensorPool and KeyedJaggedTensorPool, for dense and sparse tensors respectively.

The proposed modules provide abstractions for reading and writing large tensor and KeyedJaggedTensor values, with support for sharding and flexible data emplacement (e.g. HBM, UVM, CPU, etc). They expose APIs to update and look up values based on arbitrary indices, and support sharding to distribute the tensors across multiple devices, abstracting away the collective communications for distributed lookup and updates.
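As a rough illustration of the described API, here is a minimal sketch; the import path, constructor arguments, and `update`/`lookup` method names are assumptions inferred from this summary, not confirmed signatures:

```python
import torch

# Hypothetical import path, inferred from this summary.
from torchrec.modules.tensor_pool import TensorPool

# A pool holding up to 1,000 rows of 8-dim dense features in CPU memory.
pool = TensorPool(pool_size=1000, dim=8, dtype=torch.float32, device=torch.device("cpu"))

# Update values at arbitrary indices...
ids = torch.tensor([0, 7, 42])
pool.update(ids=ids, values=torch.randn(3, 8))

# ...and look them up later, e.g. to augment a training batch with auxiliary
# features that are too expensive to carry in the batched data itself.
features = pool.lookup(ids)  # 3 x 8 tensor
```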

# Motivation
When working with recommender systems, there is often a need to transform or augment the model’s feature inputs in various ways. For example, when training retrieval/candidate generation models, it is common to extend the training data with negative samples. In the context of video recommendation, negative samples might be the IDs of videos that the user did not click on (i.e. **hard negative samples**). Retrieval models are then trained to produce a list of positive samples as candidates for further ranking downstream.

In such cases, it may not be practical to store all the necessary features in the batched data. For example, during candidate generation, extracting features for a large corpus of candidate items may be prohibitively expensive. Instead, auxiliary features can be stored in memory and indexed to efficiently look up features when needed to augment the given samples during training or inference.

These modules can also be used to implement a distributed cache for embeddings that supports index-based lookup and updates.

Note: this is joint work by various technical contributors: xing-liu strisunshinewentingwang murphymatt Michael-JY-He YLGH jiayisuse gnahzg yanxia hongweitian SeanXiaohengMao cz171

Reviewed By: joshuadeng, gnahzg

Differential Revision: D58355479
sarckk authored and facebook-github-bot committed Jun 17, 2024
1 parent a9a5c06 commit e9b1bc2
Showing 21 changed files with 6,158 additions and 3 deletions.
367 changes: 366 additions & 1 deletion torchrec/distributed/dist_data.py
@@ -29,7 +29,7 @@
from torchrec.distributed.types import Awaitable, QuantizedCommCodecs, rank_device
from torchrec.fx.utils import fx_marker
from torchrec.pt2.checks import is_torchdynamo_compiling
-from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
+from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor

try:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -1431,3 +1431,368 @@ def forward(
unbucketize_permute_tensor=unbucketize_permute_tensor,
embedding_dim=local_embs.shape[1],
)


class JaggedTensorAllToAll(Awaitable[JaggedTensor]):
"""
    Redistributes `JaggedTensor` to a `ProcessGroup` along the batch dimension according
    to the number of items to send and receive. The number of items to send
    must be known ahead of time on each rank. This is currently used for sharded
    KeyedJaggedTensorPool, after distributing the number of IDs to look up or update on
    each rank.

    Implementation utilizes AlltoAll collective as part of torch.distributed.

    Args:
        jt (JaggedTensor): JaggedTensor to distribute.
        num_items_to_send (torch.Tensor): number of items to send to each rank.
        num_items_to_receive (torch.Tensor): number of items to receive from every
            other rank. This must be known ahead of time on each rank, usually via
            another AlltoAll.
        pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication.
"""

def __init__(
self,
jt: JaggedTensor,
num_items_to_send: torch.Tensor,
num_items_to_receive: torch.Tensor,
pg: dist.ProcessGroup,
) -> None:
super().__init__()
self._workers: int = pg.size()

self._dist_lengths: torch.Tensor = torch.empty(
sum(num_items_to_receive),
device=jt.lengths().device,
dtype=jt.lengths().dtype,
)

dist.all_to_all_single(
self._dist_lengths,
jt.lengths(),
output_split_sizes=num_items_to_receive.tolist(),
input_split_sizes=num_items_to_send.tolist(),
group=pg,
async_op=False,
)
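        # The lengths exchange above is synchronous: the received lengths are
        # needed immediately to compute the value splits for the values AlltoAll.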

        # The following computes per-rank chunk sums of the received lengths, e.g.
        # num_items_to_receive = [2, 2]
        # received lengths     = [2, 3, 1, 1]
        # -> value_output_splits = [5, 2]
dist_id_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(
num_items_to_receive
)
value_output_splits = torch.ops.fbgemm.segment_sum_csr(
1,
dist_id_offsets,
self._dist_lengths,
).tolist()
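        # value_output_splits[r] is the total number of values this rank will
        # receive from rank r in the values AlltoAll below.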

self._dist_values: torch.Tensor = torch.empty(
sum(value_output_splits),
dtype=jt.values().dtype,
device=jt.values().device,
)

        # Same as above: compute per-rank chunk sums of the local lengths to
        # determine how many values to send to each rank.
id_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(num_items_to_send)
value_input_splits = torch.ops.fbgemm.segment_sum_csr(
1,
id_offsets,
jt.lengths(),
).tolist()

self._dist_values_req: dist.Work = dist.all_to_all_single(
self._dist_values,
jt.values(),
output_split_sizes=value_output_splits,
input_split_sizes=value_input_splits,
group=pg,
async_op=True,
)

self._dist_weights: Optional[torch.Tensor] = None
self._dist_weights_req: Optional[dist.Work] = None
if jt.weights_or_none() is not None:
self._dist_weights = torch.empty(
sum(value_output_splits),
dtype=jt.weights().dtype,
device=jt.weights().device,
)

self._dist_weights_req = dist.all_to_all_single(
self._dist_weights,
jt.weights(),
output_split_sizes=value_output_splits,
input_split_sizes=value_input_splits,
group=pg,
async_op=True,
)

def _wait_impl(self) -> JaggedTensor:
self._dist_values_req.wait()
if self._dist_weights_req is not None:
self._dist_weights_req.wait()

return JaggedTensor(
values=self._dist_values,
lengths=self._dist_lengths,
weights=self._dist_weights,
)
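# A minimal usage sketch (assuming two ranks, an initialized ProcessGroup `pg`,
# and send/receive counts already exchanged, e.g. via a prior AlltoAll):
#
#   jt_a2a = JaggedTensorAllToAll(
#       jt,
#       num_items_to_send=torch.tensor([2, 1]),     # items for rank0, rank1
#       num_items_to_receive=torch.tensor([2, 3]),  # items from rank0, rank1
#       pg=pg,
#   )
#   redistributed_jt = jt_a2a.wait()  # JaggedTensor with 5 received items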


class TensorAllToAllValuesAwaitable(Awaitable[torch.Tensor]):
def __init__(
self,
pg: dist.ProcessGroup,
input: torch.Tensor,
input_splits: torch.Tensor,
output_splits: torch.Tensor,
device: torch.device,
) -> None:
super().__init__()
self._workers: int = pg.size()
self._device: torch.device = device
self._input = input

self._dist_values: torch.Tensor
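        # With a single worker there is nothing to exchange; the output values
        # are simply the local input.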
if self._workers == 1:
            self._dist_values = input
return
else:
if input.dim() > 1:
self._dist_values = torch.empty(
(sum(output_splits), input.shape[1]),
device=self._device,
dtype=input.dtype,
)
else:
self._dist_values = torch.empty(
sum(output_splits), device=self._device, dtype=input.dtype
)

with record_function("## all2all_data:ids ##"):
self._values_awaitable: dist.Work = dist.all_to_all_single(
output=self._dist_values,
input=input,
output_split_sizes=output_splits.tolist(),
input_split_sizes=input_splits.tolist(),
group=pg,
async_op=True,
)

def _wait_impl(self) -> torch.Tensor:
if self._workers > 1:
self._values_awaitable.wait()
return self._dist_values


class TensorAllToAllSplitsAwaitable(Awaitable[TensorAllToAllValuesAwaitable]):
def __init__(
self,
pg: dist.ProcessGroup,
input: torch.Tensor,
splits: torch.Tensor,
device: torch.device,
) -> None:
super().__init__()
self._workers: int = pg.size()
self._pg: dist.ProcessGroup = pg
self._device: torch.device = device
self._input = input
self._input_splits = splits

self._output_splits: torch.Tensor
if self._workers == 1:
self._output_splits = splits
return
else:
self._output_splits = torch.empty(
[self._workers],
device=device,
dtype=torch.int,
)

with record_function("## all2all_data:ids splits ##"):
self._num_ids_awaitable: dist.Work = dist.all_to_all_single(
output=self._output_splits,
input=splits,
group=pg,
async_op=True,
)

def _wait_impl(self) -> TensorAllToAllValuesAwaitable:
if self._workers > 1:
self._num_ids_awaitable.wait()

return TensorAllToAllValuesAwaitable(
pg=self._pg,
input=self._input,
input_splits=self._input_splits,
output_splits=self._output_splits,
device=self._device,
)


class TensorValuesAllToAll(nn.Module):
"""
    Redistributes a torch.Tensor to a `ProcessGroup` according to input and output splits.

    Implementation utilizes AlltoAll collective as part of torch.distributed.

    Args:
        pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication.

    Example::
        tensor_vals_A2A = TensorValuesAllToAll(pg)
        input_splits = torch.Tensor([1,2]) on rank0 and torch.Tensor([1,1]) on rank1
        output_splits = torch.Tensor([1,1]) on rank0 and torch.Tensor([2,1]) on rank1
        awaitable = tensor_vals_A2A(rank0_input, input_splits, output_splits)

        where:
        rank0_input is a 3 x 3 torch.Tensor holding
        [
            [V1, V2, V3],
            [V4, V5, V6],
            [V7, V8, V9],
        ]

        rank1_input is a 2 x 3 torch.Tensor holding
        [
            [V10, V11, V12],
            [V13, V14, V15],
        ]

        rank0_output = awaitable.wait()

        # where:
        # rank0_output is a torch.Tensor holding
        [
            [V1, V2, V3],
            [V10, V11, V12],
        ]

        # rank1_output is a torch.Tensor holding
        [
            [V4, V5, V6],
            [V7, V8, V9],
            [V13, V14, V15],
        ]
"""

def __init__(
self,
pg: dist.ProcessGroup,
) -> None:
super().__init__()
self._pg: dist.ProcessGroup = pg

def forward(
self,
input: torch.Tensor,
input_splits: torch.Tensor,
output_splits: torch.Tensor,
) -> TensorAllToAllValuesAwaitable:
"""
        Sends tensor to relevant `ProcessGroup` ranks.

        Args:
            input (torch.Tensor): `torch.Tensor` of values to distribute.
            input_splits (torch.Tensor): tensor containing the number of rows
                to be sent to each rank. len(input_splits) must equal self._pg.size().
            output_splits (torch.Tensor): tensor containing the number of rows
                to be received from each rank. len(output_splits) must equal self._pg.size().

        Returns:
            `TensorAllToAllValuesAwaitable`
"""
with torch.no_grad():
return TensorAllToAllValuesAwaitable(
pg=self._pg,
input=input,
input_splits=input_splits,
output_splits=output_splits,
device=input.device,
)


class TensorAllToAll(nn.Module):
"""
    Redistributes a torch.Tensor to a `ProcessGroup` according to splits.

    Implementation utilizes AlltoAll collective as part of torch.distributed.

    The first collective call in `TensorAllToAllSplitsAwaitable` will transmit
    splits to allocate correct space for the tensor values. The following collective
    calls in `TensorAllToAllValuesAwaitable` will transmit the actual
    tensor values asynchronously.

    Args:
        pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication.

    Example::
        tensor_A2A = TensorAllToAll(pg)
        splits = torch.Tensor([1,1]) on rank0 and rank1
        awaitable = tensor_A2A(rank0_input, splits)

        where:
        rank0_input is a torch.Tensor holding
        [
            [V1, V2, V3],
            [V4, V5, V6],
        ]

        rank1_input is a torch.Tensor holding
        [
            [V7, V8, V9],
            [V10, V11, V12],
        ]

        rank0_output = awaitable.wait().wait()

        where:
        rank0_output is a torch.Tensor holding
        [
            [V1, V2, V3],
            [V7, V8, V9],
        ]

        rank1_output is a torch.Tensor holding
        [
            [V4, V5, V6],
            [V10, V11, V12],
        ]
"""

def __init__(
self,
pg: dist.ProcessGroup,
) -> None:
super().__init__()
self._pg: dist.ProcessGroup = pg

def forward(
self,
input: torch.Tensor,
splits: torch.Tensor,
) -> TensorAllToAllSplitsAwaitable:
"""
        Sends tensor to relevant `ProcessGroup` ranks.

        The first wait will exchange the splits for the provided tensor and issue
        the AlltoAll for the tensor values. The second wait will return the tensors.

        Args:
            input (torch.Tensor): `torch.Tensor` of values to distribute.
            splits (torch.Tensor): tensor containing the number of rows to be sent
                to each rank. len(splits) must equal self._pg.size().

        Returns:
            Awaitable[TensorAllToAllValuesAwaitable]: awaitable of a `TensorAllToAllValuesAwaitable`.
"""
        with torch.no_grad():
            return TensorAllToAllSplitsAwaitable(
                pg=self._pg,
                input=input,
                splits=splits,
                device=input.device,
            )
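# A minimal sketch of the two-stage wait pattern (assuming an initialized
# ProcessGroup `pg` and a 2 x 3 `input` tensor per rank, as in the docstring
# example above):
#
#   tensor_a2a = TensorAllToAll(pg)
#   splits = torch.tensor([1, 1], dtype=torch.int, device=input.device)
#   values_awaitable = tensor_a2a(input, splits).wait()  # exchanges splits
#   output = values_awaitable.wait()                     # exchanges values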
