Check mch-ec for sequence embeddings too
Summary: If the sharded module is mch-ec (ManagedCollisionEmbeddingCollection), its embeddings are sequence (unpooled) rather than pooled (see the illustrative sketch below).

Differential Revision: D52646853
henrylhtsang authored and facebook-github-bot committed Jan 10, 2024
1 parent 813671c commit 8bf159f
Showing 2 changed files with 163 additions and 2 deletions.
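
As background for the summary above, here is a minimal sketch (not part of this commit) of the pooled-vs-sequence distinction the planner cares about: an EmbeddingBagCollection pools the per-id embeddings of each feature into one vector per sample, while an EmbeddingCollection (and therefore a ManagedCollisionEmbeddingCollection wrapping one) returns one embedding per id, i.e. sequence output. The table name, sizes, and input ids below are made up for illustration.

# Illustrative only: pooled (EBC) output vs. sequence (EC) output.
import torch

from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
from torchrec.modules.embedding_modules import (
    EmbeddingBagCollection,
    EmbeddingCollection,
)
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="table_0", embedding_dim=8, num_embeddings=100, feature_names=["f1"]
        )
    ]
)
ec = EmbeddingCollection(
    tables=[
        EmbeddingConfig(
            name="table_0", embedding_dim=8, num_embeddings=100, feature_names=["f1"]
        )
    ]
)

# Two samples for feature "f1": ids [1, 2] and [3].
features = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1"],
    values=torch.tensor([1, 2, 3]),
    lengths=torch.tensor([2, 1]),
)

pooled = ebc(features)    # KeyedTensor: one pooled vector per sample per feature
per_id = ec(features)     # Dict[str, JaggedTensor]: one vector per id
print(pooled["f1"].shape)           # torch.Size([2, 8]) -> pooled
print(per_id["f1"].values().shape)  # torch.Size([3, 8]) -> sequence

The planner's ShardingOption.is_pooled keys off this distinction, and the types.py change below extends the sequence case to mch-ec.
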
156 changes: 156 additions & 0 deletions torchrec/distributed/planner/tests/test_types.py
@@ -6,6 +6,7 @@
# LICENSE file in the root directory of this source tree.

import unittest
from typing import cast
from unittest.mock import MagicMock

import torch
@@ -19,6 +20,21 @@
DataType,
ShardingType,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
from torchrec.modules.embedding_modules import (
EmbeddingBagCollection,
EmbeddingCollection,
)
from torchrec.modules.mc_embedding_modules import (
ManagedCollisionCollection,
ManagedCollisionEmbeddingBagCollection,
ManagedCollisionEmbeddingCollection,
)
from torchrec.modules.mc_modules import (
DistanceLFU_EvictionPolicy,
ManagedCollisionModule,
MCHManagedCollisionModule,
)


class TestShardingOption(unittest.TestCase):
@@ -49,3 +65,143 @@ def test_hash_sharding_option(self) -> None:
bounds_check_mode=BoundsCheckMode.WARNING,
)
self.assertTrue(map(hash, [sharding_option]))

def test_module_pooled_ebc(self) -> None:
eb_config = EmbeddingBagConfig(
name="table_0",
embedding_dim=160,
num_embeddings=10000,
feature_names=["f1"],
data_type=DataType.FP16,
)
ebc = EmbeddingBagCollection(tables=[eb_config])

sharding_option: ShardingOption = ShardingOption(
name="table_0",
tensor=torch.empty(
(10000, 160), dtype=torch.float16, device=torch.device("meta")
),
module=("ebc", ebc),
input_lengths=MagicMock(),
batch_size=MagicMock(),
sharding_type=ShardingType.COLUMN_WISE.value,
partition_by=MagicMock(),
compute_kernel=EmbeddingComputeKernel.FUSED.value,
shards=[
Shard(size=[10000, 80], offset=offset) for offset in [[0, 0], [0, 80]]
],
)
self.assertEqual(sharding_option.is_pooled, True)

def test_module_pooled_mch_ebc(self) -> None:
eb_config = EmbeddingBagConfig(
name="table_0",
embedding_dim=160,
num_embeddings=10000,
feature_names=["f1"],
data_type=DataType.FP16,
)
ebc = EmbeddingBagCollection(tables=[eb_config])
mc_modules = {
"table_0": cast(
ManagedCollisionModule,
MCHManagedCollisionModule(
zch_size=10000,
device=torch.device("meta"),
eviction_interval=1,
eviction_policy=DistanceLFU_EvictionPolicy(),
),
),
}
mcc = ManagedCollisionCollection(
managed_collision_modules=mc_modules,
embedding_configs=[eb_config],
)
mch_ebc = ManagedCollisionEmbeddingBagCollection(ebc, mcc)

sharding_option: ShardingOption = ShardingOption(
name="table_0",
tensor=torch.empty(
                (10000, 160), dtype=torch.float16, device=torch.device("meta")
),
module=("mch_ebc", mch_ebc),
input_lengths=MagicMock(),
batch_size=MagicMock(),
sharding_type=ShardingType.COLUMN_WISE.value,
partition_by=MagicMock(),
compute_kernel=EmbeddingComputeKernel.FUSED.value,
shards=[
Shard(size=[10000, 80], offset=offset) for offset in [[0, 0], [0, 80]]
],
)
self.assertEqual(sharding_option.is_pooled, True)

def test_module_pooled_ec(self) -> None:
e_config = EmbeddingConfig(
name="table_0",
            embedding_dim=160,
num_embeddings=10000,
feature_names=["f1"],
data_type=DataType.FP16,
)
ec = EmbeddingCollection(tables=[e_config])

shard_size = [10000, 80]
shard_offsets = [[0, 0], [0, 80]]
sharding_option: ShardingOption = ShardingOption(
name="table_0",
tensor=torch.empty(
(10000, 160), dtype=torch.float16, device=torch.device("meta")
),
module=("ec", ec),
input_lengths=MagicMock(),
batch_size=MagicMock(),
sharding_type=ShardingType.COLUMN_WISE.value,
partition_by=MagicMock(),
compute_kernel=EmbeddingComputeKernel.FUSED.value,
shards=[Shard(size=shard_size, offset=offset) for offset in shard_offsets],
)
self.assertEqual(sharding_option.is_pooled, False)

def test_module_pooled_mch_ec(self) -> None:
e_config = EmbeddingConfig(
name="table_0",
            embedding_dim=160,
num_embeddings=10000,
feature_names=["f1"],
data_type=DataType.FP16,
)
ec = EmbeddingCollection(tables=[e_config])
mc_modules = {
"table_0": cast(
ManagedCollisionModule,
MCHManagedCollisionModule(
zch_size=10000,
device=torch.device("meta"),
eviction_interval=1,
eviction_policy=DistanceLFU_EvictionPolicy(),
),
),
}
mcc = ManagedCollisionCollection(
managed_collision_modules=mc_modules,
embedding_configs=[e_config],
)
mch_ec = ManagedCollisionEmbeddingCollection(ec, mcc)

shard_size = [10000, 80]
shard_offsets = [[0, 0], [0, 80]]
sharding_option: ShardingOption = ShardingOption(
name="table_0",
tensor=torch.empty(
(10000, 160), dtype=torch.float16, device=torch.device("meta")
),
module=("mch_ec", mch_ec),
input_lengths=MagicMock(),
batch_size=MagicMock(),
sharding_type=ShardingType.COLUMN_WISE.value,
partition_by=MagicMock(),
compute_kernel=EmbeddingComputeKernel.FUSED.value,
shards=[Shard(size=shard_size, offset=offset) for offset in shard_offsets],
)
self.assertEqual(sharding_option.is_pooled, False)
9 changes: 7 additions & 2 deletions torchrec/distributed/planner/types.py
@@ -31,6 +31,7 @@
ShardingPlan,
)
from torchrec.modules.embedding_modules import EmbeddingCollectionInterface
+from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection

# ---- Perf ---- #

@@ -355,11 +356,15 @@ def is_pooled(self) -> bool:
@staticmethod
def module_pooled(module: nn.Module, sharding_option_name: str) -> bool:
"""Determine if module pools output (e.g. EmbeddingBag) or uses unpooled/sequential output."""
-        if isinstance(module, EmbeddingCollectionInterface):
+        if isinstance(module, EmbeddingCollectionInterface) or isinstance(
+            module, ManagedCollisionEmbeddingCollection
+        ):
return False

for submodule in module.modules():
-            if isinstance(submodule, EmbeddingCollectionInterface):
+            if isinstance(submodule, EmbeddingCollectionInterface) or isinstance(
+                submodule, ManagedCollisionEmbeddingCollection
+            ):
for name, _ in submodule.named_parameters():
if sharding_option_name in name:
return False
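
For reference, a minimal usage sketch (not part of this commit) of the behavior the check above is meant to produce, mirroring the new test_module_pooled_mch_ec test; the table name and sizes are arbitrary:

import torch

from torchrec.distributed.planner.types import ShardingOption
from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.modules.embedding_modules import EmbeddingCollection
from torchrec.modules.mc_embedding_modules import (
    ManagedCollisionCollection,
    ManagedCollisionEmbeddingCollection,
)
from torchrec.modules.mc_modules import (
    DistanceLFU_EvictionPolicy,
    MCHManagedCollisionModule,
)

e_config = EmbeddingConfig(
    name="table_0", embedding_dim=160, num_embeddings=10000, feature_names=["f1"]
)
mch_ec = ManagedCollisionEmbeddingCollection(
    EmbeddingCollection(tables=[e_config]),
    ManagedCollisionCollection(
        managed_collision_modules={
            "table_0": MCHManagedCollisionModule(
                zch_size=10000,
                device=torch.device("meta"),
                eviction_interval=1,
                eviction_policy=DistanceLFU_EvictionPolicy(),
            )
        },
        embedding_configs=[e_config],
    ),
)

# An mch-ec wraps sequence embeddings, so it should not be treated as pooled.
assert ShardingOption.module_pooled(mch_ec, "table_0") is False

Pooled wrappers such as ManagedCollisionEmbeddingBagCollection still resolve to pooled, as the mch-ebc test above asserts.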