From 200168e6ee92ad29656667cc0204e0cae6eae37d Mon Sep 17 00:00:00 2001 From: Eric Onofrey Date: Thu, 11 Jan 2024 15:14:29 -0800 Subject: [PATCH] Validate required metrics in AxClient for complete trials (#2082) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2082 Add a `_validate_all_required_metrics_present` method and call it in `AxClient.complete_trial` to ensure a trial has all required data to be marked `COMPLETE` Reviewed By: saitcakmak, lena-kashtelyan Differential Revision: D52057299 fbshipit-source-id: ecad7686b275f072579471c7b865589e5b61eb76 --- ax/service/ax_client.py | 32 +++++++++++++++++++++++++++++- ax/service/tests/test_ax_client.py | 30 ++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index 3caee6a4625..bd7aebddfd3 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -32,6 +32,7 @@ from ax.core.arm import Arm from ax.core.base_trial import BaseTrial, TrialStatus from ax.core.experiment import DataType, Experiment +from ax.core.formatting_utils import data_and_evaluations_from_raw_data from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.generator_run import GeneratorRun from ax.core.map_data import MapData @@ -1626,7 +1627,16 @@ def _update_trial_with_raw_data( ) if complete_trial: - trial.mark_completed() + if not self._validate_all_required_metrics_present( + raw_data=raw_data, trial_index=trial_index + ): + logger.warning( + "Marking the trial as failed because it is missing one" + "or more required metrics." + ) + trial.mark_failed() + else: + trial.mark_completed() self._save_or_update_trial_in_db_if_possible( experiment=self.experiment, @@ -1789,6 +1799,26 @@ def _find_last_trial_with_parameterization( f"No trial on experiment matches parameterization {parameterization}."
) + def _validate_all_required_metrics_present( + self, raw_data: TEvaluationOutcome, trial_index: int + ) -> bool: + """Check if all required metrics are present in the given raw data.""" + opt_config = self.experiment.optimization_config + if opt_config is None: + return True + + _, data = data_and_evaluations_from_raw_data( + raw_data={"data": raw_data}, + sample_sizes={}, + trial_index=trial_index, + data_type=self.experiment.default_data_type, + metric_names=opt_config.objective.metric_names, + ) + required_metrics = set(opt_config.metrics.keys()) + provided_metrics = data.metric_names + missing_metrics = required_metrics - provided_metrics + return not missing_metrics + @classmethod def _get_pending_observation_features( cls, diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index 491616683cd..13adae51326 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -1618,6 +1618,36 @@ def test_update_trial_data(self) -> None: df = ax_client.experiment.lookup_data_for_trial(idx)[0].df self.assertEqual(df["mean"].item(), 3.0) + # Incomplete trial fails + params, idx = ax_client.get_next_trial() + ax_client.complete_trial(trial_index=idx, raw_data={"missing_metric": (1, 0.0)}) + self.assertTrue(ax_client.get_trial(idx).status.is_failed) + + def test_incomplete_multi_fidelity_trial(self) -> None: + ax_client = AxClient() + ax_client.create_experiment( + parameters=[ + {"name": "x", "type": "range", "bounds": [-5.0, 10.0]}, + {"name": "y", "type": "range", "bounds": [0.0, 1.0]}, + ], + minimize=True, + objective_name="branin", + support_intermediate_data=True, + ) + # Trial with complete data + params, idx = ax_client.get_next_trial() + ax_client.complete_trial( + trial_index=idx, raw_data=[({"fidelity": 1}, {"branin": (123, 0.0)})] + ) + self.assertTrue(ax_client.get_trial(idx).status.is_completed) + # Trial with incomplete data + params, idx = ax_client.get_next_trial() + ax_client.complete_trial( 
+ trial_index=idx, + raw_data=[({"fidelity": 2}, {"missing_metric": (456, 0.0)})], + ) + self.assertTrue(ax_client.get_trial(idx).status.is_failed) + def test_trial_completion_with_metadata_with_iso_times(self) -> None: ax_client = get_branin_optimization() params, idx = ax_client.get_next_trial()