CrossValidationPlot analysis (#2861)
Summary:
Pull Request resolved: #2861

New CrossValidationPlot using the Analysis framework. It plots the CVResults of the current model on the provided GenerationStrategy, with confidence intervals if available. The plot forces a square aspect ratio (but does not force any specific resolution), provides hover data for the x and y values with CIs and the arm name, and displays a gray dashed line at y=x.

This diff also refactors _select_metric out of Parallel Coordinates into a utils file so we can use it here to select the objective if no metric_name is provided.

NOTE: We are intentionally leaving the two dropdowns out: the "metric" dropdown because we want to generate many CVs separately and combine them more gracefully in the UI (cc Cesar-Cardoso), and the other because we don't actually want users to be able to view this plot without confidence intervals (it is not useful without them).

Reviewed By: danielcohenlive

Differential Revision: D64056891

fbshipit-source-id: 01553fc9438cff9eb539ddf0697dca130a422140
mpolson64 authored and facebook-github-bot committed Oct 16, 2024
1 parent cfb0e01 commit 98dbb39
Showing 7 changed files with 316 additions and 29 deletions.
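
For orientation before the per-file diffs, here is a minimal usage sketch of the new analysis, assembled from the new unit test shown further down. The toy experiment and the metric name "bar" come from that test and are illustrative only.

from ax.analysis.plotly.cross_validation import CrossValidationPlot
from ax.service.ax_client import AxClient, ObjectiveProperties

# Set up and run a tiny experiment (mirrors the new unit test below).
client = AxClient()
client.create_experiment(
    name="foo",
    parameters=[{"name": "x", "type": "range", "bounds": [-1.0, 1.0]}],
    objectives={"bar": ObjectiveProperties(minimize=True)},
)
for _ in range(10):
    parameterization, trial_index = client.get_next_trial()
    client.complete_trial(
        trial_index=trial_index, raw_data={"bar": parameterization["x"] ** 2}
    )

# Omitting metric_name would fall back to the objective via select_metric.
analysis = CrossValidationPlot(metric_name="bar")
card = analysis.compute(generation_strategy=client.generation_strategy)

print(card.title)     # "Cross Validation for bar"
print(card.subtitle)  # "Out-of-sample predictions using leave-one-out CV"
print(list(card.df.columns))  # arm_name, observed, observed_sem, predicted, predicted_sem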
2 changes: 2 additions & 0 deletions ax/analysis/plotly/__init__.py
@@ -5,11 +5,13 @@

# pyre-strict

from ax.analysis.plotly.cross_validation import CrossValidationPlot
from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot
from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
from ax.analysis.plotly.scatter import ScatterPlot

__all__ = [
"CrossValidationPlot",
"PlotlyAnalysis",
"PlotlyAnalysisCard",
"ParallelCoordinatesPlot",
213 changes: 213 additions & 0 deletions ax/analysis/plotly/cross_validation.py
@@ -0,0 +1,213 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict


import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel

from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
from ax.analysis.plotly.utils import select_metric
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.exceptions.core import UserInputError
from ax.modelbridge.cross_validation import cross_validate
from ax.modelbridge.generation_strategy import GenerationStrategy
from plotly import express as px, graph_objects as go
from pyre_extensions import assert_is_instance, none_throws


class CrossValidationPlot(PlotlyAnalysis):
"""
Plotly scatter plot of cross-validated model predictions using the current
model on the GenerationStrategy. This plot is useful for understanding how well
the model is able to predict out-of-sample, which in turn is indicative of its
ability to suggest valuable candidates.
Splits the model's training data into train/test folds and makes
out-of-sample predictions on the test folds.
A well-fit model will have points clustered around the y=x line, and a model with
poor fit may have points in a horizontal band in the center of the plot,
indicating a tendency to predict the observed mean of the specified metric for
all arms.
The DataFrame computed will contain one row per arm and the following columns:
- arm_name: The name of the arm
- observed: The observed mean of the metric specified
- observed_sem: The SEM of the observed mean of the metric specified
- predicted: The predicted mean of the metric specified
- predicted_sem: The SEM of the predicted mean of the metric specified
"""

def __init__(
self, metric_name: str | None = None, folds: int = -1, untransform: bool = True
) -> None:
"""
Args:
metric_name: The name of the metric to plot. If not specified the objective
will be used. Note that the metric cannot be inferred for
multi-objective or scalarized-objective experiments.
folds: Number of subsamples to partition observations into. Use -1 for
leave-one-out cross validation.
untransform: Whether to untransform the model predictions before cross
validating. Models are trained on transformed data, and candidate
generation is performed in the transformed space. Computing the model
quality metric based on the cross-validation results in the
untransformed space may not be representative of the model that
is actually used for candidate generation in case of non-invertible
transforms, e.g., Winsorize or LogY. While the model in the
transformed space may not be representative of the original data in
regions where outliers have been removed, we have found it to better
reflect how good the model used for candidate generation actually
is.
"""

self.metric_name = metric_name
self.folds = folds
self.untransform = untransform

def compute(
self,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategyInterface | None = None,
) -> PlotlyAnalysisCard:
if generation_strategy is None:
raise UserInputError("CrossValidation requires a GenerationStrategy")

metric_name = self.metric_name or select_metric(
experiment=generation_strategy.experiment
)

df = _prepare_data(
# CrossValidationPlot requires a native Ax GenerationStrategy and cannot be
# used with a GenerationStrategyInterface.
generation_strategy=assert_is_instance(
generation_strategy, GenerationStrategy
),
metric_name=metric_name,
folds=self.folds,
untransform=self.untransform,
)
fig = _prepare_plot(df=df, metric_name=metric_name)

k_folds_substring = f"{self.folds}-fold" if self.folds > 0 else "leave-one-out"

# Nudge the priority if the metric is important to the experiment
if (
experiment is not None
and (optimization_config := experiment.optimization_config) is not None
and (objective := optimization_config.objective) is not None
and metric_name in objective.metric_names
):
nudge = 2
elif (
experiment is not None
and (optimization_config := experiment.optimization_config) is not None
and metric_name
in {c.metric.name for c in optimization_config.outcome_constraints}
):
nudge = 1
else:
nudge = 0

return self._create_plotly_analysis_card(
title=f"Cross Validation for {metric_name}",
subtitle=f"Out-of-sample predictions using {k_folds_substring} CV",
level=AnalysisCardLevel.LOW.value + nudge,
df=df,
fig=fig,
)


def _prepare_data(
generation_strategy: GenerationStrategy,
metric_name: str,
folds: int,
untransform: bool,
) -> pd.DataFrame:
# If model is not fit already, fit it
if generation_strategy.model is None:
generation_strategy._fit_current_model(None)

cv_results = cross_validate(
model=none_throws(generation_strategy.model),
folds=folds,
untransform=untransform,
)

records = []
for observed, predicted in cv_results:
for i in range(len(observed.data.metric_names)):
# Find the index of the metric we want to plot
if not (
observed.data.metric_names[i] == metric_name
and predicted.metric_names[i] == metric_name
):
continue

record = {
"arm_name": observed.arm_name,
"observed": observed.data.means[i],
"predicted": predicted.means[i],
# Take the square root of the variance (the diagonal of the
# covariance matrix) to get the SEM
"observed_sem": observed.data.covariance[i][i] ** 0.5,
"predicted_sem": predicted.covariance[i][i] ** 0.5,
}
records.append(record)
break

return pd.DataFrame.from_records(records)


def _prepare_plot(df: pd.DataFrame, metric_name: str) -> go.Figure:
fig = px.scatter(
df,
x="observed",
y="predicted",
error_x="observed_sem",
error_y="predicted_sem",
labels={
"observed": f"Observed {metric_name}",
"predicted": f"Predicted {metric_name}",
},
hover_data=["arm_name", "observed", "predicted"],
)

# Add a gray dashed line at y=x starting and ending just outside of the region of
# interest for reference. A well fit model should have points clustered around this
# line.
lower_bound = (
min(
(df["observed"] - df["observed_sem"].fillna(0)).min(),
(df["predicted"] - df["predicted_sem"].fillna(0)).min(),
)
* 0.99
)
upper_bound = (
max(
(df["observed"] + df["observed_sem"].fillna(0)).max(),
(df["predicted"] + df["predicted_sem"].fillna(0)).max(),
)
* 1.01
)

fig.add_shape(
type="line",
x0=lower_bound,
y0=lower_bound,
x1=upper_bound,
y1=upper_bound,
line={"color": "gray", "dash": "dot"},
)

# Force plot to display as a square
fig.update_xaxes(range=[lower_bound, upper_bound], constrain="domain")
fig.update_yaxes(
scaleanchor="x",
scaleratio=1,
)

return fig
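
For readers less familiar with Plotly: the square rendering and the y=x reference line in _prepare_plot come from combining add_shape, a pinned x-axis range with constrain="domain", and scaleanchor on the y-axis. A self-contained sketch of the same technique on toy data (values are illustrative only, not part of this diff):

import pandas as pd
from plotly import express as px

# Toy cross-validation-style data (illustrative values only).
df = pd.DataFrame({
    "observed": [0.8, 1.9, 3.1],
    "predicted": [1.0, 2.1, 2.9],
    "observed_sem": [0.1, 0.2, 0.1],
    "predicted_sem": [0.2, 0.1, 0.2],
})
fig = px.scatter(
    df, x="observed", y="predicted", error_x="observed_sem", error_y="predicted_sem"
)

lower = 0.99 * (df["observed"] - df["observed_sem"]).min()
upper = 1.01 * (df["observed"] + df["observed_sem"]).max()
# Gray dashed y=x reference line spanning just beyond the data.
fig.add_shape(
    type="line", x0=lower, y0=lower, x1=upper, y1=upper,
    line={"color": "gray", "dash": "dot"},
)
# Pin the x range, then lock the y-axis scale to x so the plot renders square.
fig.update_xaxes(range=[lower, upper], constrain="domain")
fig.update_yaxes(scaleanchor="x", scaleratio=1)
fig.show()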
25 changes: 3 additions & 22 deletions ax/analysis/plotly/parallel_coordinates.py
@@ -12,10 +12,10 @@
from ax.analysis.analysis import AnalysisCardLevel

from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
from ax.analysis.plotly.utils import select_metric
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.objective import MultiObjective, ScalarizedObjective
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.exceptions.core import UserInputError
from plotly import graph_objects as go


@@ -50,7 +50,7 @@ def compute(
if experiment is None:
raise UserInputError("ParallelCoordinatesPlot requires an Experiment")

metric_name = self.metric_name or _select_metric(experiment=experiment)
metric_name = self.metric_name or select_metric(experiment=experiment)

df = _prepare_data(experiment=experiment, metric=metric_name)
fig = _prepare_plot(df=df, metric_name=metric_name)
@@ -113,25 +113,6 @@ def _prepare_plot(df: pd.DataFrame, metric_name: str) -> go.Figure:
)


def _select_metric(experiment: Experiment) -> str:
if experiment.optimization_config is None:
raise ValueError(
"Cannot infer metric to plot from Experiment without OptimizationConfig"
)
objective = experiment.optimization_config.objective
if isinstance(objective, MultiObjective):
raise UnsupportedError(
"Cannot infer metric to plot from MultiObjective, please "
"specify a metric"
)
if isinstance(objective, ScalarizedObjective):
raise UnsupportedError(
"Cannot infer metric to plot from ScalarizedObjective, please "
"specify a metric"
)
return experiment.optimization_config.objective.metric.name


def _find_mean_by_arm_name(
df: pd.DataFrame,
arm_name: str,
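
The new ax/analysis/plotly/utils.py itself is not among the files rendered on this page (two of the seven changed files are not shown). Based on the _select_metric body removed above and the new select_metric imports, it presumably contains something close to the following sketch; the actual file may differ.

# Presumed shape of ax/analysis/plotly/utils.py, reconstructed from the
# removed _select_metric helper above.
from ax.core.experiment import Experiment
from ax.core.objective import MultiObjective, ScalarizedObjective
from ax.exceptions.core import UnsupportedError


def select_metric(experiment: Experiment) -> str:
    """Infer the single objective metric of an Experiment, or raise."""
    if experiment.optimization_config is None:
        raise ValueError(
            "Cannot infer metric to plot from Experiment without OptimizationConfig"
        )
    objective = experiment.optimization_config.objective
    if isinstance(objective, MultiObjective):
        raise UnsupportedError(
            "Cannot infer metric to plot from MultiObjective, please specify a metric"
        )
    if isinstance(objective, ScalarizedObjective):
        raise UnsupportedError(
            "Cannot infer metric to plot from ScalarizedObjective, please specify a metric"
        )
    return objective.metric.name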
61 changes: 61 additions & 0 deletions ax/analysis/plotly/tests/test_cross_validation.py
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.plotly.cross_validation import CrossValidationPlot
from ax.exceptions.core import UserInputError
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.utils.common.testutils import TestCase
from ax.utils.testing.mock import fast_botorch_optimize


class TestCrossValidationPlot(TestCase):
@fast_botorch_optimize
def test_compute(self) -> None:
client = AxClient()
client.create_experiment(
is_test=True,
name="foo",
parameters=[
{
"name": "x",
"type": "range",
"bounds": [-1.0, 1.0],
}
],
objectives={"bar": ObjectiveProperties(minimize=True)},
)

for _ in range(10):
parameterization, trial_index = client.get_next_trial()
client.complete_trial(
trial_index=trial_index, raw_data={"bar": parameterization["x"] ** 2}
)

analysis = CrossValidationPlot(metric_name="bar")

# Test that it fails if no GenerationStrategy is provided
with self.assertRaisesRegex(UserInputError, "requires a GenerationStrategy"):
analysis.compute()

card = analysis.compute(generation_strategy=client.generation_strategy)
self.assertEqual(
card.name,
"CrossValidationPlot",
)
self.assertEqual(card.title, "Cross Validation for bar")
self.assertEqual(
card.subtitle,
"Out-of-sample predictions using leave-one-out CV",
)
self.assertEqual(card.level, AnalysisCardLevel.LOW)
self.assertEqual(
{*card.df.columns},
{"arm_name", "observed", "observed_sem", "predicted", "predicted_sem"},
)
self.assertIsNotNone(card.blob)
self.assertEqual(card.blob_annotation, "plotly")
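
A note on consuming the card outside the test: blob_annotation == "plotly" indicates the blob holds the serialized Plotly figure. Assuming it is the standard Plotly JSON serialization (an assumption, not stated in this diff), it could be rehydrated like so, with card coming from the usage sketch near the top:

import plotly.io as pio

# Assumes card.blob is a Plotly-figure JSON string, as blob_annotation="plotly" suggests.
fig = pio.from_json(card.blob)
fig.show()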
12 changes: 6 additions & 6 deletions ax/analysis/plotly/tests/test_parallel_coordinates.py
@@ -9,9 +9,9 @@
from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.plotly.parallel_coordinates import (
_get_parameter_dimension,
_select_metric,
ParallelCoordinatesPlot,
)
from ax.analysis.plotly.utils import select_metric
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
@@ -44,7 +44,7 @@ def test_compute(self) -> None:
analysis_no_metric = ParallelCoordinatesPlot()
_ = analysis_no_metric.compute(experiment=experiment)

def test_select_metric(self) -> None:
def testselect_metric(self) -> None:
experiment = get_branin_experiment()
experiment_no_optimization_config = get_branin_experiment(
has_optimization_config=False
@@ -54,16 +54,16 @@ def test_select_metric(self) -> None:
get_experiment_with_scalarized_objective_and_outcome_constraint()
)

self.assertEqual(_select_metric(experiment=experiment), "branin")
self.assertEqual(select_metric(experiment=experiment), "branin")

with self.assertRaisesRegex(ValueError, "OptimizationConfig"):
_select_metric(experiment=experiment_no_optimization_config)
select_metric(experiment=experiment_no_optimization_config)

with self.assertRaisesRegex(UnsupportedError, "MultiObjective"):
_select_metric(experiment=experiment_multi_objective)
select_metric(experiment=experiment_multi_objective)

with self.assertRaisesRegex(UnsupportedError, "ScalarizedObjective"):
_select_metric(experiment=experiment_scalarized_objective)
select_metric(experiment=experiment_scalarized_objective)

def test_get_parameter_dimension(self) -> None:
range_series = pd.Series([0, 1, 2, 3], name="range")