Skip to content

Commit

Permalink
Support prefetching for SSD TBE lookup (pytorch#2275)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2275

Currently, we cannot use prefetch pipeline with SSD-based TBE. This diff adds the requires changes in torchrec code to support this.

Differential Revision: D60838580
  • Loading branch information
sarckk authored and facebook-github-bot committed Aug 7, 2024
1 parent 07dd9b9 commit a39ac29
Showing 1 changed file with 38 additions and 17 deletions.
55 changes: 38 additions & 17 deletions torchrec/distributed/embedding_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
SplitTableBatchedEmbeddingBagsCodegen,
)
from fbgemm_gpu.tbe.ssd.training import SSDTableBatchedEmbeddingBags
from torch import nn

from torch.autograd.function import FunctionCtx
Expand Down Expand Up @@ -182,7 +183,10 @@ def _create_lookup(
config: GroupedEmbeddingConfig,
) -> BaseEmbedding:
for table in config.embedding_tables:
if table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING:
if (
table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING
or table.compute_kernel == EmbeddingComputeKernel.KEY_VALUE
):
self._need_prefetch = True
if config.compute_kernel == EmbeddingComputeKernel.DENSE:
return BatchedDenseEmbedding(
Expand Down Expand Up @@ -254,11 +258,18 @@ def prefetch(
"If you don’t turn on prefetch_pipeline, cache locations might be wrong in backward and can cause wrong results.\n"
)
if hasattr(emb_op.emb_module, "prefetch"):
emb_op.emb_module.prefetch(
indices=features.values(),
offsets=features.offsets(),
forward_stream=forward_stream,
)
if isinstance(emb_op.emb_module, SSDTableBatchedEmbeddingBags):
# only takes indices and offsets
emb_op.emb_module.prefetch(
indices=features.values(),
offsets=features.offsets(),
)
else:
emb_op.emb_module.prefetch(
indices=features.values(),
offsets=features.offsets(),
forward_stream=forward_stream,
)

def forward(
self,
Expand Down Expand Up @@ -455,7 +466,10 @@ def prefetch(
) -> None:
def _need_prefetch(config: GroupedEmbeddingConfig) -> bool:
for table in config.embedding_tables:
if table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING:
if (
table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING
or table.compute_kernel == EmbeddingComputeKernel.KEY_VALUE
):
return True
return False

Expand All @@ -476,16 +490,23 @@ def _need_prefetch(config: GroupedEmbeddingConfig) -> bool:
"If you don't turn on prefetch_pipeline, cache locations might be wrong in backward and can cause wrong results.\n"
)
if hasattr(emb_op.emb_module, "prefetch"):
emb_op.emb_module.prefetch(
indices=features.values(),
offsets=features.offsets(),
forward_stream=forward_stream,
batch_size_per_feature_per_rank=(
features.stride_per_key_per_rank()
if features.variable_stride_per_key()
else None
),
)
if isinstance(emb_op.emb_module, SSDTableBatchedEmbeddingBags):
# only takes indices and offsets
emb_op.emb_module.prefetch(
indices=features.values(),
offsets=features.offsets(),
)
else:
emb_op.emb_module.prefetch(
indices=features.values(),
offsets=features.offsets(),
forward_stream=forward_stream,
batch_size_per_feature_per_rank=(
features.stride_per_key_per_rank()
if features.variable_stride_per_key()
else None
),
)

def _merge_variable_batch_embeddings(
self, embeddings: List[torch.Tensor], splits: List[List[int]]
Expand Down

0 comments on commit a39ac29

Please sign in to comment.