Skip to content

Commit

Permalink
dyn col dim proposer (#1794)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1794

Refactored version of D48339456.

design doc (wip): https://docs.google.com/document/d/1cM8jlJJJYwkXDfeRKGT9jo_S4IDnVZ0X2DWF2XwHTvE/

The idea is to decrease the column-wise (CW) sharding dimension for tables that dominate the maximum perf or the maximum HBM usage.

Reviewed By: ge0405

Differential Revision: D54824506

fbshipit-source-id: 8fb18c00c89b059f00f63de62f0b435e60f0fe7b
  • Loading branch information
henrylhtsang authored and facebook-github-bot committed Mar 15, 2024
1 parent 32e0e14 commit 62d0742
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 1 deletion.
32 changes: 31 additions & 1 deletion torchrec/distributed/planner/perf_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,21 @@

from typing import cast, List

from torchrec.distributed.planner.types import Perf, PerfModel, ShardingOption, Topology
from torchrec.distributed.planner.types import (
Perf,
PerfModel,
ShardingOption,
Storage,
Topology,
)


class NoopPerfModel(PerfModel):
"""
A no-op model that returns the maximum perf among all shards. Here, no-op
means we estimate the performance of a model without actually running it.
"""

def __init__(self, topology: Topology) -> None:
self._topology = topology

Expand All @@ -24,3 +35,22 @@ def rate(self, plan: List[ShardingOption]) -> float:
perfs[shard.rank] += cast(Perf, shard.perf).total

return max(perfs)


class NoopStorageModel(PerfModel):
    """
    A no-op model that rates a sharding plan by its peak HBM usage: the rank
    whose shards sum to the largest HBM footprint determines the rating.
    Here, no-op means the estimate is computed without actually running the
    model.
    """

    def __init__(self, topology: Topology) -> None:
        self._topology = topology

    def rate(self, plan: List[ShardingOption]) -> float:
        """Return the maximum per-rank HBM usage implied by ``plan``."""
        per_rank_hbm = [0] * self._topology.world_size
        all_shards = (
            shard for sharding_option in plan for shard in sharding_option.shards
        )
        for shard in all_shards:
            # pyre-ignore [6]: Expected `typing_extensions.SupportsIndex`
            per_rank_hbm[shard.rank] += cast(Storage, shard.storage).hbm

        return max(per_rank_hbm)
62 changes: 62 additions & 0 deletions torchrec/distributed/planner/tests/test_perf_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest
from unittest.mock import MagicMock

from torchrec.distributed.planner.perf_models import NoopPerfModel, NoopStorageModel
from torchrec.distributed.planner.types import (
Perf,
Shard,
ShardingOption,
Storage,
Topology,
)


class TestPerfModels(unittest.TestCase):
    """Unit tests for the no-op perf and storage rating models."""

    @staticmethod
    def _make_option(rank: int) -> ShardingOption:
        # One single-shard option per rank. Perf decreases with rank
        # (fwd_compute = 2 - rank) while HBM grows with rank
        # (hbm = 100 * (rank + 1)), so the two models pick different ranks.
        shard = Shard(
            size=MagicMock(),
            offset=MagicMock(),
            rank=rank,
            perf=Perf(
                fwd_compute=2 - rank,
                fwd_comms=0,
                bwd_compute=0,
                bwd_comms=0,
            ),
            storage=Storage(hbm=100 * (rank + 1), ddr=0),
        )
        return ShardingOption(
            name=MagicMock(),
            tensor=MagicMock(),
            module=MagicMock(),
            input_lengths=MagicMock(),
            batch_size=MagicMock(),
            sharding_type=MagicMock(),
            partition_by=MagicMock(),
            compute_kernel=MagicMock(),
            shards=[shard],
        )

    def setUp(self) -> None:
        self.topology = Topology(world_size=2, compute_device="cuda")
        self.tables = [self._make_option(rank) for rank in range(2)]

    def test_noop_perf_model(self) -> None:
        # Rank 0 carries the larger total perf (2 vs 1), so it sets the rating.
        perf_model = NoopPerfModel(self.topology)
        self.assertEqual(perf_model.rate(self.tables), 2)

    def test_noop_storage_model(self) -> None:
        # Rank 1 carries the larger HBM footprint (200 vs 100).
        storage_model = NoopStorageModel(self.topology)
        self.assertEqual(storage_model.rate(self.tables), 200)

0 comments on commit 62d0742

Please sign in to comment.