Skip to content

Commit

Permalink
EC with ZCH (#1580)
Browse files Browse the repository at this point in the history
Summary:

Allow ZCH module to also use EC

Reviewed By: dstaay-fb

Differential Revision: D51872310
  • Loading branch information
PaulZhang12 authored and facebook-github-bot committed Jan 5, 2024
1 parent 1465cfd commit 594c347
Show file tree
Hide file tree
Showing 2 changed files with 442 additions and 248 deletions.
121 changes: 97 additions & 24 deletions torchrec/modules/mc_embedding_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,34 @@
# LICENSE file in the root directory of this source tree.


from typing import Dict, Optional, Tuple
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.embedding_modules import (
EmbeddingBagCollection,
EmbeddingCollection,
)
from torchrec.modules.mc_modules import ManagedCollisionCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor


def evict(
evictions: Dict[str, Optional[torch.Tensor]], ebc: EmbeddingBagCollection
evictions: Dict[str, Optional[torch.Tensor]],
ebc: Union[EmbeddingBagCollection, EmbeddingCollection],
) -> None:
# TODO: write function
return


class ManagedCollisionEmbeddingBagCollection(nn.Module):
class BaseManagedCollisionEmbeddingCollection(nn.Module):
"""
ManagedCollisionEmbeddingBagCollection represents a EmbeddingBagCollection module and a set of managed collision modules.
The inputs into the MC-EBC will first be modified by the managed collision module before being passed into the embedding bag collection.
For details of input and output types, see EmbeddingBagCollection
BaseManagedCollisionEmbeddingCollection represents an EC/EBC module and a set of managed collision modules.
The inputs into the MC-EC/EBC will first be modified by the managed collision module before being passed into the embedding collection.
Args:
embedding_bag_collection: EmbeddingBagCollection to lookup embeddings
embedding_module: EmbeddingBagCollection or EmbeddingCollection used to look up embeddings
managed_collision_modules: Dict of managed collision modules
return_remapped_features (bool): whether to return remapped input features
in addition to embeddings
Expand All @@ -40,33 +42,104 @@ class ManagedCollisionEmbeddingBagCollection(nn.Module):

def __init__(
self,
embedding_bag_collection: EmbeddingBagCollection,
embedding_module: Union[EmbeddingBagCollection, EmbeddingCollection],
managed_collision_collection: ManagedCollisionCollection,
return_remapped_features: bool = False,
) -> None:
super().__init__()
self._embedding_bag_collection = embedding_bag_collection
self._managed_collision_collection = managed_collision_collection
self._return_remapped_features = return_remapped_features

assert (
self._embedding_bag_collection.embedding_bag_configs()
== self._managed_collision_collection.embedding_configs()
), "Embedding Collection and Managed Collision Collection must contain the Embedding Configs"
self._embedding_module: Union[
EmbeddingBagCollection, EmbeddingCollection
] = embedding_module

if isinstance(embedding_module, EmbeddingBagCollection):
assert (
self._embedding_module.embedding_bag_configs()
== self._managed_collision_collection.embedding_configs()
), "Embedding Bag Collection and Managed Collision Collection must contain the Embedding Configs"

else:
assert (
self._embedding_module.embedding_configs()
== self._managed_collision_collection.embedding_configs()
), "Embedding Collection and Managed Collision Collection must contain the Embedding Configs"

def forward(
self,
features: KeyedJaggedTensor,
) -> Tuple[KeyedTensor, Optional[KeyedJaggedTensor]]:
) -> Tuple[
Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor]
]:

features = self._managed_collision_collection(features)

pooled_embeddings = self._embedding_bag_collection(features)
embedding_res = self._embedding_module(features)

evict(
self._managed_collision_collection.evict(), self._embedding_bag_collection
)
evict(self._managed_collision_collection.evict(), self._embedding_module)

if not self._return_remapped_features:
return pooled_embeddings, None
return pooled_embeddings, features
return embedding_res, None
return embedding_res, features


class ManagedCollisionEmbeddingCollection(BaseManagedCollisionEmbeddingCollection):
    """
    Pairs an EmbeddingCollection with a ManagedCollisionCollection (MC-EC).

    All behavior comes from BaseManagedCollisionEmbeddingCollection: input
    features are first passed through the managed collision collection and the
    result is then fed to the embedding collection.
    For details of input and output types, see EmbeddingCollection.

    Args:
        embedding_collection: EmbeddingCollection used to look up embeddings
        managed_collision_collection: ManagedCollisionCollection applied to the
            input features before the embedding lookup
        return_remapped_features (bool): whether to return remapped input features
            in addition to embeddings
    """

    def __init__(
        self,
        embedding_collection: EmbeddingCollection,
        managed_collision_collection: ManagedCollisionCollection,
        return_remapped_features: bool = False,
    ) -> None:
        # The base class stores the module under its generic `_embedding_module`
        # name and validates its configs against the collision collection.
        super().__init__(
            embedding_module=embedding_collection,
            managed_collision_collection=managed_collision_collection,
            return_remapped_features=return_remapped_features,
        )

        # Second, EC-specific handle to the same object, kept for consistency
        # with the `_embedding_bag_collection` attribute on the EBC variant.
        self._embedding_collection: EmbeddingCollection = embedding_collection


class ManagedCollisionEmbeddingBagCollection(BaseManagedCollisionEmbeddingCollection):
    """
    Pairs an EmbeddingBagCollection with a ManagedCollisionCollection (MC-EBC).

    All behavior comes from BaseManagedCollisionEmbeddingCollection: input
    features are first passed through the managed collision collection and the
    result is then fed to the embedding bag collection.
    For details of input and output types, see EmbeddingBagCollection.

    Args:
        embedding_bag_collection: EmbeddingBagCollection used to look up embeddings
        managed_collision_collection: ManagedCollisionCollection applied to the
            input features before the embedding lookup
        return_remapped_features (bool): whether to return remapped input features
            in addition to embeddings
    """

    def __init__(
        self,
        embedding_bag_collection: EmbeddingBagCollection,
        managed_collision_collection: ManagedCollisionCollection,
        return_remapped_features: bool = False,
    ) -> None:
        # The base class stores the module under its generic `_embedding_module`
        # name and validates its configs against the collision collection.
        super().__init__(
            embedding_module=embedding_bag_collection,
            managed_collision_collection=managed_collision_collection,
            return_remapped_features=return_remapped_features,
        )

        # Backwards compatibility: existing tests reference this attribute name,
        # so keep a second handle to the same object under the old name.
        self._embedding_bag_collection: EmbeddingBagCollection = (
            embedding_bag_collection
        )
Loading

0 comments on commit 594c347

Please sign in to comment.