Skip to content

Commit 62d0742

Browse files
henrylhtsang and facebook-github-bot
authored and committed
dyn col dim proposer (#1794)
Summary: Pull Request resolved: #1794 Refactored version of D48339456. design doc (wip): https://docs.google.com/document/d/1cM8jlJJJYwkXDfeRKGT9jo_S4IDnVZ0X2DWF2XwHTvE/ The idea is to decrease CW dim for certain tables that are causing max perf or max hbm usage. Reviewed By: ge0405 Differential Revision: D54824506 fbshipit-source-id: 8fb18c00c89b059f00f63de62f0b435e60f0fe7b
1 parent 32e0e14 commit 62d0742

File tree

2 files changed

+93
-1
lines changed

2 files changed

+93
-1
lines changed

torchrec/distributed/planner/perf_models.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,21 @@
99

1010
from typing import cast, List
1111

12-
from torchrec.distributed.planner.types import Perf, PerfModel, ShardingOption, Topology
12+
from torchrec.distributed.planner.types import (
13+
Perf,
14+
PerfModel,
15+
ShardingOption,
16+
Storage,
17+
Topology,
18+
)
1319

1420

1521
class NoopPerfModel(PerfModel):
22+
"""
23+
A no-op model that returns the maximum perf among all shards. Here, no-op
24+
means we estimate the performance of a model without actually running it.
25+
"""
26+
1627
def __init__(self, topology: Topology) -> None:
1728
self._topology = topology
1829

@@ -24,3 +35,22 @@ def rate(self, plan: List[ShardingOption]) -> float:
2435
perfs[shard.rank] += cast(Perf, shard.perf).total
2536

2637
return max(perfs)
38+
39+
40+
class NoopStorageModel(PerfModel):
    """
    A no-op model that rates a sharding plan by its peak per-rank HBM usage.
    Here, no-op means the cost is estimated without actually running the model.
    """

    def __init__(self, topology: Topology) -> None:
        self._topology = topology

    def rate(self, plan: List[ShardingOption]) -> float:
        """Return the maximum HBM usage accumulated on any single rank."""
        per_rank_hbm = [0] * self._topology.world_size
        # Flatten all shards across every sharding option in the plan.
        all_shards = (
            shard
            for sharding_option in plan
            for shard in sharding_option.shards
        )
        for shard in all_shards:
            # pyre-ignore [6]: Expected `typing_extensions.SupportsIndex`
            per_rank_hbm[shard.rank] += cast(Storage, shard.storage).hbm

        return max(per_rank_hbm)
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
import unittest
11+
from unittest.mock import MagicMock
12+
13+
from torchrec.distributed.planner.perf_models import NoopPerfModel, NoopStorageModel
14+
from torchrec.distributed.planner.types import (
15+
Perf,
16+
Shard,
17+
ShardingOption,
18+
Storage,
19+
Topology,
20+
)
21+
22+
23+
class TestPerfModels(unittest.TestCase):
    """Unit tests for the no-op planner perf/storage models."""

    def setUp(self) -> None:
        self.topology = Topology(world_size=2, compute_device="cuda")
        self.tables = []
        for rank in range(2):
            # Rank 0 carries the higher perf (2) and lower HBM (100);
            # rank 1 carries the lower perf (1) and higher HBM (200).
            single_shard = Shard(
                size=MagicMock(),
                offset=MagicMock(),
                rank=rank,
                perf=Perf(
                    fwd_compute=2 - rank,
                    fwd_comms=0,
                    bwd_compute=0,
                    bwd_comms=0,
                ),
                storage=Storage(hbm=100 * (rank + 1), ddr=0),
            )
            self.tables.append(
                ShardingOption(
                    name=MagicMock(),
                    tensor=MagicMock(),
                    module=MagicMock(),
                    input_lengths=MagicMock(),
                    batch_size=MagicMock(),
                    sharding_type=MagicMock(),
                    partition_by=MagicMock(),
                    compute_kernel=MagicMock(),
                    shards=[single_shard],
                )
            )

    def test_noop_perf_model(self) -> None:
        # Max per-rank total perf: rank 0 contributes 2, rank 1 contributes 1.
        model = NoopPerfModel(self.topology)
        rating = model.rate(self.tables)
        self.assertEqual(rating, 2)

    def test_noop_storage_model(self) -> None:
        # Max per-rank HBM: rank 0 holds 100, rank 1 holds 200.
        model = NoopStorageModel(self.topology)
        rating = model.rate(self.tables)
        self.assertEqual(rating, 200)

0 commit comments

Comments
 (0)