From 98dbb397e093d0e0df43dea7940405fb426e22a9 Mon Sep 17 00:00:00 2001
From: Miles Olson
Date: Wed, 16 Oct 2024 14:41:08 -0700
Subject: [PATCH] CrossValidationPlot analysis (#2861)

Summary:
Pull Request resolved: https://github.com/facebook/Ax/pull/2861

New CrossValidationPlot using the Analysis framework. Plots the CVResults of
the current model on the provided GS, with CIs if available. Forces a square
aspect ratio (but does not force any specific resolution). Hover data shows
the x and y values with CIs and the arm name. Displays a gray dashed line at
y=x.

This diff also refactors _select_metric out of Parallel Coordinates into a
utils file so we can use it here to select the objective if no metric_name is
provided.

NOTE: We are intentionally leaving the two dropdowns out: the "metric"
dropdown because we want to generate many CVs separately and combine them more
gracefully in the UI (cc Cesar-Cardoso), and the latter because we don't
actually want users to be able to view this plot without the confidence
intervals (it is not useful).

Reviewed By: danielcohenlive

Differential Revision: D64056891

fbshipit-source-id: 01553fc9438cff9eb539ddf0697dca130a422140
---
 ax/analysis/plotly/__init__.py                |   2 +
 ax/analysis/plotly/cross_validation.py        | 213 ++++++++++++++++++
 ax/analysis/plotly/parallel_coordinates.py    |  25 +-
 .../plotly/tests/test_cross_validation.py     |  61 +++++
 .../plotly/tests/test_parallel_coordinates.py |  12 +-
 ax/analysis/plotly/utils.py                   |  24 +-
 sphinx/source/analysis.rst                    |   8 +
 7 files changed, 316 insertions(+), 29 deletions(-)
 create mode 100644 ax/analysis/plotly/cross_validation.py
 create mode 100644 ax/analysis/plotly/tests/test_cross_validation.py

diff --git a/ax/analysis/plotly/__init__.py b/ax/analysis/plotly/__init__.py
index 59be369b41c..078ad5b594a 100644
--- a/ax/analysis/plotly/__init__.py
+++ b/ax/analysis/plotly/__init__.py
@@ -5,11 +5,13 @@
 
 # pyre-strict
 
+from ax.analysis.plotly.cross_validation import CrossValidationPlot
 from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot
 from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
 from ax.analysis.plotly.scatter import ScatterPlot
 
 __all__ = [
+    "CrossValidationPlot",
     "PlotlyAnalysis",
     "PlotlyAnalysisCard",
     "ParallelCoordinatesPlot",
diff --git a/ax/analysis/plotly/cross_validation.py b/ax/analysis/plotly/cross_validation.py
new file mode 100644
index 00000000000..34f7b55aa08
--- /dev/null
+++ b/ax/analysis/plotly/cross_validation.py
@@ -0,0 +1,213 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+
+import pandas as pd
+from ax.analysis.analysis import AnalysisCardLevel
+
+from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
+from ax.analysis.plotly.utils import select_metric
+from ax.core.experiment import Experiment
+from ax.core.generation_strategy_interface import GenerationStrategyInterface
+from ax.exceptions.core import UserInputError
+from ax.modelbridge.cross_validation import cross_validate
+from ax.modelbridge.generation_strategy import GenerationStrategy
+from plotly import express as px, graph_objects as go
+from pyre_extensions import assert_is_instance, none_throws
+
+
+class CrossValidationPlot(PlotlyAnalysis):
+    """
+    Plotly scatter plot of cross-validation results for the current model on the
+    GenerationStrategy. This plot is useful for understanding how well the model
+    is able to predict out-of-sample, which in turn is indicative of its ability
+    to suggest valuable candidates.
+
+    Splits the model's training data into train/test folds and makes
+    out-of-sample predictions on the test folds.
+
+    A well-fit model will have points clustered around the y=x line, and a model
+    with poor fit may have points in a horizontal band in the center of the plot,
+    indicating a tendency to predict the observed mean of the specified metric for
+    all arms.
+
+    The DataFrame computed will contain one row per arm and the following columns:
+        - arm_name: The name of the arm
+        - observed: The observed mean of the metric specified
+        - observed_sem: The SEM of the observed mean of the metric specified
+        - predicted: The predicted mean of the metric specified
+        - predicted_sem: The SEM of the predicted mean of the metric specified
+    """
+
+    def __init__(
+        self, metric_name: str | None = None, folds: int = -1, untransform: bool = True
+    ) -> None:
+        """
+        Args:
+            metric_name: The name of the metric to plot. If not specified, the
+                objective will be used. Note that the metric cannot be inferred
+                for multi-objective or scalarized-objective experiments.
+            folds: Number of subsamples to partition observations into. Use -1 for
+                leave-one-out cross validation.
+            untransform: Whether to untransform the model predictions before cross
+                validating. Models are trained on transformed data, and candidate
+                generation is performed in the transformed space. Computing the model
+                quality metric based on the cross-validation results in the
+                untransformed space may not be representative of the model that
+                is actually used for candidate generation in case of non-invertible
+                transforms, e.g., Winsorize or LogY. While the model in the
+                transformed space may not be representative of the original data in
+                regions where outliers have been removed, we have found it to better
+                reflect how good the model used for candidate generation actually
+                is.
+        """
+
+        self.metric_name = metric_name
+        self.folds = folds
+        self.untransform = untransform
+
+    def compute(
+        self,
+        experiment: Experiment | None = None,
+        generation_strategy: GenerationStrategyInterface | None = None,
+    ) -> PlotlyAnalysisCard:
+        if generation_strategy is None:
+            raise UserInputError("CrossValidationPlot requires a GenerationStrategy")
+
+        metric_name = self.metric_name or select_metric(
+            experiment=generation_strategy.experiment
+        )
+
+        df = _prepare_data(
+            # CrossValidationPlot requires a native Ax GenerationStrategy and cannot be
+            # used with a GenerationStrategyInterface.
+            generation_strategy=assert_is_instance(
+                generation_strategy, GenerationStrategy
+            ),
+            metric_name=metric_name,
+            folds=self.folds,
+            untransform=self.untransform,
+        )
+        fig = _prepare_plot(df=df, metric_name=metric_name)
+
+        k_folds_substring = f"{self.folds}-fold" if self.folds > 0 else "leave-one-out"
+
+        # Nudge the priority if the metric is important to the experiment
+        if (
+            experiment is not None
+            and (optimization_config := experiment.optimization_config) is not None
+            and (objective := optimization_config.objective) is not None
+            and metric_name in objective.metric_names
+        ):
+            nudge = 2
+        elif (
+            experiment is not None
+            and (optimization_config := experiment.optimization_config) is not None
+            and metric_name in optimization_config.metrics
+        ):
+            nudge = 1
+        else:
+            nudge = 0
+
+        return self._create_plotly_analysis_card(
+            title=f"Cross Validation for {metric_name}",
+            subtitle=f"Out-of-sample predictions using {k_folds_substring} CV",
+            level=AnalysisCardLevel.LOW.value + nudge,
+            df=df,
+            fig=fig,
+        )
+
+
+def _prepare_data(
+    generation_strategy: GenerationStrategy,
+    metric_name: str,
+    folds: int,
+    untransform: bool,
+) -> pd.DataFrame:
+    # If the model is not fit already, fit it
+    if generation_strategy.model is None:
+        generation_strategy._fit_current_model(None)
+
+    cv_results = cross_validate(
+        model=none_throws(generation_strategy.model),
+        folds=folds,
+        untransform=untransform,
+    )
+
+    records = []
+    for observed, predicted in cv_results:
+        for i in range(len(observed.data.metric_names)):
+            # Find the index of the metric we want to plot
+            if not (
+                observed.data.metric_names[i] == metric_name
+                and predicted.metric_names[i] == metric_name
+            ):
+                continue
+
+            record = {
+                "arm_name": observed.arm_name,
+                "observed": observed.data.means[i],
+                "predicted": predicted.means[i],
+                # Square root of the variance (covariance diagonal) gives the SEM
+                "observed_sem": observed.data.covariance[i][i] ** 0.5,
+                "predicted_sem": predicted.covariance[i][i] ** 0.5,
+            }
+            records.append(record)
+            break
+
+    return pd.DataFrame.from_records(records)
+
+
+def _prepare_plot(df: pd.DataFrame, metric_name: str) -> go.Figure:
+    fig = px.scatter(
+        df,
+        x="observed",
+        y="predicted",
+        error_x="observed_sem",
+        error_y="predicted_sem",
+        labels={
+            "observed": f"Observed {metric_name}",
+            "predicted": f"Predicted {metric_name}",
+        },
+        hover_data=["arm_name", "observed", "predicted"],
+    )
+
+    # Add a gray dashed line at y=x starting and ending just outside of the region
+    # of interest for reference. A well-fit model should have points clustered
+    # around this line.
+    lower_bound = min(
+        (df["observed"] - df["observed_sem"].fillna(0)).min(),
+        (df["predicted"] - df["predicted_sem"].fillna(0)).min(),
+    )
+    upper_bound = max(
+        (df["observed"] + df["observed_sem"].fillna(0)).max(),
+        (df["predicted"] + df["predicted_sem"].fillna(0)).max(),
+    )
+    # Pad by 1% of the data range so the reference line extends just past the
+    # points regardless of sign (a bare 0.99/1.01 multiplier would shrink the
+    # range whenever a bound is negative).
+    padding = 0.01 * (upper_bound - lower_bound)
+    lower_bound = lower_bound - padding
+    upper_bound = upper_bound + padding
+
+    fig.add_shape(
+        type="line",
+        x0=lower_bound,
+        y0=lower_bound,
+        x1=upper_bound,
+        y1=upper_bound,
+        line={"color": "gray", "dash": "dash"},
+    )
+
+    # Force plot to display as a square
+    fig.update_xaxes(range=[lower_bound, upper_bound], constrain="domain")
+    fig.update_yaxes(
+        scaleanchor="x",
+        scaleratio=1,
+    )
+
+    return fig
diff --git a/ax/analysis/plotly/parallel_coordinates.py b/ax/analysis/plotly/parallel_coordinates.py
index a861ebf6685..f3f73e2ee1d 100644
--- a/ax/analysis/plotly/parallel_coordinates.py
+++ b/ax/analysis/plotly/parallel_coordinates.py
@@ -12,10 +12,10 @@
 
 from ax.analysis.analysis import AnalysisCardLevel
 from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
+from ax.analysis.plotly.utils import select_metric
 from ax.core.experiment import Experiment
 from ax.core.generation_strategy_interface import GenerationStrategyInterface
-from ax.core.objective import MultiObjective, ScalarizedObjective
-from ax.exceptions.core import UnsupportedError, UserInputError
+from ax.exceptions.core import UserInputError
 from plotly import graph_objects as go
@@ -50,7 +50,7 @@ def compute(
         if experiment is None:
             raise UserInputError("ParallelCoordinatesPlot requires an Experiment")
 
-        metric_name = self.metric_name or _select_metric(experiment=experiment)
+        metric_name = self.metric_name or select_metric(experiment=experiment)
 
         df = _prepare_data(experiment=experiment, metric=metric_name)
         fig = _prepare_plot(df=df, metric_name=metric_name)
@@ -113,25 +113,6 @@ def _prepare_plot(df: pd.DataFrame, metric_name: str) -> go.Figure:
     )
 
 
-def _select_metric(experiment: Experiment) -> str:
-    if experiment.optimization_config is None:
-        raise ValueError(
-            "Cannot infer metric to plot from Experiment without OptimizationConfig"
-        )
-    objective = experiment.optimization_config.objective
-    if isinstance(objective, MultiObjective):
-        raise UnsupportedError(
-            "Cannot infer metric to plot from MultiObjective, please "
-            "specify a metric"
-        )
-    if isinstance(objective, ScalarizedObjective):
-        raise UnsupportedError(
-            "Cannot infer metric to plot from ScalarizedObjective, please "
-            "specify a metric"
-        )
-    return experiment.optimization_config.objective.metric.name
-
-
 def _find_mean_by_arm_name(
     df: pd.DataFrame,
     arm_name: str,
diff --git a/ax/analysis/plotly/tests/test_cross_validation.py b/ax/analysis/plotly/tests/test_cross_validation.py
new file mode 100644
index 00000000000..c3b7d17ea7d
--- /dev/null
+++ b/ax/analysis/plotly/tests/test_cross_validation.py
@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from ax.analysis.analysis import AnalysisCardLevel
+from ax.analysis.plotly.cross_validation import CrossValidationPlot
+from ax.exceptions.core import UserInputError
+from ax.service.ax_client import AxClient, ObjectiveProperties
+from ax.utils.common.testutils import TestCase
+from ax.utils.testing.mock import fast_botorch_optimize
+
+
+class TestCrossValidationPlot(TestCase):
+    @fast_botorch_optimize
+    def test_compute(self) -> None:
+        client = AxClient()
+        client.create_experiment(
+            is_test=True,
+            name="foo",
+            parameters=[
+                {
+                    "name": "x",
+                    "type": "range",
+                    "bounds": [-1.0, 1.0],
+                }
+            ],
+            objectives={"bar": ObjectiveProperties(minimize=True)},
+        )
+
+        for _ in range(10):
+            parameterization, trial_index = client.get_next_trial()
+            client.complete_trial(
+                trial_index=trial_index, raw_data={"bar": parameterization["x"] ** 2}
+            )
+
+        analysis = CrossValidationPlot(metric_name="bar")
+
+        # Test that it fails if no GenerationStrategy is provided
+        with self.assertRaisesRegex(UserInputError, "requires a GenerationStrategy"):
+            analysis.compute()
+
+        card = analysis.compute(generation_strategy=client.generation_strategy)
+        self.assertEqual(
+            card.name,
+            "CrossValidationPlot",
+        )
+        self.assertEqual(card.title, "Cross Validation for bar")
+        self.assertEqual(
+            card.subtitle,
+            "Out-of-sample predictions using leave-one-out CV",
+        )
+        self.assertEqual(card.level, AnalysisCardLevel.LOW)
+        self.assertEqual(
+            {*card.df.columns},
+            {"arm_name", "observed", "observed_sem", "predicted", "predicted_sem"},
+        )
+        self.assertIsNotNone(card.blob)
+        self.assertEqual(card.blob_annotation, "plotly")
diff --git a/ax/analysis/plotly/tests/test_parallel_coordinates.py b/ax/analysis/plotly/tests/test_parallel_coordinates.py
index d92dc89033e..1ecb178cb62 100644
--- a/ax/analysis/plotly/tests/test_parallel_coordinates.py
+++ b/ax/analysis/plotly/tests/test_parallel_coordinates.py
@@ -9,9 +9,9 @@
 from ax.analysis.analysis import AnalysisCardLevel
 from ax.analysis.plotly.parallel_coordinates import (
     _get_parameter_dimension,
-    _select_metric,
     ParallelCoordinatesPlot,
 )
+from ax.analysis.plotly.utils import select_metric
 from ax.exceptions.core import UnsupportedError, UserInputError
 from ax.utils.common.testutils import TestCase
 from ax.utils.testing.core_stubs import (
@@ -44,7 +44,7 @@ def test_compute(self) -> None:
         analysis_no_metric = ParallelCoordinatesPlot()
         _ = analysis_no_metric.compute(experiment=experiment)
 
-    def test_select_metric(self) -> None:
+    def test_select_metric(self) -> None:
         experiment = get_branin_experiment()
         experiment_no_optimization_config = get_branin_experiment(
             has_optimization_config=False
@@ -54,16 +54,16 @@
             get_experiment_with_scalarized_objective_and_outcome_constraint()
         )
 
-        self.assertEqual(_select_metric(experiment=experiment), "branin")
+        self.assertEqual(select_metric(experiment=experiment), "branin")
 
         with self.assertRaisesRegex(ValueError, "OptimizationConfig"):
-            _select_metric(experiment=experiment_no_optimization_config)
+            select_metric(experiment=experiment_no_optimization_config)
 
         with self.assertRaisesRegex(UnsupportedError, "MultiObjective"):
-            _select_metric(experiment=experiment_multi_objective)
+            select_metric(experiment=experiment_multi_objective)
 
         with self.assertRaisesRegex(UnsupportedError, "ScalarizedObjective"):
-            _select_metric(experiment=experiment_scalarized_objective)
+            select_metric(experiment=experiment_scalarized_objective)
 
     def test_get_parameter_dimension(self) -> None:
         range_series = pd.Series([0, 1, 2, 3], name="range")
diff --git a/ax/analysis/plotly/utils.py b/ax/analysis/plotly/utils.py
index 5cd3377095b..0f40cc55a52 100644
--- a/ax/analysis/plotly/utils.py
+++ b/ax/analysis/plotly/utils.py
@@ -5,8 +5,10 @@
 
 import numpy as np
 import torch
+from ax.core.experiment import Experiment
+from ax.core.objective import MultiObjective, ScalarizedObjective
 from ax.core.outcome_constraint import ComparisonOp, OutcomeConstraint
-from ax.exceptions.core import UserInputError
+from ax.exceptions.core import UnsupportedError, UserInputError
 from ax.modelbridge.base import ModelBridge
 from botorch.utils.probability.utils import compute_log_prob_feas_from_bounds
@@ -134,3 +136,23 @@ def is_predictive(model: ModelBridge) -> bool:
     except Exception:
         return True
     return True
+
+
+def select_metric(experiment: Experiment) -> str:
+    """Select the most relevant metric to plot from an Experiment."""
+    if experiment.optimization_config is None:
+        raise ValueError(
+            "Cannot infer metric to plot from Experiment without OptimizationConfig"
+        )
+    objective = experiment.optimization_config.objective
+    if isinstance(objective, MultiObjective):
+        raise UnsupportedError(
+            "Cannot infer metric to plot from MultiObjective, please "
+            "specify a metric"
+        )
+    if isinstance(objective, ScalarizedObjective):
+        raise UnsupportedError(
+            "Cannot infer metric to plot from ScalarizedObjective, please "
+            "specify a metric"
+        )
+    return objective.metric.name
diff --git a/sphinx/source/analysis.rst b/sphinx/source/analysis.rst
index 82352a4c21f..af1b146cd1c 100644
--- a/sphinx/source/analysis.rst
+++ b/sphinx/source/analysis.rst
@@ -15,6 +15,14 @@ Analysis
    :undoc-members:
    :show-inheritance:
 
+Cross Validation Analysis
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. automodule:: ax.analysis.plotly.cross_validation
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
 Markdown Analysis
 ~~~~~~~~~~~~~~~~~
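
Usage sketch (not part of the diff): a minimal end-to-end example assembled
from the unit test above. The toy experiment, the metric name "bar", and the
trial loop are stand-ins for a real setup; rendering the figure assumes the
card's blob holds the Plotly figure as JSON, which `blob_annotation ==
"plotly"` in the test suggests.

    import plotly.io as pio
    from ax.analysis.plotly import CrossValidationPlot
    from ax.service.ax_client import AxClient, ObjectiveProperties

    # Set up a toy experiment, exactly as the unit test above does.
    client = AxClient()
    client.create_experiment(
        name="foo",
        parameters=[{"name": "x", "type": "range", "bounds": [-1.0, 1.0]}],
        objectives={"bar": ObjectiveProperties(minimize=True)},
    )
    for _ in range(10):
        parameterization, trial_index = client.get_next_trial()
        client.complete_trial(
            trial_index=trial_index, raw_data={"bar": parameterization["x"] ** 2}
        )

    # folds=-1 (the default) requests leave-one-out CV. metric_name may be
    # omitted, in which case select_metric() falls back to the objective.
    analysis = CrossValidationPlot(metric_name="bar", folds=-1)

    # Passing the experiment as well lets compute() nudge the card's priority
    # when the metric belongs to the optimization config.
    card = analysis.compute(
        experiment=client.experiment,
        generation_strategy=client.generation_strategy,
    )

    print(card.df.head())  # arm_name, observed(_sem), predicted(_sem)
    pio.from_json(card.blob).show()  # square observed-vs-predicted scatter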