Skip to content

Commit

Permalink
Enable setting labels in the batch comparison plot (#1881)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1881

A few weeks ago when the batch comparison plot was being shown to some collaborators outside the team, there's was some brief confusion around what the axes were since the "Batch XX" labels are not always super descriptive. It'd be helpful to be able to give more meaningful labels for future times that these plots are shared.

Reviewed By: Balandat

Differential Revision: D49746704

fbshipit-source-id: f043ef97fc64824640140f9ac1154e87a467010d
  • Loading branch information
bletham authored and facebook-github-bot committed Sep 29, 2023
1 parent bf74e8e commit 7a8069b
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions ax/plot/diagnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,8 @@ def interact_batch_comparison(
batch_y: int,
rel: bool = False,
status_quo_name: Optional[str] = None,
x_label: Optional[str] = None,
y_label: Optional[str] = None,
) -> AxPlotConfig:
"""Compare repeated arms from two trials; select metric via dropdown.
Expand All @@ -662,6 +664,8 @@ def interact_batch_comparison(
batch_y: Index of bach for y-axis.
rel: Whether to relativize data against status_quo arm.
status_quo_name: Name of the status_quo arm.
x_label: Label for the x-axis.
y_label: Label for the y-axis.
"""
if isinstance(experiment, MultiTypeExperiment):
observations = convert_mt_observations(observations, experiment)
Expand All @@ -670,11 +674,15 @@ def interact_batch_comparison(
plot_data = _get_batch_comparison_plot_data(
observations, batch_x, batch_y, rel=rel, status_quo_name=status_quo_name
)
if x_label is None:
x_label = f"Batch {batch_x}"
if y_label is None:
y_label = f"Batch {batch_y}"
fig = _obs_vs_pred_dropdown_plot(
data=plot_data,
rel=rel,
xlabel="Batch {}".format(batch_x),
ylabel="Batch {}".format(batch_y),
xlabel=x_label,
ylabel=y_label,
)
fig["layout"]["title"] = "Repeated arms across trials"
return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)

0 comments on commit 7a8069b

Please sign in to comment.