Skip to content

Commit

Permalink
Change early stopping estimated savings calculation (#1944)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1944

The old method of estimation was incredibly optimistic and compared the cost of every trial to the cost of the most expensive completed trial. This was giving us some silly results, e.g., 1 of 9 trials stopped yielding 47% savings.

The new method assumes any early stopped trial would have taken the mean resources of completed trials had it completed successfully, with some thresholding to set negative savings estimates to zero savings.

Reviewed By: SebastianAment

Differential Revision: D50841014

fbshipit-source-id: b2cd3e09b08cacbddd3a3363dd4506074341db8d
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Nov 2, 2023
1 parent f8f9a41 commit 47cded9
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 21 deletions.
48 changes: 28 additions & 20 deletions ax/early_stopping/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from collections import defaultdict
from logging import Logger
from typing import Dict, List, Optional, Tuple
Expand Down Expand Up @@ -123,7 +122,7 @@ def estimate_early_stopping_savings(
map_key: Optional[str] = None,
) -> float:
"""Estimate resource savings due to early stopping by considering
COMPLETED and EARLY_STOPPED trials. First, use the maximum of final
COMPLETED and EARLY_STOPPED trials. First, use the mean of final
progressions of the set of completed trials as a benchmark for the
length of a single trial. The savings is then estimated as:
Expand All @@ -150,25 +149,34 @@ def estimate_early_stopping_savings(
else:
return 0

completed_trial_idcs = experiment.trial_indices_by_status[TrialStatus.COMPLETED]

total_resources = (
map_data.map_df[["trial_index", step_key]].groupby("trial_index").max().sum()
# Get final number of steps of each trial
trial_resources = (
map_data.map_df[["trial_index", step_key]]
.groupby("trial_index")
.max()
.reset_index()
)

completed_df = map_data.map_df[
(map_data.map_df["trial_index"].isin(completed_trial_idcs))
early_stopped_trial_idcs = experiment.trial_indices_by_status[
TrialStatus.EARLY_STOPPED
]
single_trial_resources = (
completed_df[["trial_index", step_key]].groupby("trial_index").max().max()
)

savings: float = (
1 - total_resources / (experiment.num_trials * single_trial_resources)
).item()

if math.isnan(savings):
# NaN implies division by zero, which should be interpreted as no savings
return 0
completed_trial_idcs = experiment.trial_indices_by_status[TrialStatus.COMPLETED]

return savings
# Assume that any early stopped trial would have had the mean number of steps of
# the completed trials
mean_completed_trial_resources = trial_resources[
trial_resources["trial_index"].isin(completed_trial_idcs)
][step_key].mean()

# Calculate the steps saved per early stopped trial. If savings are estimated to be
# negative assume no savings
stopped_trial_resources = trial_resources[
trial_resources["trial_index"].isin(early_stopped_trial_idcs)
][step_key]
saved_trial_resources = (
mean_completed_trial_resources - stopped_trial_resources
).clip(0)

# Return the ratio of the total saved resources over the total resources used plus
# the total saved resources
return saved_trial_resources.sum() / trial_resources[step_key].sum()
2 changes: 1 addition & 1 deletion ax/service/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,7 +952,7 @@ def should_stop_trials_early(
self.assertEqual(len(fetched_data.map_df), expected_num_rows)
self.assertEqual(len(looked_up_data.map_df), expected_num_rows)

self.assertAlmostEqual(scheduler.estimate_early_stopping_savings(), 2 / 3)
self.assertAlmostEqual(scheduler.estimate_early_stopping_savings(), 0.5)

def test_run_trials_in_batches(self) -> None:
# TODO[drfreund]: Use `Runner` instead when `poll_available_capacity`
Expand Down

0 comments on commit 47cded9

Please sign in to comment.