2024-10-19 nightly release (dbca437)
pytorchbot committed Oct 19, 2024
1 parent 4baf8ee commit 74d4d68
Showing 3 changed files with 38 additions and 4 deletions.
2 changes: 1 addition & 1 deletion torchrec/distributed/mc_modules.py
@@ -663,7 +663,7 @@ def compute(
                     values=features.values(),
                     lengths=features.lengths(),
                     # TODO: improve this temp solution by passing real weights
-                    weights=torch.tensor(kjt.length_per_key()),
+                    weights=torch.tensor(features.length_per_key()),
                 )
             }
             mcm = self._managed_collision_modules[table]
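Note on this hunk: the placeholder weights are now built from `features`, the KeyedJaggedTensor actually being processed here, rather than from `kjt`. The following is a minimal sketch (not part of the commit) of what `length_per_key()` returns and how it stands in for real weights; it assumes torchrec's public KeyedJaggedTensor API and uses made-up feature names.

# Sketch only: per-key value counts from a toy KeyedJaggedTensor.
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two features, three batch slots each; `lengths` gives values per slot.
features = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5, 6, 7]),
    lengths=torch.tensor([2, 0, 1, 1, 1, 2]),
)

# Total number of values per key, in key order.
print(features.length_per_key())  # [3, 4]

# The commit uses these totals as stand-in weights until real weights are passed.
placeholder_weights = torch.tensor(features.length_per_key())
print(placeholder_weights)  # tensor([3, 4])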
30 changes: 27 additions & 3 deletions torchrec/distributed/train_pipeline/train_pipelines.py
@@ -21,6 +21,7 @@
     Deque,
     Dict,
     Generic,
+    Iterable,
     Iterator,
     List,
     Optional,
@@ -911,17 +912,40 @@ def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None:
             if cast(int, context.index) % 2 == 0
             else self._embedding_odd_streams
         )
-        for stream, emb_tensors, detached_emb_tensors in zip(
+        assert len(context.embedding_features) == len(context.embedding_tensors)
+        for stream, emb_tensors, embedding_features, detached_emb_tensors in zip(
             streams,
             context.embedding_tensors,
+            context.embedding_features,
             context.detached_embedding_tensors,
         ):
             with self._stream_context(stream):
                 grads = [tensor.grad for tensor in detached_emb_tensors]
                 if stream:
                     stream.wait_stream(default_stream)
-                # pyre-ignore
-                torch.autograd.backward(emb_tensors, grads)
+                # Some embeddings may never get used in the final loss computation,
+                # so the grads will be `None`. If we don't exclude these, it will fail
+                # with error: "grad can be implicitly created only for scalar outputs"
+                # Alternatively, if the tensor has only 1 element, pytorch can still
+                # figure out how to do autograd
+                embs_to_backprop, grads_to_use, invalid_features = [], [], []
+                assert len(embedding_features) == len(emb_tensors)
+                for features, tensor, grad in zip(
+                    embedding_features, emb_tensors, grads
+                ):
+                    if tensor.numel() == 1 or grad is not None:
+                        embs_to_backprop.append(tensor)
+                        grads_to_use.append(grad)
+                    else:
+                        if isinstance(features, Iterable):
+                            invalid_features.extend(features)
+                        else:
+                            invalid_features.append(features)
+                if invalid_features and context.index == 0:
+                    logger.warning(
+                        f"SemiSync, the following features have no gradients: {invalid_features}"
+                    )
+                torch.autograd.backward(embs_to_backprop, grads_to_use)
 
     def copy_batch_to_gpu(
         self,
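Note on the new filtering logic above: torch.autograd.backward(tensors, grads) raises "grad can be implicitly created only for scalar outputs" when a non-scalar output is paired with a `None` gradient, which happens for embeddings that never reach the final loss; single-element tensors are the exception, since PyTorch can create their implicit grad. Below is a standalone sketch of the same filtering pattern outside the pipeline; variable names are illustrative and not from torchrec.

# Sketch only: drop (tensor, None-grad) pairs before calling backward.
import torch

a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)

# Only `a` participates in the "loss"; `b` is unused, so its grad stays None.
outputs = [a * 2, b * 2]
grads = [torch.ones(4), None]

kept = [(t, g) for t, g in zip(outputs, grads) if t.numel() == 1 or g is not None]
tensors_to_backprop = [t for t, _ in kept]
grads_to_use = [g for _, g in kept]

# Passing (b * 2, None) directly would raise:
# "grad can be implicitly created only for scalar outputs"
torch.autograd.backward(tensors_to_backprop, grads_to_use)
print(a.grad)  # tensor([2., 2., 2., 2.])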
10 changes: 10 additions & 0 deletions torchrec/distributed/train_pipeline/utils.py
@@ -128,6 +128,7 @@ class PrefetchTrainPipelineContext(TrainPipelineContext):
 class EmbeddingTrainPipelineContext(TrainPipelineContext):
     embedding_a2a_requests: Dict[str, Multistreamable] = field(default_factory=dict)
     embedding_tensors: List[List[torch.Tensor]] = field(default_factory=list)
+    embedding_features: List[List[Union[str, List[str]]]] = field(default_factory=list)
     detached_embedding_tensors: List[List[torch.Tensor]] = field(default_factory=list)


@@ -408,6 +409,8 @@ def __call__(self, *input, **kwargs) -> Awaitable:
             # pyre-ignore [16]
             self._context.embedding_tensors.append(tensors)
             # pyre-ignore [16]
+            self._context.embedding_features.append(list(embeddings.keys()))
+            # pyre-ignore [16]
             self._context.detached_embedding_tensors.append(detached_tensors)
         else:
             assert isinstance(embeddings, KeyedTensor)
@@ -418,6 +421,13 @@ def __call__(self, *input, **kwargs) -> Awaitable:
             tensors.append(tensor)
             detached_tensors.append(detached_tensor)
             self._context.embedding_tensors.append(tensors)
+            # KeyedTensor is returned by EmbeddingBagCollections and its variants
+            # KeyedTensor holds dense data from multiple features and .values()
+            # returns a single concatenated dense tensor. To ensure that
+            # context.embedding_tensors[i] has the same length as
+            # context.embedding_features[i], we pass in a list with a single item:
+            # a list containing all the embedding feature names.
+            self._context.embedding_features.append([list(embeddings.keys())])
             self._context.detached_embedding_tensors.append(detached_tensors)
 
         return LazyNoWait(embeddings)
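Note on the comment above: in the KeyedTensor branch only one concatenated dense tensor is recorded per call, so `embedding_features` records a single entry holding all the feature names, keeping it index-aligned with `embedding_tensors`. Below is an illustrative sketch (not part of the commit), assuming torchrec's KeyedTensor API and toy key names.

# Sketch only: one concatenated tensor paired with one list of all its keys.
import torch
from torchrec.sparse.jagged_tensor import KeyedTensor

kt = KeyedTensor(
    keys=["f1", "f2"],
    length_per_key=[2, 3],         # embedding dims per key
    values=torch.randn(4, 5),      # batch of 4, dims concatenated along dim 1
)

tensors = [kt.values()]            # one concatenated dense tensor
features = [list(kt.keys())]       # one entry: all feature names for that tensor

assert len(tensors) == len(features)
print(features)            # [['f1', 'f2']]
print(tensors[0].shape)    # torch.Size([4, 5])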
