From 586f2c3a6209bac0d0c4a4fa90d6d005b73092c2 Mon Sep 17 00:00:00 2001 From: Matthew Grange Date: Mon, 25 Mar 2024 09:05:30 -0700 Subject: [PATCH] TileFittedPlot Differential Revision: D55030326 --- ax/analysis/helpers/plot_data_df_helpers.py | 91 ++++++++++ ax/analysis/helpers/scatter_helpers.py | 22 +-- ax/analysis/parallel_coordinates_plot.py | 4 +- ax/analysis/tests/test_tile_fitted_plot.py | 40 +++++ ax/analysis/tile_fitted_plot.py | 185 ++++++++++++++++++++ 5 files changed, 330 insertions(+), 12 deletions(-) create mode 100644 ax/analysis/helpers/plot_data_df_helpers.py create mode 100644 ax/analysis/tests/test_tile_fitted_plot.py create mode 100644 ax/analysis/tile_fitted_plot.py diff --git a/ax/analysis/helpers/plot_data_df_helpers.py b/ax/analysis/helpers/plot_data_df_helpers.py new file mode 100644 index 00000000000..64dee6016c7 --- /dev/null +++ b/ax/analysis/helpers/plot_data_df_helpers.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Set + +import numpy as np +import pandas as pd + +from ax.modelbridge import ModelBridge +from ax.modelbridge.prediction_utils import predict_at_point + +from ax.modelbridge.transforms.ivw import IVW + + +def get_plot_data_in_sample_arms_df( + model: ModelBridge, + metric_names: Set[str], +) -> pd.DataFrame: + """Get in-sample arms from a model with observed and predicted values + for specified metrics. + + Returns a dataframe in which repeated observations are merged + with IVW (inverse variance weighting) + + Args: + model: An instance of the model bridge. + metric_names: Restrict predictions to these metrics. If None, uses all + metrics in the model. + + Returns: + A dataframe containing + columns: + { + "arm_name": name of the arm in the cross validation result + "metric_name": name of the observed/predicted metric + "x": value observed for the metric for this arm + "x_se": standard error of observed metric (0 for observations) + "y": value predicted for the metric for this arm + "y_se": standard error of predicted metric for this arm + "arm_parameters": Parametrization of the arm + } + """ + observations = model.get_training_data() + training_in_design: List[bool] = model.training_in_design + + # Merge multiple measurements within each Observation with IVW to get + # un-modeled prediction + observations = IVW(None, []).transform_observations(observations) + + # Create records for dict + records = [] + + for i, obs in enumerate(observations): + # Extract raw measurement + features = obs.features + + if training_in_design[i]: + pred_y_dict, pred_se_dict = predict_at_point(model, features, metric_names) + else: + pred_y_dict = None + pred_se_dict = None + + for metric_name in obs.data.metric_names: + if metric_name not in metric_names: + continue + obs_y = obs.data.means_dict[metric_name] + obs_se = np.sqrt(obs.data.covariance_matrix[metric_name][metric_name]) + + if pred_y_dict and pred_se_dict: + pred_y = pred_y_dict[metric_name] + pred_se = pred_se_dict[metric_name] + else: + pred_y = obs_y + pred_se = obs_se + + records.append( + { + "arm_name": obs.arm_name, + "metric_name": metric_name, + "x": obs_y, + "x_se": obs_se, + "y": pred_y, + "y_se": pred_se, + "arm_parameters": obs.features.parameters, + } + ) + + return pd.DataFrame.from_records(records) diff --git a/ax/analysis/helpers/scatter_helpers.py b/ax/analysis/helpers/scatter_helpers.py index 7a931213891..12dda0b2229 100644 --- a/ax/analysis/helpers/scatter_helpers.py +++ b/ax/analysis/helpers/scatter_helpers.py @@ -72,14 +72,14 @@ def make_label( estimate=( round(x_val, DECIMALS) if isinstance(x_val, numbers.Number) else x_val ), - ci=_format_CI(x_val, x_se), + ci=_format_CI(estimate=x_val, sd=x_se), ) y_lab = "{name}: {estimate} {ci}
".format( name=y_name, estimate=( round(y_val, DECIMALS) if isinstance(y_val, numbers.Number) else y_val ), - ci=_format_CI(y_val, y_se), + ci=_format_CI(estimate=y_val, sd=y_se), ) parameterization = _format_dict(param_blob, "Parameterization") @@ -114,22 +114,24 @@ def error_scatter_trace_from_df( x, x_se, y, y_se = extract_mean_and_error_from_df(df) labels = [] - param_blobs = df["arm_parameters"] arm_names = df["arm_name"] metric_name = df["metric_name"].iloc[0] - for i in range(len(param_blobs)): + print("Data frame: " + str(df)) + print("Arm names: " + str(arm_names)) + print("x" + str(x)) + for _, row in df.iterrows(): labels.append( make_label( - arm_name=arm_names[i], + arm_name=row["arm_name"], x_name=metric_name if x_axis_label is None else x_axis_label, - x_val=x[i], - x_se=x_se[i], + x_val=row["x"], + x_se=row["x_se"], y_name=(metric_name if y_axis_label is None else y_axis_label), - y_val=y[i], - y_se=y_se[i], - param_blob=param_blobs[i], + y_val=row["y"], + y_se=row["y_se"], + param_blob=row["arm_parameters"], ) ) diff --git a/ax/analysis/parallel_coordinates_plot.py b/ax/analysis/parallel_coordinates_plot.py index f8a67697432..f2b5f7b999f 100644 --- a/ax/analysis/parallel_coordinates_plot.py +++ b/ax/analysis/parallel_coordinates_plot.py @@ -9,6 +9,8 @@ import pandas as pd +from ax.analysis.base_plotly_visualization import BasePlotlyVisualization + from ax.core.arm import Arm from ax.core.base_trial import BaseTrial @@ -18,8 +20,6 @@ from plotly import express as px, graph_objs as go -from .base_plotly_visualization import BasePlotlyVisualization - class ParallelCoordinatesPlot(BasePlotlyVisualization): def __init__( diff --git a/ax/analysis/tests/test_tile_fitted_plot.py b/ax/analysis/tests/test_tile_fitted_plot.py new file mode 100644 index 00000000000..94478c261c8 --- /dev/null +++ b/ax/analysis/tests/test_tile_fitted_plot.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import unittest + +import plotly.graph_objects as go + +from ax.analysis.tile_fitted_plot import TileFittedPlot + +from ax.modelbridge.registry import Models + +from ax.utils.testing.core_stubs import get_branin_experiment + + +class TestTileFittedPlot(unittest.TestCase): + def setUp(self) -> None: + self.exp = get_branin_experiment(with_batch=True) + self.exp.trials[0].run() + self.model = Models.BOTORCH_MODULAR( + # Model bridge kwargs + experiment=self.exp, + data=self.exp.fetch_data(), + ) + + super().setUp() + + def test_tile_fitted_plot(self) -> None: + tile_fitted_plot = TileFittedPlot( + experiment=self.exp, + model=self.model, + ) + + df = tile_fitted_plot.get_df() + print(df) + + fig = tile_fitted_plot.get_fig() + self.assertIsInstance(fig, go.Figure) diff --git a/ax/analysis/tile_fitted_plot.py b/ax/analysis/tile_fitted_plot.py new file mode 100644 index 00000000000..5e25cc82c61 --- /dev/null +++ b/ax/analysis/tile_fitted_plot.py @@ -0,0 +1,185 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, Dict, Set + +import numpy as np + +import pandas as pd + +from ax.analysis.base_plotly_visualization import BasePlotlyVisualization + +from ax.analysis.helpers.plot_data_df_helpers import get_plot_data_in_sample_arms_df + +from ax.analysis.helpers.scatter_helpers import error_scatter_trace_from_df + +from ax.core.experiment import Experiment + +from ax.modelbridge import ModelBridge + +from plotly import graph_objs as go, subplots + + +class TileFittedPlot(BasePlotlyVisualization): + def __init__( + self, + experiment: Experiment, + model: ModelBridge, + ) -> None: + """ """ + + self.model = model + self.metrics: Set[str] = model.metric_names + + super().__init__(experiment=experiment) + + def get_df(self) -> pd.DataFrame: + """ """ + return get_plot_data_in_sample_arms_df( + model=self.model, metric_names=self.metrics + ) + + def get_fig( + self, + ) -> go.Figure: + """Tile version of fitted outcome plots.""" + metrics = self.metrics + nrows = int(np.ceil(len(metrics) / 2)) + ncols = min(len(metrics), 2) + + subplot_titles = metrics + + fig = subplots.make_subplots( + rows=nrows, + cols=ncols, + print_grid=False, + shared_xaxes=False, + shared_yaxes=False, + subplot_titles=tuple(subplot_titles), + horizontal_spacing=0.05, + vertical_spacing=0.30 / nrows, + ) + + name_order_args: Dict[str, Any] = {} + name_order_axes: Dict[str, Dict[str, Any]] = {} + effect_order_args: Dict[str, Any] = {} + + in_sample_df = self.get_df() + + for i, metric in enumerate(metrics): + filtered_df = in_sample_df.loc[in_sample_df["metric_name"] == metric] + data: Dict[str, Any] = error_scatter_trace_from_df( + df=filtered_df, + show_CI=True, + ) + + # order arm name sorting arm numbers within batch + """names_by_arm = sorted( + np.unique(np.concatenate([d["x"] for d in data])), + key=lambda x: arm_name_to_sort_key(x), + )""" + names_by_arm = [] + + # get arm names sorted by effect size + """names_by_effect = list( + OrderedDict.fromkeys( + np.concatenate([d["x"] for d in data]) + .flatten() + .take(np.argsort(np.concatenate([d["y"] for d in data]).flatten())) + ) + )""" + names_by_effect = [] + + # options for ordering arms (x-axis) + # Note that xaxes need to be references as xaxis, xaxis2, xaxis3, etc. + # for the purposes of updatemenus argument (dropdown) in layout. + # However, when setting the initial ordering layout, the keys should be + # xaxis1, xaxis2, xaxis3, etc. Note the discrepancy for the initial + # axis. + label = "" if i == 0 else i + 1 + name_order_args["xaxis{}.categoryorder".format(label)] = "array" + name_order_args["xaxis{}.categoryarray".format(label)] = names_by_arm + effect_order_args["xaxis{}.categoryorder".format(label)] = "array" + effect_order_args["xaxis{}.categoryarray".format(label)] = names_by_effect + name_order_axes["xaxis{}".format(i + 1)] = { + "categoryorder": "array", + "categoryarray": names_by_arm, + "type": "category", + } + name_order_axes["yaxis{}".format(i + 1)] = { + "ticksuffix": "", + "zerolinecolor": "red", + } + + fig.append_trace( # pyre-ignore[16] + data, int(np.floor(i / ncols)) + 1, i % ncols + 1 + ) + + order_options = [ + {"args": [name_order_args], "label": "Name", "method": "relayout"}, + {"args": [effect_order_args], "label": "Effect Size", "method": "relayout"}, + ] + + # if odd number of plots, need to manually remove the last blank subplot + # generated by `subplots.make_subplots` + if len(metrics) % 2 == 1: + fig["layout"].pop("xaxis{}".format(nrows * ncols)) + fig["layout"].pop("yaxis{}".format(nrows * ncols)) + + # allocate 400 px per plot + fig["layout"].update( + margin={"t": 0}, + hovermode="closest", + updatemenus=[ + { + "x": 0.15, + "y": 1 + 0.40 / nrows, + "buttons": order_options, + "xanchor": "left", + "yanchor": "middle", + } + ], + font={"size": 10}, + width=650 if ncols == 1 else 950, + height=300 * nrows, + legend={ + "orientation": "h", + "x": 0, + "y": 1 + 0.20 / nrows, + "xanchor": "left", + "yanchor": "middle", + }, + **name_order_axes, + ) + + # append dropdown annotations + fig["layout"]["annotations"] = fig["layout"]["annotations"] + ( + { + "x": 0.5, + "y": 1 + 0.40 / nrows, + "xref": "paper", + "yref": "paper", + "font": {"size": 14}, + "text": "Predicted Outcomes", + "showarrow": False, + "xanchor": "center", + "yanchor": "middle", + }, + { + "x": 0.05, + "y": 1 + 0.40 / nrows, + "xref": "paper", + "yref": "paper", + "text": "Sort By", + "showarrow": False, + "xanchor": "left", + "yanchor": "middle", + }, + ) + + # fig = resize_subtitles(figure=fig, size=10) + return fig