-
Notifications
You must be signed in to change notification settings - Fork 310
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
CrossValidationPlot analysis (#2861)
Summary: Pull Request resolved: #2861 New CrossValidationPlot using the Analysis framework. Plots CVResults of the current model on the provided GS with CI if available. Forces a square aspect ratio (but does not force any specific resolution). Hover data for x, y value with CI and arm name. Displays a gray dashed line on y=x. This diff also refactors _select_metric out of Parallel Coordinates into a utils file so we can use it here to select the objective if no metric_name is provided. NOTE: We are intentionally leaving the two dropdowns out: the "metric" dropdown because we want to generate many CVs separately and combine them more gracefully in the UI (cc Cesar-Cardoso), the latter because we dont actually want users to be able to view this without the confidence intervals (it is not useful). Reviewed By: danielcohenlive Differential Revision: D64056891 fbshipit-source-id: 01553fc9438cff9eb539ddf0697dca130a422140
- Loading branch information
1 parent
cfb0e01
commit 98dbb39
Showing
7 changed files
with
316 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,213 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
|
||
import pandas as pd | ||
from ax.analysis.analysis import AnalysisCardLevel | ||
|
||
from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard | ||
from ax.analysis.plotly.utils import select_metric | ||
from ax.core.experiment import Experiment | ||
from ax.core.generation_strategy_interface import GenerationStrategyInterface | ||
from ax.exceptions.core import UserInputError | ||
from ax.modelbridge.cross_validation import cross_validate | ||
from ax.modelbridge.generation_strategy import GenerationStrategy | ||
from plotly import express as px, graph_objects as go | ||
from pyre_extensions import assert_is_instance, none_throws | ||
|
||
|
||
class CrossValidationPlot(PlotlyAnalysis): | ||
""" | ||
Plotly Scatter plot for cross validation for model predictions using the current | ||
model on the GenerationStrategy. This plot is useful for understanding how well | ||
the model is able to predict out-of-sample which in turn is indicative of its | ||
ability to suggest valuable candidates. | ||
Splits the model's training data into train/test folds and makes | ||
out-of-sample predictions on the test folds. | ||
A well fit model will have points clustered around the y=x line, and a model with | ||
poor fit may have points in a horizontal band in the center of the plot | ||
indicating a tendency to predict the observed mean of the specificed metric for | ||
all arms. | ||
The DataFrame computed will contain one row per arm and the following columns: | ||
- arm_name: The name of the arm | ||
- observed: The observed mean of the metric specified | ||
- observed_sem: The SEM of the observed mean of the metric specified | ||
- predicted: The predicted mean of the metric specified | ||
- predicted_sem: The SEM of the predicted mean of the metric specified | ||
""" | ||
|
||
def __init__( | ||
self, metric_name: str | None = None, folds: int = -1, untransform: bool = True | ||
) -> None: | ||
""" | ||
Args: | ||
metric_name: The name of the metric to plot. If not specified the objective | ||
will be used. Note that the metric cannot be inferred for | ||
multi-objective or scalarized-objective experiments. | ||
folds: Number of subsamples to partition observations into. Use -1 for | ||
leave-one-out cross validation. | ||
untransform: Whether to untransform the model predictions before cross | ||
validating. Models are trained on transformed data, and candidate | ||
generation is performed in the transformed space. Computing the model | ||
quality metric based on the cross-validation results in the | ||
untransformed space may not be representative of the model that | ||
is actually used for candidate generation in case of non-invertible | ||
transforms, e.g., Winsorize or LogY. While the model in the | ||
transformed space may not be representative of the original data in | ||
regions where outliers have been removed, we have found it to better | ||
reflect the how good the model used for candidate generation actually | ||
is. | ||
""" | ||
|
||
self.metric_name = metric_name | ||
self.folds = folds | ||
self.untransform = untransform | ||
|
||
def compute( | ||
self, | ||
experiment: Experiment | None = None, | ||
generation_strategy: GenerationStrategyInterface | None = None, | ||
) -> PlotlyAnalysisCard: | ||
if generation_strategy is None: | ||
raise UserInputError("CrossValidation requires a GenerationStrategy") | ||
|
||
metric_name = self.metric_name or select_metric( | ||
experiment=generation_strategy.experiment | ||
) | ||
|
||
df = _prepare_data( | ||
# CrossValidationPlot requires a native Ax GenerationStrategy and cannot be | ||
# used with a GenerationStrategyInterface. | ||
generation_strategy=assert_is_instance( | ||
generation_strategy, GenerationStrategy | ||
), | ||
metric_name=metric_name, | ||
folds=self.folds, | ||
untransform=self.untransform, | ||
) | ||
fig = _prepare_plot(df=df, metric_name=metric_name) | ||
|
||
k_folds_substring = f"{self.folds}-fold" if self.folds > 0 else "leave-one-out" | ||
|
||
# Nudge the priority if the metric is important to the experiment | ||
if ( | ||
experiment is not None | ||
and (optimization_config := experiment.optimization_config) is not None | ||
and (objective := optimization_config.objective) is not None | ||
and metric_name in objective.metric_names | ||
): | ||
nudge = 2 | ||
elif ( | ||
experiment is not None | ||
and (optimization_config := experiment.optimization_config) is not None | ||
and metric_name in optimization_config.outcome_constraints | ||
): | ||
nudge = 1 | ||
else: | ||
nudge = 0 | ||
|
||
return self._create_plotly_analysis_card( | ||
title=f"Cross Validation for {metric_name}", | ||
subtitle=f"Out-of-sample predictions using {k_folds_substring} CV", | ||
level=AnalysisCardLevel.LOW.value + nudge, | ||
df=df, | ||
fig=fig, | ||
) | ||
|
||
|
||
def _prepare_data( | ||
generation_strategy: GenerationStrategy, | ||
metric_name: str, | ||
folds: int, | ||
untransform: bool, | ||
) -> pd.DataFrame: | ||
# If model is not fit already, fit it | ||
if generation_strategy.model is None: | ||
generation_strategy._fit_current_model(None) | ||
|
||
cv_results = cross_validate( | ||
model=none_throws(generation_strategy.model), | ||
folds=folds, | ||
untransform=untransform, | ||
) | ||
|
||
records = [] | ||
for observed, predicted in cv_results: | ||
for i in range(len(observed.data.metric_names)): | ||
# Find the index of the metric we want to plot | ||
if not ( | ||
observed.data.metric_names[i] == metric_name | ||
and predicted.metric_names[i] == metric_name | ||
): | ||
continue | ||
|
||
record = { | ||
"arm_name": observed.arm_name, | ||
"observed": observed.data.means[i], | ||
"predicted": predicted.means[i], | ||
# Take the square root of the the SEM to get the standard deviation | ||
"observed_sem": observed.data.covariance[i][i], | ||
"predicted_sem": predicted.covariance[i][i], | ||
} | ||
records.append(record) | ||
break | ||
|
||
return pd.DataFrame.from_records(records) | ||
|
||
|
||
def _prepare_plot(df: pd.DataFrame, metric_name: str) -> go.Figure: | ||
fig = px.scatter( | ||
df, | ||
x="observed", | ||
y="predicted", | ||
error_x="observed_sem", | ||
error_y="predicted_sem", | ||
labels={ | ||
"observed": f"Observed {metric_name}", | ||
"predicted": f"Predicted {metric_name}", | ||
}, | ||
hover_data=["arm_name", "observed", "predicted"], | ||
) | ||
|
||
# Add a gray dashed line at y=x starting and ending just outside of the region of | ||
# interest for reference. A well fit model should have points clustered around this | ||
# line. | ||
lower_bound = ( | ||
min( | ||
(df["observed"] - df["observed_sem"].fillna(0)).min(), | ||
(df["predicted"] - df["predicted_sem"].fillna(0)).min(), | ||
) | ||
* 0.99 | ||
) | ||
upper_bound = ( | ||
max( | ||
(df["observed"] + df["observed_sem"].fillna(0)).max(), | ||
(df["predicted"] + df["predicted_sem"].fillna(0)).max(), | ||
) | ||
* 1.01 | ||
) | ||
|
||
fig.add_shape( | ||
type="line", | ||
x0=lower_bound, | ||
y0=lower_bound, | ||
x1=upper_bound, | ||
y1=upper_bound, | ||
line={"color": "gray", "dash": "dot"}, | ||
) | ||
|
||
# Force plot to display as a square | ||
fig.update_xaxes(range=[lower_bound, upper_bound], constrain="domain") | ||
fig.update_yaxes( | ||
scaleanchor="x", | ||
scaleratio=1, | ||
) | ||
|
||
return fig |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
from ax.analysis.analysis import AnalysisCardLevel | ||
from ax.analysis.plotly.cross_validation import CrossValidationPlot | ||
from ax.exceptions.core import UserInputError | ||
from ax.service.ax_client import AxClient, ObjectiveProperties | ||
from ax.utils.common.testutils import TestCase | ||
from ax.utils.testing.mock import fast_botorch_optimize | ||
|
||
|
||
class TestCrossValidationPlot(TestCase): | ||
@fast_botorch_optimize | ||
def test_compute(self) -> None: | ||
client = AxClient() | ||
client.create_experiment( | ||
is_test=True, | ||
name="foo", | ||
parameters=[ | ||
{ | ||
"name": "x", | ||
"type": "range", | ||
"bounds": [-1.0, 1.0], | ||
} | ||
], | ||
objectives={"bar": ObjectiveProperties(minimize=True)}, | ||
) | ||
|
||
for _ in range(10): | ||
parameterization, trial_index = client.get_next_trial() | ||
client.complete_trial( | ||
trial_index=trial_index, raw_data={"bar": parameterization["x"] ** 2} | ||
) | ||
|
||
analysis = CrossValidationPlot(metric_name="bar") | ||
|
||
# Test that it fails if no GenerationStrategy is provided | ||
with self.assertRaisesRegex(UserInputError, "requires a GenerationStrategy"): | ||
analysis.compute() | ||
|
||
card = analysis.compute(generation_strategy=client.generation_strategy) | ||
self.assertEqual( | ||
card.name, | ||
"CrossValidationPlot", | ||
) | ||
self.assertEqual(card.title, "Cross Validation for bar") | ||
self.assertEqual( | ||
card.subtitle, | ||
"Out-of-sample predictions using leave-one-out CV", | ||
) | ||
self.assertEqual(card.level, AnalysisCardLevel.LOW) | ||
self.assertEqual( | ||
{*card.df.columns}, | ||
{"arm_name", "observed", "observed_sem", "predicted", "predicted_sem"}, | ||
) | ||
self.assertIsNotNone(card.blob) | ||
self.assertEqual(card.blob_annotation, "plotly") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.