diff --git a/torchrec/distributed/planner/planners.py b/torchrec/distributed/planner/planners.py
index 137bc7ad1..af174dd75 100644
--- a/torchrec/distributed/planner/planners.py
+++ b/torchrec/distributed/planner/planners.py
@@ -283,9 +283,11 @@ def plan(
                     )
                     if current_storage < lowest_storage:
                         lowest_storage = current_storage
-                    proposal_cache[proposal_key] = (False, None, None)
+                    proposal_cache[proposal_key] = (False, proposal, None)
                     proposer.feedback(
-                        partitionable=False, storage_constraint=storage_constraint
+                        partitionable=False,
+                        plan=proposal,
+                        storage_constraint=storage_constraint,
                     )
 
         # clear shard.rank for each sharding_option
diff --git a/torchrec/distributed/planner/proposers.py b/torchrec/distributed/planner/proposers.py
index 342c618c1..b89eb767a 100644
--- a/torchrec/distributed/planner/proposers.py
+++ b/torchrec/distributed/planner/proposers.py
@@ -22,7 +22,7 @@
     ShardingOption,
     Topology,
 )
-from torchrec.distributed.planner.utils import BinarySearchPredicate, bytes_to_gb, prod
+from torchrec.distributed.planner.utils import bytes_to_gb, LuusJaakolaSearch, prod
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -283,8 +283,7 @@ def __init__(self, use_depth: bool = True) -> None:
         self.enumerator: Optional[Enumerator] = None
         self.starting_proposal: List[ShardingOption] = []
         self.proposal: Optional[List[ShardingOption]] = None
-        self.search: Optional[BinarySearchPredicate] = None
-        self.previous_plan_perf_rating: float = 0.0
+        self.search: Optional[LuusJaakolaSearch] = None
 
     def load(
         self,
@@ -320,7 +319,7 @@ def feedback(
         perf_rating: Optional[float] = None,
         storage_constraint: Optional[Topology] = None,
     ) -> None:
-        if not self.enumerator or not plan or not storage_constraint:
+        if not self.enumerator or plan is None:
             self.proposal = None
             return
 
@@ -329,47 +328,26 @@ def feedback(
         )
 
         if self.search is None:
-            # Determine how much extra HBM memory is available for scaling our caches
-            # beyond the baseline-min-working-set plan. We may not be able to find a
-            # partitionable plan that uses all this budget, or we may find a plan that
-            # uses only a portion of this budget enables a layout that reduces overall
-            # cost at the expense of larger prefetch delay. So we perform a binary
-            # search to sample plans with different budgets to discover a good
-            # configuration.
+            if not partitionable or storage_constraint is None:
+                self.proposal = None
+                return
+
             hbm_available = EmbeddingOffloadScaleupProposer.get_budget(
                 plan, storage_constraint
             )
             logger.info(
                 f"EmbeddingOffloadScaleupProposer - cache scale up budget={round(bytes_to_gb(hbm_available), 2)} GB, exploring [{round(bytes_to_gb(hbm_used_previously), 2)}, {round(bytes_to_gb(hbm_used_previously + hbm_available), 2)}] GB"
             )
-            # Partitioning proposals is quite expensive when there are a lot of tables,
-            # so we reduce the number probes the binary search uses to find the max
-            # cache sizes that fit inside the budget by specifying a tolerance. Once
-            # we've less than tolerance bytes left of unused budget we stop searching.
-            # We set tolerance to try to waste less than 3% of budget. For 100TB budget,
-            # this reduces number of proposals from 47 to 6.
-            tolerance = round(hbm_available * 0.03)
-            self.search = BinarySearchPredicate(0, hbm_available, tolerance)
+            self.search = LuusJaakolaSearch(0, hbm_available, max_iterations=16)
 
         logger.info(
             f"EmbeddingOffloadScaleupProposer - proposed size={round(bytes_to_gb(hbm_used_previously), 2)} GB, score={perf_rating}"
         )
-        # Guide the binary search. We assume the partitioned perf model cost is
-        # monotonic with respect to CLF, so if the feedback from our prior attempt was
-        # worse than previous one, we reduce the memory for our next proposal, else we
-        # try using more. This allows us to focus the budget allocation search into the
-        # productive region where plans are still getting better.
-        warmer = partitionable and (
-            self.previous_plan_perf_rating == 0.0
-            or (
-                perf_rating is not None and perf_rating < self.previous_plan_perf_rating
-            )
-        )
-        self.previous_plan_perf_rating = perf_rating or 0.0
-        assert self.search is not None  # keep pyre happy
-        budget = self.search.next(warmer)
+        budget = self.search.next(perf_rating or 1e99)
+        if budget is not None:
+            budget = int(budget)
         self.proposal = EmbeddingOffloadScaleupProposer.next_plan(
             self.starting_proposal, budget, self.enumerator
         )
diff --git a/torchrec/distributed/planner/tests/test_proposers.py b/torchrec/distributed/planner/tests/test_proposers.py
index 89574032c..817d90c70 100644
--- a/torchrec/distributed/planner/tests/test_proposers.py
+++ b/torchrec/distributed/planner/tests/test_proposers.py
@@ -422,20 +422,21 @@ def test_allocate_budget(self) -> None:
         self.assertEqual(increase, budget)
 
     def test_scaleup(self) -> None:
-
         tables = [
             EmbeddingBagConfig(
                 num_embeddings=2_000_000,
-                embedding_dim=10,
+                embedding_dim=10000,
                 name=f"table_{i}",
                 feature_names=[f"feature_{i}"],
             )
-            for i in range(3)
+            for i in range(4)
         ]
-        # Place first two tables into cache, 3rd table leave on hbm. table_1 has a
+        # Place the first three tables into cache, leaving the 4th on hbm. table_1 has a
         # larger cacheability score so budget should be skewed to scaling table_1 more
-        # than table_0.
+        # than table_0. table_2 is a deprecated feature we have no stats for (so
+        # expected_lookups 0); we want to see it left at its original load factor,
+        # i.e. it doesn't participate in scaleup.
         constraints = {
             "table_0": ParameterConstraints(
                 compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value],
                 cache_params=CacheParams(
                     load_factor=0.1,
                     stats=MockCacheStatistics(expected_lookups=2, cacheability=0.2),
                 ),
             ),
@@ -451,126 +452,22 @@ def test_scaleup(self) -> None:
                     stats=MockCacheStatistics(expected_lookups=2, cacheability=0.5),
                 ),
             ),
-        }
-
-        MB = 1024 * 1024
-        storage_constraint = Topology(
-            world_size=2, compute_device="cuda", hbm_cap=100 * MB, ddr_cap=1000 * MB
-        )
-
-        model = TestSparseNN(tables=tables, sparse_device=torch.device("meta"))
-        enumerator = EmbeddingEnumerator(
-            topology=storage_constraint, batch_size=BATCH_SIZE, constraints=constraints
-        )
-        search_space = enumerator.enumerate(
-            module=model,
-            sharders=[
-                cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder())
-            ],
-        )
-        proposer = EmbeddingOffloadScaleupProposer()
-        proposer.load(search_space, enumerator=enumerator)
-
-        output = []
-        proposal = proposer.propose()
-        while proposal is not None:
-            output.append(
-                [
-                    (
-                        candidate.name,
-                        candidate.compute_kernel,
-                        candidate.cache_params.load_factor
-                        if candidate.cache_params
-                        else None,
-                    )
-                    for candidate in proposal
-                ]
-            )
-            proposer.feedback(
-                partitionable=True,
-                plan=proposal,
-                storage_constraint=storage_constraint,
-            )
-            proposal = proposer.propose()
-
-        # Expected output (name, kernel clf).
-        # First attempt uses the mins supplied, then as we apply increasing budget
-        # clfs increase, with the later attempts enough to promote table_3 into hbm.
-        expected_output = [
-            [
-                ("table_0", "fused_uvm_caching", 0.1),
-                ("table_1", "fused_uvm_caching", 0.1),
-                ("table_2", "fused", None),
-            ],
-            [
-                ("table_0", "fused_uvm_caching", 0.3025801181793213),
-                ("table_1", "fused_uvm_caching", 0.6064502596855164),
-                ("table_2", "fused", None),
-            ],
-            [
-                ("table_0", "fused_uvm_caching", 0.403870165348053),
-                ("table_1", "fused_uvm_caching", 0.859675407409668),
-                ("table_2", "fused", None),
-            ],
-            [
-                ("table_0", "fused_uvm_caching", 0.4545151889324188),
-                ("table_1", "fused_uvm_caching", 0.9862880110740662),
-                ("table_2", "fused", None),
-            ],
-            [
-                ("table_0", "fused_uvm_caching", 0.5294319987297058),
-                ("table_1", "fused", None),
-                ("table_2", "fused", None),
-            ],
-            [
-                ("table_0", "fused_uvm_caching", 0.573746383190155),
-                ("table_1", "fused", None),
-                ("table_2", "fused", None),
-            ],
-            [
-                ("table_0", "fused_uvm_caching", 0.5959035754203796),
-                ("table_1", "fused", None),
-                ("table_2", "fused", None),
-            ],
-        ]
-
-        self.assertEqual(output, expected_output)
-
-    def test_scaleup_ample_budget_and_deprecated_feature(self) -> None:
-        tables = [
-            EmbeddingBagConfig(
-                num_embeddings=2_000_000,
-                embedding_dim=10,
-                name=f"table_{i}",
-                feature_names=[f"feature_{i}"],
-            )
-            for i in range(3)
-        ]
-
-        # Place first two tables into cache, 3rd table leave on hbm. table_1 has an
-        # expected lookup of 0 (deprecated feature).
-        constraints = {
-            "table_0": ParameterConstraints(
-                compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value],
-                cache_params=CacheParams(
-                    load_factor=0.1,
-                    stats=MockCacheStatistics(expected_lookups=2, cacheability=0.2),
-                ),
-            ),
-            "table_1": ParameterConstraints(
+            "table_2": ParameterConstraints(
                 compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value],
                 cache_params=CacheParams(
-                    load_factor=0.1,
-                    stats=MockCacheStatistics(expected_lookups=0, cacheability=0),
+                    load_factor=0.002,
+                    stats=MockCacheStatistics(expected_lookups=0, cacheability=0.5),
                 ),
             ),
         }
 
-        MB = 1024 * 1024
+        GB = 1024 * 1024 * 1024
         storage_constraint = Topology(
-            world_size=2, compute_device="cuda", hbm_cap=100 * MB, ddr_cap=1000 * MB
+            world_size=2, compute_device="cuda", hbm_cap=100 * GB, ddr_cap=1000 * GB
         )
-
+        # Ignoring table_2, the remainder require 224GB if all placed on HBM. We only
+        # have 200GB, so we can't promote both uvm tables. The initial plan uses 90GB,
+        # leaving 110GB of available budget.
         model = TestSparseNN(tables=tables, sparse_device=torch.device("meta"))
         enumerator = EmbeddingEnumerator(
             topology=storage_constraint, batch_size=BATCH_SIZE, constraints=constraints
@@ -584,11 +481,18 @@ def test_scaleup_ample_budget_and_deprecated_feature(self) -> None:
         proposer = EmbeddingOffloadScaleupProposer()
         proposer.load(search_space, enumerator=enumerator)
 
-        output = []
         proposal = proposer.propose()
+        best_plan = None
+        best_perf = 1e99
+        proposals = -1
         while proposal is not None:
-            output.append(
-                [
+            proposals += 1
+            mem = sum(so.total_storage.hbm for so in proposal)
+            # simple perf model: assume the partitioner gives its lowest score around 150GB of memory.
+            perf = abs(mem - (150 * GB))
+            plan = {
+                "mem": mem,
+                "proposal": [
                     (
                         candidate.name,
                         candidate.compute_kernel,
@@ -597,36 +501,36 @@ def test_scaleup_ample_budget_and_deprecated_feature(self) -> None:
                         else None,
                     )
                     for candidate in proposal
-                ]
-            )
+                ],
+            }
+            if perf < best_perf:
+                best_plan = plan
+                best_perf = perf
             proposer.feedback(
                 partitionable=True,
                 plan=proposal,
+                perf_rating=perf,
                 storage_constraint=storage_constraint,
             )
             proposal = proposer.propose()
 
-        # Expected output (name, kernel clf).
-        # First attempt uses the mins supplied, then as we apply increasing budget
-        # clfs increase, table 0 gets promoted, table 1 left as original minimum.
-        expected_output = [
-            [
-                ("table_0", "fused_uvm_caching", 0.1),
-                ("table_1", "fused_uvm_caching", 0.1),
-                ("table_2", "fused", None),
-            ],
-            [
-                ("table_0", "fused_uvm_caching", 0.8090304136276245),
-                ("table_1", "fused_uvm_caching", 0.1),
-                ("table_2", "fused", None),
-            ],
-            [
-                ("table_0", "fused", None),
-                ("table_1", "fused_uvm_caching", 0.1),
-                ("table_2", "fused", None),
-            ],
-        ]
-        self.assertEqual(output[0:3], expected_output)
+        self.assertEqual(proposals, 16)
+        self.assertEqual(
+            best_plan,
+            {
+                # 146GB, close to the target of 150GB
+                "mem": 157178896800,
+                # table_1 has been scaled up 2.5x more than table_0 (vs the original 0.1),
+                # which aligns with their different cacheability scores.
+                # table_2 has been left alone (deprecated feature, expected zero lookups in stats).
+                "proposal": [
+                    ("table_0", "fused_uvm_caching", 0.3173336386680603),
+                    ("table_1", "fused_uvm_caching", 0.6433340907096863),
+                    ("table_2", "fused_uvm_caching", 0.002),
+                    ("table_3", "fused", None),
+                ],
+            },
+        )
 
     def test_proposers_to_proposals_list(self) -> None:
         def make_mock_proposal(name: str) -> List[ShardingOption]:
diff --git a/torchrec/distributed/planner/tests/test_utils.py b/torchrec/distributed/planner/tests/test_utils.py
index c1b9e8583..2bf861537 100644
--- a/torchrec/distributed/planner/tests/test_utils.py
+++ b/torchrec/distributed/planner/tests/test_utils.py
@@ -5,14 +5,18 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import math
 import unittest
 from typing import Callable, List
 from unittest.mock import MagicMock
 
+import torch
+
 from torchrec.distributed.planner.types import Perf, Shard, ShardingOption, Storage
 from torchrec.distributed.planner.utils import (
     _find_imbalance_tables,
     BinarySearchPredicate,
+    LuusJaakolaSearch,
     reset_shard_rank,
 )
 from torchrec.distributed.types import ShardingType
@@ -111,3 +115,248 @@ def probes(
         self.assertEqual(got, [1])
         got = probes(BinarySearchPredicate(1, 0, 0), F)
         self.assertEqual(got, [])
+
+
+class TestLuusJaakolaSearch(unittest.TestCase):
+
+    # Find the minimum of f between x0 and x1.
+    # Evaluate multiple times with different random seeds to ensure we're not
+    # just getting lucky.
+    # Returns an Nx2 tensor of [xs, ys] of discovered minima.
+    @staticmethod
+    def evaluate(x0: float, x1: float, f: Callable[[float], float]) -> torch.Tensor:
+        xs = []
+        ys = []
+        iterations = 16
+        for i in range(5):
+            search = LuusJaakolaSearch(x0, x1, iterations, seed=i)
+            y = search.next(0.0)
+            while y is not None:
+                fy = f(y)
+                y = search.next(fy)
+            x, y = search.best()
+            xs.append(x)
+            ys.append(y)
+        return torch.stack([torch.tensor(xs), torch.tensor(ys)], dim=1)
+
+    def test_simple(self) -> None:
+        # See N4816561 to view these results graphically
+        def f1(x: float) -> float:
+            return x
+
+        def f2(x: float) -> float:
+            return x * x - 10 * x + 10  # min at x = 5
+
+        def f3(x: float) -> float:
+            # bumpy function, overall min at x=30
+            return (x - 30) ** 2 + 100 * math.sin(x)
+
+        def f4(x: float) -> float:
+            # spiky/non-smooth function, min at x = 30
+            return (x - 30) ** 2 + (x % 10) * 100
+
+        results = TestLuusJaakolaSearch.evaluate(0, 100, f1)
+        want = torch.tensor([[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]], dtype=torch.int64)
+        torch.testing.assert_close(results, want)
+
+        results = TestLuusJaakolaSearch.evaluate(0, 100, f2)
+        want = torch.tensor(
+            [
+                [3.51914, -12.80705],
+                [4.22958, -14.40646],
+                [5.41303, -14.82940],
+                [2.35012, -7.97811],
+                [4.18552, -14.33662],
+            ]
+        )
+        torch.testing.assert_close(results, want)
+
+        results = TestLuusJaakolaSearch.evaluate(0, 100, f3)
+        want = torch.tensor(
+            [
+                [36.58517, -46.37988],
+                [29.73184, -99.28705],
+                [37.67208, 56.15779],
+                [35.85468, -62.00219],
+                [41.76223, 58.69744],
+            ]
+        )
+        torch.testing.assert_close(results, want)
+
+        results = TestLuusJaakolaSearch.evaluate(0, 100, f4)
+        want = torch.tensor(
+            [
+                [23.68681, 408.53735],
+                [31.62534, 165.17535],
+                [32.81968, 289.91898],
+                [42.81567, 445.80777],
+                [22.53002, 308.80225],
+            ]
+        )
+        torch.testing.assert_close(results, want)
+
+    def test_iterations(self) -> None:
+        search = LuusJaakolaSearch(0, 1, 3)
+        y = search.next(0)
+        probes = 0
+        while y is not None:
+            probes += 1
+            fy = y
+            y = search.next(fy)
+        self.assertEqual(probes, 3)
+
+    # https://github.com/pytorch/pytorch/issues/50334
+    @staticmethod
+    def interp(x: torch.Tensor, xp: torch.Tensor, fp: torch.Tensor) -> torch.Tensor:
+        """One-dimensional linear interpolation for monotonically increasing sample
+        points.
+
+        Returns the one-dimensional piecewise linear interpolant to a function with
+        given discrete data points :math:`(xp, fp)`, evaluated at :math:`x`.
+
+        Args:
+            x: the :math:`x`-coordinates at which to evaluate the interpolated
+                values.
+            xp: the :math:`x`-coordinates of the data points, must be increasing.
+            fp: the :math:`y`-coordinates of the data points, same length as `xp`.
+
+        Returns:
+            the interpolated values, same size as `x`.
+        """
+        m = (fp[1:] - fp[:-1]) / (xp[1:] - xp[:-1])
+        b = fp[:-1] - (m * xp[:-1])
+
+        indices = torch.sum(torch.ge(x[:, None], xp[None, :]), 1) - 1
+        indices = torch.clamp(indices, 0, len(m) - 1)
+
+        return m[indices] * x + b[indices]
+
+    def test_real(self) -> None:
+        # See N4816561 to view these results graphically
+
+        # Real data collected from bin packing has a non-smooth surface and many local minima.
+        # mem vs cost from cmf_icvr bin packing
+        cmf_icvr = torch.tensor(
+            [
+                [4.6741845183e11, 2.3563506569e02],
+                [4.6741845240e11, 2.3563506569e02],
+                [4.7121749230e11, 2.3506600864e02],
+                [4.7501653103e11, 2.3468280680e02],
+                [4.7881557076e11, 2.3430065943e02],
+                [4.8261460996e11, 2.3396533990e02],
+                [4.8641364892e11, 2.3367888393e02],
+                [4.9021268717e11, 2.3339395760e02],
+                [4.9401172728e11, 2.3316084540e02],
+                [4.9781076708e11, 2.3292654771e02],
+                [5.0160980674e11, 2.3275780179e02],
+                [5.0540884491e11, 2.3256067684e02],
+                [5.0920788486e11, 2.3235742684e02],
+                [5.1300692424e11, 2.3219262609e02],
+                [5.1680596356e11, 2.3206849693e02],
+                [5.2060500162e11, 2.3193348320e02],
+                [5.2440404195e11, 2.3180536764e02],
+                [5.2820308146e11, 2.3170546631e02],
+                [5.3200212032e11, 2.3158138440e02],
+                [5.3580115967e11, 2.3146545816e02],
+                [5.3960019895e11, 2.3138856778e02],
+                [5.4339923878e11, 2.3128211641e02],
+                [5.4719827815e11, 2.3121699239e02],
+                [5.5099731798e11, 2.3169756090e02],
+                [5.5479635643e11, 2.3103278320e02],
+                [5.5859539575e11, 2.3171106005e02],
+                [5.6239443438e11, 2.3091072319e02],
+                [5.6619349259e11, 2.3084920287e02],
+                [5.6999251415e11, 2.3078335619e02],
+                [5.7379155310e11, 2.3113596330e02],
+                [5.7759059204e11, 2.3069988094e02],
+                [5.8138963104e11, 2.3127273113e02],
+                [5.8518866978e11, 2.3172034584e02],
+                [5.8898770984e11, 2.3083009711e02],
+                [5.9278674971e11, 2.3080842049e02],
+                [5.9658578920e11, 2.3176370343e02],
+                [6.0038482804e11, 2.3071235199e02],
+                [6.0418386709e11, 2.3213900014e02],
+                [6.0798290658e11, 2.3332448570e02],
+                [6.1178194561e11, 2.3275468168e02],
+                [6.1558098586e11, 2.3028775311e02],
+                [6.1938002497e11, 2.3099002246e02],
+                [6.2317906405e11, 2.3169044278e02],
+                [6.2697810321e11, 2.3387964670e02],
+                [6.3077714335e11, 2.3211138392e02],
+                [6.3457618280e11, 2.3106450194e02],
+                [6.3837522051e11, 2.3392878354e02],
+                [6.4217426058e11, 2.3260742338e02],
+                [6.4597330044e11, 2.3212726336e02],
+                [6.4977233953e11, 2.3355375214e02],
+                [6.5357137911e11, 2.3370492744e02],
+                [6.5737041818e11, 2.3274859312e02],
+                [6.6116945832e11, 2.3454963160e02],
+                [6.6496849695e11, 2.3314306687e02],
+                [6.6876753631e11, 2.3387508611e02],
+                [6.7256657578e11, 2.3164114924e02],
+                [6.7636561494e11, 2.3335876240e02],
+                [6.8016465549e11, 2.3259160444e02],
+                [6.8396369350e11, 2.3472844839e02],
+                [6.8776273363e11, 2.3402051674e02],
+                [6.9156177298e11, 2.3574191998e02],
+                [6.9536081174e11, 2.3853930635e02],
+                [6.9915984917e11, 2.3440978885e02],
+                [7.0295889084e11, 2.3613333429e02],
+                [7.0675792895e11, 2.3783556448e02],
+                [7.1055696937e11, 2.3596357613e02],
+                [7.1435600664e11, 2.4035834255e02],
+                [7.1815504705e11, 2.3882352229e02],
+                [7.2195408724e11, 2.4316494619e02],
+                [7.2575312535e11, 2.4125740709e02],
+                [7.2955216606e11, 2.3700425464e02],
+                [7.3335120460e11, 2.4198517463e02],
+                [7.3715024347e11, 2.4290543902e02],
+                [7.4094928544e11, 2.3961167246e02],
+                [7.4474832211e11, 2.4162098068e02],
+                [7.4854736178e11, 2.4791162259e02],
+                [7.5234640124e11, 2.4706576073e02],
+                [7.5614544041e11, 2.4682659631e02],
+                [7.5994447978e11, 2.4839164423e02],
+                [7.6374351905e11, 2.5108968132e02],
+                [7.6754255785e11, 2.5344371602e02],
+                [7.7134159724e11, 2.6063943014e02],
+                [7.7514063682e11, 2.4953670969e02],
+                [7.7893967570e11, 2.5865807123e02],
+                [7.8273871453e11, 2.6094569799e02],
+                [7.8653775458e11, 2.6653191005e02],
+                [7.9033679421e11, 2.6909497473e02],
+                [7.9413583349e11, 2.7149400968e02],
+                [7.9793487494e11, 2.7245403781e02],
+                [8.0173391173e11, 2.8131908812e02],
+                [8.0553295106e11, 2.9112192412e02],
+                [8.0933199067e11, 2.9245070076e02],
+                [8.1313102998e11, 2.8235347505e02],
+                [8.1693006950e11, 2.9033406803e02],
+                [8.2072910826e11, 3.0580905927e02],
+                [8.2452814772e11, 3.1147292572e02],
+                [8.2832723864e11, 3.0812470431e02],
+                [8.3212622721e11, 3.4879506066e02],
+                [8.3592526617e11, 3.2790815984e02],
+                [8.3972430401e11, 3.6465536216e02],
+                [8.4352334347e11, 3.9066552303e02],
+            ],
+            dtype=torch.float64,
+        )
+
+        mem: torch.Tensor = cmf_icvr[:, 0]
+        cost: torch.Tensor = cmf_icvr[:, 1]
+
+        def f(x: float) -> float:
+            return TestLuusJaakolaSearch.interp(torch.tensor([x]), mem, cost).item()
+
+        results = TestLuusJaakolaSearch.evaluate(mem.min().item(), mem.max().item(), f)
+        want = torch.tensor(
+            [
+                [5.370294e11, 2.314406e02],
+                [5.426136e11, 2.313041e02],
+                [5.908549e11, 2.308194e02],
+                [5.755533e11, 2.309337e02],
+                [6.184178e11, 2.308121e02],
+            ],
+        )
+        torch.testing.assert_close(results, want)
diff --git a/torchrec/distributed/planner/utils.py b/torchrec/distributed/planner/utils.py
index 4594396f8..c8f96da03 100644
--- a/torchrec/distributed/planner/utils.py
+++ b/torchrec/distributed/planner/utils.py
@@ -5,9 +5,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import math
 import operator
 from functools import reduce
-from typing import Any, cast, Dict, Iterable, List, Optional, Type, Union
+from typing import Any, cast, Dict, Iterable, List, Optional, Tuple, Type, Union
 
 import torch
 from torchrec.distributed.planner.types import Perf, ShardingOption, Storage
@@ -163,3 +164,82 @@ def next(self, prior_result: bool) -> Optional[int]:
 
     def _mid(self) -> int:
         return self.left + ((self.right - self.left) // 2)
+
+
+class LuusJaakolaSearch:
+    """Implements a clamped variant of Luus Jaakola search.
+
+    See https://en.wikipedia.org/wiki/Luus-Jaakola.
+    """
+
+    def __init__(self, A: float, B: float, max_iterations: int, seed: int = 42) -> None:
+        self.left = A
+        self.right = B
+        self.iteration = -1
+        self.max_iterations = max_iterations
+
+        self.gen = torch.Generator()
+        self.gen.manual_seed(seed)
+
+        self.x: float = self.uniform(self.left, self.right)
+        self.fx: float = 0.0
+        self.y: float = math.nan
+        self.fleft: Optional[float] = None
+        self.fright: Optional[float] = None
+        self.d: float = self.right - self.left
+
+    def clamp(self, x: float) -> float:
+        "Clamp x into range [left, right]"
+        if x < self.left:
+            return self.left
+        if x > self.right:
+            return self.right
+        return x
+
+    def uniform(self, A: float, B: float) -> float:
+        "Return a random uniform position in range [A,B]."
+        u = torch.rand(1, generator=self.gen).item()
+        return A + (B - A) * u
+
+    def next(self, fy: float) -> Optional[float]:
+        """Return the next probe point 'y' to evaluate, given the previous result.
+
+        The first time around fy is ignored. Subsequent invocations should provide the
+        result of evaluating the function being minimized, i.e. f(y).
+
+        Returns None when the maximum number of iterations has been reached.
+        """
+        self.iteration += 1
+        if self.iteration == 0:
+            return self.x
+        elif self.iteration == 1:
+            self.fx = fy
+        elif self.iteration == self.max_iterations:
+            return None
+        elif fy <= self.fx:
+            self.x = self.y
+            self.fx = fy
+            self.d = 0.95 * self.d
+
+        if self.y == self.left:
+            self.fleft = fy
+        elif self.y == self.right:
+            self.fright = fy
+
+        while True:
+            a = self.uniform(-self.d, self.d)
+            y = self.clamp(self.x + a)
+            # Unlike standard Luus-Jaakola, we don't want to explore outside of our bounds.
+            # Clamping can cause us to explore the boundary multiple times, so we
+            # remember if we already know the boundary cost and request a new sample if
+            # we do.
+            if y == self.left and self.fleft is not None:
+                continue
+            if y == self.right and self.fright is not None:
+                continue
+            self.y = y
+            return self.y
+
+    def best(self) -> Tuple[float, float]:
+        "Return the best position so far, and its associated cost."
+        return self.x, self.fx
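
Usage note (editorial addition, not part of the patch): the planner drives this class through EmbeddingOffloadScaleupProposer.feedback(), seeding the search with the available HBM budget and feeding each plan's perf_rating back into next(). The sketch below illustrates that next()/best() protocol in isolation; the quadratic cost function is a made-up stand-in for the partitioner's perf_rating.

from typing import Callable, Tuple

from torchrec.distributed.planner.utils import LuusJaakolaSearch


def minimize(
    f: Callable[[float], float], lo: float, hi: float, iterations: int = 16
) -> Tuple[float, float]:
    # Drive LuusJaakolaSearch the same way the proposer does: ask for a probe,
    # evaluate its cost, and report that cost back to get the next probe.
    search = LuusJaakolaSearch(lo, hi, max_iterations=iterations)
    y = search.next(0.0)  # the value passed to the first next() call is ignored
    while y is not None:
        cost = f(y)  # e.g. the perf_rating of the plan built with budget y
        y = search.next(cost)  # report the cost, receive the next probe point
    return search.best()  # lowest-cost probe seen and its cost


x, fx = minimize(lambda b: (b - 30.0) ** 2, 0.0, 100.0)
# Expect x near 30 after 16 probes, mirroring max_iterations=16 in the proposer.

next() returns None once the iteration budget is exhausted, and best() then reports the lowest-cost probe observed, falling back to the initial random probe if nothing improved on it.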