Support one-TBE-per-table (#1775)
Summary:
Pull Request resolved: #1775

Implement one-TBE-per-table as a first-class option for diagnosing
uvm_caching. We'll wire this up to APF in a subsequent diff; this diff only
adds support in torchrec.
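
For context, a minimal sketch of how a caller might opt in, assuming fused
params are threaded through a sharder's fused_params dict as elsewhere in
torchrec (the sharder shown is illustrative; the APF wiring is not part of
this diff):

    from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder

    # Every table handled by this sharder gets its own TBE, which makes
    # per-table uvm_caching behavior easy to isolate when diagnosing.
    sharder = EmbeddingBagCollectionSharder(
        fused_params={"use_one_tbe_per_table": True},
    )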

Reviewed By: henrylhtsang

Differential Revision: D54222784

fbshipit-source-id: 20e16ac4b2ce36075ec41d261fc1a5a8b5dd6c9e
Damian Reeves authored and facebook-github-bot committed Mar 12, 2024
1 parent ec7c05b commit 2edb86c
Showing 2 changed files with 33 additions and 2 deletions.
10 changes: 8 additions & 2 deletions torchrec/distributed/embedding_sharding.py
@@ -9,6 +9,7 @@
 
 import abc
 import copy
+import uuid
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
@@ -48,6 +49,7 @@
 torch.fx.wrap("len")
 
 CACHE_LOAD_FACTOR_STR: str = "cache_load_factor"
+USE_ONE_TBE_PER_TABLE: str = "use_one_tbe_per_table"
 
 
 # torch.Tensor.to can not be fx symbolic traced as it does not go through __torch_dispatch__ => fx.wrap it
@@ -213,6 +215,10 @@ def _get_grouping_fused_params(
     if CACHE_LOAD_FACTOR_STR in grouping_fused_params:
         del grouping_fused_params[CACHE_LOAD_FACTOR_STR]
 
+    if grouping_fused_params.get(USE_ONE_TBE_PER_TABLE, False):
+        # Replace with unique value to force it into singleton group.
+        grouping_fused_params[USE_ONE_TBE_PER_TABLE] = str(uuid.uuid4())
+
     return grouping_fused_params


@@ -296,11 +302,11 @@ def _group_tables_per_rank(
                 _,
             ) = grouping_key
             grouped_tables = groups[grouping_key]
-
+            # remove non-native fused params
             per_tbe_fused_params = {
                 k: v
                 for k, v in fused_params_tuple
-                if k not in ["_batch_key"]  # drop '_batch_key' not a native fused param
+                if k not in ["_batch_key", USE_ONE_TBE_PER_TABLE]
             }
             cache_load_factor = _get_weighted_avg_cache_load_factor(grouped_tables)
             if cache_load_factor is not None:
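The uuid substitution in _get_grouping_fused_params is the whole mechanism:
tables are grouped into TBEs by a key derived from their fused params, so
overwriting the flag with a fresh uuid gives a flagged table a grouping key
no other table can match. A standalone sketch of the idea (hypothetical
helper, not torchrec's actual grouping code):

    import uuid
    from collections import defaultdict
    from typing import Any, Dict, List, Tuple

    USE_ONE_TBE_PER_TABLE = "use_one_tbe_per_table"

    def group_by_fused_params(tables: List[Dict[str, Any]]) -> List[List[str]]:
        groups: Dict[Tuple, List[str]] = defaultdict(list)
        for table in tables:
            fused = dict(table["fused_params"])
            if fused.get(USE_ONE_TBE_PER_TABLE, False):
                # A fresh uuid makes this table's key unlike any other,
                # so it lands in a singleton group (its own TBE).
                fused[USE_ONE_TBE_PER_TABLE] = str(uuid.uuid4())
            key = tuple(sorted(fused.items()))
            groups[key].append(table["name"])
        return list(groups.values())

    tables = [
        {"name": f"table_{i}", "fused_params": {USE_ONE_TBE_PER_TABLE: i % 2 != 0}}
        for i in range(5)
    ]
    print(group_by_fused_params(tables))
    # [['table_0', 'table_2', 'table_4'], ['table_1'], ['table_3']]
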
25 changes: 25 additions & 0 deletions torchrec/distributed/tests/test_embedding_sharding.py
@@ -399,3 +399,28 @@ def test_should_not_group_together(
             sorted(_get_table_names_by_groups(tables)),
             [["table_0"], ["table_1"]],
         )
+
+    def test_use_one_tbe_per_table(
+        self,
+    ) -> None:
+
+        tables = [
+            ShardedEmbeddingTable(
+                name=f"table_{i}",
+                data_type=DataType.FP16,
+                pooling=PoolingType.SUM,
+                has_feature_processor=False,
+                fused_params={"use_one_tbe_per_table": i % 2 != 0},
+                compute_kernel=EmbeddingComputeKernel.FUSED_UVM_CACHING,
+                embedding_dim=10,
+                num_embeddings=10000,
+            )
+            for i in range(5)
+        ]
+
+        # Even tables should all be grouped in a single TBE, odd tables should be in
+        # their own TBEs.
+        self.assertEqual(
+            _get_table_names_by_groups(tables),
+            [["table_0", "table_2", "table_4"], ["table_1"], ["table_3"]],
+        )
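
Note the division of labor between the two files: use_one_tbe_per_table only
steers grouping. _group_tables_per_rank strips it (alongside _batch_key) from
per_tbe_fused_params before the fused params reach the TBE kernel, since it
is not a native fused param.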
