Skip to content

Commit

Permalink
Embedding Modules BC tests
Browse files Browse the repository at this point in the history
Summary: BC unit tests for embedding modules APIs

Reviewed By: iamzainhuda

Differential Revision: D62539188

fbshipit-source-id: de0b3e2a356a2016a79029fabda0f898991129ad
  • Loading branch information
PaulZhang12 authored and facebook-github-bot committed Sep 12, 2024
1 parent 48d6eac commit ff6dc0a
Show file tree
Hide file tree
Showing 2 changed files with 235 additions and 0 deletions.
76 changes: 76 additions & 0 deletions torchrec/schema/api_tests/test_embedding_config_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import inspect
import unittest
from dataclasses import dataclass, field
from typing import Callable, List, Optional

import torch
from torchrec.modules.embedding_configs import (
DataType,
EmbeddingBagConfig,
EmbeddingConfig,
PoolingType,
)

from torchrec.schema.utils import is_signature_compatible


@dataclass
class StableEmbeddingBagConfig:
num_embeddings: int
embedding_dim: int
name: str = ""
data_type: DataType = DataType.FP32
feature_names: List[str] = field(default_factory=list)
weight_init_max: Optional[float] = None
weight_init_min: Optional[float] = None
num_embeddings_post_pruning: Optional[int] = None

init_fn: Optional[Callable[[torch.Tensor], Optional[torch.Tensor]]] = None
# when the position_weighted feature is in this table config,
# enable this flag to support rw_sharding
need_pos: bool = False
pooling: PoolingType = PoolingType.SUM


@dataclass
class StableEmbeddingConfig:
num_embeddings: int
embedding_dim: int
name: str = ""
data_type: DataType = DataType.FP32
feature_names: List[str] = field(default_factory=list)
weight_init_max: Optional[float] = None
weight_init_min: Optional[float] = None
num_embeddings_post_pruning: Optional[int] = None

init_fn: Optional[Callable[[torch.Tensor], Optional[torch.Tensor]]] = None
# when the position_weighted feature is in this table config,
# enable this flag to support rw_sharding
need_pos: bool = False


class TestEmbeddingConfig(unittest.TestCase):
def test_embedding_bag_config(self) -> None:
self.assertTrue(
is_signature_compatible(
inspect.signature(StableEmbeddingBagConfig.__init__),
inspect.signature(EmbeddingBagConfig.__init__),
)
)

def test_embedding_config(self) -> None:
self.assertTrue(
is_signature_compatible(
inspect.signature(StableEmbeddingConfig.__init__),
inspect.signature(EmbeddingConfig.__init__),
)
)
159 changes: 159 additions & 0 deletions torchrec/schema/api_tests/test_embedding_module_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import inspect
import unittest
from typing import Dict, List, Optional

import torch
from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
from torchrec.modules.embedding_modules import (
EmbeddingBagCollection,
EmbeddingCollection,
)

from torchrec.schema.utils import is_signature_compatible
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor


class StableEmbeddingBagCollectionInterface:
"""
Stable Interface for `EmbeddingBagCollection`.
"""

def __init__(
self,
tables: List[EmbeddingBagConfig],
is_weighted: bool = False,
device: Optional[torch.device] = None,
) -> None:
pass

def forward(
self,
features: KeyedJaggedTensor,
) -> KeyedTensor:
return KeyedTensor(
keys=[],
length_per_key=[],
values=torch.empty(0),
)

def embedding_bag_configs(
self,
) -> List[EmbeddingBagConfig]:
return []

def is_weighted(self) -> bool:
return False


class StableEmbeddingCollectionInterface:
"""
Stable Interface for `EmbeddingBagCollection`.
"""

def __init__(
self,
tables: List[EmbeddingConfig],
device: Optional[torch.device] = None,
need_indices: bool = False,
) -> None:
return

def forward(
self,
features: KeyedJaggedTensor,
) -> Dict[str, JaggedTensor]:
return {}

def embedding_configs(
self,
) -> List[EmbeddingConfig]:
return []

def need_indices(self) -> bool:
return False

def embedding_dim(self) -> int:
return 0

def embedding_names_by_table(self) -> List[List[str]]:
return []


class TestEmbeddingConfig(unittest.TestCase):
def test_embedding_bag_collection(self) -> None:
self.assertTrue(
is_signature_compatible(
inspect.signature(StableEmbeddingBagCollectionInterface.__init__),
inspect.signature(EmbeddingBagCollection.__init__),
)
)

self.assertTrue(
is_signature_compatible(
inspect.signature(StableEmbeddingBagCollectionInterface.forward),
inspect.signature(EmbeddingBagCollection.forward),
)
)

self.assertTrue(
is_signature_compatible(
inspect.signature(
StableEmbeddingBagCollectionInterface.embedding_bag_configs
),
inspect.signature(EmbeddingBagCollection.embedding_bag_configs),
)
)

self.assertTrue(
is_signature_compatible(
inspect.signature(StableEmbeddingBagCollectionInterface.is_weighted),
inspect.signature(EmbeddingBagCollection.is_weighted),
)
)

def test_embedding_collection(self) -> None:
self.assertTrue(
is_signature_compatible(
inspect.signature(StableEmbeddingCollectionInterface.__init__),
inspect.signature(EmbeddingCollection.__init__),
)
)

self.assertTrue(
is_signature_compatible(
inspect.signature(StableEmbeddingCollectionInterface.forward),
inspect.signature(EmbeddingCollection.forward),
)
)

self.assertTrue(
is_signature_compatible(
inspect.signature(StableEmbeddingCollectionInterface.embedding_configs),
inspect.signature(EmbeddingCollection.embedding_configs),
)
)

self.assertTrue(
is_signature_compatible(
inspect.signature(StableEmbeddingCollectionInterface.embedding_dim),
inspect.signature(EmbeddingCollection.embedding_dim),
)
)

self.assertTrue(
is_signature_compatible(
inspect.signature(
StableEmbeddingCollectionInterface.embedding_names_by_table
),
inspect.signature(EmbeddingCollection.embedding_names_by_table),
)
)

0 comments on commit ff6dc0a

Please sign in to comment.