Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TileFittedPlot #2303

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions ax/analysis/helpers/plot_data_df_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Set

import numpy as np
import pandas as pd

from ax.modelbridge import ModelBridge
from ax.modelbridge.prediction_utils import predict_at_point

from ax.modelbridge.transforms.ivw import IVW


def get_plot_data_in_sample_arms_df(
    model: ModelBridge,
    metric_names: Optional[Set[str]] = None,
) -> pd.DataFrame:
    """Get in-sample arms from a model with observed and predicted values
    for specified metrics.

    Returns a dataframe in which repeated observations are merged
    with IVW (inverse variance weighting).

    Args:
        model: An instance of the model bridge.
        metric_names: Restrict predictions to these metrics. If None, uses all
            metrics in the model (``model.metric_names``).

    Returns:
        A dataframe with one row per (observation, metric) pair, containing
        columns:
        {
            "arm_name": name of the arm the observation belongs to
            "metric_name": name of the observed/predicted metric
            "x": value observed for the metric for this arm
            "x_se": standard error of the observed metric (square root of
                the metric's observed variance)
            "y": value predicted for the metric for this arm; falls back to
                the observed value when no prediction is made (e.g. the
                training point is out of design)
            "y_se": standard error of the predicted metric for this arm;
                falls back to "x_se" in the same cases as "y"
            "arm_parameters": Parametrization of the arm
        }
    """
    # Honor the documented contract: ``None`` means "every metric the model
    # knows about". Previously ``None`` crashed on the membership test below.
    if metric_names is None:
        metric_names = model.metric_names

    training_in_design: List[bool] = model.training_in_design

    # Merge multiple measurements within each Observation with IVW to get
    # un-modeled prediction
    observations = IVW(None, []).transform_observations(model.get_training_data())

    # One record per (observation, metric) pair; becomes one dataframe row.
    records = []

    for i, obs in enumerate(observations):
        features = obs.features

        # The model is only queried for training points flagged as in-design.
        if training_in_design[i]:
            pred_y_dict, pred_se_dict = predict_at_point(model, features, metric_names)
        else:
            pred_y_dict = None
            pred_se_dict = None

        for metric_name in obs.data.metric_names:
            if metric_name not in metric_names:
                continue

            # Extract raw measurement
            obs_y = obs.data.means_dict[metric_name]
            obs_se = np.sqrt(obs.data.covariance_matrix[metric_name][metric_name])

            if pred_y_dict and pred_se_dict:
                pred_y = pred_y_dict[metric_name]
                pred_se = pred_se_dict[metric_name]
            else:
                # No prediction available: plot the observation against itself.
                pred_y = obs_y
                pred_se = obs_se

            records.append(
                {
                    "arm_name": obs.arm_name,
                    "metric_name": metric_name,
                    "x": obs_y,
                    "x_se": obs_se,
                    "y": pred_y,
                    "y_se": pred_se,
                    "arm_parameters": features.parameters,
                }
            )

    return pd.DataFrame.from_records(records)
22 changes: 12 additions & 10 deletions ax/analysis/helpers/scatter_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,14 @@ def make_label(
estimate=(
round(x_val, DECIMALS) if isinstance(x_val, numbers.Number) else x_val
),
ci=_format_CI(x_val, x_se),
ci=_format_CI(estimate=x_val, sd=x_se),
)
y_lab = "{name}: {estimate} {ci}<br>".format(
name=y_name,
estimate=(
round(y_val, DECIMALS) if isinstance(y_val, numbers.Number) else y_val
),
ci=_format_CI(y_val, y_se),
ci=_format_CI(estimate=y_val, sd=y_se),
)

parameterization = _format_dict(param_blob, "Parameterization")
Expand Down Expand Up @@ -114,22 +114,24 @@ def error_scatter_trace_from_df(
x, x_se, y, y_se = extract_mean_and_error_from_df(df)

labels = []
param_blobs = df["arm_parameters"]
arm_names = df["arm_name"]

metric_name = df["metric_name"].iloc[0]

for i in range(len(param_blobs)):
print("Data frame: " + str(df))
print("Arm names: " + str(arm_names))
print("x" + str(x))
for _, row in df.iterrows():
labels.append(
make_label(
arm_name=arm_names[i],
arm_name=row["arm_name"],
x_name=metric_name if x_axis_label is None else x_axis_label,
x_val=x[i],
x_se=x_se[i],
x_val=row["x"],
x_se=row["x_se"],
y_name=(metric_name if y_axis_label is None else y_axis_label),
y_val=y[i],
y_se=y_se[i],
param_blob=param_blobs[i],
y_val=row["y"],
y_se=row["y_se"],
param_blob=row["arm_parameters"],
)
)

Expand Down
4 changes: 2 additions & 2 deletions ax/analysis/parallel_coordinates_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import pandas as pd

from ax.analysis.base_plotly_visualization import BasePlotlyVisualization

from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial

Expand All @@ -18,8 +20,6 @@

from plotly import express as px, graph_objs as go

from .base_plotly_visualization import BasePlotlyVisualization


class ParallelCoordinatesPlot(BasePlotlyVisualization):
def __init__(
Expand Down
40 changes: 40 additions & 0 deletions ax/analysis/tests/test_tile_fitted_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import unittest

import plotly.graph_objects as go

from ax.analysis.tile_fitted_plot import TileFittedPlot

from ax.modelbridge.registry import Models

from ax.utils.testing.core_stubs import get_branin_experiment


class TestTileFittedPlot(unittest.TestCase):
    """Tests for ``TileFittedPlot`` built on a Branin experiment with one batch."""

    def setUp(self) -> None:
        super().setUp()

        # A Branin experiment with a single batch trial that has been run,
        # so that there is data to fit the model on.
        self.exp = get_branin_experiment(with_batch=True)
        self.exp.trials[0].run()
        self.model = Models.BOTORCH_MODULAR(
            # Model bridge kwargs
            experiment=self.exp,
            data=self.exp.fetch_data(),
        )

    def test_tile_fitted_plot(self) -> None:
        tile_fitted_plot = TileFittedPlot(
            experiment=self.exp,
            model=self.model,
        )

        # Verify the dataframe backing the plot instead of just printing it.
        df = tile_fitted_plot.get_df()
        self.assertFalse(df.empty)
        self.assertEqual(
            set(df.columns),
            {
                "arm_name",
                "metric_name",
                "x",
                "x_se",
                "y",
                "y_se",
                "arm_parameters",
            },
        )
        # Only metrics known to the model may appear in the plot data.
        self.assertTrue(set(df["metric_name"]).issubset(self.model.metric_names))

        fig = tile_fitted_plot.get_fig()
        self.assertIsInstance(fig, go.Figure)
185 changes: 185 additions & 0 deletions ax/analysis/tile_fitted_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any, Dict, Set

import numpy as np

import pandas as pd

from ax.analysis.base_plotly_visualization import BasePlotlyVisualization

from ax.analysis.helpers.plot_data_df_helpers import get_plot_data_in_sample_arms_df

from ax.analysis.helpers.scatter_helpers import error_scatter_trace_from_df

from ax.core.experiment import Experiment

from ax.modelbridge import ModelBridge

from plotly import graph_objs as go, subplots


class TileFittedPlot(BasePlotlyVisualization):
    """Tiled figure with one subplot per model metric, laid out at most two
    subplots per row, showing the in-sample arms of a fitted model.
    """

    def __init__(
        self,
        experiment: Experiment,
        model: ModelBridge,
    ) -> None:
        """
        Args:
            experiment: Experiment forwarded to ``BasePlotlyVisualization``.
            model: Fitted model bridge; its ``metric_names`` determine which
                metrics get a subplot.
        """

        self.model = model
        self.metrics: Set[str] = model.metric_names

        super().__init__(experiment=experiment)

    def get_df(self) -> pd.DataFrame:
        """Return the in-sample arm data for all of the model's metrics, as
        produced by ``get_plot_data_in_sample_arms_df``.
        """
        return get_plot_data_in_sample_arms_df(
            model=self.model, metric_names=self.metrics
        )

    def get_fig(
        self,
    ) -> go.Figure:
        """Tile version of fitted outcome plots.

        Returns:
            A figure with one subplot per metric (two columns when there is
            more than one metric), a "Sort By" dropdown and a title annotation.

        Raises:
            ValueError: If the model has no metrics to plot.
        """
        # ``self.metrics`` is a set; sort it so that the subplot order (and the
        # matching subplot titles) is the same on every run.
        metrics = sorted(self.metrics)
        if not metrics:
            raise ValueError("Cannot build a tile plot: the model has no metrics.")

        nrows = int(np.ceil(len(metrics) / 2))
        ncols = min(len(metrics), 2)

        fig = subplots.make_subplots(
            rows=nrows,
            cols=ncols,
            print_grid=False,
            shared_xaxes=False,
            shared_yaxes=False,
            subplot_titles=tuple(metrics),
            horizontal_spacing=0.05,
            vertical_spacing=0.30 / nrows,
        )

        name_order_args: Dict[str, Any] = {}
        name_order_axes: Dict[str, Dict[str, Any]] = {}
        effect_order_args: Dict[str, Any] = {}

        in_sample_df = self.get_df()

        for i, metric in enumerate(metrics):
            filtered_df = in_sample_df.loc[in_sample_df["metric_name"] == metric]
            data: Dict[str, Any] = error_scatter_trace_from_df(
                df=filtered_df,
                show_CI=True,
            )

            # TODO: order arm names by arm number within batch. Not implemented
            # yet, so the "Name" dropdown entry applies an empty ordering.
            names_by_arm = []

            # TODO: order arm names by effect size. Not implemented yet, so the
            # "Effect Size" dropdown entry applies an empty ordering.
            names_by_effect = []

            # options for ordering arms (x-axis)
            # Note that xaxes need to be references as xaxis, xaxis2, xaxis3, etc.
            # for the purposes of updatemenus argument (dropdown) in layout.
            # However, when setting the initial ordering layout, the keys should be
            # xaxis1, xaxis2, xaxis3, etc. Note the discrepancy for the initial
            # axis.
            label = "" if i == 0 else i + 1
            name_order_args["xaxis{}.categoryorder".format(label)] = "array"
            name_order_args["xaxis{}.categoryarray".format(label)] = names_by_arm
            effect_order_args["xaxis{}.categoryorder".format(label)] = "array"
            effect_order_args["xaxis{}.categoryarray".format(label)] = names_by_effect
            name_order_axes["xaxis{}".format(i + 1)] = {
                "categoryorder": "array",
                "categoryarray": names_by_arm,
                "type": "category",
            }
            name_order_axes["yaxis{}".format(i + 1)] = {
                "ticksuffix": "",
                "zerolinecolor": "red",
            }

            # Fill the grid row by row; plotly rows/cols are 1-based.
            fig.append_trace(  # pyre-ignore[16]
                data, int(np.floor(i / ncols)) + 1, i % ncols + 1
            )

        order_options = [
            {"args": [name_order_args], "label": "Name", "method": "relayout"},
            {"args": [effect_order_args], "label": "Effect Size", "method": "relayout"},
        ]

        # If the grid has more cells than metrics, manually remove the last
        # blank subplot generated by `subplots.make_subplots`. A single metric
        # yields a 1x1 grid with no blank cell, so nothing is removed then.
        num_cells = nrows * ncols
        if len(metrics) < num_cells:
            fig["layout"].pop("xaxis{}".format(num_cells))
            fig["layout"].pop("yaxis{}".format(num_cells))

        # allocate 300 px of height per row of plots
        fig["layout"].update(
            margin={"t": 0},
            hovermode="closest",
            updatemenus=[
                {
                    "x": 0.15,
                    "y": 1 + 0.40 / nrows,
                    "buttons": order_options,
                    "xanchor": "left",
                    "yanchor": "middle",
                }
            ],
            font={"size": 10},
            width=650 if ncols == 1 else 950,
            height=300 * nrows,
            legend={
                "orientation": "h",
                "x": 0,
                "y": 1 + 0.20 / nrows,
                "xanchor": "left",
                "yanchor": "middle",
            },
            **name_order_axes,
        )

        # append dropdown annotations
        fig["layout"]["annotations"] = fig["layout"]["annotations"] + (
            {
                "x": 0.5,
                "y": 1 + 0.40 / nrows,
                "xref": "paper",
                "yref": "paper",
                "font": {"size": 14},
                "text": "Predicted Outcomes",
                "showarrow": False,
                "xanchor": "center",
                "yanchor": "middle",
            },
            {
                "x": 0.05,
                "y": 1 + 0.40 / nrows,
                "xref": "paper",
                "yref": "paper",
                "text": "Sort By",
                "showarrow": False,
                "xanchor": "left",
                "yanchor": "middle",
            },
        )

        # TODO: shrink subplot titles, e.g. resize_subtitles(figure=fig, size=10)
        return fig
Loading