Skip to content

Commit 7a8069b

Browse files
blethamfacebook-github-bot
authored andcommitted
Enable setting labels in the batch comparison plot (#1881)
Summary: Pull Request resolved: #1881 A few weeks ago when the batch comparison plot was being shown to some collaborators outside the team, there's was some brief confusion around what the axes were since the "Batch XX" labels are not always super descriptive. It'd be helpful to be able to give more meaningful labels for future times that these plots are shared. Reviewed By: Balandat Differential Revision: D49746704 fbshipit-source-id: f043ef97fc64824640140f9ac1154e87a467010d
1 parent bf74e8e commit 7a8069b

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

ax/plot/diagnostic.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,8 @@ def interact_batch_comparison(
653653
batch_y: int,
654654
rel: bool = False,
655655
status_quo_name: Optional[str] = None,
656+
x_label: Optional[str] = None,
657+
y_label: Optional[str] = None,
656658
) -> AxPlotConfig:
657659
"""Compare repeated arms from two trials; select metric via dropdown.
658660
@@ -662,6 +664,8 @@ def interact_batch_comparison(
662664
batch_y: Index of bach for y-axis.
663665
rel: Whether to relativize data against status_quo arm.
664666
status_quo_name: Name of the status_quo arm.
667+
x_label: Label for the x-axis.
668+
y_label: Label for the y-axis.
665669
"""
666670
if isinstance(experiment, MultiTypeExperiment):
667671
observations = convert_mt_observations(observations, experiment)
@@ -670,11 +674,15 @@ def interact_batch_comparison(
670674
plot_data = _get_batch_comparison_plot_data(
671675
observations, batch_x, batch_y, rel=rel, status_quo_name=status_quo_name
672676
)
677+
if x_label is None:
678+
x_label = f"Batch {batch_x}"
679+
if y_label is None:
680+
y_label = f"Batch {batch_y}"
673681
fig = _obs_vs_pred_dropdown_plot(
674682
data=plot_data,
675683
rel=rel,
676-
xlabel="Batch {}".format(batch_x),
677-
ylabel="Batch {}".format(batch_y),
684+
xlabel=x_label,
685+
ylabel=y_label,
678686
)
679687
fig["layout"]["title"] = "Repeated arms across trials"
680688
return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)

0 commit comments

Comments
 (0)