From 47cded966eedd62a70702eedb54cc24a91772ba1 Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Wed, 1 Nov 2023 18:27:16 -0700 Subject: [PATCH] Change early stopping estimated savings calculation (#1944) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/1944 The old method of estimation was incredibly optimistic and compared the cost of every trial to the cost of the most expensive completed trial. This was giving us some silly results, e.g. 1 of 9 trials stopped yielding 47% savings. The new method assumes any early stopped trial would have taken the mean resources of completed trials had it completed successfully, with some thresholding to set negative savings estimates to zero savings. Reviewed By: SebastianAment Differential Revision: D50841014 fbshipit-source-id: b2cd3e09b08cacbddd3a3363dd4506074341db8d --- ax/early_stopping/utils.py | 48 +++++++++++++++++------------- ax/service/tests/test_scheduler.py | 2 +- 2 files changed, 29 insertions(+), 21 deletions(-) diff --git a/ax/early_stopping/utils.py b/ax/early_stopping/utils.py index d45b59204b3..a81d9dde238 100644 --- a/ax/early_stopping/utils.py +++ b/ax/early_stopping/utils.py @@ -4,7 +4,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import math from collections import defaultdict from logging import Logger from typing import Dict, List, Optional, Tuple @@ -123,7 +122,7 @@ def estimate_early_stopping_savings( map_key: Optional[str] = None, ) -> float: """Estimate resource savings due to early stopping by considering - COMPLETED and EARLY_STOPPED trials. First, use the maximum of final + COMPLETED and EARLY_STOPPED trials. First, use the mean of final progressions of the set completed trials as a benchmark for the length of a single trial. 
The savings is then estimated as: @@ -150,25 +149,34 @@ def estimate_early_stopping_savings( else: return 0 - completed_trial_idcs = experiment.trial_indices_by_status[TrialStatus.COMPLETED] - - total_resources = ( - map_data.map_df[["trial_index", step_key]].groupby("trial_index").max().sum() + # Get final number of steps of each trial + trial_resources = ( + map_data.map_df[["trial_index", step_key]] + .groupby("trial_index") + .max() + .reset_index() ) - completed_df = map_data.map_df[ - (map_data.map_df["trial_index"].isin(completed_trial_idcs)) + early_stopped_trial_idcs = experiment.trial_indices_by_status[ + TrialStatus.EARLY_STOPPED ] - single_trial_resources = ( - completed_df[["trial_index", step_key]].groupby("trial_index").max().max() - ) - - savings: float = ( - 1 - total_resources / (experiment.num_trials * single_trial_resources) - ).item() - - if math.isnan(savings): - # NaN implies division by zero, which should be interpreted as no savings - return 0 + completed_trial_idcs = experiment.trial_indices_by_status[TrialStatus.COMPLETED] - return savings + # Assume that any early stopped trial would have had the mean number of steps of + # the completed trials + mean_completed_trial_resources = trial_resources[ + trial_resources["trial_index"].isin(completed_trial_idcs) + ][step_key].mean() + + # Calculate the steps saved per early stopped trial. 
If savings are estimated to be + # negative assume no savings + stopped_trial_resources = trial_resources[ + trial_resources["trial_index"].isin(early_stopped_trial_idcs) + ][step_key] + saved_trial_resources = ( + mean_completed_trial_resources - stopped_trial_resources + ).clip(0) + + # Return the ratio of the total saved resources over the total resources used plus + # the total saved resources + return saved_trial_resources.sum() / trial_resources[step_key].sum() diff --git a/ax/service/tests/test_scheduler.py b/ax/service/tests/test_scheduler.py index 593a64e4aa1..fc35fcc5d10 100644 --- a/ax/service/tests/test_scheduler.py +++ b/ax/service/tests/test_scheduler.py @@ -952,7 +952,7 @@ def should_stop_trials_early( self.assertEqual(len(fetched_data.map_df), expected_num_rows) self.assertEqual(len(looked_up_data.map_df), expected_num_rows) - self.assertAlmostEqual(scheduler.estimate_early_stopping_savings(), 2 / 3) + self.assertAlmostEqual(scheduler.estimate_early_stopping_savings(), 0.5) def test_run_trials_in_batches(self) -> None: # TODO[drfreund]: Use `Runner` instead when `poll_available_capacity`