diff --git a/ax/analysis/plotly/interaction.py b/ax/analysis/plotly/interaction.py index d460dfd4bad..1d2f5e7eb8c 100644 --- a/ax/analysis/plotly/interaction.py +++ b/ax/analysis/plotly/interaction.py @@ -5,7 +5,6 @@ # pyre-strict -# pyre-unsafe import math from typing import Any @@ -19,11 +18,9 @@ from ax.analysis.analysis import AnalysisCardLevel from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard -from ax.core.data import Data from ax.core.experiment import Experiment from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.observation import ObservationFeatures -from ax.exceptions.core import UserInputError from ax.modelbridge.registry import Models from ax.modelbridge.torch import TorchModelBridge from ax.models.torch.botorch_modular.surrogate import Surrogate @@ -40,7 +37,6 @@ from plotly import graph_objects as go, io as pio from plotly.subplots import make_subplots from pyre_extensions import none_throws -from torch import Tensor TOP_K_TOO_LARGE_ERROR = ( @@ -142,57 +138,31 @@ class InteractionPlot(PlotlyAnalysis): def __init__( self, metric_name: str, - top_k: int = 6, - data: Data | None = None, - most_important: bool = True, fit_interactions: bool = True, - display_components: bool = False, - decompose_components: bool = False, - plots_share_range: bool = True, - num_mc_samples: int = 10_000, - model_fit_seed: int = 0, + most_important: bool = True, + seed: int = 0, torch_device: torch.device | None = None, ) -> None: """Constructor for InteractionAnalysis. Args: metric_name: The metric to analyze. - top_k: The 'k' most imortant interactions according to Sobol indices. - Supports up to 6 components visualized at once. - data: The data to analyze. Defaults to None, in which case the data is taken - from the experiment. - most_important: Whether to plot the most or least important interactions. fit_interactions: Whether to fit interaction effects in addition to main effects. - display_components: Display individual components instead of the summarized - plot of sobol index values. - decompose_components: Whether to visualize surfaces as the total effect of - x1 & x2 (False) or only the interaction term (True). Setting - decompose_components = True thus plots f(x1, x2) - f(x1) - f(x2). - plots_share_range: Whether to have all plots share the same output range in - the final visualization. - num_mc_samples: The number of Monte Carlo samples to use for the Sobol - index calculations. - model_fit_seed: The seed with which to fit the model. Defaults to 0. Used + most_important: Whether to sort by most or least important features in the + bar subplot. Also controls whether the six most or least important + features are plotted in the surface subplots. + seed: The seed with which to fit the model. Defaults to 0. Used to ensure that the model fit is identical across the generation of various plots. torch_device: The torch device to use for the model. """ - super().__init__() - if top_k > 6 and display_components: - raise UserInputError(TOP_K_TOO_LARGE_ERROR.format(str(top_k))) - self.metric_name: str = metric_name - self.top_k: int = top_k - self.data: Data | None = data - self.most_important: bool = most_important - self.fit_interactions: bool = fit_interactions - self.display_components: bool = display_components - self.decompose_components: bool = decompose_components - self.num_mc_samples: int = num_mc_samples - self.model_fit_seed: int = model_fit_seed - self.torch_device: torch.device | None = torch_device - self.plots_share_range: bool = plots_share_range + self.metric_name = metric_name + self.fit_interactions = fit_interactions + self.most_important = most_important + self.seed = seed + self.torch_device = torch_device def get_model( self, experiment: Experiment, metric_names: list[str] | None = None @@ -207,19 +177,16 @@ def get_model( num_parameters=len(experiment.search_space.tunable_parameters), torch_device=self.torch_device, ) - data = experiment.lookup_data() if self.data is None else self.data + data = experiment.lookup_data() if metric_names: data = data.filter(metric_names=metric_names) - with torch.random.fork_rng(): - # fixing the seed to ensure that the model is fit identically across - # different analyses of the same experiment - torch.torch.manual_seed(self.model_fit_seed) - model_bridge = Models.BOTORCH_MODULAR( - search_space=experiment.search_space, - experiment=experiment, - data=data, - surrogate=Surrogate(**covar_module_kwargs), - ) + + model_bridge = Models.BOTORCH_MODULAR( + search_space=experiment.search_space, + experiment=experiment, + data=data, + surrogate=Surrogate(**covar_module_kwargs), + ) return model_bridge # pyre-ignore[7] Return type is always a TorchModelBridge # pyre-ignore[14] Must pass in an Experiment (not Experiment | None) @@ -241,19 +208,29 @@ def compute( """ experiment = none_throws(experiment) model_bridge = self.get_model(experiment, [self.metric_name]) - with torch.random.fork_rng(): - # fixing the seed to ensure that the model is fit identically across - # different analyses of the same experiment - torch.torch.manual_seed(self.model_fit_seed) - sens = ax_parameter_sens( - model_bridge=model_bridge, - metrics=[self.metric_name], - order="second" if self.fit_interactions else "first", - signed=not self.fit_interactions, - num_mc_samples=self.num_mc_samples, - ) + sens = ax_parameter_sens( + model_bridge=model_bridge, + metrics=[self.metric_name], + order="second" if self.fit_interactions else "first", + signed=not self.fit_interactions, + ) sens = sort_and_filter_top_k_components( - indices=sens, k=self.top_k, most_important=self.most_important + indices=sens, + k=6, + ) + # reformat the keys from tuple to a proper "x1 & x2" string + interaction_name = "Interaction" if self.fit_interactions else "Main Effect" + return PlotlyAnalysisCard( + name="Interaction Analysis", + title="Feature Importance Analysis", + subtitle=f"{interaction_name} Analysis for {self.metric_name}", + level=AnalysisCardLevel.MID, + df=pd.DataFrame(sens), + blob=pio.to_json( + plot_feature_importance_by_feature_plotly( + sensitivity_values=sens, # pyre-ignore[6] + ) + ), ) if not self.display_components: return PlotlyAnalysisCard(