Support weighted_bwd_compute_multiplier in sharding estimators (#2068)
Summary:
# Context
This diff stack is a quick mitigation for sharding imbalance on weighted features.

The TorchRec sharding planner focuses mostly on forward cost; for example, it is not aware of differences like:
- for unweighted lookups, the backward kernel is, say, 2x the cost of the forward kernel;
- but for weighted lookups, the backward kernel may be 4x the cost of the forward kernel.

See https://docs.google.com/document/d/1o-lB6veGVIZFO148ljSuhVr8sA9fm2EGwOfhX6TSdrg/edit?usp=sharing for more context.

# This Diff
- Enable the shard estimators to apply weighted_feature_bwd_compute_multiplier when computing the backward compute cost for weighted features (a minimal sketch follows below).
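
A minimal sketch of the adjustment, for illustration only — the names below are simplified stand-ins for values the estimator already derives, not the exact production code:

```python
def adjusted_bwd_compute(
    fwd_compute: float,
    bwd_compute_multiplier: float,
    weighted_feature_bwd_compute_multiplier: float,
    is_weighted: bool,
) -> float:
    # Backward compute cost is modeled as a multiple of forward compute cost.
    bwd_compute = fwd_compute * bwd_compute_multiplier
    # New in this diff: weighted features get an extra backward multiplier,
    # so their backward cost can be modeled independently of unweighted ones.
    if is_weighted:
        bwd_compute *= weighted_feature_bwd_compute_multiplier
    return bwd_compute
```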

Pull Request resolved: #2068

Reviewed By: xush6528, sarckk

Differential Revision: D53550851

fbshipit-source-id: bd14c9b8dc01d47802741978288a002cc58e85ee
ys97529 authored and facebook-github-bot committed Jun 6, 2024
1 parent da49f44 commit 8393202
Showing 4 changed files with 118 additions and 1 deletion.
1 change: 1 addition & 0 deletions torchrec/distributed/planner/constants.py
@@ -34,6 +34,7 @@
HALF_BLOCK_PENALTY: float = 1.15 # empirical studies
QUARTER_BLOCK_PENALTY: float = 1.75 # empirical studies
BWD_COMPUTE_MULTIPLIER: float = 2 # empirical studies
WEIGHTED_FEATURE_BWD_COMPUTE_MULTIPLIER: float = 1 # empirical studies
WEIGHTED_KERNEL_MULTIPLIER: float = 1.1 # empirical studies
DP_ELEMENTWISE_KERNELS_PERF_FACTOR: float = 9.22 # empirical studies

18 changes: 18 additions & 0 deletions torchrec/distributed/planner/shard_estimators.py
@@ -216,6 +216,7 @@ def estimate(
intra_host_bw=self._topology.intra_host_bw,
inter_host_bw=self._topology.inter_host_bw,
bwd_compute_multiplier=self._topology.bwd_compute_multiplier,
weighted_feature_bwd_compute_multiplier=self._topology.weighted_feature_bwd_compute_multiplier,
is_pooled=sharding_option.is_pooled,
is_weighted=is_weighted,
is_inference=self._is_inference,
@@ -251,6 +252,7 @@ def perf_func_emb_wall_time(
intra_host_bw: float,
inter_host_bw: float,
bwd_compute_multiplier: float,
weighted_feature_bwd_compute_multiplier: float,
is_pooled: bool,
is_weighted: bool = False,
caching_ratio: Optional[float] = None,
@@ -336,6 +338,7 @@ def perf_func_emb_wall_time(
inter_host_bw=inter_host_bw,
intra_host_bw=intra_host_bw,
bwd_compute_multiplier=bwd_compute_multiplier,
weighted_feature_bwd_compute_multiplier=weighted_feature_bwd_compute_multiplier,
is_pooled=is_pooled,
is_weighted=is_weighted,
is_inference=is_inference,
@@ -361,6 +364,7 @@ def perf_func_emb_wall_time(
inter_host_bw=inter_host_bw,
intra_host_bw=intra_host_bw,
bwd_compute_multiplier=bwd_compute_multiplier,
weighted_feature_bwd_compute_multiplier=weighted_feature_bwd_compute_multiplier,
is_pooled=is_pooled,
is_weighted=is_weighted,
expected_cache_fetches=expected_cache_fetches,
@@ -386,6 +390,7 @@ def perf_func_emb_wall_time(
inter_host_bw=inter_host_bw,
intra_host_bw=intra_host_bw,
bwd_compute_multiplier=bwd_compute_multiplier,
weighted_feature_bwd_compute_multiplier=weighted_feature_bwd_compute_multiplier,
is_pooled=is_pooled,
is_weighted=is_weighted,
expected_cache_fetches=expected_cache_fetches,
@@ -405,6 +410,7 @@ def perf_func_emb_wall_time(
device_bw=device_bw,
inter_host_bw=inter_host_bw,
bwd_compute_multiplier=bwd_compute_multiplier,
weighted_feature_bwd_compute_multiplier=weighted_feature_bwd_compute_multiplier,
is_pooled=is_pooled,
is_weighted=is_weighted,
)
@@ -447,6 +453,7 @@ def _get_tw_sharding_perf(
inter_host_bw: float,
intra_host_bw: float,
bwd_compute_multiplier: float,
weighted_feature_bwd_compute_multiplier: float,
is_pooled: bool,
is_weighted: bool = False,
is_inference: bool = False,
@@ -507,6 +514,8 @@ def _get_tw_sharding_perf(

# includes fused optimizers
bwd_compute = fwd_compute * bwd_compute_multiplier
if is_weighted:
bwd_compute = bwd_compute * weighted_feature_bwd_compute_multiplier

prefetch_compute = cls._get_expected_cache_prefetch_time(
ddr_mem_bw, expected_cache_fetches, emb_dim, table_data_type_size
@@ -543,6 +552,7 @@ def _get_rw_sharding_perf(
inter_host_bw: float,
intra_host_bw: float,
bwd_compute_multiplier: float,
weighted_feature_bwd_compute_multiplier: float,
is_pooled: bool,
is_weighted: bool = False,
expected_cache_fetches: float = 0,
@@ -601,6 +611,8 @@ def _get_rw_sharding_perf(
)

bwd_compute = fwd_compute * bwd_compute_multiplier
if is_weighted:
bwd_compute = bwd_compute * weighted_feature_bwd_compute_multiplier

# for row-wise, expected_cache_fetches per shard is / world_size
prefetch_compute = cls._get_expected_cache_prefetch_time(
@@ -639,6 +651,7 @@ def _get_twrw_sharding_perf(
inter_host_bw: float,
intra_host_bw: float,
bwd_compute_multiplier: float,
weighted_feature_bwd_compute_multiplier: float,
is_pooled: bool,
is_weighted: bool = False,
expected_cache_fetches: float = 0,
@@ -697,6 +710,8 @@ def _get_twrw_sharding_perf(
bwd_batched_copy = bwd_output_write_size * BATCHED_COPY_PERF_FACTOR / device_bw

bwd_compute = fwd_compute * bwd_compute_multiplier
if is_weighted:
bwd_compute = bwd_compute * weighted_feature_bwd_compute_multiplier

# for table-wise-row-wise, expected_cache_fetches per shard is / local_world_size
prefetch_compute = cls._get_expected_cache_prefetch_time(
@@ -730,6 +745,7 @@ def _get_dp_sharding_perf(
device_bw: float,
inter_host_bw: float,
bwd_compute_multiplier: float,
weighted_feature_bwd_compute_multiplier: float,
is_pooled: bool,
is_weighted: bool = False,
) -> Perf:
@@ -772,6 +788,8 @@ def _get_dp_sharding_perf(
optimizer_kernels = table_size * DP_ELEMENTWISE_KERNELS_PERF_FACTOR / device_bw

bwd_compute = fwd_compute * bwd_compute_multiplier
if is_weighted:
bwd_compute = bwd_compute * weighted_feature_bwd_compute_multiplier

bwd_grad_indice_weights_kernel = (
fwd_compute * WEIGHTED_KERNEL_MULTIPLIER if is_weighted else 0
91 changes: 90 additions & 1 deletion torchrec/distributed/planner/tests/test_shard_estimators.py
@@ -8,7 +8,7 @@
# pyre-strict

import unittest
-from typing import cast
+from typing import cast, Dict, List, Tuple

from unittest.mock import Mock, patch

@@ -526,6 +526,95 @@ def cacheability(self) -> float:
}
self.assertEqual(expected_prefetch_computes, prefetch_computes)

def test_weighted_feature_bwd_compute_multiplier(self) -> None:
def _get_bwd_computes(
model: torch.nn.Module,
weighted_feature_bwd_compute_multiplier: float,
) -> Dict[Tuple[str, str, str], List[float]]:
topology = Topology(
world_size=2,
compute_device="cuda",
weighted_feature_bwd_compute_multiplier=weighted_feature_bwd_compute_multiplier,
)
estimator = EmbeddingPerfEstimator(topology=topology)
enumerator = EmbeddingEnumerator(
topology=topology, batch_size=BATCH_SIZE, estimator=estimator
)
sharding_options = enumerator.enumerate(
module=model,
sharders=[
cast(
ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder()
)
],
)
bwd_computes = {
(
sharding_option.name,
sharding_option.compute_kernel,
sharding_option.sharding_type,
): [
shard.perf.bwd_compute if shard.perf else -1
for shard in sharding_option.shards
]
for sharding_option in sharding_options
}
return bwd_computes

tables = [
EmbeddingBagConfig(
num_embeddings=100,
embedding_dim=10,
name="table_0",
feature_names=["feature_0"],
)
]
weighted_tables = [
EmbeddingBagConfig(
num_embeddings=100,
embedding_dim=10,
name="weighted_table_0",
feature_names=["weighted_feature_0"],
)
]
model = TestSparseNN(tables=tables, weighted_tables=weighted_tables)

MULTIPLIER = 7
bwd_computes_1 = _get_bwd_computes(
model, weighted_feature_bwd_compute_multiplier=1
)
bwd_computes_2 = _get_bwd_computes(
model,
weighted_feature_bwd_compute_multiplier=2,
)
bwd_computes_n = _get_bwd_computes(
model,
weighted_feature_bwd_compute_multiplier=MULTIPLIER,
)
self.assertEqual(bwd_computes_1.keys(), bwd_computes_2.keys())
self.assertEqual(bwd_computes_1.keys(), bwd_computes_n.keys())
for key in bwd_computes_1.keys():
table_name, _, sharding_type = key
if table_name.startswith("weighted"):
self.assertEqual(len(bwd_computes_1), len(bwd_computes_2))
self.assertEqual(len(bwd_computes_1), len(bwd_computes_n))
for bwd_compute_1, bwd_compute_2, bwd_compute_n in zip(
bwd_computes_1[key], bwd_computes_2[key], bwd_computes_n[key]
):
# bwd_compute_1 = base_bwd_compute + offset
# bwd_compute_2 = base_bwd_compute * 2 + offset
# bwd_compute_n = base_bwd_compute * MULTIPLIER + offset
# (where offset = bwd_grad_indice_weights_kernel in production
# https://fburl.com/code/u9hq6vhf)
base_bwd_compute = bwd_compute_2 - bwd_compute_1
offset = bwd_compute_1 - base_bwd_compute
self.assertAlmostEqual(
base_bwd_compute * MULTIPLIER,
bwd_compute_n - offset,
)
else:
self.assertEqual(bwd_computes_1[key], bwd_computes_2[key])


# pyre-ignore[3]
def calculate_storage_specific_size_data_provider():
9 changes: 9 additions & 0 deletions torchrec/distributed/planner/types.py
@@ -25,6 +25,7 @@
HBM_MEM_BW,
INTRA_NODE_BANDWIDTH,
POOLING_FACTOR,
WEIGHTED_FEATURE_BWD_COMPUTE_MULTIPLIER,
)
from torchrec.distributed.types import (
BoundsCheckMode,
@@ -186,6 +187,7 @@ def __init__(
inter_host_bw: float = CROSS_NODE_BANDWIDTH,
bwd_compute_multiplier: float = BWD_COMPUTE_MULTIPLIER,
custom_topology_data: Optional[CustomTopologyData] = None,
weighted_feature_bwd_compute_multiplier: float = WEIGHTED_FEATURE_BWD_COMPUTE_MULTIPLIER,
) -> None:
"""
Representation of a network of devices in a cluster.
@@ -238,6 +240,9 @@ def __init__(
self._inter_host_bw = inter_host_bw
self._bwd_compute_multiplier = bwd_compute_multiplier
self._custom_topology_data = custom_topology_data
self._weighted_feature_bwd_compute_multiplier = (
weighted_feature_bwd_compute_multiplier
)

@property
def compute_device(self) -> str:
@@ -275,6 +280,10 @@ def inter_host_bw(self) -> float:
def bwd_compute_multiplier(self) -> float:
return self._bwd_compute_multiplier

@property
def weighted_feature_bwd_compute_multiplier(self) -> float:
return self._weighted_feature_bwd_compute_multiplier

def __repr__(self) -> str:
topology_repr: str = f"world_size={self._world_size} \n"
topology_repr += f"compute_device={self._compute_device}\n"
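
For reference, a minimal usage sketch of the new knob, mirroring the test above. The import paths follow the torchrec planner modules touched by this diff, but the example values (multiplier, batch size, table config) are placeholders, not part of this change:

```python
from typing import cast

import torch
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
from torchrec.distributed.planner.shard_estimators import EmbeddingPerfEstimator
from torchrec.distributed.planner.types import Topology
from torchrec.distributed.types import ModuleSharder
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

# A tiny model with one weighted table (is_weighted=True marks its features as weighted).
model = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            num_embeddings=100,
            embedding_dim=10,
            name="weighted_table_0",
            feature_names=["weighted_feature_0"],
        )
    ],
    is_weighted=True,
)

# Topology carries the new multiplier; the perf estimator reads it when
# costing the backward pass of weighted features.
topology = Topology(
    world_size=2,
    compute_device="cuda",
    weighted_feature_bwd_compute_multiplier=4.0,  # example value
)
estimator = EmbeddingPerfEstimator(topology=topology)
enumerator = EmbeddingEnumerator(
    topology=topology, batch_size=512, estimator=estimator
)
sharding_options = enumerator.enumerate(
    module=model,
    sharders=[cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder())],
)
```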
