diff --git a/ax/analysis/old/analysis_report.py b/ax/analysis/old/analysis_report.py
deleted file mode 100644
index 350213a0e44..00000000000
--- a/ax/analysis/old/analysis_report.py
+++ /dev/null
@@ -1,96 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-strict
-
-
-import pandas as pd
-import plotly.graph_objects as go
-
-from ax.analysis.old.base_analysis import BaseAnalysis
-from ax.analysis.old.base_plotly_visualization import BasePlotlyVisualization
-
-from ax.core.experiment import Experiment
-
-from ax.utils.common.timeutils import current_timestamp_in_millis
-
-
-class AnalysisReport:
- """
-    A class corresponding to a set of analyses run on the same
- set of data from an experiment.
- """
-
- analyses: list[BaseAnalysis] = []
- experiment: Experiment
-
- time_started: int | None = None
- time_completed: int | None = None
-
- def __init__(
- self,
- experiment: Experiment,
- analyses: list[BaseAnalysis],
- time_started: int | None = None,
- time_completed: int | None = None,
- ) -> None:
- """
-        An AnalysisReport is a collection of analyses run on the same experiment data.
-
- Args:
- experiment: Experiment which the analyses are generated from
-            analyses: list of analyses to run on the experiment's data
-            time_started: time the report was started
-            time_completed: time the report was completed
- """
- self.experiment = experiment
- self.analyses = analyses
- self.time_started = time_started
- self.time_completed = time_completed
-
- @property
- def report_completed(self) -> bool:
- """
- Returns:
- True if the report is completed, False otherwise.
- """
- return self.time_completed is not None
-
- def run_analysis_report(
- self,
- ) -> list[
- tuple[
- BaseAnalysis,
- pd.DataFrame,
- go.Figure | None,
- ]
- ]:
- """
- Runs all analyses in the report and produces the result.
-
- Returns:
- analysis_report_result: list of tuples (analysis, df, Optional[fig])
- """
- if not self.report_completed:
- self.time_started = current_timestamp_in_millis()
-
- analysis_report_result = []
- for analysis in self.analyses:
- analysis_report_result.append(
- (
- analysis,
- analysis.get_df(),
- (
- None
- if not isinstance(analysis, BasePlotlyVisualization)
- else analysis.get_fig()
- ),
- )
- )
-
- if not self.report_completed:
- self.time_completed = current_timestamp_in_millis()
-
- return analysis_report_result
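
For reviewers tracking what this removal takes away: a minimal sketch of how the deleted `AnalysisReport` was driven, runnable against the pre-removal tree. `MyAnalysis` is a hypothetical subclass, not part of the removed code.

```python
import pandas as pd

from ax.analysis.old.analysis_report import AnalysisReport
from ax.analysis.old.base_analysis import BaseAnalysis
from ax.utils.testing.core_stubs import get_branin_experiment


class MyAnalysis(BaseAnalysis):  # hypothetical subclass, for illustration only
    def get_df(self) -> pd.DataFrame:
        return pd.DataFrame({"arm_name": ["0_0"], "metric_name": ["branin"]})


experiment = get_branin_experiment(with_batch=True)
report = AnalysisReport(experiment=experiment, analyses=[MyAnalysis(experiment)])

# Runs every analysis once and stamps time_started/time_completed.
for analysis, df, maybe_fig in report.run_analysis_report():
    # maybe_fig is None for analyses that are not BasePlotlyVisualizations.
    print(type(analysis).__name__, df.shape, maybe_fig)

assert report.report_completed
```
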
diff --git a/ax/analysis/old/base_analysis.py b/ax/analysis/old/base_analysis.py
deleted file mode 100644
index d7a69187879..00000000000
--- a/ax/analysis/old/base_analysis.py
+++ /dev/null
@@ -1,52 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-strict
-
-
-import pandas as pd
-from ax.core.experiment import Experiment
-
-
-class BaseAnalysis:
- """
- Abstract Analysis class for ax.
- This is an interface that defines the methods to be implemented by all analyses.
-    Each analysis computes an output dataframe.
- """
-
- def __init__(
- self,
- experiment: Experiment,
- df_input: pd.DataFrame | None = None,
- # TODO: add support for passing in experiment name, and markdown message
- ) -> None:
- """
- Initialize the analysis with the experiment object.
- For scenarios where an analysis output is already available,
- we can pass the dataframe as an input.
- """
- self._experiment = experiment
- self._df: pd.DataFrame | None = df_input
-
- @property
- def experiment(self) -> Experiment:
- return self._experiment
-
- @property
- def df(self) -> pd.DataFrame:
- """
- Return the output of the analysis of this class.
- """
- if self._df is None:
- self._df = self.get_df()
- return self._df
-
- def get_df(self) -> pd.DataFrame:
- """
- Return the output of the analysis of this class.
-        Subclasses should override this.
- """
- raise NotImplementedError("get_df must be implemented by subclass")
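
The contract above is small enough to show in full: subclasses override `get_df`, and the `df` property computes it lazily and caches it. A sketch against the pre-removal tree; `ConstantAnalysis` is illustrative.

```python
import pandas as pd

from ax.analysis.old.base_analysis import BaseAnalysis
from ax.utils.testing.core_stubs import get_branin_experiment


class ConstantAnalysis(BaseAnalysis):  # illustrative subclass
    def get_df(self) -> pd.DataFrame:
        return pd.DataFrame({"value": [1.0]})


analysis = ConstantAnalysis(experiment=get_branin_experiment())
first = analysis.df   # first access calls get_df and caches the result
second = analysis.df  # cache hit; get_df is not called again
assert first is second

# A precomputed dataframe can be injected up front, bypassing get_df.
preloaded = ConstantAnalysis(
    experiment=get_branin_experiment(), df_input=pd.DataFrame({"value": [2.0]})
)
assert preloaded.df["value"].iloc[0] == 2.0
```
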
diff --git a/ax/analysis/old/base_plotly_visualization.py b/ax/analysis/old/base_plotly_visualization.py
deleted file mode 100644
index ab65dc0d0e6..00000000000
--- a/ax/analysis/old/base_plotly_visualization.py
+++ /dev/null
@@ -1,52 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-strict
-
-
-import pandas as pd
-
-import plotly.graph_objects as go
-
-from ax.analysis.old.base_analysis import BaseAnalysis
-from ax.core.experiment import Experiment
-
-
-class BasePlotlyVisualization(BaseAnalysis):
- """
- Abstract PlotlyVisualization class for ax.
-    This is an interface that defines the methods to be implemented by all ax plots.
-    Each plot computes an output dataframe and a plotly figure.
- """
-
- def __init__(
- self,
- experiment: Experiment,
- df_input: pd.DataFrame | None = None,
- fig_input: go.Figure | None = None,
- ) -> None:
- """
- Initialize the analysis with the experiment object.
- For scenarios where an analysis output is already available,
- we can pass the dataframe as an input.
- """
- self._fig = fig_input
- super().__init__(experiment=experiment, df_input=df_input)
-
- @property
- def fig(self) -> go.Figure:
- """
-        Return the plotly figure of the analysis of this class.
- """
- if self._fig is None:
- self._fig = self.get_fig()
- return self._fig
-
- def get_fig(self) -> go.Figure:
- """
- Return the plotly figure of the analysis of this class.
-        Subclasses should override this.
- """
- raise NotImplementedError("get_fig must be implemented by subclass")
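
Same lazy-caching pattern, extended with a figure. A sketch; `ScatterAnalysis` is illustrative only.

```python
import pandas as pd
import plotly.graph_objects as go

from ax.analysis.old.base_plotly_visualization import BasePlotlyVisualization
from ax.utils.testing.core_stubs import get_branin_experiment


class ScatterAnalysis(BasePlotlyVisualization):  # illustrative subclass
    def get_df(self) -> pd.DataFrame:
        return pd.DataFrame({"x": [0, 1, 2], "y": [0.0, 1.0, 4.0]})

    def get_fig(self) -> go.Figure:
        df = self.df  # reuses the dataframe cached by BaseAnalysis
        return go.Figure(data=go.Scatter(x=df["x"], y=df["y"], mode="markers"))


viz = ScatterAnalysis(experiment=get_branin_experiment())
fig = viz.fig          # first access builds and caches the figure
assert viz.fig is fig  # later accesses hit the cache
```
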
diff --git a/ax/analysis/old/cross_validation_plot.py b/ax/analysis/old/cross_validation_plot.py
deleted file mode 100644
index 3f3ac5646a3..00000000000
--- a/ax/analysis/old/cross_validation_plot.py
+++ /dev/null
@@ -1,252 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-strict
-
-from copy import deepcopy
-from typing import Any
-
-import pandas as pd
-
-from ax.analysis.old.base_plotly_visualization import BasePlotlyVisualization
-
-from ax.analysis.old.helpers.cross_validation_helpers import (
- cv_results_to_df,
- diagonal_trace,
- get_plotting_limit_ignore_outliers,
-)
-
-from ax.analysis.old.helpers.layout_helpers import layout_format, updatemenus_format
-
-from ax.analysis.old.helpers.scatter_helpers import (
- error_scatter_trace_from_df,
- extract_mean_and_error_from_df,
-)
-
-from ax.core.experiment import Experiment
-from ax.modelbridge import ModelBridge
-
-from ax.modelbridge.cross_validation import cross_validate, CVResult
-
-from plotly import graph_objs as go
-
-
-class CrossValidationPlot(BasePlotlyVisualization):
- CROSS_VALIDATION_CAPTION = (
- "NOTE: We have tried our best to only plot the region of interest.
"
- "This may hide outliers. You can autoscale the axes to see all trials."
- )
-
- def __init__(
- self,
- experiment: Experiment,
- model: ModelBridge,
- label_dict: dict[str, str] | None = None,
- caption: str = CROSS_VALIDATION_CAPTION,
- ) -> None:
- """
- Args:
- experiment: Experiment containing trials to plot
- model: ModelBridge to cross validate against
- label_dict: optional map from real metric names to shortened names
- caption: text to display below the plot
- """
- self.model = model
- self.cv: list[CVResult] = cross_validate(model=model)
-
- self.label_dict: dict[str, str] | None = label_dict
- if self.label_dict:
- self.cv = self.remap_label(cv_results=self.cv, label_dict=self.label_dict)
-
- self.metric_names: list[str] = list(
- set().union(*(cv_result.predicted.metric_names for cv_result in self.cv))
- )
- self.caption = caption
-
- super().__init__(experiment=experiment)
-
- def get_df(self) -> pd.DataFrame:
- """
- Overrides BaseAnalysis.get_df()
-
- Returns:
- df representation of the cross validation results.
- columns:
- {
- "arm_name": name of the arm in the cross validation result
- "metric_name": name of the observed/predicted metric
- "x": value observed for the metric for this arm
- "x_se": standard error of observed metric (0 for observations)
- "y": value predicted for the metric for this arm
- "y_se": standard error of predicted metric for this arm
- "arm_parameters": Parametrization of the arm
- }
- """
-
- df = pd.concat(
- [
- cv_results_to_df(
- cv_results=self.cv,
- metric_name=metric,
- )
- for metric in self.metric_names
- ]
- )
-
- return df
-
- @staticmethod
- def compose_annotation(
- caption: str, x: float = 0.0, y: float = -0.15
- ) -> list[dict[str, Any]]:
- """Composes an annotation dict for use in Plotly figure.
-        Args:
- caption: str to use for dropdown text
- x: x position of the annotation
- y: y position of the annotation
-
-        Returns:
- Annotation dict for use in Plotly figure.
- """
- return [
- {
- "showarrow": False,
- "text": caption,
- "x": x,
- "xanchor": "left",
- "xref": "paper",
- "y": y,
- "yanchor": "top",
- "yref": "paper",
- "align": "left",
- },
- ]
-
- @staticmethod
- def remap_label(
- cv_results: list[CVResult], label_dict: dict[str, str]
- ) -> list[CVResult]:
- """Remaps labels in cv_results according to label_dict.
-
- Args:
- cv_results: A CVResult for each observation in the training data.
- label_dict: optional map from real metric names to shortened names
-
- Returns:
-            A list of CVResults with metric names remapped via label_dict.
- """
- cv_results = deepcopy(cv_results) # Copy and edit in-place
- for cv_i in cv_results:
- cv_i.observed.data.metric_names = [
- label_dict.get(m, m) for m in cv_i.observed.data.metric_names
- ]
- cv_i.predicted.metric_names = [
- label_dict.get(m, m) for m in cv_i.predicted.metric_names
- ]
- return cv_results
-
- def obs_vs_pred_dropdown_plot(
- self,
- xlabel: str = "Actual Outcome",
- ylabel: str = "Predicted Outcome",
- ) -> go.Figure:
- """Plot a dropdown plot of observed vs. predicted values from the
- cross validation results.
-
- Args:
- xlabel: Label for x-axis.
- ylabel: Label for y-axis.
- """
- traces = []
- metric_dropdown = []
- layout_axis_range = []
-
- # Get the union of all metric_names seen in predictions
- metric_names = self.metric_names
- df = self.get_df()
-
- for i, metric in enumerate(metric_names):
- metric_filtered_df = df.loc[df["metric_name"] == metric]
-
- y_raw, se_raw, y_hat, se_hat = extract_mean_and_error_from_df(
- metric_filtered_df
- )
-
- # Use the min/max of the limits
- layout_range, diagonal_trace_range = get_plotting_limit_ignore_outliers(
- x=y_raw, y=y_hat, se_x=se_raw, se_y=se_hat
- )
- layout_axis_range.append(layout_range)
-
- # add a diagonal dotted line to plot
- traces.append(
- diagonal_trace(
- diagonal_trace_range[0],
- diagonal_trace_range[1],
- visible=(i == 0),
- )
- )
-
- traces.append(
- error_scatter_trace_from_df(
- df=metric_filtered_df,
- show_CI=True,
- visible=(i == 0),
- x_axis_label="Actual Outcome",
- y_axis_label="Predicted Outcome",
- )
- )
-
-            # each dropdown entry shows only the two traces (diagonal line +
-            # scatter) for its metric
- is_visible = [False] * (len(metric_names) * 2)
- is_visible[2 * i] = True
- is_visible[2 * i + 1] = True
-
- # on dropdown change, restyle
- metric_dropdown.append(
- {
- "args": [
- {"visible": is_visible},
- {
- "xaxis.range": layout_axis_range[-1],
- "yaxis.range": layout_axis_range[-1],
- },
- ],
- "label": metric,
- "method": "update",
- }
- )
-
- updatemenus = updatemenus_format(metric_dropdown=metric_dropdown)
- layout = layout_format(
- layout_axis_range_value=layout_axis_range[0],
- xlabel=xlabel,
- ylabel=ylabel,
- updatemenus=updatemenus,
- )
-
- return go.Figure(data=traces, layout=layout)
-
- def get_fig(self) -> go.Figure:
- """
- Interactive cross-validation (CV) plotting; select metric via dropdown.
- Note: uses the Plotly version of dropdown (which means that all data is
- stored within the notebook).
-
- Returns:
- go.Figure: Plotly figure with cross validation plot
- """
- caption = self.caption
-
- fig = self.obs_vs_pred_dropdown_plot()
-
- current_bmargin = fig["layout"]["margin"].b or 90
- caption_height = 100 * (len(caption) > 0)
- fig["layout"]["margin"].b = current_bmargin + caption_height
- fig["layout"]["height"] += caption_height
- fig["layout"]["annotations"] += tuple(self.compose_annotation(caption))
- fig["layout"]["title"] = "Cross-validation"
- return fig
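
End-to-end, the deleted plot was used roughly as below, mirroring the tests further down (without the test-only `fast_botorch_optimize` decorator the model fit just runs slower):

```python
from ax.analysis.old.cross_validation_plot import CrossValidationPlot
from ax.modelbridge.registry import Models
from ax.utils.testing.core_stubs import get_branin_experiment

experiment = get_branin_experiment(with_batch=True)
experiment.trials[0].run()
model = Models.BOTORCH_MODULAR(experiment=experiment, data=experiment.fetch_data())

plot = CrossValidationPlot(
    experiment=experiment,
    model=model,
    label_dict={"branin": "Branin"},  # optional metric-name shortening
)
df = plot.get_df()    # one row per (arm, metric): observed x/x_se vs. predicted y/y_se
fig = plot.get_fig()  # dropdown figure with one entry per metric
fig.show()
```
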
diff --git a/ax/analysis/old/helpers/color_helpers.py b/ax/analysis/old/helpers/color_helpers.py
deleted file mode 100644
index 5eda51507bd..00000000000
--- a/ax/analysis/old/helpers/color_helpers.py
+++ /dev/null
@@ -1,18 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-strict
-
-
-from numbers import Real
-
-# type aliases
-TRGB = tuple[Real, ...]
-
-
-def rgba(rgb_tuple: TRGB, alpha: float = 1) -> str:
- """Convert RGB tuple to an RGBA string."""
- return "rgba({},{},{},{alpha})".format(*rgb_tuple, alpha=alpha)
diff --git a/ax/analysis/old/helpers/constants.py b/ax/analysis/old/helpers/constants.py
deleted file mode 100644
index 63de5a55094..00000000000
--- a/ax/analysis/old/helpers/constants.py
+++ /dev/null
@@ -1,23 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-strict
-
-import enum
-
-# Constants used for numerous plots
-CI_OPACITY = 0.4
-DECIMALS = 3
-Z = 1.96
-
-
-# color constants used for plotting
-class COLORS(enum.Enum):
- STEELBLUE = (128, 177, 211)
- CORAL = (251, 128, 114)
- TEAL = (141, 211, 199)
- PINK = (188, 128, 189)
- LIGHT_PURPLE = (190, 186, 218)
- ORANGE = (253, 180, 98)
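
These constants were consumed together with `rgba` from the file above; a short worked example (outputs are exact):

```python
from ax.analysis.old.helpers.color_helpers import rgba
from ax.analysis.old.helpers.constants import CI_OPACITY, COLORS, Z

# Semi-transparent steel blue, as used for confidence-interval error bars.
print(rgba(COLORS.STEELBLUE.value, CI_OPACITY))  # rgba(128,177,211,0.4)

# Z = 1.96 converts a standard error into a 95% CI half-width.
print(Z * 0.5)  # 0.98
```
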
diff --git a/ax/analysis/old/helpers/cross_validation_helpers.py b/ax/analysis/old/helpers/cross_validation_helpers.py
deleted file mode 100644
index 3522c37abda..00000000000
--- a/ax/analysis/old/helpers/cross_validation_helpers.py
+++ /dev/null
@@ -1,200 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-strict
-
-from typing import Any
-
-import numpy as np
-import pandas as pd
-import plotly.graph_objs as go
-
-from ax.analysis.old.helpers.constants import Z
-
-from ax.analysis.old.helpers.plot_helpers import arm_name_to_sort_key
-
-from ax.modelbridge.cross_validation import CVResult
-
-
-def error_scatter_data_from_cv_results(
- cv_results: list[CVResult],
- metric_name: str,
-) -> tuple[list[float], list[float], list[float], list[float]]:
- """Extract mean and error from CVResults
-
- Args:
- cv_results: list of cross_validation result objects
- metric_name: metric name to use for prediction
- and observation
- Returns:
- x: list of x values
- x_se: list of x standard error
- y: list of y values
- y_se: list of y standard error
- """
- y = [cv_result.predicted.means_dict[metric_name] for cv_result in cv_results]
- y_se = [
- np.sqrt(cv_result.predicted.covariance_matrix[metric_name][metric_name])
- for cv_result in cv_results
- ]
-
- x = [cv_result.observed.data.means_dict[metric_name] for cv_result in cv_results]
- x_se = [
- np.sqrt(cv_result.observed.data.covariance_matrix[metric_name][metric_name])
- for cv_result in cv_results
- ]
-
- return x, x_se, y, y_se
-
-
-def cv_results_to_df(
- cv_results: list[CVResult],
- metric_name: str,
-) -> pd.DataFrame:
- """Create a dataframe with error scatterplot data
-
- Args:
- cv_results: list of cross validation results
- metric_name: name of metric. Predicted val on y-axis,
- observed val on x-axis.
- """
-
- # Opportunistically sort if arm names are in {trial}_{arm} format
- cv_results = sorted(
- cv_results,
- key=lambda c: arm_name_to_sort_key(c.observed.arm_name),
- reverse=True,
- )
-
- x, x_se, y, y_se = error_scatter_data_from_cv_results(
- cv_results=cv_results,
- metric_name=metric_name,
- )
-
- arm_names = [c.observed.arm_name for c in cv_results]
- records = []
-
- for i in range(len(arm_names)):
- records.append(
- {
- "arm_name": arm_names[i],
- "metric_name": metric_name,
- "x": x[i],
- "x_se": x_se[i],
- "y": y[i],
- "y_se": y_se[i],
- "arm_parameters": cv_results[i].observed.features.parameters,
- }
- )
- return pd.DataFrame.from_records(records)
-
-
-# Helper functions for plotting model fits
-def get_min_max_with_errors(
- x: list[float], y: list[float], se_x: list[float], se_y: list[float]
-) -> tuple[float, float]:
- """Get min and max of a bivariate dataset (across variables).
-
- Args:
- x: point estimate of x variable.
- y: point estimate of y variable.
- se_x: standard error of x variable.
- se_y: standard error of y variable.
-
- Returns:
- min_: minimum of points, including uncertainty.
- max_: maximum of points, including uncertainty.
-
- """
- min_ = min(
- min(np.array(x) - np.multiply(se_x, Z)), min(np.array(y) - np.multiply(se_y, Z))
- )
- max_ = max(
- max(np.array(x) + np.multiply(se_x, Z)), max(np.array(y) + np.multiply(se_y, Z))
- )
- return min_, max_
-
-
-def get_plotting_limit_ignore_outliers(
- x: list[float], y: list[float], se_x: list[float], se_y: list[float]
-) -> tuple[list[float], tuple[float, float]]:
- """Get a range for a bivarite dataset based on the 25th and 75th percentiles
- Used as plotting limit to ignore outliers.
-
- Args:
- x: point estimate of x variable.
- y: point estimate of y variable.
- se_x: standard error of x variable.
- se_y: standard error of y variable.
-
- Returns:
- (min, max): layout axis range
- (min, max): diagonal trace range
-
- """
- se_x = default_value_se_raw(se_raw=se_x, out_length=len(x))
-
- min_, max_ = get_min_max_with_errors(x=x, y=y, se_x=se_x, se_y=se_y)
-
- x_np = np.array(x)
- # TODO: replace interpolation->method once it becomes standard.
- # pyre-fixme[28]: Unexpected keyword argument `interpolation`.
- q1 = np.nanpercentile(x_np, q=25, interpolation="lower").min()
- # pyre-fixme[28]: Unexpected keyword argument `interpolation`.
- q3 = np.nanpercentile(x_np, q=75, interpolation="higher").max()
- quartile_difference = q3 - q1
-
- y_lower = q1 - 1.5 * quartile_difference
- y_upper = q3 + 1.5 * quartile_difference
-
- # clip outliers from x
- x_np = x_np.clip(y_lower, y_upper).tolist()
- min_robust, max_robust = get_min_max_with_errors(x=x_np, y=y, se_x=se_x, se_y=se_y)
- y_padding = 0.05 * (max_robust - min_robust)
-
- layout_range = [
- max(min_robust, min_) - y_padding,
- min(max_robust, max_) + y_padding,
- ]
- diagonal_trace_range = (
- min(min_robust, min_) - y_padding,
- max(max_robust, max_) + y_padding,
- )
-
- return (layout_range, diagonal_trace_range)
-
-
-def diagonal_trace(min_: float, max_: float, visible: bool = True) -> dict[str, Any]:
- """Diagonal line trace from (min_, min_) to (max_, max_).
-
- Args:
- min_: minimum to be used for starting point of line.
- max_: maximum to be used for ending point of line.
- visible: if True, trace is set to visible.
- """
- # pyre-fixme[7]: Expected `Dict[str, typing.Any]` but got `Scatter`.
- return go.Scatter(
- x=[min_, max_],
- y=[min_, max_],
- line=dict(color="black", width=2, dash="dot"), # noqa: C408
- mode="lines",
- hoverinfo="none",
- visible=visible,
- showlegend=False,
- )
-
-
-def default_value_se_raw(se_raw: list[float] | None, out_length: int) -> list[float]:
- """
- Takes a list of standard errors and maps edge cases to default list
- of floats.
- """
- new_se_raw = (
- [0.0 if np.isnan(se) else se for se in se_raw]
- if se_raw is not None
- else [0.0] * out_length
- )
- return new_se_raw
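
`get_plotting_limit_ignore_outliers` is the standard Tukey-fence recipe: clip x to [Q1 − 1.5·IQR, Q3 + 1.5·IQR] before recomputing the extent. A small numeric sketch of that clipping step in plain numpy (`method=` is the modern spelling of the `interpolation=` argument used above):

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0, 100.0])  # one gross outlier

q1 = np.nanpercentile(x, q=25, method="lower")   # 2.0
q3 = np.nanpercentile(x, q=75, method="higher")  # 4.0
iqr = q3 - q1                                    # 2.0

lower, upper = q1 - 1.5 * iqr, q3 + 1.5 * iqr    # -1.0, 7.0
print(x.clip(lower, upper))  # [1. 2. 3. 4. 7.] -- the outlier no longer stretches the axis
```
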
diff --git a/ax/analysis/old/helpers/layout_helpers.py b/ax/analysis/old/helpers/layout_helpers.py
deleted file mode 100644
index 67bd3a09c19..00000000000
--- a/ax/analysis/old/helpers/layout_helpers.py
+++ /dev/null
@@ -1,111 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-strict
-
-from typing import Any
-
-import plotly.graph_objs as go
-
-
-def updatemenus_format(metric_dropdown: list[dict[str, Any]]) -> list[dict[str, Any]]:
- """
-    Format the metric dropdown and show-CI toggle updatemenus for the cross validation plot.
- """
- return [
- {
- "x": 0,
- "y": 1.125,
- "yanchor": "top",
- "xanchor": "left",
- "buttons": metric_dropdown,
- },
- {
- "buttons": [
- {
- "args": [
- {
- "error_x.width": 4,
- "error_x.thickness": 2,
- "error_y.width": 4,
- "error_y.thickness": 2,
- }
- ],
- "label": "Yes",
- "method": "restyle",
- },
- {
- "args": [
- {
- "error_x.width": 0,
- "error_x.thickness": 0,
- "error_y.width": 0,
- "error_y.thickness": 0,
- }
- ],
- "label": "No",
- "method": "restyle",
- },
- ],
- "x": 1.125,
- "xanchor": "left",
- "y": 0.8,
- "yanchor": "middle",
- },
- ]
-
-
-def layout_format(
- layout_axis_range_value: tuple[float, float],
- xlabel: str,
- ylabel: str,
- updatemenus: list[dict[str, Any]],
-) -> type[go.Figure]:
- """
- Constructs a layout object for a CrossValidation figure.
-    Args:
- layout_axis_range_value: A tuple containing the range of values
- for the x-axis and y-axis.
- xlabel: Label for the x-axis.
- ylabel: Label for the y-axis.
- updatemenus: A list of dictionaries containing information to use on update.
- """
- layout = go.Layout(
- annotations=[
- {
- "showarrow": False,
- "text": "Show CI",
- "x": 1.125,
- "xanchor": "left",
- "xref": "paper",
- "y": 0.9,
- "yanchor": "middle",
- "yref": "paper",
- }
- ],
- xaxis={
- "range": layout_axis_range_value,
- "title": xlabel,
- "zeroline": False,
- "mirror": True,
- "linecolor": "black",
- "linewidth": 0.5,
- },
- yaxis={
- "range": layout_axis_range_value,
- "title": ylabel,
- "zeroline": False,
- "mirror": True,
- "linecolor": "black",
- "linewidth": 0.5,
- },
- showlegend=False,
- hovermode="closest",
- updatemenus=updatemenus,
- width=530,
- height=500,
- )
- # pyre-fixme[7]: Expected `Type[Figure]` but got `Layout`.
- return layout
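
Schematically, the two formatters compose like this: one dropdown entry per metric is fed to `updatemenus_format`, whose output goes into `layout_format`. A minimal sketch with a single placeholder button:

```python
import plotly.graph_objs as go

from ax.analysis.old.helpers.layout_helpers import layout_format, updatemenus_format

# One dropdown entry: which traces to show and what axis ranges to apply.
metric_dropdown = [
    {
        "args": [{"visible": [True]}, {"xaxis.range": [0, 1], "yaxis.range": [0, 1]}],
        "label": "my_metric",
        "method": "update",
    }
]

layout = layout_format(
    layout_axis_range_value=(0.0, 1.0),
    xlabel="Actual Outcome",
    ylabel="Predicted Outcome",
    updatemenus=updatemenus_format(metric_dropdown=metric_dropdown),
)
fig = go.Figure(data=[go.Scatter(x=[0.2, 0.8], y=[0.3, 0.7], mode="markers")], layout=layout)
```
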
diff --git a/ax/analysis/old/helpers/plot_data_df_helpers.py b/ax/analysis/old/helpers/plot_data_df_helpers.py
deleted file mode 100644
index 445a1f4340e..00000000000
--- a/ax/analysis/old/helpers/plot_data_df_helpers.py
+++ /dev/null
@@ -1,92 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-strict
-
-
-import numpy as np
-import pandas as pd
-
-from ax.modelbridge import ModelBridge
-from ax.modelbridge.prediction_utils import predict_at_point
-
-from ax.modelbridge.transforms.ivw import IVW
-
-
-def get_plot_data_in_sample_arms_df(
- model: ModelBridge,
- metric_names: set[str],
-) -> pd.DataFrame:
- """Get in-sample arms from a model with observed and predicted values
- for specified metrics.
-
- Returns a dataframe in which repeated observations are merged
- with IVW (inverse variance weighting)
-
- Args:
- model: An instance of the model bridge.
-        metric_names: Restrict predictions and observations to these
-            metrics.
-
- Returns:
- A dataframe containing
- columns:
- {
- "arm_name": name of the arm in the cross validation result
- "metric_name": name of the observed/predicted metric
- "x": value observed for the metric for this arm
- "x_se": standard error of observed metric (0 for observations)
- "y": value predicted for the metric for this arm
- "y_se": standard error of predicted metric for this arm
- "arm_parameters": Parametrization of the arm
- }
- """
- observations = model.get_training_data()
- training_in_design: list[bool] = model.training_in_design
-
- # Merge multiple measurements within each Observation with IVW to get
- # un-modeled prediction
- observations = IVW(None, []).transform_observations(observations)
-
- # Create records for dict
- records = []
-
- for i, obs in enumerate(observations):
- # Extract raw measurement
- features = obs.features
-
- if training_in_design[i]:
- pred_y_dict, pred_se_dict = predict_at_point(model, features, metric_names)
- else:
- pred_y_dict = None
- pred_se_dict = None
-
- for metric_name in obs.data.metric_names:
- if metric_name not in metric_names:
- continue
- obs_y = obs.data.means_dict[metric_name]
- obs_se = np.sqrt(obs.data.covariance_matrix[metric_name][metric_name])
-
- if pred_y_dict and pred_se_dict:
- pred_y = pred_y_dict[metric_name]
- pred_se = pred_se_dict[metric_name]
- else:
- pred_y = obs_y
- pred_se = obs_se
-
- records.append(
- {
- "arm_name": obs.arm_name,
- "metric_name": metric_name,
- "x": obs_y,
- "x_se": obs_se,
- "y": pred_y,
- "y_se": pred_se,
- "arm_parameters": obs.features.parameters,
- }
- )
-
- return pd.DataFrame.from_records(records)
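
A usage sketch, assuming a fitted model bridge as in the tests below:

```python
from ax.analysis.old.helpers.plot_data_df_helpers import get_plot_data_in_sample_arms_df
from ax.modelbridge.registry import Models
from ax.utils.testing.core_stubs import get_branin_experiment

experiment = get_branin_experiment(with_batch=True)
experiment.trials[0].run()
model = Models.BOTORCH_MODULAR(experiment=experiment, data=experiment.fetch_data())

# One row per (in-sample arm, metric). Repeated observations are merged with
# IVW; out-of-design arms fall back to their raw observed values.
df = get_plot_data_in_sample_arms_df(model=model, metric_names={"branin"})
print(df[["arm_name", "metric_name", "x", "x_se", "y", "y_se"]].head())
```
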
diff --git a/ax/analysis/old/helpers/plot_helpers.py b/ax/analysis/old/helpers/plot_helpers.py
deleted file mode 100644
index 5ed1a7dfeee..00000000000
--- a/ax/analysis/old/helpers/plot_helpers.py
+++ /dev/null
@@ -1,91 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-strict
-
-
-from logging import Logger
-from typing import Optional, Union
-
-from ax.analysis.old.helpers.constants import DECIMALS, Z
-
-from ax.core.generator_run import GeneratorRun
-
-from ax.core.types import TParameterization
-from ax.utils.common.logger import get_logger
-
-from plotly import graph_objs as go
-
-
-logger: Logger = get_logger(__name__)
-
-# Typing alias
-RawData = list[dict[str, Union[str, float]]]
-
-TNullableGeneratorRunsDict = Optional[dict[str, GeneratorRun]]
-
-
-def _format_dict(param_dict: TParameterization, name: str = "Parameterization") -> str:
- """Format a dictionary for labels.
-
- Args:
- param_dict: Dictionary to be formatted
- name: String name of the thing being formatted.
-
- Returns: stringified blob.
- """
- if len(param_dict) >= 10:
- blob = "{} has too many items to render on hover ({}).".format(
- name, len(param_dict)
- )
- else:
- blob = "
{}:
{}".format(
- name, "
".join(f"{n}: {v}" for n, v in param_dict.items())
- )
- return blob
-
-
-def _format_CI(estimate: float, sd: float, zval: float = Z) -> str:
- """Format confidence intervals given estimate and standard deviation.
-
- Args:
- estimate: point estimate.
- sd: standard deviation of point estimate.
- zval: z-value associated with desired CI (e.g. 1.96 for 95% CIs)
-
- Returns: formatted confidence interval.
- """
- return "[{lb:.{digits}f}, {ub:.{digits}f}]".format(
- lb=estimate - zval * sd,
- ub=estimate + zval * sd,
- digits=DECIMALS,
- )
-
-
-def resize_subtitles(figure: go.Figure, size: int) -> go.Figure:
- """Resize subtitles in a plotly figure
-    Args:
- figure: plotly figure to resize subtitles of
- size: font size to resize subtitles to
- """
- for ant in figure["layout"]["annotations"]:
- ant["font"].update(size=size)
- return figure
-
-
-def arm_name_to_sort_key(arm_name: str) -> tuple[str, int, int]:
- """Parses arm name into tuple suitable for reverse sorting by key
-
- Example:
- arm_names = ["0_0", "1_10", "1_2", "10_0", "control"]
- sorted(arm_names, key=arm_name_to_sort_key, reverse=True)
- ["control", "0_0", "1_2", "1_10", "10_0"]
- """
- try:
- trial_index, arm_index = arm_name.split("_")
- return ("", -int(trial_index), -int(arm_index))
- except (ValueError, IndexError):
- return (arm_name, 0, 0)
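
Two of these helpers are clearest from their outputs; exact values given `Z = 1.96` and `DECIMALS = 3`:

```python
from ax.analysis.old.helpers.plot_helpers import _format_CI, arm_name_to_sort_key

# 95% CI string around a point estimate: 1.0 +/- 1.96 * 0.5.
print(_format_CI(estimate=1.0, sd=0.5))  # [0.020, 1.980]

# Reverse sorting puts non-numeric names first, then trial/arm order.
arm_names = ["0_0", "1_10", "1_2", "10_0", "control"]
print(sorted(arm_names, key=arm_name_to_sort_key, reverse=True))
# ['control', '0_0', '1_2', '1_10', '10_0']
```
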
diff --git a/ax/analysis/old/helpers/scatter_helpers.py b/ax/analysis/old/helpers/scatter_helpers.py
deleted file mode 100644
index 13fd9144ac8..00000000000
--- a/ax/analysis/old/helpers/scatter_helpers.py
+++ /dev/null
@@ -1,314 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-strict
-
-import numbers
-
-from typing import Any
-
-import numpy as np
-import pandas as pd
-
-import plotly.graph_objs as go
-from ax.analysis.old.helpers.color_helpers import rgba
-
-from ax.analysis.old.helpers.constants import CI_OPACITY, COLORS, DECIMALS, Z
-
-from ax.analysis.old.helpers.plot_helpers import _format_CI, _format_dict
-
-from ax.core.types import TParameterization
-
-from ax.utils.stats.statstools import relativize
-
-# disable false positive "SettingWithCopyWarning"
-pd.options.mode.chained_assignment = None
-
-
-def relativize_dataframe(df: pd.DataFrame, status_quo_name: str) -> pd.DataFrame:
- """
- Relativizes the dataframe with respect to "status_quo_name". Assumes as a
- precondition that for each metric in the dataframe, there is a row with
- arm_name == status_quo_name to relativize against.
-
- Args:
- df: dataframe with the following columns
- {
- "arm_name": name of the arm in the cross validation result
- "metric_name": name of the observed/predicted metric
- "x": value of the observation for the metric for this arm
- "x_se": standard error of the observation for the metric of this arm
- "y": value predicted for the metric for this arm
- "y_se": standard error of predicted metric for this arm
- }
- status_quo_name: name of the status quo arm in the dataframe to use
- for relativization.
- Returns:
- A dataframe containing the same rows as df, with the observation and predicted
- data values relativized with respect to the status quo.
- An additional column "rel" is added to indicate whether the data is relativized.
-
- """
- metrics = df["metric_name"].unique()
-
- def _relativize_filtered_dataframe(
- df: pd.DataFrame, metric_name: str, status_quo_name: str
- ) -> pd.DataFrame:
- df = df.loc[df["metric_name"] == metric_name]
- status_quo_row = df.loc[df["arm_name"] == status_quo_name]
-
- mean_c = status_quo_row["y"].iloc[0]
- sem_c = status_quo_row["y_se"].iloc[0]
- y_rel, y_se_rel = relativize(
- means_t=df["y"].tolist(),
- sems_t=df["y_se"].tolist(),
- mean_c=mean_c,
- sem_c=sem_c,
- as_percent=True,
- )
- df["y"] = y_rel
- df["y_se"] = y_se_rel
-
- mean_c = status_quo_row["x"].iloc[0]
- sem_c = status_quo_row["x_se"].iloc[0]
- x_rel, x_se_rel = relativize(
- means_t=df["x"].tolist(),
- sems_t=df["x_se"].tolist(),
- mean_c=mean_c,
- sem_c=sem_c,
- as_percent=True,
- )
- df["x"] = x_rel
- df["x_se"] = x_se_rel
- df["rel"] = True
- return df
-
- return pd.concat(
- [
- _relativize_filtered_dataframe(
- df=df, metric_name=metric, status_quo_name=status_quo_name
- )
- for metric in metrics
- ]
- )
-
-
-def extract_mean_and_error_from_df(
- df: pd.DataFrame,
-) -> tuple[list[float], list[float], list[float], list[float]]:
- """Extract mean and error from dataframe.
-
- Args:
- df: dataframe containing the scatter plot data
- Returns:
- x: list of x values
- x_se: list of x standard error
- y: list of y values
- y_se: list of y standard error
- """
- x = df["x"]
- x_se = df["x_se"]
- y = df["y"]
- y_se = df["y_se"]
-
- return (x, x_se, y, y_se)
-
-
-def make_label(
- arm_name: str,
- x_axis_values: tuple[str, float, float] | None,
- y_axis_values: tuple[str, float, float],
- param_blob: TParameterization,
- rel: bool,
-) -> str:
- """Make label for scatter plot.
-
- Args:
- arm_name: Name of arm
- x_axis_values: Optional Tuple of
- x_name: Name of x variable
- x_val: Value of x variable
- x_se: Standard error of x variable
- y_axis_values: Tuple of
- y_name: Name of y variable
- y_val: Value of y variable
- y_se: Standard error of y variable
- param_blob: Parameterization of arm
- rel: whether the data is relativized as a %
-
- Returns:
- Label for scatter plot.
- """
- heading = f"Arm {arm_name}
"
- x_lab = ""
- if x_axis_values is not None:
- x_name, x_val, x_se = x_axis_values
- x_lab = "{name}: {estimate}{perc} {ci}
".format(
- name=x_name,
- estimate=(
- round(x_val, DECIMALS) if isinstance(x_val, numbers.Number) else x_val
- ),
- ci=_format_CI(estimate=x_val, sd=x_se),
- perc="%" if rel else "",
- )
-
- y_name, y_val, y_se = y_axis_values
- y_lab = "{name}: {estimate}{perc} {ci}
".format(
- name=y_name,
- estimate=(
- round(y_val, DECIMALS) if isinstance(y_val, numbers.Number) else y_val
- ),
- ci=_format_CI(estimate=y_val, sd=y_se),
- perc="%" if rel else "",
- )
-
- parameterization = _format_dict(param_blob, "Parameterization")
-
- return "{arm_name}
{xlab}{ylab}{param_blob}".format(
- arm_name=heading,
- xlab=x_lab,
- ylab=y_lab,
- param_blob=parameterization,
- )
-
-
-def error_dot_plot_trace_from_df(
- df: pd.DataFrame,
- show_CI: bool = True,
- visible: bool = True,
-) -> dict[str, Any]:
- """Creates trace for dot plot with confidence intervals.
- Categorizes by arm name.
-
- Args:
- df: dataframe containing the scatter plot data
- show_CI: if True, plot confidence intervals.
- visible: if True, trace will be visible in figure
- """
-
- _, _, y, y_se = extract_mean_and_error_from_df(df)
-
- labels = []
-
- metric_name = df["metric_name"].iloc[0]
-
- for _, row in df.iterrows():
- labels.append(
- make_label(
- arm_name=row["arm_name"],
-                x_axis_values=None,
- y_axis_values=(
- metric_name,
- row["y"],
- row["y_se"],
- ),
- param_blob=row["arm_parameters"],
- rel=(False if "rel" not in row else row["rel"]),
- )
- )
-
- trace = go.Scatter(
- x=df["arm_name"],
- y=y,
- marker={"color": rgba(COLORS.STEELBLUE.value)},
- mode="markers",
- name="In-sample",
- text=labels,
- hoverinfo="text",
- )
-
- if show_CI:
- if y_se is not None:
- trace.update(
- error_y={
- "type": "data",
- "array": np.multiply(y_se, Z),
- "color": rgba(COLORS.STEELBLUE.value, CI_OPACITY),
- }
- )
-
- trace.update(visible=visible)
- trace.update(showlegend=False)
- # pyre-fixme[7]: Expected `Dict[str, typing.Any]` but got `Scatter`.
- return trace
-
-
-def error_scatter_trace_from_df(
- df: pd.DataFrame,
- show_CI: bool = True,
- visible: bool = True,
- y_axis_label: str | None = None,
- x_axis_label: str | None = None,
-) -> dict[str, Any]:
- """Plot scatterplot with error bars.
-
- Args:
- df: dataframe containing the scatter plot data
- show_CI: if True, plot confidence intervals.
- visible: if True, trace will be visible in figure
-        y_axis_label: custom label to use for y axis.
-            If None, use the metric name from the dataframe.
-        x_axis_label: custom label to use for x axis.
-            If None, use the metric name from the dataframe.
- """
-
- x, x_se, y, y_se = extract_mean_and_error_from_df(df)
-
- labels = []
-
- metric_name = df["metric_name"].iloc[0]
-
- for _, row in df.iterrows():
- labels.append(
- make_label(
- arm_name=row["arm_name"],
- x_axis_values=(
- metric_name if x_axis_label is None else x_axis_label,
- row["x"],
- row["x_se"],
- ),
- y_axis_values=(
- (metric_name if y_axis_label is None else y_axis_label),
- row["y"],
- row["y_se"],
- ),
- param_blob=row["arm_parameters"],
- rel=(False if "rel" not in row else row["rel"]),
- )
- )
-
- trace = go.Scatter(
- x=x,
- y=y,
- marker={"color": rgba(COLORS.STEELBLUE.value)},
- mode="markers",
- name="In-sample",
- text=labels,
- hoverinfo="text",
- )
-
- if show_CI:
- if x_se is not None:
- trace.update(
- error_x={
- "type": "data",
- "array": np.multiply(x_se, Z),
- "color": rgba(COLORS.STEELBLUE.value, CI_OPACITY),
- }
- )
- if y_se is not None:
- trace.update(
- error_y={
- "type": "data",
- "array": np.multiply(y_se, Z),
- "color": rgba(COLORS.STEELBLUE.value, CI_OPACITY),
- }
- )
-
- trace.update(visible=visible)
- trace.update(showlegend=True)
- # pyre-fixme[7]: Expected `Dict[str, typing.Any]` but got `Scatter`.
- return trace
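
A small end-to-end sketch of `relativize_dataframe`: observed and predicted values become percent deltas against the status-quo row, and a `rel` column is added. Toy numbers, illustrative only:

```python
import pandas as pd

from ax.analysis.old.helpers.scatter_helpers import relativize_dataframe

df = pd.DataFrame.from_records(
    [
        {"arm_name": "status_quo", "metric_name": "m", "x": 10.0, "x_se": 0.0,
         "y": 10.0, "y_se": 0.1, "arm_parameters": {}},
        {"arm_name": "0_0", "metric_name": "m", "x": 12.0, "x_se": 0.0,
         "y": 11.0, "y_se": 0.1, "arm_parameters": {}},
    ]
)

rel_df = relativize_dataframe(df=df, status_quo_name="status_quo")
# x for arm 0_0 is now ~+20% (12 vs. 10), and rel is True on every row.
print(rel_df[["arm_name", "x", "y", "rel"]])
```
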
diff --git a/ax/analysis/old/helpers/tests/test_cross_validation_helpers.py b/ax/analysis/old/helpers/tests/test_cross_validation_helpers.py
deleted file mode 100644
index 3b615aa2ae0..00000000000
--- a/ax/analysis/old/helpers/tests/test_cross_validation_helpers.py
+++ /dev/null
@@ -1,91 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-strict
-
-import tempfile
-
-import plotly.graph_objects as go
-import plotly.io as pio
-from ax.analysis.old.cross_validation_plot import CrossValidationPlot
-from ax.analysis.old.helpers.constants import Z
-from ax.analysis.old.helpers.cross_validation_helpers import get_min_max_with_errors
-from ax.modelbridge.registry import Models
-from ax.utils.common.testutils import TestCase
-from ax.utils.testing.core_stubs import get_branin_experiment
-from ax.utils.testing.mock import fast_botorch_optimize
-from pandas import read_json
-from pandas.testing import assert_frame_equal
-
-
-class TestCrossValidationHelpers(TestCase):
- @fast_botorch_optimize
- def setUp(self) -> None:
- super().setUp()
- self.exp = get_branin_experiment(with_batch=True)
- self.exp.trials[0].run()
- self.model = Models.BOTORCH_MODULAR(
- # Model bridge kwargs
- experiment=self.exp,
- data=self.exp.fetch_data(),
- )
-
- self.exp_status_quo = get_branin_experiment(
- with_batch=True, with_status_quo=True
- )
- self.exp_status_quo.trials[0].run()
- self.model_status_quo = Models.BOTORCH_MODULAR(
- # Model bridge kwargs
-            experiment=self.exp_status_quo,
-            data=self.exp_status_quo.fetch_data(),
- )
-
- def test_get_min_max_with_errors(self) -> None:
- # Test with sample data
- x = [1.0, 2.0, 3.0]
- y = [4.0, 5.0, 6.0]
- sd_x = [0.1, 0.2, 0.3]
- sd_y = [0.1, 0.2, 0.3]
- min_, max_ = get_min_max_with_errors(x, y, sd_x, sd_y)
-
- expected_min = 1.0 - 0.1 * Z
- expected_max = 6.0 + 0.3 * Z
- # Check that the returned values are correct
- print(f"min: {min_} {expected_min=}")
- print(f"max: {max_} {expected_max=}")
- self.assertAlmostEqual(min_, expected_min, delta=1e-4)
- self.assertAlmostEqual(max_, expected_max, delta=1e-4)
-
- def test_obs_vs_pred_dropdown_plot(self) -> None:
- cross_validation_plot = CrossValidationPlot(
- experiment=self.exp, model=self.model
- )
- fig = cross_validation_plot.get_fig()
-
- self.assertIsInstance(fig, go.Figure)
-
- def test_store_df_to_file(self) -> None:
- with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json") as f:
- cross_validation_plot = CrossValidationPlot(
- experiment=self.exp, model=self.model
- )
- cv_df = cross_validation_plot.get_df()
- cv_df.to_json(f.name)
-
- loaded_dataframe = read_json(f.name, dtype={"arm_name": "str"})
-
- assert_frame_equal(cv_df, loaded_dataframe, check_dtype=False)
-
- def test_store_plot_as_dict(self) -> None:
- cross_validation_plot = CrossValidationPlot(
- experiment=self.exp, model=self.model
- )
- cv_fig = cross_validation_plot.get_fig()
-
- json_obj = pio.to_json(cv_fig, validate=True, remove_uids=False)
-
- loaded_json_obj = pio.from_json(json_obj, output_type="Figure")
- self.assertEqual(cv_fig, loaded_json_obj)
diff --git a/ax/analysis/old/helpers/tests/test_cv_consistency_checks.py b/ax/analysis/old/helpers/tests/test_cv_consistency_checks.py
deleted file mode 100644
index d36b4c8ab9d..00000000000
--- a/ax/analysis/old/helpers/tests/test_cv_consistency_checks.py
+++ /dev/null
@@ -1,129 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-strict
-
-import copy
-
-import plotly.graph_objects as go
-from ax.analysis.old.cross_validation_plot import CrossValidationPlot
-from ax.analysis.old.helpers.cross_validation_helpers import (
- error_scatter_data_from_cv_results,
-)
-from ax.analysis.old.helpers.scatter_helpers import error_scatter_trace_from_df
-from ax.modelbridge.cross_validation import cross_validate
-from ax.modelbridge.registry import Models
-from ax.plot.base import PlotMetric
-from ax.plot.diagnostic import (
- _get_cv_plot_data as PLOT_get_cv_plot_data,
- interact_cross_validation_plotly as PLOT_interact_cross_validation_plotly,
-)
-from ax.plot.scatter import (
- _error_scatter_data as PLOT_error_scatter_data,
- _error_scatter_trace as PLOT_error_scatter_trace,
-)
-from ax.utils.common.testutils import TestCase
-from ax.utils.testing.core_stubs import get_branin_experiment
-from ax.utils.testing.mock import fast_botorch_optimize
-
-
-class TestCVConsistencyCheck(TestCase):
- @fast_botorch_optimize
- def setUp(self) -> None:
- super().setUp()
- self.exp = get_branin_experiment(with_batch=True)
- self.exp.trials[0].run()
- self.model = Models.BOTORCH_MODULAR(
- # Model bridge kwargs
- experiment=self.exp,
- data=self.exp.fetch_data(),
- )
-
- self.exp_status_quo = get_branin_experiment(
- with_batch=True, with_status_quo=True
- )
- self.exp_status_quo.trials[0].run()
- self.model_status_quo = Models.BOTORCH_MODULAR(
- # Model bridge kwargs
-            experiment=self.exp_status_quo,
-            data=self.exp_status_quo.fetch_data(),
- )
-
- def test_error_scatter_data_branin(self) -> None:
- cv_results = cross_validate(self.model)
- cv_results_plot = copy.deepcopy(cv_results)
-
- result_analysis = error_scatter_data_from_cv_results(
- cv_results=cv_results,
- metric_name="branin",
- )
-
- data = PLOT_get_cv_plot_data(cv_results_plot, label_dict={})
- result_plot = PLOT_error_scatter_data(
- list(data.in_sample.values()),
- y_axis_var=PlotMetric("branin", pred=True, rel=False),
- x_axis_var=PlotMetric("branin", pred=False, rel=False),
- )
- self.assertEqual(result_analysis, result_plot)
-
- def test_error_scatter_trace_branin(self) -> None:
- cv_results = cross_validate(self.model)
- cv_results_plot = copy.deepcopy(cv_results)
-
- cross_validation_plot = CrossValidationPlot(
- experiment=self.exp, model=self.model
- )
- df = cross_validation_plot.get_df()
-
- metric_filtered_df = df.loc[df["metric_name"] == "branin"]
- result_analysis = error_scatter_trace_from_df(
- df=metric_filtered_df,
- show_CI=True,
- visible=True,
- x_axis_label="Actual Outcome",
- y_axis_label="Predicted Outcome",
- )
-
-        data = PLOT_get_cv_plot_data(cv_results_plot, label_dict={})
-
- result_plot = PLOT_error_scatter_trace(
- arms=list(data.in_sample.values()),
- hoverinfo="text",
- show_arm_details_on_hover=True,
- show_CI=True,
- show_context=False,
- status_quo_arm=None,
- visible=True,
- y_axis_var=PlotMetric("branin", pred=True, rel=False),
- x_axis_var=PlotMetric("branin", pred=False, rel=False),
- x_axis_label="Actual Outcome",
- y_axis_label="Predicted Outcome",
- )
-
- print(str(result_analysis))
- print(str(result_plot))
-
- self.assertEqual(result_analysis, result_plot)
-
- def test_obs_vs_pred_dropdown_plot_branin(self) -> None:
- label_dict = {"branin": "BrAnIn"}
-
- cross_validation_plot = CrossValidationPlot(
- experiment=self.exp, model=self.model, label_dict=label_dict
- )
- fig = cross_validation_plot.get_fig()
-
- self.assertIsInstance(fig, go.Figure)
-
- cv_results_plot = cross_validate(self.model)
-
- fig_PLOT = PLOT_interact_cross_validation_plotly(
- cv_results_plot,
- show_context=False,
- label_dict=label_dict,
- caption=CrossValidationPlot.CROSS_VALIDATION_CAPTION,
- )
- self.assertEqual(fig, fig_PLOT)
diff --git a/ax/analysis/old/predicted_outcomes_dot_plot.py b/ax/analysis/old/predicted_outcomes_dot_plot.py
deleted file mode 100644
index 6022134616d..00000000000
--- a/ax/analysis/old/predicted_outcomes_dot_plot.py
+++ /dev/null
@@ -1,143 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-strict
-
-from typing import Any
-
-import numpy as np
-
-import pandas as pd
-
-from ax.analysis.old.base_plotly_visualization import BasePlotlyVisualization
-
-from ax.analysis.old.helpers.layout_helpers import updatemenus_format
-
-from ax.analysis.old.helpers.plot_data_df_helpers import get_plot_data_in_sample_arms_df
-from ax.analysis.old.helpers.plot_helpers import arm_name_to_sort_key, resize_subtitles
-
-from ax.analysis.old.helpers.scatter_helpers import (
- error_dot_plot_trace_from_df,
- relativize_dataframe,
-)
-
-from ax.core.experiment import Experiment
-
-from ax.exceptions.core import UnsupportedPlotError
-
-from ax.modelbridge import ModelBridge
-
-from plotly import graph_objs as go
-
-
-class PredictedOutcomesDotPlot(BasePlotlyVisualization):
- def __init__(
- self,
- experiment: Experiment,
- model: ModelBridge,
- ) -> None:
- """
- Args:
- experiment: The experiment associated with this plot
- model: model which is used to fetch the plotting data.
- """
-
- self.model = model
- self.metrics: set[str] = model.metric_names
- if model.status_quo is None or model.status_quo.arm_name is None:
- raise UnsupportedPlotError(
- "status quo must be specified for PredictedOutcomesDotPlot"
- )
- self.status_quo_name: str = model.status_quo.arm_name
-
- super().__init__(experiment=experiment)
-
- def get_df(self) -> pd.DataFrame:
- """
- Returns:
- A dataframe containing:
- {
- "arm_name": name of the arm in the cross validation result
- "metric_name": name of the observed/predicted metric
- "x": value of the observation for the metric for this arm
- "x_se": standard error of the observation for the metric of this arm
- "y": value predicted for the metric for this arm
- "y_se": standard error of predicted metric for this arm
- "arm_parameters": Parametrization of the arm
- "rel": whether the data is relativized with respect to status quo
- }"""
- return relativize_dataframe(
- get_plot_data_in_sample_arms_df(
- model=self.model, metric_names=self.metrics
- ),
- status_quo_name=self.status_quo_name,
- )
-
- def get_fig(
- self,
- ) -> go.Figure:
- """
- For each metric, we plot the predicted values for each arm along with its CI
- These values are relativized with respect to the status quo.
- """
- name_order_axes: dict[str, dict[str, Any]] = {}
-
- in_sample_df = self.get_df()
- traces = []
- metric_dropdown = []
-
- for i, metric in enumerate(self.metrics):
- filtered_df = in_sample_df.loc[in_sample_df["metric_name"] == metric]
- data_single: dict[str, Any] = error_dot_plot_trace_from_df(
- df=filtered_df, show_CI=True, visible=(i == 0)
- )
-
-            # order arm names by trial and arm number within the batch
- names_by_arm = sorted(
- np.unique(data_single["x"]),
- key=lambda x: arm_name_to_sort_key(x),
- reverse=True,
- )
-
- name_order_axes[f"xaxis{i + 1}"] = {
- "categoryorder": "array",
- "categoryarray": names_by_arm,
- "type": "category",
- }
- name_order_axes[f"yaxis{i + 1}"] = {
- "ticksuffix": "%",
- "zerolinecolor": "red",
- }
-
- traces.append(data_single)
-
-            is_visible = [False] * len(self.metrics)
- is_visible[i] = True
- metric_dropdown.append(
- {
- "args": [
- {
- "visible": is_visible,
- },
- ],
- "label": metric,
- "method": "update",
- }
- )
-
- updatemenus = updatemenus_format(metric_dropdown=metric_dropdown)
-
- fig = go.Figure(data=traces)
-
- fig["layout"].update(
- updatemenus=updatemenus,
- width=1030,
- height=500,
- **name_order_axes,
- )
-
- fig = resize_subtitles(figure=fig, size=10)
- fig["layout"]["title"] = "Predicted Outcomes by Metric"
- return fig
diff --git a/ax/analysis/old/tests/test_analysis_report.py b/ax/analysis/old/tests/test_analysis_report.py
deleted file mode 100644
index 919baf8aaa6..00000000000
--- a/ax/analysis/old/tests/test_analysis_report.py
+++ /dev/null
@@ -1,144 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-strict
-
-import pandas as pd
-import plotly.graph_objects as go
-
-from ax.analysis.old.analysis_report import AnalysisReport
-
-from ax.analysis.old.base_analysis import BaseAnalysis
-from ax.analysis.old.base_plotly_visualization import BasePlotlyVisualization
-
-from ax.modelbridge.registry import Models
-from ax.utils.common.testutils import TestCase
-
-from ax.utils.common.timeutils import current_timestamp_in_millis
-from ax.utils.testing.core_stubs import get_branin_experiment
-from ax.utils.testing.mock import fast_botorch_optimize
-
-
-class TestAnalysisReport(TestCase):
- class TestAnalysis(BaseAnalysis):
- def get_df(self) -> pd.DataFrame:
- return pd.DataFrame()
-
- class TestPlotlyVisualization(BasePlotlyVisualization):
- def get_df(self) -> pd.DataFrame:
- return pd.DataFrame()
-
- def get_fig(self) -> go.Figure:
- return go.Figure()
-
- @fast_botorch_optimize
- def setUp(self) -> None:
- super().setUp()
-
- self.exp = get_branin_experiment(with_batch=True)
- self.exp.trials[0].run()
- self.model = Models.BOTORCH_MODULAR(
- # Model bridge kwargs
- experiment=self.exp,
- data=self.exp.fetch_data(),
- )
-
- self.test_analysis = self.TestAnalysis(experiment=self.exp)
- self.test_plotly_visualization = self.TestPlotlyVisualization(
- experiment=self.exp
- )
-
- def test_init_analysis_report(self) -> None:
- analysis_report = AnalysisReport(
- experiment=self.exp,
- analyses=[self.test_analysis, self.test_plotly_visualization],
- )
-
- self.assertEqual(len(analysis_report.analyses), 2)
-
- self.assertIsInstance(analysis_report.analyses[0], BaseAnalysis)
- self.assertIsInstance(analysis_report.analyses[1], BasePlotlyVisualization)
-
- self.assertIsNone(analysis_report.time_started)
- self.assertIsNone(analysis_report.time_completed)
- self.assertFalse(analysis_report.report_completed)
-
- def test_execute_analysis_report(self) -> None:
- analysis_report = AnalysisReport(
- experiment=self.exp,
- analyses=[self.test_analysis, self.test_plotly_visualization],
- )
-
- self.assertEqual(len(analysis_report.analyses), 2)
-
- results = analysis_report.run_analysis_report()
- self.assertEqual(len(results), 2)
-
- self.assertIsNotNone(analysis_report.time_started)
- self.assertIsNotNone(analysis_report.time_completed)
- self.assertTrue(analysis_report.report_completed)
-
- # assert no plot is returned for BaseAnalysis
- self.assertIsInstance(results[0][1], pd.DataFrame)
- self.assertIsNone(results[0][2])
-
- # assert plot is returned for BasePlotlyVisualization
- self.assertIsInstance(results[1][1], pd.DataFrame)
- self.assertIsInstance(results[1][2], go.Figure)
-
- def test_analysis_report_repeated_execute(self) -> None:
- analysis_report = AnalysisReport(
- experiment=self.exp,
- analyses=[self.test_analysis, self.test_plotly_visualization],
- )
-
- self.assertEqual(len(analysis_report.analyses), 2)
-
- _ = analysis_report.run_analysis_report()
-
- saved_start = analysis_report.time_started
- self.assertIsNotNone(saved_start)
-
-        # ensure analyses are not re-run
- _ = analysis_report.run_analysis_report()
- self.assertEqual(saved_start, analysis_report.time_started)
-
- def test_no_analysis_report(self) -> None:
- analysis_report = AnalysisReport(
- experiment=self.exp,
- analyses=[],
- )
-
- self.assertEqual(len(analysis_report.analyses), 0)
-
- results = analysis_report.run_analysis_report()
- self.assertEqual(len(results), 0)
-
- def test_singleton_analysis_report(self) -> None:
- analysis_report = AnalysisReport(
- experiment=self.exp,
- analyses=[self.test_plotly_visualization],
- )
-
- self.assertEqual(len(analysis_report.analyses), 1)
-
- results = analysis_report.run_analysis_report()
- self.assertEqual(len(results), 1)
-
- # assert plot is returned for BasePlotlyVisualization
- self.assertIsInstance(results[0][1], pd.DataFrame)
- self.assertIsInstance(results[0][2], go.Figure)
-
- def test_create_report_as_completed(self) -> None:
- analysis_report = AnalysisReport(
- experiment=self.exp,
- analyses=[self.test_analysis, self.test_plotly_visualization],
- time_started=current_timestamp_in_millis(),
- time_completed=current_timestamp_in_millis(),
- )
-
- self.assertIsNotNone(analysis_report.time_started)
- self.assertIsNotNone(analysis_report.time_completed)
- self.assertTrue(analysis_report.report_completed)
diff --git a/ax/analysis/old/tests/test_base_classes.py b/ax/analysis/old/tests/test_base_classes.py
deleted file mode 100644
index 64d2040e056..00000000000
--- a/ax/analysis/old/tests/test_base_classes.py
+++ /dev/null
@@ -1,142 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-strict
-
-import pandas as pd
-
-import plotly.express as px
-
-import plotly.graph_objects as go
-
-from ax.analysis.old.base_analysis import BaseAnalysis
-from ax.analysis.old.base_plotly_visualization import BasePlotlyVisualization
-
-from ax.modelbridge.registry import Models
-
-from ax.utils.common.testutils import TestCase
-
-from ax.utils.testing.core_stubs import get_branin_experiment
-from ax.utils.testing.mock import fast_botorch_optimize
-
-
-class TestBaseClasses(TestCase):
- class TestAnalysis(BaseAnalysis):
- get_df_call_count: int = 0
-
- def get_df(self) -> pd.DataFrame:
- self.get_df_call_count = self.get_df_call_count + 1
- return pd.DataFrame()
-
- class TestPlotlyVisualization(BasePlotlyVisualization):
- get_df_call_count: int = 0
- get_fig_call_count: int = 0
-
- def get_df(self) -> pd.DataFrame:
- self.get_df_call_count = self.get_df_call_count + 1
- return pd.DataFrame()
-
- def get_fig(self) -> go.Figure:
- self.get_fig_call_count = self.get_fig_call_count + 1
- return go.Figure()
-
- @fast_botorch_optimize
- def setUp(self) -> None:
- super().setUp()
-
- self.exp = get_branin_experiment(with_batch=True)
- self.exp.trials[0].run()
- self.model = Models.BOTORCH_MODULAR(
- # Model bridge kwargs
- experiment=self.exp,
- data=self.exp.fetch_data(),
- )
-
- self.test_analysis = self.TestAnalysis(experiment=self.exp)
- self.test_plotly_visualization = self.TestPlotlyVisualization(
- experiment=self.exp
- )
-
- def test_base_analysis_df_property(self) -> None:
- test_analysis = self.TestAnalysis(experiment=self.exp)
-
- self.assertEqual(test_analysis.get_df_call_count, 0)
-
- # accessing the df property calls get_df
- _ = test_analysis.df
- self.assertEqual(test_analysis.get_df_call_count, 1)
-
- # once saved, get_df is not called again.
- _ = test_analysis.df
- self.assertEqual(test_analysis.get_df_call_count, 1)
-
- def test_base_analysis_pass_df_in(self) -> None:
- existing_df = pd.DataFrame([1])
- test_analysis = self.TestAnalysis(experiment=self.exp, df_input=existing_df)
-
- self.assertEqual(test_analysis.get_df_call_count, 0)
- saved_df = test_analysis.df
- self.assertTrue(existing_df.equals(saved_df))
-
- # when df is passed in directly, get_df is never called.
- self.assertEqual(test_analysis.get_df_call_count, 0)
-
- def test_base_plotly_visualization_fig_property(self) -> None:
- test_analysis = self.TestPlotlyVisualization(experiment=self.exp)
-
- self.assertEqual(test_analysis.get_fig_call_count, 0)
-
- # accessing the fig property calls get_fig
- _ = test_analysis.fig
- self.assertEqual(test_analysis.get_fig_call_count, 1)
-
- # once saved, get_fig is not called again.
- _ = test_analysis.fig
- self.assertEqual(test_analysis.get_fig_call_count, 1)
-
- def test_base_plotly_visualization_pass_fig_in(self) -> None:
- fig = px.scatter(x=[0, 1, 2, 3, 4], y=[0, 1, 4, 9, 16])
- test_analysis = self.TestPlotlyVisualization(
- experiment=self.exp,
- fig_input=fig,
- )
-
- self.assertEqual(test_analysis.get_fig_call_count, 0)
- saved_fig = test_analysis.fig
- self.assertEqual(fig, saved_fig)
-
- # when fig is passed in directly, get_fig is never called.
- self.assertEqual(test_analysis.get_fig_call_count, 0)
-
- def test_base_plotly_visualization_pass_df_and_fig(self) -> None:
- df = pd.DataFrame([1])
- fig = px.scatter(x=[0, 1, 2, 3, 4], y=[0, 1, 4, 9, 16])
-
- test_analysis = self.TestPlotlyVisualization(
- experiment=self.exp,
- df_input=df,
- fig_input=fig,
- )
-
- self.assertEqual(test_analysis.get_df_call_count, 0)
- saved_df = test_analysis.df
- self.assertTrue(df.equals(saved_df))
- # when df is passed in directly, get_df is never called.
- self.assertEqual(test_analysis.get_df_call_count, 0)
-
- self.assertEqual(test_analysis.get_fig_call_count, 0)
- saved_fig = test_analysis.fig
- self.assertEqual(fig, saved_fig)
- # when fig is passed in directly, get_fig is never called.
- self.assertEqual(test_analysis.get_fig_call_count, 0)
-
- def test_instantiate_base_classes(self) -> None:
- test_analysis = BaseAnalysis(experiment=self.exp)
- with self.assertRaises(NotImplementedError):
- _ = test_analysis.df
-
- test_fig = BasePlotlyVisualization(experiment=self.exp)
- with self.assertRaises(NotImplementedError):
- _ = test_fig.fig
diff --git a/ax/analysis/old/tests/test_cross_validation_plot.py b/ax/analysis/old/tests/test_cross_validation_plot.py
deleted file mode 100644
index 5a6f68dea7a..00000000000
--- a/ax/analysis/old/tests/test_cross_validation_plot.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-strict
-
-import plotly.graph_objects as go
-from ax.analysis.old.cross_validation_plot import CrossValidationPlot
-from ax.modelbridge.registry import Models
-from ax.utils.common.testutils import TestCase
-from ax.utils.testing.core_stubs import get_branin_experiment
-from ax.utils.testing.mock import fast_botorch_optimize
-
-
-class TestCrossValidationPlot(TestCase):
- @fast_botorch_optimize
- def setUp(self) -> None:
- super().setUp()
- self.exp = get_branin_experiment(with_batch=True)
- self.exp.trials[0].run()
- self.model = Models.BOTORCH_MODULAR(
- # Model bridge kwargs
- experiment=self.exp,
- data=self.exp.fetch_data(),
- )
-
- def test_cross_validation_plot(self) -> None:
- plot = CrossValidationPlot(experiment=self.exp, model=self.model).get_fig()
- x_range = plot.layout.updatemenus[0].buttons[0].args[1]["xaxis.range"]
- y_range = plot.layout.updatemenus[0].buttons[0].args[1]["yaxis.range"]
- self.assertTrue((len(x_range) == 2) and (x_range[0] < x_range[1]))
- self.assertTrue((len(y_range) == 2) and (y_range[0] < y_range[1]))
-
- self.assertIsInstance(plot, go.Figure)
-
- def test_cross_validation_plot_get_df(self) -> None:
- plot = CrossValidationPlot(experiment=self.exp, model=self.model)
- _ = plot.get_df()
diff --git a/ax/analysis/old/tests/test_predicted_outcomes_dot_plot.py b/ax/analysis/old/tests/test_predicted_outcomes_dot_plot.py
deleted file mode 100644
index 89669330abc..00000000000
--- a/ax/analysis/old/tests/test_predicted_outcomes_dot_plot.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-strict
-
-
-import unittest
-
-import plotly.graph_objects as go
-from ax.analysis.old.predicted_outcomes_dot_plot import PredictedOutcomesDotPlot
-from ax.exceptions.core import UnsupportedPlotError
-from ax.modelbridge.registry import Models
-from ax.utils.testing.core_stubs import get_branin_experiment
-
-
-class TestPredictedOutcomesDotPlot(unittest.TestCase):
- def setUp(self) -> None:
- super().setUp()
- self.exp = get_branin_experiment(with_batch=True, with_status_quo=True)
- self.exp.trials[0].run()
- self.model = Models.BOTORCH_MODULAR(
- # Model bridge kwargs
- experiment=self.exp,
- data=self.exp.fetch_data(),
- )
-
- def test_predicted_outcomes_dot_plot_no_status_quo(self) -> None:
- exp = get_branin_experiment(with_batch=True, with_status_quo=False)
- exp.trials[0].run()
- model = Models.BOTORCH_MODULAR(
- # Model bridge kwargs
- experiment=exp,
- data=exp.fetch_data(),
- )
-
- with self.assertRaisesRegex(
- UnsupportedPlotError,
- "status quo must be specified for PredictedOutcomesDotPlot",
- ):
- _ = PredictedOutcomesDotPlot(
- experiment=exp,
- model=model,
- )
-
- def test_predicted_outcomes_dot_plot(self) -> None:
- predicted_outcomes_dot_plot = PredictedOutcomesDotPlot(
- experiment=self.exp,
- model=self.model,
- )
-
- _ = predicted_outcomes_dot_plot.get_df()
-
- fig = predicted_outcomes_dot_plot.get_fig()
- self.assertIsInstance(fig, go.Figure)