Skip to content

Commit

Permalink
Remove various args from InteractionPlot.__init__
Browse files Browse the repository at this point in the history
Summary:
Simplifying InteractionPlot by removing some arguments from its initializer. If this simplifications seem appropriate then I will continue to simplify the Plot.

Removals:
* top_k: Prefer to show sobol indices for all components on bar chart, slice/sufrace for top 6 always
* data: Let's always use the data on the experiment
* most_important: Always sort most important to least important, never least to most
* display_components: Always display components
* decompose_components: Never decompose components
* plots_share_range: Always share range
* num_mc_samples: Always use 10k samples
* [RFC] model_fit_seed: Do not bother with seed setting -- we dont do this for any other plots so its probably not worth the complexity here

The following diffs will restructure the code here to take advantage of the simplified options

Differential Revision: D65148289
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Nov 21, 2024
1 parent 81257e6 commit 8f5fcf3
Showing 1 changed file with 41 additions and 64 deletions.
105 changes: 41 additions & 64 deletions ax/analysis/plotly/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

# pyre-strict

# pyre-unsafe

import math
from typing import Any
Expand All @@ -19,11 +18,9 @@
from ax.analysis.analysis import AnalysisCardLevel

from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.observation import ObservationFeatures
from ax.exceptions.core import UserInputError
from ax.modelbridge.registry import Models
from ax.modelbridge.torch import TorchModelBridge
from ax.models.torch.botorch_modular.surrogate import Surrogate
Expand All @@ -40,7 +37,6 @@
from plotly import graph_objects as go, io as pio
from plotly.subplots import make_subplots
from pyre_extensions import none_throws
from torch import Tensor


TOP_K_TOO_LARGE_ERROR = (
Expand Down Expand Up @@ -142,57 +138,31 @@ class InteractionPlot(PlotlyAnalysis):
def __init__(
self,
metric_name: str,
top_k: int = 6,
data: Data | None = None,
most_important: bool = True,
fit_interactions: bool = True,
display_components: bool = False,
decompose_components: bool = False,
plots_share_range: bool = True,
num_mc_samples: int = 10_000,
model_fit_seed: int = 0,
most_important: bool = True,
seed: int = 0,
torch_device: torch.device | None = None,
) -> None:
"""Constructor for InteractionAnalysis.
Args:
metric_name: The metric to analyze.
top_k: The 'k' most imortant interactions according to Sobol indices.
Supports up to 6 components visualized at once.
data: The data to analyze. Defaults to None, in which case the data is taken
from the experiment.
most_important: Whether to plot the most or least important interactions.
fit_interactions: Whether to fit interaction effects in addition to main
effects.
display_components: Display individual components instead of the summarized
plot of sobol index values.
decompose_components: Whether to visualize surfaces as the total effect of
x1 & x2 (False) or only the interaction term (True). Setting
decompose_components = True thus plots f(x1, x2) - f(x1) - f(x2).
plots_share_range: Whether to have all plots share the same output range in
the final visualization.
num_mc_samples: The number of Monte Carlo samples to use for the Sobol
index calculations.
model_fit_seed: The seed with which to fit the model. Defaults to 0. Used
most_important: Whether to sort by most or least important features in the
bar subplot. Also controls whether the six most or least important
features are plotted in the surface subplots.
seed: The seed with which to fit the model. Defaults to 0. Used
to ensure that the model fit is identical across the generation of
various plots.
torch_device: The torch device to use for the model.
"""

super().__init__()
if top_k > 6 and display_components:
raise UserInputError(TOP_K_TOO_LARGE_ERROR.format(str(top_k)))
self.metric_name: str = metric_name
self.top_k: int = top_k
self.data: Data | None = data
self.most_important: bool = most_important
self.fit_interactions: bool = fit_interactions
self.display_components: bool = display_components
self.decompose_components: bool = decompose_components
self.num_mc_samples: int = num_mc_samples
self.model_fit_seed: int = model_fit_seed
self.torch_device: torch.device | None = torch_device
self.plots_share_range: bool = plots_share_range
self.metric_name = metric_name
self.fit_interactions = fit_interactions
self.most_important = most_important
self.seed = seed
self.torch_device = torch_device

def get_model(
self, experiment: Experiment, metric_names: list[str] | None = None
Expand All @@ -207,19 +177,16 @@ def get_model(
num_parameters=len(experiment.search_space.tunable_parameters),
torch_device=self.torch_device,
)
data = experiment.lookup_data() if self.data is None else self.data
data = experiment.lookup_data()
if metric_names:
data = data.filter(metric_names=metric_names)
with torch.random.fork_rng():
# fixing the seed to ensure that the model is fit identically across
# different analyses of the same experiment
torch.torch.manual_seed(self.model_fit_seed)
model_bridge = Models.BOTORCH_MODULAR(
search_space=experiment.search_space,
experiment=experiment,
data=data,
surrogate=Surrogate(**covar_module_kwargs),
)

model_bridge = Models.BOTORCH_MODULAR(
search_space=experiment.search_space,
experiment=experiment,
data=data,
surrogate=Surrogate(**covar_module_kwargs),
)
return model_bridge # pyre-ignore[7] Return type is always a TorchModelBridge

# pyre-ignore[14] Must pass in an Experiment (not Experiment | None)
Expand All @@ -241,19 +208,29 @@ def compute(
"""
experiment = none_throws(experiment)
model_bridge = self.get_model(experiment, [self.metric_name])
with torch.random.fork_rng():
# fixing the seed to ensure that the model is fit identically across
# different analyses of the same experiment
torch.torch.manual_seed(self.model_fit_seed)
sens = ax_parameter_sens(
model_bridge=model_bridge,
metrics=[self.metric_name],
order="second" if self.fit_interactions else "first",
signed=not self.fit_interactions,
num_mc_samples=self.num_mc_samples,
)
sens = ax_parameter_sens(
model_bridge=model_bridge,
metrics=[self.metric_name],
order="second" if self.fit_interactions else "first",
signed=not self.fit_interactions,
)
sens = sort_and_filter_top_k_components(
indices=sens, k=self.top_k, most_important=self.most_important
indices=sens,
k=6,
)
# reformat the keys from tuple to a proper "x1 & x2" string
interaction_name = "Interaction" if self.fit_interactions else "Main Effect"
return PlotlyAnalysisCard(
name="Interaction Analysis",
title="Feature Importance Analysis",
subtitle=f"{interaction_name} Analysis for {self.metric_name}",
level=AnalysisCardLevel.MID,
df=pd.DataFrame(sens),
blob=pio.to_json(
plot_feature_importance_by_feature_plotly(
sensitivity_values=sens, # pyre-ignore[6]
)
),
)
if not self.display_components:
return PlotlyAnalysisCard(
Expand Down

0 comments on commit 8f5fcf3

Please sign in to comment.