From a85c7eebbff0d46fbdcbf6f002c6aca60bcfce55 Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Thu, 24 Oct 2024 11:18:51 -0700 Subject: [PATCH] Reap ax.analysis.old (#2956) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2956 All of these plots have been recreated using the new Analysis structure we intend to use going forward. Reaping dead code Reviewed By: Cesar-Cardoso Differential Revision: D64907658 fbshipit-source-id: 1014721104c6a9084ae7158cac23a002bc99b451 --- ax/analysis/old/analysis_report.py | 96 ------ ax/analysis/old/base_analysis.py | 52 --- ax/analysis/old/base_plotly_visualization.py | 52 --- ax/analysis/old/cross_validation_plot.py | 252 -------------- ax/analysis/old/helpers/color_helpers.py | 18 - ax/analysis/old/helpers/constants.py | 23 -- .../old/helpers/cross_validation_helpers.py | 200 ----------- ax/analysis/old/helpers/layout_helpers.py | 111 ------- .../old/helpers/plot_data_df_helpers.py | 92 ----- ax/analysis/old/helpers/plot_helpers.py | 91 ----- ax/analysis/old/helpers/scatter_helpers.py | 314 ------------------ .../tests/test_cross_validation_helpers.py | 91 ----- .../tests/test_cv_consistency_checks.py | 129 ------- .../old/predicted_outcomes_dot_plot.py | 143 -------- ax/analysis/old/tests/test_analysis_report.py | 144 -------- ax/analysis/old/tests/test_base_classes.py | 142 -------- .../old/tests/test_cross_validation_plot.py | 39 --- .../tests/test_predicted_outcomes_dot_plot.py | 56 ---- 18 files changed, 2045 deletions(-) delete mode 100644 ax/analysis/old/analysis_report.py delete mode 100644 ax/analysis/old/base_analysis.py delete mode 100644 ax/analysis/old/base_plotly_visualization.py delete mode 100644 ax/analysis/old/cross_validation_plot.py delete mode 100644 ax/analysis/old/helpers/color_helpers.py delete mode 100644 ax/analysis/old/helpers/constants.py delete mode 100644 ax/analysis/old/helpers/cross_validation_helpers.py delete mode 100644 ax/analysis/old/helpers/layout_helpers.py delete mode 100644 ax/analysis/old/helpers/plot_data_df_helpers.py delete mode 100644 ax/analysis/old/helpers/plot_helpers.py delete mode 100644 ax/analysis/old/helpers/scatter_helpers.py delete mode 100644 ax/analysis/old/helpers/tests/test_cross_validation_helpers.py delete mode 100644 ax/analysis/old/helpers/tests/test_cv_consistency_checks.py delete mode 100644 ax/analysis/old/predicted_outcomes_dot_plot.py delete mode 100644 ax/analysis/old/tests/test_analysis_report.py delete mode 100644 ax/analysis/old/tests/test_base_classes.py delete mode 100644 ax/analysis/old/tests/test_cross_validation_plot.py delete mode 100644 ax/analysis/old/tests/test_predicted_outcomes_dot_plot.py diff --git a/ax/analysis/old/analysis_report.py b/ax/analysis/old/analysis_report.py deleted file mode 100644 index 350213a0e44..00000000000 --- a/ax/analysis/old/analysis_report.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - - -import pandas as pd -import plotly.graph_objects as go - -from ax.analysis.old.base_analysis import BaseAnalysis -from ax.analysis.old.base_plotly_visualization import BasePlotlyVisualization - -from ax.core.experiment import Experiment - -from ax.utils.common.timeutils import current_timestamp_in_millis - - -class AnalysisReport: - """ - A class corresponding to a set of analysis ran on the same - set of data from an experiment. 
- """ - - analyses: list[BaseAnalysis] = [] - experiment: Experiment - - time_started: int | None = None - time_completed: int | None = None - - def __init__( - self, - experiment: Experiment, - analyses: list[BaseAnalysis], - time_started: int | None = None, - time_completed: int | None = None, - ) -> None: - """ - This class is a collection of AnalysisReport. - - Args: - experiment: Experiment which the analyses are generated from - time_started: time the completed report was started - time_completed: time the completed report was started - - """ - self.experiment = experiment - self.analyses = analyses - self.time_started = time_started - self.time_completed = time_completed - - @property - def report_completed(self) -> bool: - """ - Returns: - True if the report is completed, False otherwise. - """ - return self.time_completed is not None - - def run_analysis_report( - self, - ) -> list[ - tuple[ - BaseAnalysis, - pd.DataFrame, - go.Figure | None, - ] - ]: - """ - Runs all analyses in the report and produces the result. - - Returns: - analysis_report_result: list of tuples (analysis, df, Optional[fig]) - """ - if not self.report_completed: - self.time_started = current_timestamp_in_millis() - - analysis_report_result = [] - for analysis in self.analyses: - analysis_report_result.append( - ( - analysis, - analysis.get_df(), - ( - None - if not isinstance(analysis, BasePlotlyVisualization) - else analysis.get_fig() - ), - ) - ) - - if not self.report_completed: - self.time_completed = current_timestamp_in_millis() - - return analysis_report_result diff --git a/ax/analysis/old/base_analysis.py b/ax/analysis/old/base_analysis.py deleted file mode 100644 index d7a69187879..00000000000 --- a/ax/analysis/old/base_analysis.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - - -import pandas as pd -from ax.core.experiment import Experiment - - -class BaseAnalysis: - """ - Abstract Analysis class for ax. - This is an interface that defines the methods to be implemented by all analyses. - Computes an output dataframe for each analysis - """ - - def __init__( - self, - experiment: Experiment, - df_input: pd.DataFrame | None = None, - # TODO: add support for passing in experiment name, and markdown message - ) -> None: - """ - Initialize the analysis with the experiment object. - For scenarios where an analysis output is already available, - we can pass the dataframe as an input. - """ - self._experiment = experiment - self._df: pd.DataFrame | None = df_input - - @property - def experiment(self) -> Experiment: - return self._experiment - - @property - def df(self) -> pd.DataFrame: - """ - Return the output of the analysis of this class. - """ - if self._df is None: - self._df = self.get_df() - return self._df - - def get_df(self) -> pd.DataFrame: - """ - Return the output of the analysis of this class. - Subclasses should overwrite this. - """ - raise NotImplementedError("get_df must be implemented by subclass") diff --git a/ax/analysis/old/base_plotly_visualization.py b/ax/analysis/old/base_plotly_visualization.py deleted file mode 100644 index ab65dc0d0e6..00000000000 --- a/ax/analysis/old/base_plotly_visualization.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. 
-# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - - -import pandas as pd - -import plotly.graph_objects as go - -from ax.analysis.old.base_analysis import BaseAnalysis -from ax.core.experiment import Experiment - - -class BasePlotlyVisualization(BaseAnalysis): - """ - Abstract PlotlyVisualization class for ax. - This is an interface that defines the method to be implemented by all ax plots. - Computes an output dataframe for each analysis - """ - - def __init__( - self, - experiment: Experiment, - df_input: pd.DataFrame | None = None, - fig_input: go.Figure | None = None, - ) -> None: - """ - Initialize the analysis with the experiment object. - For scenarios where an analysis output is already available, - we can pass the dataframe as an input. - """ - self._fig = fig_input - super().__init__(experiment=experiment, df_input=df_input) - - @property - def fig(self) -> go.Figure: - """ - Return the output of the analysis of this class. - """ - if self._fig is None: - self._fig = self.get_fig() - return self._fig - - def get_fig(self) -> go.Figure: - """ - Return the plotly figure of the analysis of this class. - Subclasses should overwrite this. - """ - raise NotImplementedError("get_fig must be implemented by subclass") diff --git a/ax/analysis/old/cross_validation_plot.py b/ax/analysis/old/cross_validation_plot.py deleted file mode 100644 index 3f3ac5646a3..00000000000 --- a/ax/analysis/old/cross_validation_plot.py +++ /dev/null @@ -1,252 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -from copy import deepcopy -from typing import Any - -import pandas as pd - -from ax.analysis.old.base_plotly_visualization import BasePlotlyVisualization - -from ax.analysis.old.helpers.cross_validation_helpers import ( - cv_results_to_df, - diagonal_trace, - get_plotting_limit_ignore_outliers, -) - -from ax.analysis.old.helpers.layout_helpers import layout_format, updatemenus_format - -from ax.analysis.old.helpers.scatter_helpers import ( - error_scatter_trace_from_df, - extract_mean_and_error_from_df, -) - -from ax.core.experiment import Experiment -from ax.modelbridge import ModelBridge - -from ax.modelbridge.cross_validation import cross_validate, CVResult - -from plotly import graph_objs as go - - -class CrossValidationPlot(BasePlotlyVisualization): - CROSS_VALIDATION_CAPTION = ( - "NOTE: We have tried our best to only plot the region of interest.
" - "This may hide outliers. You can autoscale the axes to see all trials." - ) - - def __init__( - self, - experiment: Experiment, - model: ModelBridge, - label_dict: dict[str, str] | None = None, - caption: str = CROSS_VALIDATION_CAPTION, - ) -> None: - """ - Args: - experiment: Experiment containing trials to plot - model: ModelBridge to cross validate against - label_dict: optional map from real metric names to shortened names - caption: text to display below the plot - """ - self.model = model - self.cv: list[CVResult] = cross_validate(model=model) - - self.label_dict: dict[str, str] | None = label_dict - if self.label_dict: - self.cv = self.remap_label(cv_results=self.cv, label_dict=self.label_dict) - - self.metric_names: list[str] = list( - set().union(*(cv_result.predicted.metric_names for cv_result in self.cv)) - ) - self.caption = caption - - super().__init__(experiment=experiment) - - def get_df(self) -> pd.DataFrame: - """ - Overrides BaseAnalysis.get_df() - - Returns: - df representation of the cross validation results. - columns: - { - "arm_name": name of the arm in the cross validation result - "metric_name": name of the observed/predicted metric - "x": value observed for the metric for this arm - "x_se": standard error of observed metric (0 for observations) - "y": value predicted for the metric for this arm - "y_se": standard error of predicted metric for this arm - "arm_parameters": Parametrization of the arm - } - """ - - df = pd.concat( - [ - cv_results_to_df( - cv_results=self.cv, - metric_name=metric, - ) - for metric in self.metric_names - ] - ) - - return df - - @staticmethod - def compose_annotation( - caption: str, x: float = 0.0, y: float = -0.15 - ) -> list[dict[str, Any]]: - """Composes an annotation dict for use in Plotly figure. - args: - caption: str to use for dropdown text - x: x position of the annotation - y: y position of the annotation - - returns: - Annotation dict for use in Plotly figure. - """ - return [ - { - "showarrow": False, - "text": caption, - "x": x, - "xanchor": "left", - "xref": "paper", - "y": y, - "yanchor": "top", - "yref": "paper", - "align": "left", - }, - ] - - @staticmethod - def remap_label( - cv_results: list[CVResult], label_dict: dict[str, str] - ) -> list[CVResult]: - """Remaps labels in cv_results according to label_dict. - - Args: - cv_results: A CVResult for each observation in the training data. - label_dict: optional map from real metric names to shortened names - - Returns: - A CVResult with metric names mapped from label_dict. - """ - cv_results = deepcopy(cv_results) # Copy and edit in-place - for cv_i in cv_results: - cv_i.observed.data.metric_names = [ - label_dict.get(m, m) for m in cv_i.observed.data.metric_names - ] - cv_i.predicted.metric_names = [ - label_dict.get(m, m) for m in cv_i.predicted.metric_names - ] - return cv_results - - def obs_vs_pred_dropdown_plot( - self, - xlabel: str = "Actual Outcome", - ylabel: str = "Predicted Outcome", - ) -> go.Figure: - """Plot a dropdown plot of observed vs. predicted values from the - cross validation results. - - Args: - xlabel: Label for x-axis. - ylabel: Label for y-axis. 
- """ - traces = [] - metric_dropdown = [] - layout_axis_range = [] - - # Get the union of all metric_names seen in predictions - metric_names = self.metric_names - df = self.get_df() - - for i, metric in enumerate(metric_names): - metric_filtered_df = df.loc[df["metric_name"] == metric] - - y_raw, se_raw, y_hat, se_hat = extract_mean_and_error_from_df( - metric_filtered_df - ) - - # Use the min/max of the limits - layout_range, diagonal_trace_range = get_plotting_limit_ignore_outliers( - x=y_raw, y=y_hat, se_x=se_raw, se_y=se_hat - ) - layout_axis_range.append(layout_range) - - # add a diagonal dotted line to plot - traces.append( - diagonal_trace( - diagonal_trace_range[0], - diagonal_trace_range[1], - visible=(i == 0), - ) - ) - - traces.append( - error_scatter_trace_from_df( - df=metric_filtered_df, - show_CI=True, - visible=(i == 0), - x_axis_label="Actual Outcome", - y_axis_label="Predicted Outcome", - ) - ) - - # only the first two traces are visible (corresponding to first outcome - # in dropdown) - is_visible = [False] * (len(metric_names) * 2) - is_visible[2 * i] = True - is_visible[2 * i + 1] = True - - # on dropdown change, restyle - metric_dropdown.append( - { - "args": [ - {"visible": is_visible}, - { - "xaxis.range": layout_axis_range[-1], - "yaxis.range": layout_axis_range[-1], - }, - ], - "label": metric, - "method": "update", - } - ) - - updatemenus = updatemenus_format(metric_dropdown=metric_dropdown) - layout = layout_format( - layout_axis_range_value=layout_axis_range[0], - xlabel=xlabel, - ylabel=ylabel, - updatemenus=updatemenus, - ) - - return go.Figure(data=traces, layout=layout) - - def get_fig(self) -> go.Figure: - """ - Interactive cross-validation (CV) plotting; select metric via dropdown. - Note: uses the Plotly version of dropdown (which means that all data is - stored within the notebook). - - Returns: - go.Figure: Plotly figure with cross validation plot - """ - caption = self.caption - - fig = self.obs_vs_pred_dropdown_plot() - - current_bmargin = fig["layout"]["margin"].b or 90 - caption_height = 100 * (len(caption) > 0) - fig["layout"]["margin"].b = current_bmargin + caption_height - fig["layout"]["height"] += caption_height - fig["layout"]["annotations"] += tuple(self.compose_annotation(caption)) - fig["layout"]["title"] = "Cross-validation" - return fig diff --git a/ax/analysis/old/helpers/color_helpers.py b/ax/analysis/old/helpers/color_helpers.py deleted file mode 100644 index 5eda51507bd..00000000000 --- a/ax/analysis/old/helpers/color_helpers.py +++ /dev/null @@ -1,18 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - - -from numbers import Real - -# type aliases -TRGB = tuple[Real, ...] - - -def rgba(rgb_tuple: TRGB, alpha: float = 1) -> str: - """Convert RGB tuple to an RGBA string.""" - return "rgba({},{},{},{alpha})".format(*rgb_tuple, alpha=alpha) diff --git a/ax/analysis/old/helpers/constants.py b/ax/analysis/old/helpers/constants.py deleted file mode 100644 index 63de5a55094..00000000000 --- a/ax/analysis/old/helpers/constants.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-strict - -import enum - -# Constants used for numerous plots -CI_OPACITY = 0.4 -DECIMALS = 3 -Z = 1.96 - - -# color constants used for plotting -class COLORS(enum.Enum): - STEELBLUE = (128, 177, 211) - CORAL = (251, 128, 114) - TEAL = (141, 211, 199) - PINK = (188, 128, 189) - LIGHT_PURPLE = (190, 186, 218) - ORANGE = (253, 180, 98) diff --git a/ax/analysis/old/helpers/cross_validation_helpers.py b/ax/analysis/old/helpers/cross_validation_helpers.py deleted file mode 100644 index 3522c37abda..00000000000 --- a/ax/analysis/old/helpers/cross_validation_helpers.py +++ /dev/null @@ -1,200 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -from typing import Any - -import numpy as np -import pandas as pd -import plotly.graph_objs as go - -from ax.analysis.old.helpers.constants import Z - -from ax.analysis.old.helpers.plot_helpers import arm_name_to_sort_key - -from ax.modelbridge.cross_validation import CVResult - - -def error_scatter_data_from_cv_results( - cv_results: list[CVResult], - metric_name: str, -) -> tuple[list[float], list[float], list[float], list[float]]: - """Extract mean and error from CVResults - - Args: - cv_results: list of cross_validation result objects - metric_name: metric name to use for prediction - and observation - Returns: - x: list of x values - x_se: list of x standard error - y: list of y values - y_se: list of y standard error - """ - y = [cv_result.predicted.means_dict[metric_name] for cv_result in cv_results] - y_se = [ - np.sqrt(cv_result.predicted.covariance_matrix[metric_name][metric_name]) - for cv_result in cv_results - ] - - x = [cv_result.observed.data.means_dict[metric_name] for cv_result in cv_results] - x_se = [ - np.sqrt(cv_result.observed.data.covariance_matrix[metric_name][metric_name]) - for cv_result in cv_results - ] - - return x, x_se, y, y_se - - -def cv_results_to_df( - cv_results: list[CVResult], - metric_name: str, -) -> pd.DataFrame: - """Create a dataframe with error scatterplot data - - Args: - cv_results: list of cross validation results - metric_name: name of metric. Predicted val on y-axis, - observed val on x-axis. - """ - - # Opportunistically sort if arm names are in {trial}_{arm} format - cv_results = sorted( - cv_results, - key=lambda c: arm_name_to_sort_key(c.observed.arm_name), - reverse=True, - ) - - x, x_se, y, y_se = error_scatter_data_from_cv_results( - cv_results=cv_results, - metric_name=metric_name, - ) - - arm_names = [c.observed.arm_name for c in cv_results] - records = [] - - for i in range(len(arm_names)): - records.append( - { - "arm_name": arm_names[i], - "metric_name": metric_name, - "x": x[i], - "x_se": x_se[i], - "y": y[i], - "y_se": y_se[i], - "arm_parameters": cv_results[i].observed.features.parameters, - } - ) - return pd.DataFrame.from_records(records) - - -# Helper functions for plotting model fits -def get_min_max_with_errors( - x: list[float], y: list[float], se_x: list[float], se_y: list[float] -) -> tuple[float, float]: - """Get min and max of a bivariate dataset (across variables). - - Args: - x: point estimate of x variable. - y: point estimate of y variable. - se_x: standard error of x variable. - se_y: standard error of y variable. - - Returns: - min_: minimum of points, including uncertainty. - max_: maximum of points, including uncertainty. 
- - """ - min_ = min( - min(np.array(x) - np.multiply(se_x, Z)), min(np.array(y) - np.multiply(se_y, Z)) - ) - max_ = max( - max(np.array(x) + np.multiply(se_x, Z)), max(np.array(y) + np.multiply(se_y, Z)) - ) - return min_, max_ - - -def get_plotting_limit_ignore_outliers( - x: list[float], y: list[float], se_x: list[float], se_y: list[float] -) -> tuple[list[float], tuple[float, float]]: - """Get a range for a bivarite dataset based on the 25th and 75th percentiles - Used as plotting limit to ignore outliers. - - Args: - x: point estimate of x variable. - y: point estimate of y variable. - se_x: standard error of x variable. - se_y: standard error of y variable. - - Returns: - (min, max): layout axis range - (min, max): diagonal trace range - - """ - se_x = default_value_se_raw(se_raw=se_x, out_length=len(x)) - - min_, max_ = get_min_max_with_errors(x=x, y=y, se_x=se_x, se_y=se_y) - - x_np = np.array(x) - # TODO: replace interpolation->method once it becomes standard. - # pyre-fixme[28]: Unexpected keyword argument `interpolation`. - q1 = np.nanpercentile(x_np, q=25, interpolation="lower").min() - # pyre-fixme[28]: Unexpected keyword argument `interpolation`. - q3 = np.nanpercentile(x_np, q=75, interpolation="higher").max() - quartile_difference = q3 - q1 - - y_lower = q1 - 1.5 * quartile_difference - y_upper = q3 + 1.5 * quartile_difference - - # clip outliers from x - x_np = x_np.clip(y_lower, y_upper).tolist() - min_robust, max_robust = get_min_max_with_errors(x=x_np, y=y, se_x=se_x, se_y=se_y) - y_padding = 0.05 * (max_robust - min_robust) - - layout_range = [ - max(min_robust, min_) - y_padding, - min(max_robust, max_) + y_padding, - ] - diagonal_trace_range = ( - min(min_robust, min_) - y_padding, - max(max_robust, max_) + y_padding, - ) - - return (layout_range, diagonal_trace_range) - - -def diagonal_trace(min_: float, max_: float, visible: bool = True) -> dict[str, Any]: - """Diagonal line trace from (min_, min_) to (max_, max_). - - Args: - min_: minimum to be used for starting point of line. - max_: maximum to be used for ending point of line. - visible: if True, trace is set to visible. - """ - # pyre-fixme[7]: Expected `Dict[str, typing.Any]` but got `Scatter`. - return go.Scatter( - x=[min_, max_], - y=[min_, max_], - line=dict(color="black", width=2, dash="dot"), # noqa: C408 - mode="lines", - hoverinfo="none", - visible=visible, - showlegend=False, - ) - - -def default_value_se_raw(se_raw: list[float] | None, out_length: int) -> list[float]: - """ - Takes a list of standard errors and maps edge cases to default list - of floats. - """ - new_se_raw = ( - [0.0 if np.isnan(se) else se for se in se_raw] - if se_raw is not None - else [0.0] * out_length - ) - return new_se_raw diff --git a/ax/analysis/old/helpers/layout_helpers.py b/ax/analysis/old/helpers/layout_helpers.py deleted file mode 100644 index 67bd3a09c19..00000000000 --- a/ax/analysis/old/helpers/layout_helpers.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-strict - -from typing import Any - -import plotly.graph_objs as go - - -def updatemenus_format(metric_dropdown: list[dict[str, Any]]) -> list[dict[str, Any]]: - """ - Formats for use in the cross validation plot - """ - return [ - { - "x": 0, - "y": 1.125, - "yanchor": "top", - "xanchor": "left", - "buttons": metric_dropdown, - }, - { - "buttons": [ - { - "args": [ - { - "error_x.width": 4, - "error_x.thickness": 2, - "error_y.width": 4, - "error_y.thickness": 2, - } - ], - "label": "Yes", - "method": "restyle", - }, - { - "args": [ - { - "error_x.width": 0, - "error_x.thickness": 0, - "error_y.width": 0, - "error_y.thickness": 0, - } - ], - "label": "No", - "method": "restyle", - }, - ], - "x": 1.125, - "xanchor": "left", - "y": 0.8, - "yanchor": "middle", - }, - ] - - -def layout_format( - layout_axis_range_value: tuple[float, float], - xlabel: str, - ylabel: str, - updatemenus: list[dict[str, Any]], -) -> type[go.Figure]: - """ - Constructs a layout object for a CrossValidation figure. - args: - layout_axis_range_value: A tuple containing the range of values - for the x-axis and y-axis. - xlabel: Label for the x-axis. - ylabel: Label for the y-axis. - updatemenus: A list of dictionaries containing information to use on update. - """ - layout = go.Layout( - annotations=[ - { - "showarrow": False, - "text": "Show CI", - "x": 1.125, - "xanchor": "left", - "xref": "paper", - "y": 0.9, - "yanchor": "middle", - "yref": "paper", - } - ], - xaxis={ - "range": layout_axis_range_value, - "title": xlabel, - "zeroline": False, - "mirror": True, - "linecolor": "black", - "linewidth": 0.5, - }, - yaxis={ - "range": layout_axis_range_value, - "title": ylabel, - "zeroline": False, - "mirror": True, - "linecolor": "black", - "linewidth": 0.5, - }, - showlegend=False, - hovermode="closest", - updatemenus=updatemenus, - width=530, - height=500, - ) - # pyre-fixme[7]: Expected `Type[Figure]` but got `Layout`. - return layout diff --git a/ax/analysis/old/helpers/plot_data_df_helpers.py b/ax/analysis/old/helpers/plot_data_df_helpers.py deleted file mode 100644 index 445a1f4340e..00000000000 --- a/ax/analysis/old/helpers/plot_data_df_helpers.py +++ /dev/null @@ -1,92 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - - -import numpy as np -import pandas as pd - -from ax.modelbridge import ModelBridge -from ax.modelbridge.prediction_utils import predict_at_point - -from ax.modelbridge.transforms.ivw import IVW - - -def get_plot_data_in_sample_arms_df( - model: ModelBridge, - metric_names: set[str], -) -> pd.DataFrame: - """Get in-sample arms from a model with observed and predicted values - for specified metrics. - - Returns a dataframe in which repeated observations are merged - with IVW (inverse variance weighting) - - Args: - model: An instance of the model bridge. - metric_names: Restrict predictions to these metrics. If None, uses all - metrics in the model. 
- - Returns: - A dataframe containing - columns: - { - "arm_name": name of the arm in the cross validation result - "metric_name": name of the observed/predicted metric - "x": value observed for the metric for this arm - "x_se": standard error of observed metric (0 for observations) - "y": value predicted for the metric for this arm - "y_se": standard error of predicted metric for this arm - "arm_parameters": Parametrization of the arm - } - """ - observations = model.get_training_data() - training_in_design: list[bool] = model.training_in_design - - # Merge multiple measurements within each Observation with IVW to get - # un-modeled prediction - observations = IVW(None, []).transform_observations(observations) - - # Create records for dict - records = [] - - for i, obs in enumerate(observations): - # Extract raw measurement - features = obs.features - - if training_in_design[i]: - pred_y_dict, pred_se_dict = predict_at_point(model, features, metric_names) - else: - pred_y_dict = None - pred_se_dict = None - - for metric_name in obs.data.metric_names: - if metric_name not in metric_names: - continue - obs_y = obs.data.means_dict[metric_name] - obs_se = np.sqrt(obs.data.covariance_matrix[metric_name][metric_name]) - - if pred_y_dict and pred_se_dict: - pred_y = pred_y_dict[metric_name] - pred_se = pred_se_dict[metric_name] - else: - pred_y = obs_y - pred_se = obs_se - - records.append( - { - "arm_name": obs.arm_name, - "metric_name": metric_name, - "x": obs_y, - "x_se": obs_se, - "y": pred_y, - "y_se": pred_se, - "arm_parameters": obs.features.parameters, - } - ) - - return pd.DataFrame.from_records(records) diff --git a/ax/analysis/old/helpers/plot_helpers.py b/ax/analysis/old/helpers/plot_helpers.py deleted file mode 100644 index 5ed1a7dfeee..00000000000 --- a/ax/analysis/old/helpers/plot_helpers.py +++ /dev/null @@ -1,91 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - - -from logging import Logger -from typing import Optional, Union - -from ax.analysis.old.helpers.constants import DECIMALS, Z - -from ax.core.generator_run import GeneratorRun - -from ax.core.types import TParameterization -from ax.utils.common.logger import get_logger - -from plotly import graph_objs as go - - -logger: Logger = get_logger(__name__) - -# Typing alias -RawData = list[dict[str, Union[str, float]]] - -TNullableGeneratorRunsDict = Optional[dict[str, GeneratorRun]] - - -def _format_dict(param_dict: TParameterization, name: str = "Parameterization") -> str: - """Format a dictionary for labels. - - Args: - param_dict: Dictionary to be formatted - name: String name of the thing being formatted. - - Returns: stringified blob. - """ - if len(param_dict) >= 10: - blob = "{} has too many items to render on hover ({}).".format( - name, len(param_dict) - ) - else: - blob = "
<br>{}:<br>{}".format(
-            name, "<br>
".join(f"{n}: {v}" for n, v in param_dict.items()) - ) - return blob - - -def _format_CI(estimate: float, sd: float, zval: float = Z) -> str: - """Format confidence intervals given estimate and standard deviation. - - Args: - estimate: point estimate. - sd: standard deviation of point estimate. - zval: z-value associated with desired CI (e.g. 1.96 for 95% CIs) - - Returns: formatted confidence interval. - """ - return "[{lb:.{digits}f}, {ub:.{digits}f}]".format( - lb=estimate - zval * sd, - ub=estimate + zval * sd, - digits=DECIMALS, - ) - - -def resize_subtitles(figure: go.Figure, size: int) -> go.Figure: - """Resize subtitles in a plotly figure - args: - figure: plotly figure to resize subtitles of - size: font size to resize subtitles to - """ - for ant in figure["layout"]["annotations"]: - ant["font"].update(size=size) - return figure - - -def arm_name_to_sort_key(arm_name: str) -> tuple[str, int, int]: - """Parses arm name into tuple suitable for reverse sorting by key - - Example: - arm_names = ["0_0", "1_10", "1_2", "10_0", "control"] - sorted(arm_names, key=arm_name_to_sort_key, reverse=True) - ["control", "0_0", "1_2", "1_10", "10_0"] - """ - try: - trial_index, arm_index = arm_name.split("_") - return ("", -int(trial_index), -int(arm_index)) - except (ValueError, IndexError): - return (arm_name, 0, 0) diff --git a/ax/analysis/old/helpers/scatter_helpers.py b/ax/analysis/old/helpers/scatter_helpers.py deleted file mode 100644 index 13fd9144ac8..00000000000 --- a/ax/analysis/old/helpers/scatter_helpers.py +++ /dev/null @@ -1,314 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -import numbers - -from typing import Any - -import numpy as np -import pandas as pd - -import plotly.graph_objs as go -from ax.analysis.old.helpers.color_helpers import rgba - -from ax.analysis.old.helpers.constants import CI_OPACITY, COLORS, DECIMALS, Z - -from ax.analysis.old.helpers.plot_helpers import _format_CI, _format_dict - -from ax.core.types import TParameterization - -from ax.utils.stats.statstools import relativize - -# disable false positive "SettingWithCopyWarning" -pd.options.mode.chained_assignment = None - - -def relativize_dataframe(df: pd.DataFrame, status_quo_name: str) -> pd.DataFrame: - """ - Relativizes the dataframe with respect to "status_quo_name". Assumes as a - precondition that for each metric in the dataframe, there is a row with - arm_name == status_quo_name to relativize against. - - Args: - df: dataframe with the following columns - { - "arm_name": name of the arm in the cross validation result - "metric_name": name of the observed/predicted metric - "x": value of the observation for the metric for this arm - "x_se": standard error of the observation for the metric of this arm - "y": value predicted for the metric for this arm - "y_se": standard error of predicted metric for this arm - } - status_quo_name: name of the status quo arm in the dataframe to use - for relativization. - Returns: - A dataframe containing the same rows as df, with the observation and predicted - data values relativized with respect to the status quo. - An additional column "rel" is added to indicate whether the data is relativized. 
- - """ - metrics = df["metric_name"].unique() - - def _relativize_filtered_dataframe( - df: pd.DataFrame, metric_name: str, status_quo_name: str - ) -> pd.DataFrame: - df = df.loc[df["metric_name"] == metric_name] - status_quo_row = df.loc[df["arm_name"] == status_quo_name] - - mean_c = status_quo_row["y"].iloc[0] - sem_c = status_quo_row["y_se"].iloc[0] - y_rel, y_se_rel = relativize( - means_t=df["y"].tolist(), - sems_t=df["y_se"].tolist(), - mean_c=mean_c, - sem_c=sem_c, - as_percent=True, - ) - df["y"] = y_rel - df["y_se"] = y_se_rel - - mean_c = status_quo_row["x"].iloc[0] - sem_c = status_quo_row["x_se"].iloc[0] - x_rel, x_se_rel = relativize( - means_t=df["x"].tolist(), - sems_t=df["x_se"].tolist(), - mean_c=mean_c, - sem_c=sem_c, - as_percent=True, - ) - df["x"] = x_rel - df["x_se"] = x_se_rel - df["rel"] = True - return df - - return pd.concat( - [ - _relativize_filtered_dataframe( - df=df, metric_name=metric, status_quo_name=status_quo_name - ) - for metric in metrics - ] - ) - - -def extract_mean_and_error_from_df( - df: pd.DataFrame, -) -> tuple[list[float], list[float], list[float], list[float]]: - """Extract mean and error from dataframe. - - Args: - df: dataframe containing the scatter plot data - Returns: - x: list of x values - x_se: list of x standard error - y: list of y values - y_se: list of y standard error - """ - x = df["x"] - x_se = df["x_se"] - y = df["y"] - y_se = df["y_se"] - - return (x, x_se, y, y_se) - - -def make_label( - arm_name: str, - x_axis_values: tuple[str, float, float] | None, - y_axis_values: tuple[str, float, float], - param_blob: TParameterization, - rel: bool, -) -> str: - """Make label for scatter plot. - - Args: - arm_name: Name of arm - x_axis_values: Optional Tuple of - x_name: Name of x variable - x_val: Value of x variable - x_se: Standard error of x variable - y_axis_values: Tuple of - y_name: Name of y variable - y_val: Value of y variable - y_se: Standard error of y variable - param_blob: Parameterization of arm - rel: whether the data is relativized as a % - - Returns: - Label for scatter plot. - """ - heading = f"Arm {arm_name}
" - x_lab = "" - if x_axis_values is not None: - x_name, x_val, x_se = x_axis_values - x_lab = "{name}: {estimate}{perc} {ci}
".format( - name=x_name, - estimate=( - round(x_val, DECIMALS) if isinstance(x_val, numbers.Number) else x_val - ), - ci=_format_CI(estimate=x_val, sd=x_se), - perc="%" if rel else "", - ) - - y_name, y_val, y_se = y_axis_values - y_lab = "{name}: {estimate}{perc} {ci}
".format( - name=y_name, - estimate=( - round(y_val, DECIMALS) if isinstance(y_val, numbers.Number) else y_val - ), - ci=_format_CI(estimate=y_val, sd=y_se), - perc="%" if rel else "", - ) - - parameterization = _format_dict(param_blob, "Parameterization") - - return "{arm_name}
{xlab}{ylab}{param_blob}".format( - arm_name=heading, - xlab=x_lab, - ylab=y_lab, - param_blob=parameterization, - ) - - -def error_dot_plot_trace_from_df( - df: pd.DataFrame, - show_CI: bool = True, - visible: bool = True, -) -> dict[str, Any]: - """Creates trace for dot plot with confidence intervals. - Categorizes by arm name. - - Args: - df: dataframe containing the scatter plot data - show_CI: if True, plot confidence intervals. - visible: if True, trace will be visible in figure - """ - - _, _, y, y_se = extract_mean_and_error_from_df(df) - - labels = [] - - metric_name = df["metric_name"].iloc[0] - - for _, row in df.iterrows(): - labels.append( - make_label( - arm_name=row["arm_name"], - x_axis_values=(None), - y_axis_values=( - metric_name, - row["y"], - row["y_se"], - ), - param_blob=row["arm_parameters"], - rel=(False if "rel" not in row else row["rel"]), - ) - ) - - trace = go.Scatter( - x=df["arm_name"], - y=y, - marker={"color": rgba(COLORS.STEELBLUE.value)}, - mode="markers", - name="In-sample", - text=labels, - hoverinfo="text", - ) - - if show_CI: - if y_se is not None: - trace.update( - error_y={ - "type": "data", - "array": np.multiply(y_se, Z), - "color": rgba(COLORS.STEELBLUE.value, CI_OPACITY), - } - ) - - trace.update(visible=visible) - trace.update(showlegend=False) - # pyre-fixme[7]: Expected `Dict[str, typing.Any]` but got `Scatter`. - return trace - - -def error_scatter_trace_from_df( - df: pd.DataFrame, - show_CI: bool = True, - visible: bool = True, - y_axis_label: str | None = None, - x_axis_label: str | None = None, -) -> dict[str, Any]: - """Plot scatterplot with error bars. - - Args: - df: dataframe containing the scatter plot data - show_CI: if True, plot confidence intervals. - visible: if True, trace will be visible in figure - y_axis_label: custom label to use for y axis. - If None, use metric name from `y_axis_var`. - x_axis_label: custom label to use for x axis. - If None, use metric name from `x_axis_var` if that is not None. - """ - - x, x_se, y, y_se = extract_mean_and_error_from_df(df) - - labels = [] - - metric_name = df["metric_name"].iloc[0] - - for _, row in df.iterrows(): - labels.append( - make_label( - arm_name=row["arm_name"], - x_axis_values=( - metric_name if x_axis_label is None else x_axis_label, - row["x"], - row["x_se"], - ), - y_axis_values=( - (metric_name if y_axis_label is None else y_axis_label), - row["y"], - row["y_se"], - ), - param_blob=row["arm_parameters"], - rel=(False if "rel" not in row else row["rel"]), - ) - ) - - trace = go.Scatter( - x=x, - y=y, - marker={"color": rgba(COLORS.STEELBLUE.value)}, - mode="markers", - name="In-sample", - text=labels, - hoverinfo="text", - ) - - if show_CI: - if x_se is not None: - trace.update( - error_x={ - "type": "data", - "array": np.multiply(x_se, Z), - "color": rgba(COLORS.STEELBLUE.value, CI_OPACITY), - } - ) - if y_se is not None: - trace.update( - error_y={ - "type": "data", - "array": np.multiply(y_se, Z), - "color": rgba(COLORS.STEELBLUE.value, CI_OPACITY), - } - ) - - trace.update(visible=visible) - trace.update(showlegend=True) - # pyre-fixme[7]: Expected `Dict[str, typing.Any]` but got `Scatter`. - return trace diff --git a/ax/analysis/old/helpers/tests/test_cross_validation_helpers.py b/ax/analysis/old/helpers/tests/test_cross_validation_helpers.py deleted file mode 100644 index 3b615aa2ae0..00000000000 --- a/ax/analysis/old/helpers/tests/test_cross_validation_helpers.py +++ /dev/null @@ -1,91 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. 
and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -import tempfile - -import plotly.graph_objects as go -import plotly.io as pio -from ax.analysis.old.cross_validation_plot import CrossValidationPlot -from ax.analysis.old.helpers.constants import Z -from ax.analysis.old.helpers.cross_validation_helpers import get_min_max_with_errors -from ax.modelbridge.registry import Models -from ax.utils.common.testutils import TestCase -from ax.utils.testing.core_stubs import get_branin_experiment -from ax.utils.testing.mock import fast_botorch_optimize -from pandas import read_json -from pandas.testing import assert_frame_equal - - -class TestCrossValidationHelpers(TestCase): - @fast_botorch_optimize - def setUp(self) -> None: - super().setUp() - self.exp = get_branin_experiment(with_batch=True) - self.exp.trials[0].run() - self.model = Models.BOTORCH_MODULAR( - # Model bridge kwargs - experiment=self.exp, - data=self.exp.fetch_data(), - ) - - self.exp_status_quo = get_branin_experiment( - with_batch=True, with_status_quo=True - ) - self.exp_status_quo.trials[0].run() - self.model_status_quo = Models.BOTORCH_MODULAR( - # Model bridge kwargs - experiment=self.exp, - data=self.exp.fetch_data(), - ) - - def test_get_min_max_with_errors(self) -> None: - # Test with sample data - x = [1.0, 2.0, 3.0] - y = [4.0, 5.0, 6.0] - sd_x = [0.1, 0.2, 0.3] - sd_y = [0.1, 0.2, 0.3] - min_, max_ = get_min_max_with_errors(x, y, sd_x, sd_y) - - expected_min = 1.0 - 0.1 * Z - expected_max = 6.0 + 0.3 * Z - # Check that the returned values are correct - print(f"min: {min_} {expected_min=}") - print(f"max: {max_} {expected_max=}") - self.assertAlmostEqual(min_, expected_min, delta=1e-4) - self.assertAlmostEqual(max_, expected_max, delta=1e-4) - - def test_obs_vs_pred_dropdown_plot(self) -> None: - cross_validation_plot = CrossValidationPlot( - experiment=self.exp, model=self.model - ) - fig = cross_validation_plot.get_fig() - - self.assertIsInstance(fig, go.Figure) - - def test_store_df_to_file(self) -> None: - with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json") as f: - cross_validation_plot = CrossValidationPlot( - experiment=self.exp, model=self.model - ) - cv_df = cross_validation_plot.get_df() - cv_df.to_json(f.name) - - loaded_dataframe = read_json(f.name, dtype={"arm_name": "str"}) - - assert_frame_equal(cv_df, loaded_dataframe, check_dtype=False) - - def test_store_plot_as_dict(self) -> None: - cross_validation_plot = CrossValidationPlot( - experiment=self.exp, model=self.model - ) - cv_fig = cross_validation_plot.get_fig() - - json_obj = pio.to_json(cv_fig, validate=True, remove_uids=False) - - loaded_json_obj = pio.from_json(json_obj, output_type="Figure") - self.assertEqual(cv_fig, loaded_json_obj) diff --git a/ax/analysis/old/helpers/tests/test_cv_consistency_checks.py b/ax/analysis/old/helpers/tests/test_cv_consistency_checks.py deleted file mode 100644 index d36b4c8ab9d..00000000000 --- a/ax/analysis/old/helpers/tests/test_cv_consistency_checks.py +++ /dev/null @@ -1,129 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-strict - -import copy - -import plotly.graph_objects as go -from ax.analysis.old.cross_validation_plot import CrossValidationPlot -from ax.analysis.old.helpers.cross_validation_helpers import ( - error_scatter_data_from_cv_results, -) -from ax.analysis.old.helpers.scatter_helpers import error_scatter_trace_from_df -from ax.modelbridge.cross_validation import cross_validate -from ax.modelbridge.registry import Models -from ax.plot.base import PlotMetric -from ax.plot.diagnostic import ( - _get_cv_plot_data as PLOT_get_cv_plot_data, - interact_cross_validation_plotly as PLOT_interact_cross_validation_plotly, -) -from ax.plot.scatter import ( - _error_scatter_data as PLOT_error_scatter_data, - _error_scatter_trace as PLOT_error_scatter_trace, -) -from ax.utils.common.testutils import TestCase -from ax.utils.testing.core_stubs import get_branin_experiment -from ax.utils.testing.mock import fast_botorch_optimize - - -class TestCVConsistencyCheck(TestCase): - @fast_botorch_optimize - def setUp(self) -> None: - super().setUp() - self.exp = get_branin_experiment(with_batch=True) - self.exp.trials[0].run() - self.model = Models.BOTORCH_MODULAR( - # Model bridge kwargs - experiment=self.exp, - data=self.exp.fetch_data(), - ) - - self.exp_status_quo = get_branin_experiment( - with_batch=True, with_status_quo=True - ) - self.exp_status_quo.trials[0].run() - self.model_status_quo = Models.BOTORCH_MODULAR( - # Model bridge kwargs - experiment=self.exp, - data=self.exp.fetch_data(), - ) - - def test_error_scatter_data_branin(self) -> None: - cv_results = cross_validate(self.model) - cv_results_plot = copy.deepcopy(cv_results) - - result_analysis = error_scatter_data_from_cv_results( - cv_results=cv_results, - metric_name="branin", - ) - - data = PLOT_get_cv_plot_data(cv_results_plot, label_dict={}) - result_plot = PLOT_error_scatter_data( - list(data.in_sample.values()), - y_axis_var=PlotMetric("branin", pred=True, rel=False), - x_axis_var=PlotMetric("branin", pred=False, rel=False), - ) - self.assertEqual(result_analysis, result_plot) - - def test_error_scatter_trace_branin(self) -> None: - cv_results = cross_validate(self.model) - cv_results_plot = copy.deepcopy(cv_results) - - cross_validation_plot = CrossValidationPlot( - experiment=self.exp, model=self.model - ) - df = cross_validation_plot.get_df() - - metric_filtered_df = df.loc[df["metric_name"] == "branin"] - result_analysis = error_scatter_trace_from_df( - df=metric_filtered_df, - show_CI=True, - visible=True, - x_axis_label="Actual Outcome", - y_axis_label="Predicted Outcome", - ) - - data = data = PLOT_get_cv_plot_data(cv_results_plot, label_dict={}) - - result_plot = PLOT_error_scatter_trace( - arms=list(data.in_sample.values()), - hoverinfo="text", - show_arm_details_on_hover=True, - show_CI=True, - show_context=False, - status_quo_arm=None, - visible=True, - y_axis_var=PlotMetric("branin", pred=True, rel=False), - x_axis_var=PlotMetric("branin", pred=False, rel=False), - x_axis_label="Actual Outcome", - y_axis_label="Predicted Outcome", - ) - - print(str(result_analysis)) - print(str(result_plot)) - - self.assertEqual(result_analysis, result_plot) - - def test_obs_vs_pred_dropdown_plot_branin(self) -> None: - label_dict = {"branin": "BrAnIn"} - - cross_validation_plot = CrossValidationPlot( - experiment=self.exp, model=self.model, label_dict=label_dict - ) - fig = cross_validation_plot.get_fig() - - self.assertIsInstance(fig, go.Figure) - - cv_results_plot = cross_validate(self.model) - - fig_PLOT = 
PLOT_interact_cross_validation_plotly( - cv_results_plot, - show_context=False, - label_dict=label_dict, - caption=CrossValidationPlot.CROSS_VALIDATION_CAPTION, - ) - self.assertEqual(fig, fig_PLOT) diff --git a/ax/analysis/old/predicted_outcomes_dot_plot.py b/ax/analysis/old/predicted_outcomes_dot_plot.py deleted file mode 100644 index 6022134616d..00000000000 --- a/ax/analysis/old/predicted_outcomes_dot_plot.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -from typing import Any - -import numpy as np - -import pandas as pd - -from ax.analysis.old.base_plotly_visualization import BasePlotlyVisualization - -from ax.analysis.old.helpers.layout_helpers import updatemenus_format - -from ax.analysis.old.helpers.plot_data_df_helpers import get_plot_data_in_sample_arms_df -from ax.analysis.old.helpers.plot_helpers import arm_name_to_sort_key, resize_subtitles - -from ax.analysis.old.helpers.scatter_helpers import ( - error_dot_plot_trace_from_df, - relativize_dataframe, -) - -from ax.core.experiment import Experiment - -from ax.exceptions.core import UnsupportedPlotError - -from ax.modelbridge import ModelBridge - -from plotly import graph_objs as go - - -class PredictedOutcomesDotPlot(BasePlotlyVisualization): - def __init__( - self, - experiment: Experiment, - model: ModelBridge, - ) -> None: - """ - Args: - experiment: The experiment associated with this plot - model: model which is used to fetch the plotting data. - """ - - self.model = model - self.metrics: set[str] = model.metric_names - if model.status_quo is None or model.status_quo.arm_name is None: - raise UnsupportedPlotError( - "status quo must be specified for PredictedOutcomesDotPlot" - ) - self.status_quo_name: str = model.status_quo.arm_name - - super().__init__(experiment=experiment) - - def get_df(self) -> pd.DataFrame: - """ - Returns: - A dataframe containing: - { - "arm_name": name of the arm in the cross validation result - "metric_name": name of the observed/predicted metric - "x": value of the observation for the metric for this arm - "x_se": standard error of the observation for the metric of this arm - "y": value predicted for the metric for this arm - "y_se": standard error of predicted metric for this arm - "arm_parameters": Parametrization of the arm - "rel": whether the data is relativized with respect to status quo - }""" - return relativize_dataframe( - get_plot_data_in_sample_arms_df( - model=self.model, metric_names=self.metrics - ), - status_quo_name=self.status_quo_name, - ) - - def get_fig( - self, - ) -> go.Figure: - """ - For each metric, we plot the predicted values for each arm along with its CI - These values are relativized with respect to the status quo. 
- """ - name_order_axes: dict[str, dict[str, Any]] = {} - - in_sample_df = self.get_df() - traces = [] - metric_dropdown = [] - - for i, metric in enumerate(self.metrics): - filtered_df = in_sample_df.loc[in_sample_df["metric_name"] == metric] - data_single: dict[str, Any] = error_dot_plot_trace_from_df( - df=filtered_df, show_CI=True, visible=(i == 0) - ) - - # order arm name sorting arm numbers within batch - names_by_arm = sorted( - np.unique(data_single["x"]), - key=lambda x: arm_name_to_sort_key(x), - reverse=True, - ) - - name_order_axes[f"xaxis{i + 1}"] = { - "categoryorder": "array", - "categoryarray": names_by_arm, - "type": "category", - } - name_order_axes[f"yaxis{i + 1}"] = { - "ticksuffix": "%", - "zerolinecolor": "red", - } - - traces.append(data_single) - - is_visible = [False] * (len(metric)) - is_visible[i] = True - metric_dropdown.append( - { - "args": [ - { - "visible": is_visible, - }, - ], - "label": metric, - "method": "update", - } - ) - - updatemenus = updatemenus_format(metric_dropdown=metric_dropdown) - - fig = go.Figure(data=traces) - - fig["layout"].update( - updatemenus=updatemenus, - width=1030, - height=500, - **name_order_axes, - ) - - fig = resize_subtitles(figure=fig, size=10) - fig["layout"]["title"] = "Predicted Outcomes by Metric" - return fig diff --git a/ax/analysis/old/tests/test_analysis_report.py b/ax/analysis/old/tests/test_analysis_report.py deleted file mode 100644 index 919baf8aaa6..00000000000 --- a/ax/analysis/old/tests/test_analysis_report.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -import pandas as pd -import plotly.graph_objects as go - -from ax.analysis.old.analysis_report import AnalysisReport - -from ax.analysis.old.base_analysis import BaseAnalysis -from ax.analysis.old.base_plotly_visualization import BasePlotlyVisualization - -from ax.modelbridge.registry import Models -from ax.utils.common.testutils import TestCase - -from ax.utils.common.timeutils import current_timestamp_in_millis -from ax.utils.testing.core_stubs import get_branin_experiment -from ax.utils.testing.mock import fast_botorch_optimize - - -class TestCrossValidationPlot(TestCase): - class TestAnalysis(BaseAnalysis): - def get_df(self) -> pd.DataFrame: - return pd.DataFrame() - - class TestPlotlyVisualization(BasePlotlyVisualization): - def get_df(self) -> pd.DataFrame: - return pd.DataFrame() - - def get_fig(self) -> go.Figure: - return go.Figure() - - @fast_botorch_optimize - def setUp(self) -> None: - super().setUp() - - self.exp = get_branin_experiment(with_batch=True) - self.exp.trials[0].run() - self.model = Models.BOTORCH_MODULAR( - # Model bridge kwargs - experiment=self.exp, - data=self.exp.fetch_data(), - ) - - self.test_analysis = self.TestAnalysis(experiment=self.exp) - self.test_plotly_visualization = self.TestPlotlyVisualization( - experiment=self.exp - ) - - def test_init_analysis_report(self) -> None: - analysis_report = AnalysisReport( - experiment=self.exp, - analyses=[self.test_analysis, self.test_plotly_visualization], - ) - - self.assertEqual(len(analysis_report.analyses), 2) - - self.assertIsInstance(analysis_report.analyses[0], BaseAnalysis) - self.assertIsInstance(analysis_report.analyses[1], BasePlotlyVisualization) - - self.assertIsNone(analysis_report.time_started) - self.assertIsNone(analysis_report.time_completed) - 
-        self.assertFalse(analysis_report.report_completed)
-
-    def test_execute_analysis_report(self) -> None:
-        analysis_report = AnalysisReport(
-            experiment=self.exp,
-            analyses=[self.test_analysis, self.test_plotly_visualization],
-        )
-
-        self.assertEqual(len(analysis_report.analyses), 2)
-
-        results = analysis_report.run_analysis_report()
-        self.assertEqual(len(results), 2)
-
-        self.assertIsNotNone(analysis_report.time_started)
-        self.assertIsNotNone(analysis_report.time_completed)
-        self.assertTrue(analysis_report.report_completed)
-
-        # assert no plot is returned for BaseAnalysis
-        self.assertIsInstance(results[0][1], pd.DataFrame)
-        self.assertIsNone(results[0][2])
-
-        # assert plot is returned for BasePlotlyVisualization
-        self.assertIsInstance(results[1][1], pd.DataFrame)
-        self.assertIsInstance(results[1][2], go.Figure)
-
-    def test_analysis_report_repeated_execute(self) -> None:
-        analysis_report = AnalysisReport(
-            experiment=self.exp,
-            analyses=[self.test_analysis, self.test_plotly_visualization],
-        )
-
-        self.assertEqual(len(analysis_report.analyses), 2)
-
-        _ = analysis_report.run_analysis_report()
-
-        saved_start = analysis_report.time_started
-        self.assertIsNotNone(saved_start)
-
-        # ensure analyses are not re-ran
-        _ = analysis_report.run_analysis_report()
-        self.assertEqual(saved_start, analysis_report.time_started)
-
-    def test_no_analysis_report(self) -> None:
-        analysis_report = AnalysisReport(
-            experiment=self.exp,
-            analyses=[],
-        )
-
-        self.assertEqual(len(analysis_report.analyses), 0)
-
-        results = analysis_report.run_analysis_report()
-        self.assertEqual(len(results), 0)
-
-    def test_singleton_analysis_report(self) -> None:
-        analysis_report = AnalysisReport(
-            experiment=self.exp,
-            analyses=[self.test_plotly_visualization],
-        )
-
-        self.assertEqual(len(analysis_report.analyses), 1)
-
-        results = analysis_report.run_analysis_report()
-        self.assertEqual(len(results), 1)
-
-        # assert plot is returned for BasePlotlyVisualization
-        self.assertIsInstance(results[0][1], pd.DataFrame)
-        self.assertIsInstance(results[0][2], go.Figure)
-
-    def test_create_report_as_completed(self) -> None:
-        analysis_report = AnalysisReport(
-            experiment=self.exp,
-            analyses=[self.test_analysis, self.test_plotly_visualization],
-            time_started=current_timestamp_in_millis(),
-            time_completed=current_timestamp_in_millis(),
-        )
-
-        self.assertIsNotNone(analysis_report.time_started)
-        self.assertIsNotNone(analysis_report.time_completed)
-        self.assertTrue(analysis_report.report_completed)
diff --git a/ax/analysis/old/tests/test_base_classes.py b/ax/analysis/old/tests/test_base_classes.py
deleted file mode 100644
index 64d2040e056..00000000000
--- a/ax/analysis/old/tests/test_base_classes.py
+++ /dev/null
@@ -1,142 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
- -# pyre-strict - -import pandas as pd - -import plotly.express as px - -import plotly.graph_objects as go - -from ax.analysis.old.base_analysis import BaseAnalysis -from ax.analysis.old.base_plotly_visualization import BasePlotlyVisualization - -from ax.modelbridge.registry import Models - -from ax.utils.common.testutils import TestCase - -from ax.utils.testing.core_stubs import get_branin_experiment -from ax.utils.testing.mock import fast_botorch_optimize - - -class TestBaseClasses(TestCase): - class TestAnalysis(BaseAnalysis): - get_df_call_count: int = 0 - - def get_df(self) -> pd.DataFrame: - self.get_df_call_count = self.get_df_call_count + 1 - return pd.DataFrame() - - class TestPlotlyVisualization(BasePlotlyVisualization): - get_df_call_count: int = 0 - get_fig_call_count: int = 0 - - def get_df(self) -> pd.DataFrame: - self.get_df_call_count = self.get_df_call_count + 1 - return pd.DataFrame() - - def get_fig(self) -> go.Figure: - self.get_fig_call_count = self.get_fig_call_count + 1 - return go.Figure() - - @fast_botorch_optimize - def setUp(self) -> None: - super().setUp() - - self.exp = get_branin_experiment(with_batch=True) - self.exp.trials[0].run() - self.model = Models.BOTORCH_MODULAR( - # Model bridge kwargs - experiment=self.exp, - data=self.exp.fetch_data(), - ) - - self.test_analysis = self.TestAnalysis(experiment=self.exp) - self.test_plotly_visualization = self.TestPlotlyVisualization( - experiment=self.exp - ) - - def test_base_analysis_df_property(self) -> None: - test_analysis = self.TestAnalysis(experiment=self.exp) - - self.assertEqual(test_analysis.get_df_call_count, 0) - - # accessing the df property calls get_df - _ = test_analysis.df - self.assertEqual(test_analysis.get_df_call_count, 1) - - # once saved, get_df is not called again. - _ = test_analysis.df - self.assertEqual(test_analysis.get_df_call_count, 1) - - def test_base_analysis_pass_df_in(self) -> None: - existing_df = pd.DataFrame([1]) - test_analysis = self.TestAnalysis(experiment=self.exp, df_input=existing_df) - - self.assertEqual(test_analysis.get_df_call_count, 0) - saved_df = test_analysis.df - self.assertTrue(existing_df.equals(saved_df)) - - # when df is passed in directly, get_df is never called. - self.assertEqual(test_analysis.get_df_call_count, 0) - - def test_base_plotly_visualization_fig_property(self) -> None: - test_analysis = self.TestPlotlyVisualization(experiment=self.exp) - - self.assertEqual(test_analysis.get_fig_call_count, 0) - - # accessing the fig property calls get_fig - _ = test_analysis.fig - self.assertEqual(test_analysis.get_fig_call_count, 1) - - # once saved, get_fig is not called again. - _ = test_analysis.fig - self.assertEqual(test_analysis.get_fig_call_count, 1) - - def test_base_plotly_visualization_pass_fig_in(self) -> None: - fig = px.scatter(x=[0, 1, 2, 3, 4], y=[0, 1, 4, 9, 16]) - test_analysis = self.TestPlotlyVisualization( - experiment=self.exp, - fig_input=fig, - ) - - self.assertEqual(test_analysis.get_fig_call_count, 0) - saved_fig = test_analysis.fig - self.assertEqual(fig, saved_fig) - - # when fig is passed in directly, get_fig is never called. 
- self.assertEqual(test_analysis.get_fig_call_count, 0) - - def test_base_plotly_visualization_pass_df_and_fig(self) -> None: - df = pd.DataFrame([1]) - fig = px.scatter(x=[0, 1, 2, 3, 4], y=[0, 1, 4, 9, 16]) - - test_analysis = self.TestPlotlyVisualization( - experiment=self.exp, - df_input=df, - fig_input=fig, - ) - - self.assertEqual(test_analysis.get_df_call_count, 0) - saved_df = test_analysis.df - self.assertTrue(df.equals(saved_df)) - # when df is passed in directly, get_df is never called. - self.assertEqual(test_analysis.get_df_call_count, 0) - - self.assertEqual(test_analysis.get_fig_call_count, 0) - saved_fig = test_analysis.fig - self.assertEqual(fig, saved_fig) - # when fig is passed in directly, get_fig is never called. - self.assertEqual(test_analysis.get_fig_call_count, 0) - - def test_instantiate_base_classes(self) -> None: - test_analysis = BaseAnalysis(experiment=self.exp) - with self.assertRaises(NotImplementedError): - _ = test_analysis.df - - test_fig = BasePlotlyVisualization(experiment=self.exp) - with self.assertRaises(NotImplementedError): - _ = test_fig.fig diff --git a/ax/analysis/old/tests/test_cross_validation_plot.py b/ax/analysis/old/tests/test_cross_validation_plot.py deleted file mode 100644 index 5a6f68dea7a..00000000000 --- a/ax/analysis/old/tests/test_cross_validation_plot.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -import plotly.graph_objects as go -from ax.analysis.old.cross_validation_plot import CrossValidationPlot -from ax.modelbridge.registry import Models -from ax.utils.common.testutils import TestCase -from ax.utils.testing.core_stubs import get_branin_experiment -from ax.utils.testing.mock import fast_botorch_optimize - - -class TestCrossValidationPlot(TestCase): - @fast_botorch_optimize - def setUp(self) -> None: - super().setUp() - self.exp = get_branin_experiment(with_batch=True) - self.exp.trials[0].run() - self.model = Models.BOTORCH_MODULAR( - # Model bridge kwargs - experiment=self.exp, - data=self.exp.fetch_data(), - ) - - def test_cross_validation_plot(self) -> None: - plot = CrossValidationPlot(experiment=self.exp, model=self.model).get_fig() - x_range = plot.layout.updatemenus[0].buttons[0].args[1]["xaxis.range"] - y_range = plot.layout.updatemenus[0].buttons[0].args[1]["yaxis.range"] - self.assertTrue((len(x_range) == 2) and (x_range[0] < x_range[1])) - self.assertTrue((len(y_range) == 2) and (y_range[0] < y_range[1])) - - self.assertIsInstance(plot, go.Figure) - - def test_cross_validation_plot_get_df(self) -> None: - plot = CrossValidationPlot(experiment=self.exp, model=self.model) - _ = plot.get_df() diff --git a/ax/analysis/old/tests/test_predicted_outcomes_dot_plot.py b/ax/analysis/old/tests/test_predicted_outcomes_dot_plot.py deleted file mode 100644 index 89669330abc..00000000000 --- a/ax/analysis/old/tests/test_predicted_outcomes_dot_plot.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-strict - - -import unittest - -import plotly.graph_objects as go -from ax.analysis.old.predicted_outcomes_dot_plot import PredictedOutcomesDotPlot -from ax.exceptions.core import UnsupportedPlotError -from ax.modelbridge.registry import Models -from ax.utils.testing.core_stubs import get_branin_experiment - - -class TestPredictedOutcomesDotPlot(unittest.TestCase): - def setUp(self) -> None: - super().setUp() - self.exp = get_branin_experiment(with_batch=True, with_status_quo=True) - self.exp.trials[0].run() - self.model = Models.BOTORCH_MODULAR( - # Model bridge kwargs - experiment=self.exp, - data=self.exp.fetch_data(), - ) - - def test_predicted_outcomes_dot_plot_no_status_quo(self) -> None: - exp = get_branin_experiment(with_batch=True, with_status_quo=False) - exp.trials[0].run() - model = Models.BOTORCH_MODULAR( - # Model bridge kwargs - experiment=exp, - data=exp.fetch_data(), - ) - - with self.assertRaisesRegex( - UnsupportedPlotError, - "status quo must be specified for PredictedOutcomesDotPlot", - ): - _ = PredictedOutcomesDotPlot( - experiment=exp, - model=model, - ) - - def test_predicted_outcomes_dot_plot(self) -> None: - predicted_outcomes_dot_plot = PredictedOutcomesDotPlot( - experiment=self.exp, - model=self.model, - ) - - _ = predicted_outcomes_dot_plot.get_df() - - fig = predicted_outcomes_dot_plot.get_fig() - self.assertIsInstance(fig, go.Figure)