diff --git a/torchrec/distributed/planner/proposers.py b/torchrec/distributed/planner/proposers.py
index b89eb767a..bc355af3b 100644
--- a/torchrec/distributed/planner/proposers.py
+++ b/torchrec/distributed/planner/proposers.py
@@ -338,7 +338,9 @@ def feedback(
             logger.info(
                 f"EmbeddingOffloadScaleupProposer - cache scale up budget={round(bytes_to_gb(hbm_available), 2)} GB, exploring [{round(bytes_to_gb(hbm_used_previously), 2)}, {round(bytes_to_gb(hbm_used_previously + hbm_available), 2)}] GB"
             )
-            self.search = LuusJaakolaSearch(0, hbm_available, max_iterations=16)
+            self.search = LuusJaakolaSearch(
+                0, hbm_available, max_iterations=16, left_cost=perf_rating
+            )
 
         logger.info(
             f"EmbeddingOffloadScaleupProposer - proposed size={round(bytes_to_gb(hbm_used_previously), 2)} GB, score={perf_rating}"
diff --git a/torchrec/distributed/planner/tests/test_utils.py b/torchrec/distributed/planner/tests/test_utils.py
index 2bf861537..5fe3ce9ed 100644
--- a/torchrec/distributed/planner/tests/test_utils.py
+++ b/torchrec/distributed/planner/tests/test_utils.py
@@ -7,7 +7,7 @@
 
 import math
 import unittest
-from typing import Callable, List
+from typing import Callable, List, Optional
 from unittest.mock import MagicMock
 
 import torch
@@ -124,12 +124,17 @@ class TestLuusJaakolaSearch(unittest.TestCase):
     # just getting lucky.
     # Returns a Nx2 tensor of [xs, ys] of discovered minimums.
     @staticmethod
-    def evaluate(x0: float, x1: float, f: Callable[[float], float]) -> torch.Tensor:
+    def evaluate(
+        x0: float,
+        x1: float,
+        f: Callable[[float], float],
+        left_cost: Optional[float] = None,
+    ) -> torch.Tensor:
         xs = []
         ys = []
         iterations = 16
         for i in range(5):
-            search = LuusJaakolaSearch(x0, x1, iterations, seed=i)
+            search = LuusJaakolaSearch(x0, x1, iterations, seed=i, left_cost=left_cost)
             y = search.next(0.0)
             while y is not None:
                 fy = f(y)
@@ -360,3 +365,18 @@ def f(x: float) -> float:
             ],
         )
         torch.testing.assert_close(results, want)
+
+        results = TestLuusJaakolaSearch.evaluate(
+            mem.min().item(), mem.max().item(), f, left_cost=cost[0].item()
+        )
+        want = torch.tensor(
+            [
+                [5.370294e11, 2.314406e02],
+                # 2nd search finds better result given left_cost
+                [5.918126e11, 2.308140e02],
+                [5.908549e11, 2.308194e02],
+                [5.755533e11, 2.309337e02],
+                [6.184178e11, 2.308121e02],
+            ],
+        )
+        torch.testing.assert_close(results, want)
diff --git a/torchrec/distributed/planner/utils.py b/torchrec/distributed/planner/utils.py
index c8f96da03..2e33d31d1 100644
--- a/torchrec/distributed/planner/utils.py
+++ b/torchrec/distributed/planner/utils.py
@@ -172,7 +172,14 @@ class LuusJaakolaSearch:
     See https://en.wikipedia.org/wiki/Luus-Jaakola.
     """
 
-    def __init__(self, A: float, B: float, max_iterations: int, seed: int = 42) -> None:
+    def __init__(
+        self,
+        A: float,
+        B: float,
+        max_iterations: int,
+        seed: int = 42,
+        left_cost: Optional[float] = None,
+    ) -> None:
         self.left = A
         self.right = B
         self.iteration = -1
@@ -184,7 +191,7 @@ def __init__(self, A: float, B: float, max_iterations: int, seed: int = 42) -> N
         self.x: float = self.uniform(self.left, self.right)
         self.fx: float = 0.0
         self.y: float = math.nan
-        self.fleft: Optional[float] = None
+        self.fleft: Optional[float] = left_cost
         self.fright: Optional[float] = None
         self.d: float = self.right - self.left
 
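
Note (not part of the diff): a minimal sketch of how a caller might use the new left_cost parameter, assuming the next()-driven iteration protocol shown by evaluate() in test_utils.py above; the minimize() helper and its signature are hypothetical, for illustration only.

from typing import Callable, Optional, Tuple

from torchrec.distributed.planner.utils import LuusJaakolaSearch


def minimize(
    f: Callable[[float], float],
    lo: float,
    hi: float,
    left_cost: Optional[float] = None,
) -> Tuple[float, float]:
    # Passing left_cost seeds the search with an already-known cost at x=lo
    # (e.g. the score of the previously evaluated plan) so it need not be
    # rediscovered by sampling near the left boundary.
    search = LuusJaakolaSearch(lo, hi, max_iterations=16, left_cost=left_cost)
    best_x, best_y = lo, float("inf") if left_cost is None else left_cost
    y = search.next(0.0)
    while y is not None:
        fy = f(y)
        if fy < best_y:
            best_x, best_y = y, fy
        y = search.next(fy)
    return best_x, best_y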