From 12ed9f76e5272927ceffe369adcae172e512f0b0 Mon Sep 17 00:00:00 2001 From: Daniel Cohen Date: Thu, 31 Oct 2024 12:06:02 -0700 Subject: [PATCH] Save error cards in scheduler (#3004) Summary: TODO: create a separate error card class in a follow up. TODO: Better reflect the would be card title in the error card title. Reviewed By: Cesar-Cardoso Differential Revision: D64993416 --- ax/analysis/analysis.py | 41 +++++++++++++++++-- ax/analysis/markdown/markdown_analysis.py | 4 +- .../plotly/arm_effects/insample_effects.py | 14 +++++-- ax/analysis/plotly/plotly_analysis.py | 4 +- .../plotly/tests/test_insample_effects.py | 4 +- ax/service/scheduler.py | 31 +++++++++++++- ax/service/tests/scheduler_test_utils.py | 19 ++++++++- ax/storage/sqa_store/decoder.py | 4 +- 8 files changed, 104 insertions(+), 17 deletions(-) diff --git a/ax/analysis/analysis.py b/ax/analysis/analysis.py index a6d7eb2564f..5d079802455 100644 --- a/ax/analysis/analysis.py +++ b/ax/analysis/analysis.py @@ -5,6 +5,7 @@ # pyre-strict +import json from collections.abc import Iterable from enum import IntEnum from logging import Logger @@ -146,9 +147,10 @@ def compute_result( logger.error(f"Failed to compute {self.__class__.__name__}: {e}") return Err( - value=ExceptionE( + value=AnalysisE( message=f"Failed to compute {self.__class__.__name__}", exception=e, + analysis=self, ) ) @@ -164,11 +166,44 @@ def _create_analysis_card( details about the Analysis class. """ return AnalysisCard( - name=self.__class__.__name__, - attributes=self.__dict__, + name=self.name, + attributes=self.attributes, title=title, subtitle=subtitle, level=level, df=df, blob=df.to_json(), ) + + @property + def name(self) -> str: + """The name the AnalysisCard will be given in compute.""" + return self.__class__.__name__ + + @property + def attributes(self) -> dict[str, Any]: + """The attributes the AnalysisCard will be given in compute.""" + return self.__dict__ + + def __repr__(self) -> str: + try: + return ( + f"{self.__class__.__name__}(name={self.name}, " + f"attributes={json.dumps(self.attributes)})" + ) + # in case there is logic in name or attributes that throws a json error + except Exception: + return self.__class__.__name__ + + +class AnalysisE(ExceptionE): + analysis: Analysis + + def __init__( + self, + message: str, + exception: Exception, + analysis: Analysis, + ) -> None: + super().__init__(message, exception) + self.analysis = analysis diff --git a/ax/analysis/markdown/markdown_analysis.py b/ax/analysis/markdown/markdown_analysis.py index ff75d30b5bf..75393630a0d 100644 --- a/ax/analysis/markdown/markdown_analysis.py +++ b/ax/analysis/markdown/markdown_analysis.py @@ -51,8 +51,8 @@ def _create_markdown_analysis_card( details about the Analysis class. """ return MarkdownAnalysisCard( - name=self.__class__.__name__, - attributes=self.__dict__, + name=self.name, + attributes=self.attributes, title=title, subtitle=subtitle, level=level, diff --git a/ax/analysis/plotly/arm_effects/insample_effects.py b/ax/analysis/plotly/arm_effects/insample_effects.py index c027e88bbb7..a42b57cdeed 100644 --- a/ax/analysis/plotly/arm_effects/insample_effects.py +++ b/ax/analysis/plotly/arm_effects/insample_effects.py @@ -129,15 +129,14 @@ def compute( max_trial_index = max(experiment.trial_indices_expecting_data, default=0) nudge -= min(max_trial_index - self.trial_index, 9) - plot_type = "Modeled" if self.use_modeled_effects else "Observed" subtitle = ( "View a trial and its arms' " - f"{'predicted' if self.use_modeled_effects else 'observed'} " + f"{self._plot_type_string.lower()} " "metric values" ) card = self._create_plotly_analysis_card( title=( - f"{plot_type} Effects for {self.metric_name} " + f"{self._plot_type_string} Effects for {self.metric_name} " f"on trial {self.trial_index}" ), subtitle=subtitle, @@ -145,9 +144,16 @@ def compute( df=df, fig=fig, ) - card.name = f"{plot_type}EffectsPlot" return card + @property + def name(self) -> str: + return f"{self._plot_type_string}EffectsPlot" + + @property + def _plot_type_string(self) -> str: + return "Modeled" if self.use_modeled_effects else "Observed" + def _get_max_observed_trial_index(model: ModelBridge) -> int | None: """Returns the max observed trial index to appease multitask models for prediction diff --git a/ax/analysis/plotly/plotly_analysis.py b/ax/analysis/plotly/plotly_analysis.py index 0f3dfeec86c..e45a28310b9 100644 --- a/ax/analysis/plotly/plotly_analysis.py +++ b/ax/analysis/plotly/plotly_analysis.py @@ -53,8 +53,8 @@ def _create_plotly_analysis_card( details about the Analysis class. """ return PlotlyAnalysisCard( - name=self.__class__.__name__, - attributes=self.__dict__, + name=self.name, + attributes=self.attributes, title=title, subtitle=subtitle, level=level, diff --git a/ax/analysis/plotly/tests/test_insample_effects.py b/ax/analysis/plotly/tests/test_insample_effects.py index 4124aa7b20c..7fa06004beb 100644 --- a/ax/analysis/plotly/tests/test_insample_effects.py +++ b/ax/analysis/plotly/tests/test_insample_effects.py @@ -172,7 +172,7 @@ def test_compute_modeled_can_use_ebts_for_gs_with_non_predictive_model( self.assertEqual(card.name, "ModeledEffectsPlot") self.assertEqual(card.title, "Modeled Effects for branin on trial 0") self.assertEqual( - card.subtitle, "View a trial and its arms' predicted metric values" + card.subtitle, "View a trial and its arms' modeled metric values" ) # +2 because it's on objective, +1 because it's modeled self.assertEqual(card.level, AnalysisCardLevel.MID + 3) @@ -218,7 +218,7 @@ def test_compute_modeled_can_use_ebts_for_no_gs(self) -> None: self.assertEqual(card.name, "ModeledEffectsPlot") self.assertEqual(card.title, "Modeled Effects for branin on trial 0") self.assertEqual( - card.subtitle, "View a trial and its arms' predicted metric values" + card.subtitle, "View a trial and its arms' modeled metric values" ) # +2 because it's on objective, +1 because it's modeled self.assertEqual(card.level, AnalysisCardLevel.MID + 3) diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py index 4b104baa415..2ba0f7b51e5 100644 --- a/ax/service/scheduler.py +++ b/ax/service/scheduler.py @@ -8,6 +8,8 @@ from __future__ import annotations +import traceback + from collections import defaultdict from collections.abc import Callable, Generator, Iterable from copy import deepcopy @@ -17,7 +19,9 @@ from typing import Any, cast, NamedTuple, Optional import ax.service.utils.early_stopping as early_stopping_utils -from ax.analysis.analysis import Analysis, AnalysisCard +import pandas as pd +from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel, AnalysisE +from ax.analysis.markdown.markdown_analysis import MarkdownAnalysisCard from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot from ax.core.base_trial import BaseTrial, TrialStatus from ax.core.experiment import Experiment @@ -68,6 +72,7 @@ set_stderr_log_level, ) from ax.utils.common.timeutils import current_timestamp_in_millis +from ax.utils.common.typeutils import checked_cast from pyre_extensions import assert_is_instance, none_throws @@ -664,6 +669,30 @@ def compute_analyses( # TODO Accumulate Es into their own card, perhaps via unwrap_or_else cards = [result.unwrap() for result in results if result.is_ok()] + for result in results: + if result.is_err(): + e = checked_cast(AnalysisE, result.err) + traceback_str = "".join( + traceback.format_exception( + type(result.err.exception), + e.exception, + e.exception.__traceback__, + ) + ) + cards.append( + MarkdownAnalysisCard( + name=e.analysis.name, + # It would be better if we could reliably compute the title + # without risking another error + title=f"{e.analysis.name} Error", + subtitle=f"An error occurred while computing {e.analysis}", + attributes=e.analysis.attributes, + blob=traceback_str, + df=pd.DataFrame(), + level=AnalysisCardLevel.DEBUG, + ) + ) + self._save_analysis_cards_to_db_if_possible( analysis_cards=cards, experiment=self.experiment, diff --git a/ax/service/tests/scheduler_test_utils.py b/ax/service/tests/scheduler_test_utils.py index 56095e70e3b..5602bdb7e76 100644 --- a/ax/service/tests/scheduler_test_utils.py +++ b/ax/service/tests/scheduler_test_utils.py @@ -2763,21 +2763,36 @@ def test_generate_candidates_does_not_generate_if_missing_opt_config(self) -> No self.assertEqual(len(scheduler.experiment.trials), 1) def test_compute_analyses(self) -> None: - scheduler = Scheduler( + init_test_engine_and_session_factory(force_init=True) + gs = self._get_generation_strategy_strategy_for_test( experiment=self.branin_experiment, generation_strategy=get_generation_strategy(), + ) + scheduler = Scheduler( + experiment=self.branin_experiment, + generation_strategy=gs, options=SchedulerOptions( total_trials=0, tolerated_trial_failure_rate=0.2, init_seconds_between_polls=10, **self.scheduler_options_kwargs, ), + db_settings=self.db_settings, ) with self.assertLogs(logger="ax.analysis", level="ERROR") as lg: cards = scheduler.compute_analyses(analyses=[ParallelCoordinatesPlot()]) - self.assertEqual(len(cards), 0) + self.assertEqual(len(cards), 1) + # it saved the error card + self.assertIsNotNone(cards[0].db_id) + self.assertEqual(cards[0].name, "ParallelCoordinatesPlot") + self.assertEqual(cards[0].title, "ParallelCoordinatesPlot Error") + self.assertEqual( + cards[0].subtitle, + f"An error occurred while computing {ParallelCoordinatesPlot()}", + ) + self.assertIn("Traceback", cards[0].blob) self.assertTrue( any( ( diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index 2283341be84..e30704aef12 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -1024,7 +1024,7 @@ def analysis_card_from_sqa( analysis_card_sqa: SQAAnalysisCard, ) -> AnalysisCard: """Convert SQLAlchemy Analysis to Ax Analysis Object.""" - return AnalysisCard( + card = AnalysisCard( name=analysis_card_sqa.name, title=analysis_card_sqa.title, subtitle=analysis_card_sqa.subtitle, @@ -1037,6 +1037,8 @@ def analysis_card_from_sqa( else json.loads(analysis_card_sqa.attributes) ), ) + card.db_id = analysis_card_sqa.id + return card def _metric_from_sqa_util(self, metric_sqa: SQAMetric) -> Metric: """Convert SQLAlchemy Metric to Ax Metric"""