Skip to content

Commit

Permalink
Save error cards in scheduler (#3004)
Browse files Browse the repository at this point in the history
Summary:

TODO: create a separate error card class in a follow up.
TODO: Better reflect the would be card title in the error card title.

Reviewed By: Cesar-Cardoso

Differential Revision: D64993416
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Oct 31, 2024
1 parent a9a9a7c commit 12ed9f7
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 17 deletions.
41 changes: 38 additions & 3 deletions ax/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

# pyre-strict

import json
from collections.abc import Iterable
from enum import IntEnum
from logging import Logger
Expand Down Expand Up @@ -146,9 +147,10 @@ def compute_result(
logger.error(f"Failed to compute {self.__class__.__name__}: {e}")

return Err(
value=ExceptionE(
value=AnalysisE(
message=f"Failed to compute {self.__class__.__name__}",
exception=e,
analysis=self,
)
)

Expand All @@ -164,11 +166,44 @@ def _create_analysis_card(
details about the Analysis class.
"""
return AnalysisCard(
name=self.__class__.__name__,
attributes=self.__dict__,
name=self.name,
attributes=self.attributes,
title=title,
subtitle=subtitle,
level=level,
df=df,
blob=df.to_json(),
)

@property
def name(self) -> str:
"""The name the AnalysisCard will be given in compute."""
return self.__class__.__name__

@property
def attributes(self) -> dict[str, Any]:
"""The attributes the AnalysisCard will be given in compute."""
return self.__dict__

def __repr__(self) -> str:
try:
return (
f"{self.__class__.__name__}(name={self.name}, "
f"attributes={json.dumps(self.attributes)})"
)
# in case there is logic in name or attributes that throws a json error
except Exception:
return self.__class__.__name__


class AnalysisE(ExceptionE):
analysis: Analysis

def __init__(
self,
message: str,
exception: Exception,
analysis: Analysis,
) -> None:
super().__init__(message, exception)
self.analysis = analysis
4 changes: 2 additions & 2 deletions ax/analysis/markdown/markdown_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def _create_markdown_analysis_card(
details about the Analysis class.
"""
return MarkdownAnalysisCard(
name=self.__class__.__name__,
attributes=self.__dict__,
name=self.name,
attributes=self.attributes,
title=title,
subtitle=subtitle,
level=level,
Expand Down
14 changes: 10 additions & 4 deletions ax/analysis/plotly/arm_effects/insample_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,25 +129,31 @@ def compute(
max_trial_index = max(experiment.trial_indices_expecting_data, default=0)
nudge -= min(max_trial_index - self.trial_index, 9)

plot_type = "Modeled" if self.use_modeled_effects else "Observed"
subtitle = (
"View a trial and its arms' "
f"{'predicted' if self.use_modeled_effects else 'observed'} "
f"{self._plot_type_string.lower()} "
"metric values"
)
card = self._create_plotly_analysis_card(
title=(
f"{plot_type} Effects for {self.metric_name} "
f"{self._plot_type_string} Effects for {self.metric_name} "
f"on trial {self.trial_index}"
),
subtitle=subtitle,
level=level + nudge,
df=df,
fig=fig,
)
card.name = f"{plot_type}EffectsPlot"
return card

@property
def name(self) -> str:
return f"{self._plot_type_string}EffectsPlot"

@property
def _plot_type_string(self) -> str:
return "Modeled" if self.use_modeled_effects else "Observed"


def _get_max_observed_trial_index(model: ModelBridge) -> int | None:
"""Returns the max observed trial index to appease multitask models for prediction
Expand Down
4 changes: 2 additions & 2 deletions ax/analysis/plotly/plotly_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def _create_plotly_analysis_card(
details about the Analysis class.
"""
return PlotlyAnalysisCard(
name=self.__class__.__name__,
attributes=self.__dict__,
name=self.name,
attributes=self.attributes,
title=title,
subtitle=subtitle,
level=level,
Expand Down
4 changes: 2 additions & 2 deletions ax/analysis/plotly/tests/test_insample_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def test_compute_modeled_can_use_ebts_for_gs_with_non_predictive_model(
self.assertEqual(card.name, "ModeledEffectsPlot")
self.assertEqual(card.title, "Modeled Effects for branin on trial 0")
self.assertEqual(
card.subtitle, "View a trial and its arms' predicted metric values"
card.subtitle, "View a trial and its arms' modeled metric values"
)
# +2 because it's on objective, +1 because it's modeled
self.assertEqual(card.level, AnalysisCardLevel.MID + 3)
Expand Down Expand Up @@ -218,7 +218,7 @@ def test_compute_modeled_can_use_ebts_for_no_gs(self) -> None:
self.assertEqual(card.name, "ModeledEffectsPlot")
self.assertEqual(card.title, "Modeled Effects for branin on trial 0")
self.assertEqual(
card.subtitle, "View a trial and its arms' predicted metric values"
card.subtitle, "View a trial and its arms' modeled metric values"
)
# +2 because it's on objective, +1 because it's modeled
self.assertEqual(card.level, AnalysisCardLevel.MID + 3)
Expand Down
31 changes: 30 additions & 1 deletion ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from __future__ import annotations

import traceback

from collections import defaultdict
from collections.abc import Callable, Generator, Iterable
from copy import deepcopy
Expand All @@ -17,7 +19,9 @@
from typing import Any, cast, NamedTuple, Optional

import ax.service.utils.early_stopping as early_stopping_utils
from ax.analysis.analysis import Analysis, AnalysisCard
import pandas as pd
from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel, AnalysisE
from ax.analysis.markdown.markdown_analysis import MarkdownAnalysisCard
from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.experiment import Experiment
Expand Down Expand Up @@ -68,6 +72,7 @@
set_stderr_log_level,
)
from ax.utils.common.timeutils import current_timestamp_in_millis
from ax.utils.common.typeutils import checked_cast
from pyre_extensions import assert_is_instance, none_throws


Expand Down Expand Up @@ -664,6 +669,30 @@ def compute_analyses(
# TODO Accumulate Es into their own card, perhaps via unwrap_or_else
cards = [result.unwrap() for result in results if result.is_ok()]

for result in results:
if result.is_err():
e = checked_cast(AnalysisE, result.err)
traceback_str = "".join(
traceback.format_exception(
type(result.err.exception),
e.exception,
e.exception.__traceback__,
)
)
cards.append(
MarkdownAnalysisCard(
name=e.analysis.name,
# It would be better if we could reliably compute the title
# without risking another error
title=f"{e.analysis.name} Error",
subtitle=f"An error occurred while computing {e.analysis}",
attributes=e.analysis.attributes,
blob=traceback_str,
df=pd.DataFrame(),
level=AnalysisCardLevel.DEBUG,
)
)

self._save_analysis_cards_to_db_if_possible(
analysis_cards=cards,
experiment=self.experiment,
Expand Down
19 changes: 17 additions & 2 deletions ax/service/tests/scheduler_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2763,21 +2763,36 @@ def test_generate_candidates_does_not_generate_if_missing_opt_config(self) -> No
self.assertEqual(len(scheduler.experiment.trials), 1)

def test_compute_analyses(self) -> None:
scheduler = Scheduler(
init_test_engine_and_session_factory(force_init=True)
gs = self._get_generation_strategy_strategy_for_test(
experiment=self.branin_experiment,
generation_strategy=get_generation_strategy(),
)
scheduler = Scheduler(
experiment=self.branin_experiment,
generation_strategy=gs,
options=SchedulerOptions(
total_trials=0,
tolerated_trial_failure_rate=0.2,
init_seconds_between_polls=10,
**self.scheduler_options_kwargs,
),
db_settings=self.db_settings,
)

with self.assertLogs(logger="ax.analysis", level="ERROR") as lg:
cards = scheduler.compute_analyses(analyses=[ParallelCoordinatesPlot()])

self.assertEqual(len(cards), 0)
self.assertEqual(len(cards), 1)
# it saved the error card
self.assertIsNotNone(cards[0].db_id)
self.assertEqual(cards[0].name, "ParallelCoordinatesPlot")
self.assertEqual(cards[0].title, "ParallelCoordinatesPlot Error")
self.assertEqual(
cards[0].subtitle,
f"An error occurred while computing {ParallelCoordinatesPlot()}",
)
self.assertIn("Traceback", cards[0].blob)
self.assertTrue(
any(
(
Expand Down
4 changes: 3 additions & 1 deletion ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,7 +1024,7 @@ def analysis_card_from_sqa(
analysis_card_sqa: SQAAnalysisCard,
) -> AnalysisCard:
"""Convert SQLAlchemy Analysis to Ax Analysis Object."""
return AnalysisCard(
card = AnalysisCard(
name=analysis_card_sqa.name,
title=analysis_card_sqa.title,
subtitle=analysis_card_sqa.subtitle,
Expand All @@ -1037,6 +1037,8 @@ def analysis_card_from_sqa(
else json.loads(analysis_card_sqa.attributes)
),
)
card.db_id = analysis_card_sqa.id
return card

def _metric_from_sqa_util(self, metric_sqa: SQAMetric) -> Metric:
"""Convert SQLAlchemy Metric to Ax Metric"""
Expand Down

0 comments on commit 12ed9f7

Please sign in to comment.