Skip to content

Commit

Permalink
Redo constraint violation indicators and legend on insample and predi…
Browse files Browse the repository at this point in the history
…cted plots (facebook#2955)

Summary:

Instead of varying sizing, we're adding a red outline of varying opacity based on the overall probability of constraint violation.

There is a hack to make sure the outlines don't appear in the legend, as the legend normally reflects the first point in each group.  We're hiding the original and creating a duplicate group of the same name at point (None, None).  In doing this, we give up the ability to toggle points on and off the graph by clicking on their legend item, which I just discovered existed.

{F1944580069}

Reviewed By: ItsMrLin

Differential Revision: D64850289
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Oct 25, 2024
1 parent 9803b25 commit 5f1c669
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 6 deletions.
89 changes: 86 additions & 3 deletions ax/analysis/plotly/arm_effects/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,60 @@ def prepare_arm_effects_plot(
color="source",
# TODO: can we format this by callable or string template?
hover_data=_get_parameter_columns(df),
size="size_column",
size_max=10,
# avoid red because it will match the constraint violation indicator
color_discrete_sequence=px.colors.qualitative.Vivid,
)
dot_size = 8
# set all dots to size 8 in plots
fig.update_traces(marker={"line": {"width": 2}, "size": dot_size})

# Manually create each constraint violation indicator
# as a red outline around the dot, with alpha based on the
# probability of constraint violation.
for trace in fig.data:
# there is a trace per source, so get the rows of df
# pertaining to this trace
indices = df["source"] == trace.name
trace.marker.line.color = [
# raising the alpha to a power < 1 makes the colors more
# visible when there is a lower chance of constraint violation
f"rgba(255, 0, 0, {(alpha) ** .75})"
for alpha in df.loc[indices, "overall_probability_constraints_violated"]
]
# Create a separate trace for the legend, otherwise the legend
# will have the constraint violation indicator of the first arm
# in the source group
legend_trace = go.Scatter(
# (None, None) is a hack to get a legend item without
# appearing on the plot
x=[None],
y=[None],
mode="markers",
marker={
"size": dot_size,
"color": trace.marker.color,
},
name=trace.name,
)
fig.add_trace(legend_trace)
trace.showlegend = False

# Add an item to the legend for the constraint violation indicator
legend_trace = go.Scatter(
# (None, None) is a hack to get a legend item without
# appearing on the plot
x=[None],
y=[None],
mode="markers",
marker={
"size": dot_size,
"color": "white",
"line": {"width": 2, "color": "red"},
},
name="Constraint Violation",
)
fig.add_trace(legend_trace)

_add_style_to_effects_by_arm_plot(
fig=fig, df=df, metric_name=metric_name, outcome_constraints=outcome_constraints
)
Expand Down Expand Up @@ -100,6 +151,20 @@ def _add_style_to_effects_by_arm_plot(
y=df[df["arm_name"] == "status_quo"]["mean"].iloc[0],
line_width=1,
line_color="red",
showlegend=True,
name="Status Quo Mean",
)
# Add the status quo mean to the legend
fig.add_trace(
go.Scatter(
# (None, None) is a hack to get a legend item without
# appearing on the plot
x=[None],
y=[None],
mode="lines",
line={"color": "red", "width": 1},
name="Status Quo Mean",
)
)
for constraint in outcome_constraints:
if constraint.metric.name == metric_name:
Expand All @@ -110,10 +175,25 @@ def _add_style_to_effects_by_arm_plot(
line_color="red",
line_dash="dash",
)
# Add the constraint bound to the legend
fig.add_trace(
go.Scatter(
# (None, None) is a hack to get a legend item without
# appearing on the plot
x=[None],
y=[None],
mode="lines",
line={"color": "red", "width": 1, "dash": "dash"},
name="Constraint Bound",
)
)
fig.update_layout(
xaxis={
"tickangle": 45,
},
legend={
"title": None,
},
)


Expand Down Expand Up @@ -206,7 +286,10 @@ def get_predictions_by_arm(
"constraints_violated": format_constraint_violated_probabilities(
constraints_violated[i]
),
"size_column": 100 - probabilities_not_feasible[i] * 100,
# used for constraint violation indicator
"overall_probability_constraints_violated": round(
probabilities_not_feasible[i], ndigits=2
),
"parameters": format_parameters_for_effects_by_arm_plot(
parameters=features[i].parameters
),
Expand Down
4 changes: 3 additions & 1 deletion ax/analysis/plotly/tests/test_insample_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,9 @@ def test_constraints(self) -> None:
str(non_sq_df["constraints_violated"][0]),
)
# AND THEN it marks that constraints are not violated for the SQ
self.assertEqual(sq_row["size_column"].iloc[0], 100)
self.assertEqual(
sq_row["overall_probability_constraints_violated"].iloc[0], 0
)
self.assertEqual(
sq_row["constraints_violated"].iloc[0], "No constraints violated"
)
Expand Down
6 changes: 4 additions & 2 deletions ax/analysis/plotly/tests/test_predicted_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def test_compute(self) -> None:
"sem",
"error_margin",
"constraints_violated",
"size_column",
"overall_probability_constraints_violated",
},
)
self.assertIsNotNone(card.blob)
Expand Down Expand Up @@ -380,7 +380,9 @@ def test_constraints(self) -> None:
str(non_sq_df["constraints_violated"][0]),
)
# AND THEN it marks that constraints are not violated for the SQ
self.assertEqual(sq_row["size_column"].iloc[0], 100)
self.assertEqual(
sq_row["overall_probability_constraints_violated"].iloc[0], 0
)
self.assertEqual(
sq_row["constraints_violated"].iloc[0], "No constraints violated"
)
Expand Down

0 comments on commit 5f1c669

Please sign in to comment.