From 3e2e290d39077dd6c052389b381f900b88bbcaa9 Mon Sep 17 00:00:00 2001 From: Elizabeth Santorella Date: Tue, 27 Aug 2024 08:59:06 -0700 Subject: [PATCH] Consolidate `BenchmarkMetric` functionality in one file (#2710) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2710 There is no need to have multiple files here anymore now that a lot of BenchmarkMetric functionality has disappeared. D61736027 follows up by moving `benchmark/metrics/benchmark.py` to `benchmark/benchmark_metric.py` and moving the corresponding test file. Reviewed By: Balandat Differential Revision: D61432000 fbshipit-source-id: c45b78e5b79bce827adf9cc40e7b805a4e5e318b --- ax/benchmark/metrics/benchmark.py | 66 +++++++++++++++++++++---- ax/benchmark/metrics/utils.py | 81 ------------------------------- 2 files changed, 56 insertions(+), 91 deletions(-) delete mode 100644 ax/benchmark/metrics/utils.py diff --git a/ax/benchmark/metrics/benchmark.py b/ax/benchmark/metrics/benchmark.py index 5e854ff29e7..139f1e2e64e 100644 --- a/ax/benchmark/metrics/benchmark.py +++ b/ax/benchmark/metrics/benchmark.py @@ -5,13 +5,14 @@ # pyre-strict -from __future__ import annotations - from typing import Any, Optional -from ax.benchmark.metrics.utils import _fetch_trial_data +import pandas as pd from ax.core.base_trial import BaseTrial -from ax.core.metric import Metric, MetricFetchResult + +from ax.core.data import Data +from ax.core.metric import Metric, MetricFetchE, MetricFetchResult +from ax.utils.common.result import Err, Ok class BenchmarkMetric(Metric): @@ -48,14 +49,59 @@ def __init__( self.outcome_index = outcome_index def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult: + """ + Args: + trial: The trial from which to fetch data. + kwargs: Unsupported and will raise an exception. + + Returns: + A MetricFetchResult containing the data for the requested metric. + """ if len(kwargs) > 0: raise NotImplementedError( f"Arguments {set(kwargs)} are not supported in " f"{self.__class__.__name__}.fetch_trial_data." ) - return _fetch_trial_data( - trial=trial, - metric_name=self.name, - outcome_index=self.outcome_index, - include_noise_sd=self.observe_noise_sd, - ) + outcome_index = self.outcome_index + if outcome_index is None: + # Look up the index based on the outcome name under which we track the data + # as part of `run_metadata`. + outcome_names = trial.run_metadata.get("outcome_names") + if outcome_names is None: + raise RuntimeError( + "Trials' `run_metadata` must contain `outcome_names` if " + "no `outcome_index` is provided." + ) + outcome_index = outcome_names.index(self.name) + + try: + arm_names = list(trial.arms_by_name.keys()) + all_Ys = trial.run_metadata["Ys"] + Ys = [all_Ys[arm_name][outcome_index] for arm_name in arm_names] + + if self.observe_noise_sd: + stdvs = [ + trial.run_metadata["Ystds"][arm_name][outcome_index] + for arm_name in arm_names + ] + else: + stdvs = [float("nan")] * len(Ys) + + df = pd.DataFrame( + { + "arm_name": arm_names, + "metric_name": self.name, + "mean": Ys, + "sem": stdvs, + "trial_index": trial.index, + } + ) + return Ok(value=Data(df=df)) + + except Exception as e: + return Err( + MetricFetchE( + message=f"Failed to obtain data for trial {trial.index}", + exception=e, + ) + ) diff --git a/ax/benchmark/metrics/utils.py b/ax/benchmark/metrics/utils.py deleted file mode 100644 index 0b55df0aca0..00000000000 --- a/ax/benchmark/metrics/utils.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. 
-# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -from typing import Optional - -import pandas as pd -from ax.core.base_trial import BaseTrial -from ax.core.data import Data -from ax.core.metric import MetricFetchE, MetricFetchResult -from ax.utils.common.result import Err, Ok - - -def _fetch_trial_data( - trial: BaseTrial, - metric_name: str, - outcome_index: Optional[int] = None, - include_noise_sd: bool = True, -) -> MetricFetchResult: - """ - Args: - trial: The trial from which to fetch data. - metric_name: Name of the metric to fetch. If `metric_index` is not specified, - this is used to retrieve the index (of the outcomes) from the - `outcome_names` dict in a trial's `run_metadata`. If `metric_index` is - specified, this is simply the name of the metric. - outcome_index: The index (in the last dimension) of the `Ys` and - `Ystds` lists of outcomes stored by the respective runner in the trial's - `run_metadata`. If omitted, `run_metadata` must contain a `outcome_names` - list of names in the same order as the outcomes that will be used to - determine the index. - include_noise_sd: Whether to include noise standard deviation in the returned - data. - - Returns: - A MetricFetchResult containing the data for the requested metric. - """ - if outcome_index is None: - # Look up the index based on the outcome name under which we track the data - # as part of `run_metadata`. - outcome_names = trial.run_metadata.get("outcome_names") - if outcome_names is None: - raise RuntimeError( - "Trials' `run_metadata` must contain `outcome_names` if " - "no `outcome_index` is provided." - ) - outcome_index = outcome_names.index(metric_name) - - try: - arm_names = list(trial.arms_by_name.keys()) - all_Ys = trial.run_metadata["Ys"] - Ys = [all_Ys[arm_name][outcome_index] for arm_name in arm_names] - - if include_noise_sd: - stdvs = [ - trial.run_metadata["Ystds"][arm_name][outcome_index] - for arm_name in arm_names - ] - else: - stdvs = [float("nan")] * len(Ys) - - df = pd.DataFrame( - { - "arm_name": arm_names, - "metric_name": metric_name, - "mean": Ys, - "sem": stdvs, - "trial_index": trial.index, - } - ) - return Ok(value=Data(df=df)) - - except Exception as e: - return Err( - MetricFetchE( - message=f"Failed to obtain data for trial {trial.index}", exception=e - ) - )
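Illustrative usage sketch (not part of the patch above): the consolidated `BenchmarkMetric.fetch_trial_data` builds its DataFrame from the `outcome_names`, `Ys`, and `Ystds` entries that a benchmark runner stores in `trial.run_metadata`. The snippet below mirrors that mapping with a hypothetical `FakeTrial` stand-in and a `fetch_metric_df` helper; both names are assumptions for illustration, not Ax APIs, and only pandas is required to run it.

# Minimal sketch mirroring the consolidated fetch logic. `FakeTrial` and
# `fetch_metric_df` are hypothetical names used only for this illustration;
# only the attributes the fetch logic reads are modeled here.
from dataclasses import dataclass, field
from typing import Any

import pandas as pd


@dataclass
class FakeTrial:
    index: int
    arms_by_name: dict[str, Any]
    run_metadata: dict[str, Any] = field(default_factory=dict)


def fetch_metric_df(
    trial: FakeTrial, metric_name: str, observe_noise_sd: bool = True
) -> pd.DataFrame:
    # Resolve the outcome's position from `outcome_names`, as the consolidated
    # method does when no explicit `outcome_index` is provided.
    outcome_index = trial.run_metadata["outcome_names"].index(metric_name)
    arm_names = list(trial.arms_by_name.keys())
    Ys = [trial.run_metadata["Ys"][a][outcome_index] for a in arm_names]
    stdvs = (
        [trial.run_metadata["Ystds"][a][outcome_index] for a in arm_names]
        if observe_noise_sd
        else [float("nan")] * len(Ys)
    )
    return pd.DataFrame(
        {
            "arm_name": arm_names,
            "metric_name": metric_name,
            "mean": Ys,
            "sem": stdvs,
            "trial_index": trial.index,
        }
    )


# Example: one trial with two arms and two outcomes ("f", "constraint").
trial = FakeTrial(
    index=0,
    arms_by_name={"0_0": object(), "0_1": object()},
    run_metadata={
        "outcome_names": ["f", "constraint"],
        "Ys": {"0_0": [1.0, -0.5], "0_1": [2.0, 0.3]},
        "Ystds": {"0_0": [0.1, 0.1], "0_1": [0.1, 0.1]},
    },
)
print(fetch_metric_df(trial, metric_name="f"))

In the real class, the same DataFrame is wrapped in `Ok(value=Data(df=df))`, and any exception during the lookup is returned as an `Err(MetricFetchE(...))` rather than raised.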