Skip to content

Commit

Permalink
Validate required metrics in AxClient for complete trials, (uncommitt…
Browse files Browse the repository at this point in the history
…ed/untracked changes), (uncommitted/untracked changes) (#2082)

Summary:
Pull Request resolved: #2082

Add a `_validate_all_required_metrics_present` and call in `AxClient.complete_trial` to ensure a trial has all required data to be marked `COMPLETE`

Reviewed By: saitcakmak, lena-kashtelyan

Differential Revision: D52057299

fbshipit-source-id: ecad7686b275f072579471c7b865589e5b61eb76
  • Loading branch information
eonofrey authored and facebook-github-bot committed Jan 11, 2024
1 parent 15d3122 commit 200168e
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 1 deletion.
32 changes: 31 additions & 1 deletion ax/service/ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.experiment import DataType, Experiment
from ax.core.formatting_utils import data_and_evaluations_from_raw_data
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.generator_run import GeneratorRun
from ax.core.map_data import MapData
Expand Down Expand Up @@ -1626,7 +1627,16 @@ def _update_trial_with_raw_data(
)

if complete_trial:
trial.mark_completed()
if not self._validate_all_required_metrics_present(
raw_data=raw_data, trial_index=trial_index
):
logger.warning(
"Marking the trial as failed because it is missing one"
"or more required metrics."
)
trial.mark_failed()
else:
trial.mark_completed()

self._save_or_update_trial_in_db_if_possible(
experiment=self.experiment,
Expand Down Expand Up @@ -1789,6 +1799,26 @@ def _find_last_trial_with_parameterization(
f"No trial on experiment matches parameterization {parameterization}."
)

def _validate_all_required_metrics_present(
self, raw_data: TEvaluationOutcome, trial_index: int
) -> bool:
"""Check if all required metrics are present in the given raw data."""
opt_config = self.experiment.optimization_config
if opt_config is None:
return True

_, data = data_and_evaluations_from_raw_data(
raw_data={"data": raw_data},
sample_sizes={},
trial_index=trial_index,
data_type=self.experiment.default_data_type,
metric_names=opt_config.objective.metric_names,
)
required_metrics = set(opt_config.metrics.keys())
provided_metrics = data.metric_names
missing_metrics = required_metrics - provided_metrics
return not missing_metrics

@classmethod
def _get_pending_observation_features(
cls,
Expand Down
30 changes: 30 additions & 0 deletions ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1618,6 +1618,36 @@ def test_update_trial_data(self) -> None:
df = ax_client.experiment.lookup_data_for_trial(idx)[0].df
self.assertEqual(df["mean"].item(), 3.0)

# Incomplete trial fails
params, idx = ax_client.get_next_trial()
ax_client.complete_trial(trial_index=idx, raw_data={"missing_metric": (1, 0.0)})
self.assertTrue(ax_client.get_trial(idx).status.is_failed)

def test_incomplete_multi_fidelity_trial(self) -> None:
ax_client = AxClient()
ax_client.create_experiment(
parameters=[
{"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
{"name": "y", "type": "range", "bounds": [0.0, 1.0]},
],
minimize=True,
objective_name="branin",
support_intermediate_data=True,
)
# Trial with complete data
params, idx = ax_client.get_next_trial()
ax_client.complete_trial(
trial_index=idx, raw_data=[({"fidelity": 1}, {"branin": (123, 0.0)})]
)
self.assertTrue(ax_client.get_trial(idx).status.is_completed)
# Trial with incomplete data
params, idx = ax_client.get_next_trial()
ax_client.complete_trial(
trial_index=idx,
raw_data=[({"fidelity": 2}, {"missing_metric": (456, 0.0)})],
)
self.assertTrue(ax_client.get_trial(idx).status.is_failed)

def test_trial_completion_with_metadata_with_iso_times(self) -> None:
ax_client = get_branin_optimization()
params, idx = ax_client.get_next_trial()
Expand Down

0 comments on commit 200168e

Please sign in to comment.