From 6544a94ed2c8b734b13a34b8b3d74fb1675237c2 Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Thu, 31 Oct 2024 09:12:02 -0700 Subject: [PATCH] Replace not_none with none_throws in Ax (#3002) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/3002 Was working on something else and it was bothering me that we had both not_none and none_throws both being used in our codebase. Since we seem to be moving away from typeutils generally I think its worthwhile to coalesce all usage around none_throws. Reviewed By: Balandat Differential Revision: D65275698 --- ax/benchmark/tests/test_benchmark.py | 19 +++---- ax/benchmark/tests/test_benchmark_method.py | 6 +-- ax/core/base_trial.py | 16 +++--- ax/core/batch_trial.py | 17 ++++--- ax/core/data.py | 5 +- ax/core/experiment.py | 17 ++++--- ax/core/generation_strategy_interface.py | 4 +- ax/core/generator_run.py | 4 +- ax/core/objective.py | 6 +-- ax/core/observation.py | 5 +- ax/core/parameter.py | 29 ++++++----- ax/core/search_space.py | 12 ++--- ax/core/tests/test_observation.py | 6 +-- ax/core/tests/test_parameter.py | 4 +- ax/core/trial.py | 13 +++-- ax/early_stopping/strategies/base.py | 7 +-- ax/early_stopping/strategies/percentile.py | 4 +- ax/early_stopping/tests/test_strategies.py | 4 +- ax/global_stopping/strategies/improvement.py | 7 +-- ax/global_stopping/tests/test_strategies.py | 9 ++-- ax/metrics/branin_map.py | 5 +- ax/metrics/chemistry.py | 4 +- ax/metrics/torchx.py | 4 +- ax/modelbridge/base.py | 11 +++-- ax/modelbridge/best_model_selector.py | 4 +- ax/modelbridge/dispatch_utils.py | 6 +-- ax/modelbridge/generation_node.py | 6 +-- ax/modelbridge/generation_strategy.py | 2 +- ax/modelbridge/map_torch.py | 4 +- ax/modelbridge/model_spec.py | 8 +-- ax/modelbridge/modelbridge_utils.py | 11 +++-- ax/modelbridge/registry.py | 9 ++-- ax/modelbridge/tests/test_base_modelbridge.py | 8 +-- ax/modelbridge/tests/test_dispatch_utils.py | 38 +++++++------- .../tests/test_external_generation_node.py | 8 +-- .../tests/test_generation_strategy.py | 31 ++++++------ .../tests/test_hierarchical_search_space.py | 11 +++-- ax/modelbridge/tests/test_model_spec.py | 6 +-- .../tests/test_modelbridge_utils.py | 24 +++++---- ax/modelbridge/tests/test_prediction_utils.py | 10 ++-- .../tests/test_torch_modelbridge.py | 19 +++---- .../tests/test_torch_moo_modelbridge.py | 13 +++-- ax/modelbridge/tests/test_utils.py | 3 +- ax/modelbridge/torch.py | 24 ++++----- ax/modelbridge/transforms/cast.py | 5 +- .../transforms/convert_metric_names.py | 4 +- ax/modelbridge/transforms/derelativize.py | 4 +- ax/modelbridge/transforms/int_to_float.py | 5 +- ax/modelbridge/transforms/relativize.py | 6 +-- .../transforms/stratified_standardize_y.py | 13 ++--- .../tests/test_relativize_transform.py | 21 ++++---- ax/modelbridge/transforms/time_as_feature.py | 5 +- ax/modelbridge/transition_criterion.py | 4 +- ax/models/random/sobol.py | 6 +-- ax/models/tests/test_botorch_defaults.py | 5 +- ax/models/tests/test_botorch_model.py | 4 +- ax/models/tests/test_random.py | 6 +-- ax/models/tests/test_torch_model_utils.py | 4 +- ax/models/tests/test_torch_utils.py | 17 ++++--- .../torch/botorch_modular/acquisition.py | 3 +- ax/models/torch/botorch_modular/sebo.py | 6 +-- ax/models/torch/botorch_modular/utils.py | 5 +- ax/models/torch/botorch_moo.py | 9 ++-- ax/models/torch/botorch_moo_defaults.py | 7 +-- ax/models/torch/tests/test_sebo.py | 8 +-- ax/models/torch/tests/test_surrogate.py | 6 +-- ax/models/torch/tests/test_utils.py | 9 ++-- ax/plot/diagnostic.py | 4 +- 
ax/plot/helper.py | 8 +-- ax/plot/pareto_frontier.py | 13 ++--- ax/plot/slice.py | 6 +-- ax/plot/trace.py | 8 +-- ax/runners/torchx.py | 4 +- ax/service/ax_client.py | 41 +++++++++------- ax/service/managed_loop.py | 6 +-- ax/service/tests/scheduler_test_utils.py | 12 ++--- ax/service/tests/test_ax_client.py | 18 +++---- ax/service/tests/test_best_point.py | 7 +-- ax/service/tests/test_best_point_utils.py | 32 ++++++------ ax/service/tests/test_report_utils.py | 5 +- ax/service/utils/best_point.py | 15 +++--- ax/service/utils/best_point_mixin.py | 25 +++++----- ax/service/utils/early_stopping.py | 4 +- ax/service/utils/instantiation.py | 4 +- ax/service/utils/report_utils.py | 35 +++++++------ ax/service/utils/with_db_settings_base.py | 8 +-- ax/storage/json_store/decoder.py | 5 +- ax/storage/sqa_store/decoder.py | 49 +++++++++---------- ax/storage/sqa_store/encoder.py | 13 ++--- ax/storage/sqa_store/load.py | 5 +- ax/storage/sqa_store/save.py | 5 +- ax/storage/sqa_store/tests/test_sqa_store.py | 21 ++++---- ax/utils/measurement/synthetic_functions.py | 6 +-- ax/utils/sensitivity/derivative_measures.py | 19 +++---- ax/utils/testing/core_stubs.py | 17 ++++--- ax/utils/testing/preference_stubs.py | 7 +-- 96 files changed, 542 insertions(+), 485 deletions(-) diff --git a/ax/benchmark/tests/test_benchmark.py b/ax/benchmark/tests/test_benchmark.py index 80347724bcb..3f994262bc9 100644 --- a/ax/benchmark/tests/test_benchmark.py +++ b/ax/benchmark/tests/test_benchmark.py @@ -32,7 +32,6 @@ from ax.storage.json_store.load import load_experiment from ax.storage.json_store.save import save_experiment from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import not_none from ax.utils.testing.benchmark_stubs import ( get_moo_surrogate, get_multi_objective_benchmark_problem, @@ -102,7 +101,7 @@ def test_storage(self) -> None: # test saving to temporary file with tempfile.NamedTemporaryFile(mode="w", delete=True, suffix=".json") as f: - save_experiment(not_none(res.experiment), f.name) + save_experiment(none_throws(res.experiment), f.name) res.experiment_storage_id = f.name res.experiment = None self.assertIsNone(res.experiment) @@ -154,7 +153,9 @@ def test_replication_sobol_synthetic(self) -> None: for problem in problems: res = benchmark_replication(problem=problem, method=method, seed=0) - self.assertEqual(problem.num_trials, len(not_none(res.experiment).trials)) + self.assertEqual( + problem.num_trials, len(none_throws(res.experiment).trials) + ) self.assertTrue(np.isfinite(res.score_trace).all()) self.assertTrue(np.all(res.score_trace <= 100)) experiment = none_throws(res.experiment) @@ -180,7 +181,7 @@ def test_replication_sobol_surrogate(self) -> None: self.assertEqual( problem.num_trials, - len(not_none(res.experiment).trials), + len(none_throws(res.experiment).trials), ) self.assertTrue(np.isfinite(res.score_trace).all()) @@ -328,7 +329,7 @@ def test_replication_mbm(self) -> None: res = benchmark_replication(problem=problem, method=method, seed=0) self.assertEqual( problem.num_trials, - len(not_none(res.experiment).trials), + len(none_throws(res.experiment).trials), ) self.assertTrue(np.all(res.score_trace <= 100)) self.assertEqual(method.name, method.generation_strategy.name) @@ -345,11 +346,11 @@ def test_replication_moo_sobol(self) -> None: self.assertEqual( problem.num_trials, - len(not_none(res.experiment).trials), + len(none_throws(res.experiment).trials), ) self.assertEqual( problem.num_trials * 2, - len(not_none(res.experiment).fetch_data().df), + 
len(none_throws(res.experiment).fetch_data().df), ) self.assertTrue(np.all(res.score_trace <= 100)) @@ -365,7 +366,7 @@ def test_benchmark_one_method_problem(self) -> None: self.assertEqual(len(agg.results), 2) self.assertTrue( all( - len(not_none(result.experiment).trials) == problem.num_trials + len(none_throws(result.experiment).trials) == problem.num_trials for result in agg.results ), "All experiments must have 4 trials", @@ -457,5 +458,5 @@ def test_replication_with_generation_node(self) -> None: problem = get_single_objective_benchmark_problem() res = benchmark_replication(problem=problem, method=method, seed=0) - self.assertEqual(problem.num_trials, len(not_none(res.experiment).trials)) + self.assertEqual(problem.num_trials, len(none_throws(res.experiment).trials)) self.assertTrue(np.isnan(res.score_trace).all()) diff --git a/ax/benchmark/tests/test_benchmark_method.py b/ax/benchmark/tests/test_benchmark_method.py index a566535e4c3..46d6d01eea3 100644 --- a/ax/benchmark/tests/test_benchmark_method.py +++ b/ax/benchmark/tests/test_benchmark_method.py @@ -17,7 +17,7 @@ from ax.modelbridge.model_spec import ModelSpec from ax.modelbridge.registry import Models from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import not_none +from pyre_extensions import none_throws class TestBenchmarkMethod(TestCase): @@ -33,7 +33,7 @@ def test_benchmark_method(self) -> None: # test that `fit_tracking_metrics` has been correctly set to False for step in method.generation_strategy._steps: - self.assertFalse(not_none(step.model_kwargs).get("fit_tracking_metrics")) + self.assertFalse(none_throws(step.model_kwargs).get("fit_tracking_metrics")) self.assertEqual(method.scheduler_options, options) self.assertEqual(options.max_pending_trials, 1) @@ -60,7 +60,7 @@ def test_benchmark_method(self) -> None: ) for node in method.generation_strategy._nodes: self.assertFalse( - not_none(node.model_spec_to_gen_from.model_kwargs).get( + none_throws(node.model_spec_to_gen_from.model_kwargs).get( "fit_tracking_metrics" ) ) diff --git a/ax/core/base_trial.py b/ax/core/base_trial.py index 3717d49a073..148c289c9cc 100644 --- a/ax/core/base_trial.py +++ b/ax/core/base_trial.py @@ -26,7 +26,7 @@ from ax.core.types import TCandidateMetadata, TEvaluationOutcome from ax.exceptions.core import UnsupportedError from ax.utils.common.base import SortableBase -from ax.utils.common.typeutils import not_none +from pyre_extensions import none_throws if TYPE_CHECKING: @@ -283,7 +283,7 @@ def index(self) -> int: def status(self) -> TrialStatus: """The status of the trial in the experimentation lifecycle.""" self._mark_failed_if_past_TTL() - return not_none(self._status) + return none_throws(self._status) @status.setter def status(self, status: TrialStatus) -> None: @@ -413,9 +413,9 @@ def run(self) -> BaseTrial: if self._runner is None: raise ValueError("No runner set on trial or experiment.") - self.update_run_metadata(not_none(self._runner).run(self)) + self.update_run_metadata(none_throws(self._runner).run(self)) - if not_none(self._runner).staging_required: + if none_throws(self._runner).staging_required: self.mark_staged() else: self.mark_running() @@ -452,7 +452,7 @@ def stop(self, new_status: TrialStatus, reason: str | None = None) -> BaseTrial: self.assign_runner() if self._runner is None: raise ValueError("No runner set on trial or experiment.") - runner = not_none(self._runner) + runner = none_throws(self._runner) self._stop_metadata = runner.stop(self, reason=reason) self.mark_as(new_status) @@ 
-713,7 +713,7 @@ def mark_abandoned( Returns: The trial instance. """ - if not unsafe and not_none(self._status).is_terminal: + if not unsafe and none_throws(self._status).is_terminal: raise ValueError("Cannot abandon a trial in a terminal state.") self._abandoned_reason = reason @@ -792,12 +792,12 @@ def _mark_failed_if_past_TTL(self) -> None: """If trial has TTL set and is running, check if the TTL has elapsed and mark the trial failed if so. """ - if self.ttl_seconds is None or not not_none(self._status).is_running: + if self.ttl_seconds is None or not none_throws(self._status).is_running: return time_run_started = self._time_run_started assert time_run_started is not None dt = datetime.now() - time_run_started - if dt > timedelta(seconds=not_none(self.ttl_seconds)): + if dt > timedelta(seconds=none_throws(self.ttl_seconds)): self.mark_failed() @property diff --git a/ax/core/batch_trial.py b/ax/core/batch_trial.py index fbc36778ce5..0507105131f 100644 --- a/ax/core/batch_trial.py +++ b/ax/core/batch_trial.py @@ -34,7 +34,8 @@ from ax.utils.common.docutils import copy_doc from ax.utils.common.equality import datetime_equals, equality_typechecker from ax.utils.common.logger import _round_floats_for_logging, get_logger -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast +from pyre_extensions import none_throws logger: Logger = get_logger(__name__) @@ -317,7 +318,9 @@ def add_generator_run( generator_run.index = len(self._generator_run_structs) - 1 if self.status_quo is not None and self.optimize_for_power: - self.set_status_quo_and_optimize_power(status_quo=not_none(self.status_quo)) + self.set_status_quo_and_optimize_power( + status_quo=none_throws(self.status_quo) + ) if generator_run._generation_step_index is not None: self._set_generation_step_index( @@ -403,7 +406,7 @@ def set_status_quo_and_optimize_power(self, status_quo: Arm) -> BatchTrial: return self # arm_weights should always have at least one arm now - arm_weights = not_none(self.arm_weights) + arm_weights = none_throws(self.arm_weights) sum_weights = sum(w for arm, w in arm_weights.items() if arm != status_quo) optimal_status_quo_weight_override = np.sqrt(sum_weights) self.set_status_quo_with_weight( @@ -480,7 +483,7 @@ def is_factorial(self) -> bool: param_levels: defaultdict[str, dict[str | float, int]] = defaultdict(dict) for arm in self.arms: for param_name, param_value in arm.parameters.items(): - param_levels[param_name][not_none(param_value)] = 1 + param_levels[param_name][none_throws(param_value)] = 1 param_cardinality = 1 for param_values in param_levels.values(): param_cardinality *= len(param_values) @@ -701,7 +704,7 @@ def _get_candidate_metadata(self, arm_name: str) -> TCandidateMetadata: for gr_struct in self._generator_run_structs: gr = gr_struct.generator_run if gr and gr.candidate_metadata_by_arm_signature and arm in gr.arms: - return not_none(gr.candidate_metadata_by_arm_signature).get( + return none_throws(gr.candidate_metadata_by_arm_signature).get( arm.signature ) return None @@ -710,8 +713,8 @@ def _validate_batch_trial_data(self, data: Data) -> None: """Utility function to validate batch data before further processing.""" if ( self.status_quo - and not_none(self.status_quo).name in self.arms_by_name - and not_none(self.status_quo).name not in data.df["arm_name"].values + and none_throws(self.status_quo).name in self.arms_by_name + and none_throws(self.status_quo).name not in data.df["arm_name"].values ): raise AxError( f"Trial #{self.index} 
was completed with data that did " diff --git a/ax/core/data.py b/ax/core/data.py index 18df7f4acae..b48119ef724 100644 --- a/ax/core/data.py +++ b/ax/core/data.py @@ -28,7 +28,8 @@ TClassDecoderRegistry, TDecoderRegistry, ) -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast +from pyre_extensions import none_throws TBaseData = TypeVar("TBaseData", bound="BaseData") DF_REPR_MAX_LENGTH = 1000 @@ -235,7 +236,7 @@ def df_hash(self) -> str: str: The hash of the DataFrame. """ - return md5(not_none(self.df.to_json()).encode("utf-8")).hexdigest() + return md5(none_throws(self.df.to_json()).encode("utf-8")).hexdigest() def get_filtered_results( self: TBaseData, **filters: dict[str, Any] diff --git a/ax/core/experiment.py b/ax/core/experiment.py index 6d33a26915d..3e06eae0f06 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -49,7 +49,8 @@ from ax.utils.common.logger import _round_floats_for_logging, get_logger from ax.utils.common.result import Err, Ok from ax.utils.common.timeutils import current_timestamp_in_millis -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast +from pyre_extensions import none_throws logger: logging.Logger = get_logger(__name__) @@ -366,7 +367,7 @@ def is_moo_problem(self) -> bool: """Whether the experiment's optimization config contains multiple objectives.""" if self.optimization_config is None: return False - return not_none(self.optimization_config).is_moo_problem + return none_throws(self.optimization_config).is_moo_problem @property def data_by_trial(self) -> dict[int, OrderedDict[int, Data]]: @@ -1296,7 +1297,7 @@ def warm_start_from_old_experiment( "Only experiments with 1-arm trials currently supported." ) self.search_space.check_membership( - not_none(trial.arm).parameters, + none_throws(trial.arm).parameters, raise_error=search_space_check_membership_raise_error, ) dat, ts = old_experiment.lookup_data_for_trial(trial_index=trial.index) @@ -1331,7 +1332,7 @@ def warm_start_from_old_experiment( {trial.index: new_trial.index}, inplace=True ) new_df["arm_name"].replace( - {not_none(trial.arm).name: not_none(new_trial.arm).name}, + {none_throws(trial.arm).name: none_throws(new_trial.arm).name}, inplace=True, ) # Attach updated data to new trial on experiment. @@ -1783,20 +1784,20 @@ def add_arm_and_prevent_naming_collision( # the arm name. Preserves all names not matching the automatic naming format. # experiment is not named, clear the arm's name. # `arm_index` is 0 since all trials are single-armed. - old_arm_name = not_none(old_trial.arm).name + old_arm_name = none_throws(old_trial.arm).name has_default_name = bool(old_arm_name == old_trial._get_default_name(arm_index=0)) if has_default_name: - new_arm = not_none(old_trial.arm).clone(clear_name=True) + new_arm = none_throws(old_trial.arm).clone(clear_name=True) if old_experiment_name is not None: new_arm.name = f"{old_arm_name}_{old_experiment_name}" new_trial.add_arm(new_arm) else: try: - new_trial.add_arm(not_none(old_trial.arm).clone(clear_name=False)) + new_trial.add_arm(none_throws(old_trial.arm).clone(clear_name=False)) except ValueError as e: warnings.warn( f"Attaching arm {old_trial.arm} to trial {new_trial} while preserving " f"its name failed with error: {e}. 
Retrying with `clear_name=True`.", stacklevel=2, ) - new_trial.add_arm(not_none(old_trial.arm).clone(clear_name=True)) + new_trial.add_arm(none_throws(old_trial.arm).clone(clear_name=True)) diff --git a/ax/core/generation_strategy_interface.py b/ax/core/generation_strategy_interface.py index 636d0865707..7c4d14a8bb7 100644 --- a/ax/core/generation_strategy_interface.py +++ b/ax/core/generation_strategy_interface.py @@ -17,7 +17,7 @@ from ax.core.observation import ObservationFeatures from ax.exceptions.core import AxError, UnsupportedError from ax.utils.common.base import Base -from ax.utils.common.typeutils import not_none +from pyre_extensions import none_throws class GenerationStrategyInterface(ABC, Base): @@ -148,7 +148,7 @@ def experiment(self) -> Experiment: """Experiment, currently set on this generation strategy.""" if self._experiment is None: raise AxError("No experiment set on generation strategy.") - return not_none(self._experiment) + return none_throws(self._experiment) @experiment.setter def experiment(self, experiment: Experiment) -> None: diff --git a/ax/core/generator_run.py b/ax/core/generator_run.py index ed08febbdbf..280af473480 100644 --- a/ax/core/generator_run.py +++ b/ax/core/generator_run.py @@ -30,7 +30,7 @@ from ax.exceptions.core import UnsupportedError from ax.utils.common.base import Base, SortableBase from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import not_none +from pyre_extensions import none_throws logger: Logger = get_logger(__name__) @@ -303,7 +303,7 @@ def model_predictions_by_arm(self) -> dict[str, TModelPredictArm] | None: predictions: dict[str, TModelPredictArm] = {} for idx, cond in enumerate(self.arms): predictions[cond.signature] = extract_arm_predictions( - model_predictions=not_none(self._model_predictions), arm_idx=idx + model_predictions=none_throws(self._model_predictions), arm_idx=idx ) return predictions diff --git a/ax/core/objective.py b/ax/core/objective.py index 3530df1518b..c3d89946b85 100644 --- a/ax/core/objective.py +++ b/ax/core/objective.py @@ -17,7 +17,7 @@ from ax.exceptions.core import UserInputError from ax.utils.common.base import SortableBase from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import not_none +from pyre_extensions import none_throws logger: Logger = get_logger(__name__) @@ -56,7 +56,7 @@ def __init__(self, metric: Metric, minimize: bool | None = None) -> None: f"{minimize=}." ) self._metric: Metric = metric - self.minimize: bool = not_none(minimize) + self.minimize: bool = none_throws(minimize) @property def metric(self) -> Metric: @@ -136,7 +136,7 @@ def __init__( objectives.append(Objective(metric=metric, minimize=minimize)) # pyre-fixme[4]: Attribute must be annotated. - self._objectives = not_none(objectives) + self._objectives = none_throws(objectives) # For now, assume all objectives are weighted equally. 
# This might be used in the future to change emphasis on the diff --git a/ax/core/observation.py b/ax/core/observation.py index 7dd6a7bd768..a28be42520e 100644 --- a/ax/core/observation.py +++ b/ax/core/observation.py @@ -28,7 +28,8 @@ from ax.utils.common.base import Base from ax.utils.common.constants import Keys from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast +from pyre_extensions import none_throws logger: Logger = get_logger(__name__) @@ -308,7 +309,7 @@ def _observations_from_dataframe( metadata = trial._get_candidate_metadata(arm_name) or {} if Keys.TRIAL_COMPLETION_TIMESTAMP not in metadata: if trial._time_completed is not None: - metadata[Keys.TRIAL_COMPLETION_TIMESTAMP] = not_none( + metadata[Keys.TRIAL_COMPLETION_TIMESTAMP] = none_throws( trial._time_completed ).timestamp() obs_kwargs[Keys.METADATA] = metadata diff --git a/ax/core/parameter.py b/ax/core/parameter.py index 9ee8445ef75..c2fbefd929c 100644 --- a/ax/core/parameter.py +++ b/ax/core/parameter.py @@ -20,8 +20,7 @@ from ax.exceptions.core import AxParameterWarning, UserInputError from ax.utils.common.base import SortableBase from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import not_none -from pyre_extensions import assert_is_instance +from pyre_extensions import assert_is_instance, none_throws logger: Logger = get_logger(__name__) @@ -278,8 +277,8 @@ def __init__( raise UserInputError("RangeParameter type must be int or float.") self._parameter_type = parameter_type self._digits = digits - self._lower: TNumeric = not_none(self.cast(lower)) - self._upper: TNumeric = not_none(self.cast(upper)) + self._lower: TNumeric = none_throws(self.cast(lower)) + self._upper: TNumeric = none_throws(self.cast(upper)) self._log_scale = log_scale self._logit_scale = logit_scale self._is_fidelity = is_fidelity @@ -356,7 +355,7 @@ def upper(self, value: TNumeric) -> None: log_scale=self.log_scale, logit_scale=self.logit_scale, ) - self._upper = not_none(self.cast(value)) + self._upper = none_throws(self.cast(value)) @property def lower(self) -> TNumeric: @@ -375,7 +374,7 @@ def lower(self, value: TNumeric) -> None: log_scale=self.log_scale, logit_scale=self.logit_scale, ) - self._lower = not_none(self.cast(value)) + self._lower = none_throws(self.cast(value)) @property def digits(self) -> int | None: @@ -411,8 +410,8 @@ def update_range( if upper is None: upper = self._upper - cast_lower = not_none(self.cast(lower)) - cast_upper = not_none(self.cast(upper)) + cast_lower = none_throws(self.cast(lower)) + cast_upper = none_throws(self.cast(upper)) self._validate_range_param( lower=cast_lower, upper=cast_upper, @@ -427,8 +426,8 @@ def set_digits(self, digits: int) -> RangeParameter: self._digits = digits # Re-scale min and max to new digits definition - cast_lower = not_none(self.cast(self._lower)) - cast_upper = not_none(self.cast(self._upper)) + cast_lower = none_throws(self.cast(self._lower)) + cast_upper = none_throws(self.cast(self._upper)) if cast_lower >= cast_upper: raise UserInputError( f"Lower bound {cast_lower} is >= upper bound {cast_upper}." 
@@ -479,7 +478,7 @@ def is_valid_type(self, value: TParamValue) -> bool: # This might have issues with ints > 2^24 if self.parameter_type is ParameterType.INT: - return isinstance(value, int) or float(not_none(value)).is_integer() + return isinstance(value, int) or float(none_throws(value)).is_integer() return True def clone(self) -> RangeParameter: @@ -499,7 +498,7 @@ def cast(self, value: TParamValue) -> TNumeric | None: if value is None: return None if self.parameter_type is ParameterType.FLOAT and self._digits is not None: - return round(float(value), not_none(self._digits)) + return round(float(value), none_throws(self._digits)) return assert_is_instance(self.python_type(value), TNumeric) def __repr__(self) -> str: @@ -607,7 +606,7 @@ def __init__( else self._get_default_sort_values_and_warn() ) if self.sort_values: - values = cast(list[TParamValue], sorted([not_none(v) for v in values])) + values = cast(list[TParamValue], sorted([none_throws(v) for v in values])) self._values: list[TParamValue] = self._cast_values(values) if dependents: @@ -714,7 +713,7 @@ def dependents(self) -> dict[TParamValue, list[str]]: raise NotImplementedError( "Only hierarchical parameters support the `dependents` property." ) - return not_none(self._dependents) + return none_throws(self._dependents) def _cast_values(self, values: list[TParamValue]) -> list[TParamValue]: return [self.cast(value) for value in values] @@ -829,7 +828,7 @@ def dependents(self) -> dict[TParamValue, list[str]]: raise NotImplementedError( "Only hierarchical parameters support the `dependents` property." ) - return not_none(self._dependents) + return none_throws(self._dependents) def clone(self) -> FixedParameter: return FixedParameter( diff --git a/ax/core/search_space.py b/ax/core/search_space.py index b77cfd8dff2..37c795fdb2a 100644 --- a/ax/core/search_space.py +++ b/ax/core/search_space.py @@ -40,7 +40,7 @@ from ax.utils.common.base import Base from ax.utils.common.constants import Keys from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import not_none +from pyre_extensions import none_throws from scipy.special import expit, logit @@ -241,7 +241,7 @@ def check_membership( # parameter constraints only accept numeric parameters numerical_param_dict = { - name: float(not_none(value)) + name: float(none_throws(value)) for name, value in parameterization.items() if self.parameters[name].is_numeric } @@ -343,7 +343,7 @@ def construct_arm( raise ValueError( f"`{p_value}` is not a valid value for parameter {p_name}." ) - final_parameters.update(not_none(parameters)) + final_parameters.update(none_throws(parameters)) return Arm(parameters=final_parameters, name=name) def clone(self) -> SearchSpace: @@ -530,7 +530,7 @@ def flatten_observation_features( if has_full_parameterization: # If full parameterization is recorded, use it to fill in missing values. 
- full_parameterization = not_none(obs_feats.metadata)[ + full_parameterization = none_throws(obs_feats.metadata)[ Keys.FULL_PARAMETERIZATION ] obs_feats.parameters = {**full_parameterization, **obs_feats.parameters} @@ -629,7 +629,7 @@ def hierarchical_structure_str(self, parameter_names_only: bool = False) -> str: def _hrepr(param: Parameter | None, value: str | None, level: int) -> str: is_level_param = param and not value if is_level_param: - param = not_none(param) + param = none_throws(param) node_name = f"{param.name if parameter_names_only else param}" ret = "\t" * level + node_name + "\n" if param.is_hierarchical: @@ -642,7 +642,7 @@ def _hrepr(param: Parameter | None, value: str | None, level: int) -> str: level=level + 2, ) else: - value = not_none(value) + value = none_throws(value) node_name = f"({value})" ret = "\t" * level + node_name + "\n" diff --git a/ax/core/tests/test_observation.py b/ax/core/tests/test_observation.py index c39bce9f9b4..2c304353502 100644 --- a/ax/core/tests/test_observation.py +++ b/ax/core/tests/test_observation.py @@ -31,7 +31,7 @@ from ax.core.trial import Trial from ax.core.types import TParameterization from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import not_none +from pyre_extensions import none_throws class ObservationsTest(TestCase): @@ -842,14 +842,14 @@ def test_ObservationsFromDataWithDifferentTimesSingleTrial(self) -> None: self.assertEqual(obs.arm_name, obs_truth["arm_name"][i]) if i == 0: self.assertEqual( - not_none(obs.features.start_time).strftime("%Y-%m-%d %X"), + none_throws(obs.features.start_time).strftime("%Y-%m-%d %X"), "2024-03-20 08:45:00", ) self.assertIsNone(obs.features.end_time) else: self.assertIsNone(obs.features.start_time) self.assertEqual( - not_none(obs.features.end_time).strftime("%Y-%m-%d %X"), + none_throws(obs.features.end_time).strftime("%Y-%m-%d %X"), "2024-03-20 08:46:00", ) diff --git a/ax/core/tests/test_parameter.py b/ax/core/tests/test_parameter.py index 465609c2e9c..28c9d5293d9 100644 --- a/ax/core/tests/test_parameter.py +++ b/ax/core/tests/test_parameter.py @@ -18,7 +18,7 @@ ) from ax.exceptions.core import AxParameterWarning, AxWarning, UserInputError from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import not_none +from pyre_extensions import none_throws class RangeParameterTest(TestCase): @@ -358,7 +358,7 @@ def test_Clone(self) -> None: dependents={"foo": ["y", "z"], "bar": ["w"]}, ) param_clone = param.clone() - not_none(param_clone._dependents)["foo"] = ["y"] + none_throws(param_clone._dependents)["foo"] = ["y"] self.assertEqual(param.dependents, {"foo": ["y", "z"], "bar": ["w"]}) self.assertEqual(param_clone.dependents, {"foo": ["y"], "bar": ["w"]}) diff --git a/ax/core/trial.py b/ax/core/trial.py index a6980d1f603..c81b80e1cb3 100644 --- a/ax/core/trial.py +++ b/ax/core/trial.py @@ -22,8 +22,7 @@ from ax.exceptions.core import UnsupportedError from ax.utils.common.docutils import copy_doc from ax.utils.common.logger import _round_floats_for_logging, get_logger -from ax.utils.common.typeutils import not_none -from pyre_extensions import override +from pyre_extensions import none_throws, override logger: Logger = get_logger(__name__) @@ -100,7 +99,7 @@ def arm(self) -> Arm | None: if self.generator_run is None: return None - generator_run = not_none(self.generator_run) + generator_run = none_throws(self.generator_run) if len(generator_run.arms) == 0: return None elif len(generator_run.arms) > 1: @@ -193,7 +192,7 @@ def 
arms_by_name(self) -> dict[str, Arm]: def abandoned_arms(self) -> list[Arm]: """Abandoned arms attached to this trial.""" return ( - [not_none(self.arm)] + [none_throws(self.arm)] if self.generator_run is not None and self.arm is not None and self.is_abandoned @@ -251,7 +250,7 @@ def _get_candidate_metadata_from_all_generator_runs( if gr is None or gr.candidate_metadata_by_arm_signature is None: return {} - cand_metadata = not_none(gr.candidate_metadata_by_arm_signature) + cand_metadata = none_throws(gr.candidate_metadata_by_arm_signature) return {a.name: cand_metadata.get(a.signature) for a in gr.arms} def _get_candidate_metadata(self, arm_name: str) -> TCandidateMetadata: @@ -267,7 +266,7 @@ def _get_candidate_metadata(self, arm_name: str) -> TCandidateMetadata: return None arm = gr.arms[0] - return not_none(gr.candidate_metadata_by_arm_signature).get(arm.signature) + return none_throws(gr.candidate_metadata_by_arm_signature).get(arm.signature) def validate_data_for_trial(self, data: Data) -> None: """Utility method to validate data before further processing.""" @@ -310,7 +309,7 @@ def update_trial_data( Returns: A string message summarizing the update. """ - arm_name = not_none(self.arm).name + arm_name = none_throws(self.arm).name sample_sizes = {arm_name: sample_size} if sample_size else {} raw_data_by_arm = {arm_name: raw_data} diff --git a/ax/early_stopping/strategies/base.py b/ax/early_stopping/strategies/base.py index a5f073f9158..e5989b02eac 100644 --- a/ax/early_stopping/strategies/base.py +++ b/ax/early_stopping/strategies/base.py @@ -33,7 +33,8 @@ from ax.models.torch_base import TorchModel from ax.utils.common.base import Base from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast +from pyre_extensions import none_throws logger: Logger = get_logger(__name__) @@ -290,7 +291,7 @@ def is_eligible_any( self._log_and_return_completed_trials( logger=logger, num_completed=num_completed, - min_curves=not_none(self.min_curves), + min_curves=none_throws(self.min_curves), ) return False @@ -388,7 +389,7 @@ def _default_objective_and_direction( self, experiment: Experiment ) -> tuple[str, bool]: if self.metric_names is None: - optimization_config = not_none(experiment.optimization_config) + optimization_config = none_throws(experiment.optimization_config) objective = optimization_config.objective # if multi-objective optimization, infer as first objective diff --git a/ax/early_stopping/strategies/percentile.py b/ax/early_stopping/strategies/percentile.py index 36fad77a13c..f6620d0bb36 100644 --- a/ax/early_stopping/strategies/percentile.py +++ b/ax/early_stopping/strategies/percentile.py @@ -16,7 +16,7 @@ from ax.early_stopping.utils import align_partial_results from ax.exceptions.core import UnsupportedError from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import not_none +from pyre_extensions import none_throws logger: Logger = get_logger(__name__) @@ -193,7 +193,7 @@ def _should_stop_trial_early( # dropna() here will exclude trials that have not made it to the # last progression of the trial under consideration, and therefore # can't be included in the comparison - df_trial = not_none(df[trial_index].dropna()) + df_trial = none_throws(df[trial_index].dropna()) trial_last_prog = df_trial.index.max() data_at_last_progression = df.loc[trial_last_prog].dropna() logger.info( diff --git a/ax/early_stopping/tests/test_strategies.py 
b/ax/early_stopping/tests/test_strategies.py index 51edb756b64..2afd9ae3d83 100644 --- a/ax/early_stopping/tests/test_strategies.py +++ b/ax/early_stopping/tests/test_strategies.py @@ -28,7 +28,7 @@ from ax.early_stopping.utils import align_partial_results from ax.exceptions.core import UnsupportedError from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from ax.utils.testing.core_stubs import ( get_branin_arms, get_branin_experiment, @@ -698,7 +698,7 @@ def _evaluate_early_stopping_with_df( ) -> dict[int, str | None]: """Helper function for testing PercentileEarlyStoppingStrategy on an arbitrary (MapData) df.""" - data = not_none( + data = none_throws( early_stopping_strategy._check_validity_and_get_data(experiment, [metric_name]) ) metric_to_aligned_means, _ = align_partial_results( diff --git a/ax/global_stopping/strategies/improvement.py b/ax/global_stopping/strategies/improvement.py index 77b4d7c2af1..afe2236cefd 100644 --- a/ax/global_stopping/strategies/improvement.py +++ b/ax/global_stopping/strategies/improvement.py @@ -27,7 +27,8 @@ infer_reference_point_from_experiment, ) from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast +from pyre_extensions import none_throws logger: Logger = get_logger(__name__) @@ -188,7 +189,7 @@ def _should_stop_optimization( return self._should_stop_moo( experiment=experiment, trial_to_check=trial_to_check, - objective_thresholds=not_none(objective_thresholds), + objective_thresholds=none_throws(objective_thresholds), ) else: return self._should_stop_single_objective( @@ -347,7 +348,7 @@ def constraint_satisfaction(trial: BaseTrial) -> bool: Returns: A boolean which is True iff all outcome constraints are satisfied. 
""" - outcome_constraints = not_none( + outcome_constraints = none_throws( trial.experiment.optimization_config ).outcome_constraints if len(outcome_constraints) == 0: diff --git a/ax/global_stopping/tests/test_strategies.py b/ax/global_stopping/tests/test_strategies.py index 34847369f87..6b5b0141f9a 100644 --- a/ax/global_stopping/tests/test_strategies.py +++ b/ax/global_stopping/tests/test_strategies.py @@ -30,8 +30,9 @@ ImprovementGlobalStoppingStrategy, ) from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from ax.utils.testing.core_stubs import get_experiment, get_experiment_with_data +from pyre_extensions import none_throws class TestImprovementGlobalStoppingStrategy(TestCase): @@ -110,21 +111,21 @@ def _get_data_for_trial( { "trial_index": trial.index, "metric_name": "m1", - "arm_name": not_none(trial.arm).name, + "arm_name": none_throws(trial.arm).name, "mean": values[0], "sem": 0.0, }, { "trial_index": trial.index, "metric_name": "m2", - "arm_name": not_none(trial.arm).name, + "arm_name": none_throws(trial.arm).name, "mean": values[1], "sem": 0.0, }, { "trial_index": trial.index, "metric_name": "m3", - "arm_name": not_none(trial.arm).name, + "arm_name": none_throws(trial.arm).name, "mean": values[2], "sem": 0.0, }, diff --git a/ax/metrics/branin_map.py b/ax/metrics/branin_map.py index 686e3bca353..c90919c39b4 100644 --- a/ax/metrics/branin_map.py +++ b/ax/metrics/branin_map.py @@ -23,8 +23,9 @@ from ax.core.metric import MetricFetchE from ax.metrics.noisy_function_map import NoisyFunctionMapMetric from ax.utils.common.result import Err, Ok -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from ax.utils.measurement.synthetic_functions import branin +from pyre_extensions import none_throws FIDELITY = [0.1, 0.4, 0.7, 1.0] @@ -122,7 +123,7 @@ def f(self, x: npt.NDArray, timestamp: int) -> Mapping[str, Any]: x1, x2 = x if self.rate is not None: - weight = 1.0 + np.exp(-not_none(self.rate) * timestamp) + weight = 1.0 + np.exp(-none_throws(self.rate) * timestamp) else: weight = 1.0 diff --git a/ax/metrics/chemistry.py b/ax/metrics/chemistry.py index 9d1c3862ad2..ceb94044517 100644 --- a/ax/metrics/chemistry.py +++ b/ax/metrics/chemistry.py @@ -45,7 +45,7 @@ from ax.core.metric import Metric, MetricFetchE, MetricFetchResult from ax.core.types import TParameterization, TParamValue from ax.utils.common.result import Err, Ok -from ax.utils.common.typeutils import not_none +from pyre_extensions import none_throws class ChemistryProblemType(Enum): @@ -112,7 +112,7 @@ def clone(self) -> ChemistryMetric: name=self._name, noiseless=self.noiseless, problem_type=self.problem_type, - lower_is_better=not_none(self.lower_is_better), + lower_is_better=none_throws(self.lower_is_better), ) def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult: diff --git a/ax/metrics/torchx.py b/ax/metrics/torchx.py index a9c60bc75d6..98455631602 100644 --- a/ax/metrics/torchx.py +++ b/ax/metrics/torchx.py @@ -15,7 +15,7 @@ from ax.core.metric import Metric, MetricFetchE, MetricFetchResult from ax.utils.common.logger import get_logger from ax.utils.common.result import Err, Ok -from ax.utils.common.typeutils import not_none +from pyre_extensions import none_throws logger: Logger = get_logger(__name__) @@ -70,7 +70,7 @@ def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult ) df_dict = { - 
"arm_name": not_none(cast(Trial, trial).arm).name, + "arm_name": none_throws(cast(Trial, trial).arm).name, "trial_index": trial.index, "metric_name": self.name, "mean": mean, diff --git a/ax/modelbridge/base.py b/ax/modelbridge/base.py index 2c2cca8f8fb..cb6695f08d9 100644 --- a/ax/modelbridge/base.py +++ b/ax/modelbridge/base.py @@ -40,8 +40,9 @@ from ax.modelbridge.transforms.cast import Cast from ax.models.types import TConfig from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from botorch.exceptions.warnings import InputDataWarning +from pyre_extensions import none_throws logger: Logger = get_logger(__name__) @@ -995,13 +996,13 @@ def _get_serialized_model_state(self) -> dict[str, Any]: """Obtains the state of the underlying model (if using a stateful one) in a readily JSON-serializable form. """ - model = not_none(self.model) + model = none_throws(self.model) return model.serialize_state(raw_state=model._get_state()) def _deserialize_model_state( self, serialized_state: dict[str, Any] ) -> dict[str, Any]: - model = not_none(self.model) + model = none_throws(self.model) return model.deserialize_state(serialized_state=serialized_state) def feature_importances(self, metric_name: str) -> dict[str, float]: @@ -1179,7 +1180,7 @@ def _get_status_quo_by_trial( if status_quo_name is not None: # identify status quo by arm name trial_idx_to_sq_data = { - int(not_none(obs.features.trial_index)): obs.data + int(none_throws(obs.features.trial_index)): obs.data for obs in observations if obs.arm_name == status_quo_name } @@ -1189,7 +1190,7 @@ def _get_status_quo_by_trial( status_quo_features.parameters, sort_keys=True ) trial_idx_to_sq_data = { - int(not_none(obs.features.trial_index)): obs.data + int(none_throws(obs.features.trial_index)): obs.data for obs in observations if json.dumps(obs.features.parameters, sort_keys=True) == status_quo_signature diff --git a/ax/modelbridge/best_model_selector.py b/ax/modelbridge/best_model_selector.py index fef28a68e03..7100a241e79 100644 --- a/ax/modelbridge/best_model_selector.py +++ b/ax/modelbridge/best_model_selector.py @@ -19,7 +19,7 @@ from ax.exceptions.core import UserInputError from ax.modelbridge.model_spec import ModelSpec from ax.utils.common.base import Base -from ax.utils.common.typeutils import not_none +from pyre_extensions import none_throws # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters. 
ARRAYLIKE = Union[np.ndarray, list[float], list[np.ndarray]] @@ -122,7 +122,7 @@ def best_model(self, model_specs: list[ModelSpec]) -> ModelSpec: model_spec.cross_validate(model_cv_kwargs=self.model_cv_kwargs) aggregated_diagnostic_values = [ self.metric_aggregation( - list(not_none(model_spec.diagnostics)[self.diagnostic].values()) + list(none_throws(model_spec.diagnostics)[self.diagnostic].values()) ) for model_spec in model_specs ] diff --git a/ax/modelbridge/dispatch_utils.py b/ax/modelbridge/dispatch_utils.py index f566c4fbab9..548369a866c 100644 --- a/ax/modelbridge/dispatch_utils.py +++ b/ax/modelbridge/dispatch_utils.py @@ -25,7 +25,7 @@ from ax.models.winsorization_config import WinsorizationConfig from ax.utils.common.deprecation import _validate_force_random_search from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import not_none +from pyre_extensions import none_throws logger: logging.Logger = get_logger(__name__) @@ -192,7 +192,7 @@ def _suggest_gp_model( num_possible_points *= num_param_discrete_values if should_enumerate_param: - num_enumerated_combinations *= not_none(num_param_discrete_values) + num_enumerated_combinations *= none_throws(num_param_discrete_values) else: all_parameters_are_enumerated = False @@ -281,7 +281,7 @@ def calculate_num_initialization_trials( ret = 2 * num_tunable_parameters if num_trials is not None: - ret = min(ret, not_none(num_trials) // 5) + ret = min(ret, none_throws(num_trials) // 5) return max(ret, 5) diff --git a/ax/modelbridge/generation_node.py b/ax/modelbridge/generation_node.py index 8ae76f4ab37..b4f399b3d67 100644 --- a/ax/modelbridge/generation_node.py +++ b/ax/modelbridge/generation_node.py @@ -43,7 +43,7 @@ from ax.utils.common.constants import Keys from ax.utils.common.logger import get_logger from ax.utils.common.serialization import SerializationMixin -from ax.utils.common.typeutils import not_none +from pyre_extensions import none_throws logger: Logger = get_logger(__name__) @@ -212,7 +212,7 @@ def generation_strategy(self) -> modelbridge.generation_strategy.GenerationStrat raise ValueError( "Generation strategy has not been initialized on this node." 
) - return not_none(self._generation_strategy) + return none_throws(self._generation_strategy) @property def transition_criteria(self) -> Sequence[TransitionCriterion]: @@ -504,7 +504,7 @@ def _pick_fitted_model_to_gen_from(self) -> ModelSpec: raise UserInputError(MISSING_MODEL_SELECTOR_MESSAGE) return self.model_specs[0] - best_model = not_none(self.best_model_selector).best_model( + best_model = none_throws(self.best_model_selector).best_model( model_specs=self.model_specs, ) return best_model diff --git a/ax/modelbridge/generation_strategy.py b/ax/modelbridge/generation_strategy.py index 21143fa7b69..ffaa9428b43 100644 --- a/ax/modelbridge/generation_strategy.py +++ b/ax/modelbridge/generation_strategy.py @@ -238,7 +238,7 @@ def experiment(self) -> Experiment: """Experiment, currently set on this generation strategy.""" if self._experiment is None: raise ValueError("No experiment set on generation strategy.") - return not_none(self._experiment) + return none_throws(self._experiment) @experiment.setter def experiment(self, experiment: Experiment) -> None: diff --git a/ax/modelbridge/map_torch.py b/ax/modelbridge/map_torch.py index a3a6391767a..04ca083d476 100644 --- a/ax/modelbridge/map_torch.py +++ b/ax/modelbridge/map_torch.py @@ -36,7 +36,7 @@ from ax.models.torch_base import TorchModel from ax.models.types import TConfig from ax.utils.common.constants import Keys -from ax.utils.common.typeutils import not_none +from pyre_extensions import none_throws # A mapping from map_key to its target (or final) value; by default, @@ -173,7 +173,7 @@ def _predict( X = observation_features_to_array( self.parameters_with_map_keys, observation_features ) - f, cov = not_none(self.model).predict(X=self._array_to_tensor(X)) + f, cov = none_throws(self.model).predict(X=self._array_to_tensor(X)) f = f.detach().cpu().clone().numpy() cov = cov.detach().cpu().clone().numpy() # Convert resulting arrays to observations diff --git a/ax/modelbridge/model_spec.py b/ax/modelbridge/model_spec.py index f3f3ac9fc02..333b275caa6 100644 --- a/ax/modelbridge/model_spec.py +++ b/ax/modelbridge/model_spec.py @@ -38,7 +38,7 @@ get_function_argument_names, ) from ax.utils.common.serialization import SerializationMixin -from ax.utils.common.typeutils import not_none +from pyre_extensions import none_throws TModelFactory = Callable[..., ModelBridge] @@ -91,7 +91,7 @@ def __post_init__(self) -> None: def fitted_model(self) -> ModelBridge: """Returns the fitted Ax model, asserting fit() was called""" self._assert_fitted() - return not_none(self._fitted_model) + return none_throws(self._fitted_model) @property def fixed_features(self) -> ObservationFeatures | None: @@ -351,7 +351,7 @@ def __post_init__(self) -> None: try: # `model` is defined via a factory function. # pyre-ignore[16]: Anonymous callable has no attribute `__name__`. 
- self.model_key_override = not_none(self.factory_function).__name__ + self.model_key_override = none_throws(self.factory_function).__name__ except Exception: raise TypeError( f"{self.factory_function} is not a valid function, cannot extract " @@ -377,7 +377,7 @@ def fit( model kwargs set on the model spec, alongside any passed down as kwargs to this function (local kwargs take precedent) """ - factory_function = not_none(self.factory_function) + factory_function = none_throws(self.factory_function) all_kwargs = deepcopy(self.model_kwargs) all_kwargs.update(model_kwargs) self._fitted_model = factory_function( diff --git a/ax/modelbridge/modelbridge_utils.py b/ax/modelbridge/modelbridge_utils.py index cdea70c222d..550068d02ae 100644 --- a/ax/modelbridge/modelbridge_utils.py +++ b/ax/modelbridge/modelbridge_utils.py @@ -60,7 +60,6 @@ checked_cast, checked_cast_optional, checked_cast_to_tuple, - not_none, ) from botorch.acquisition.multi_objective.multi_output_risk_measures import ( IndependentCVaR, @@ -80,6 +79,8 @@ from botorch.utils.multi_objective.box_decompositions.dominated import ( DominatedPartitioning, ) + +from pyre_extensions import none_throws from torch import Tensor logger: Logger = get_logger(__name__) @@ -144,7 +145,7 @@ def check_has_multi_objective_and_data( optimization_config: OptimizationConfig | None = None, ) -> None: """Raise an error if not using a `MultiObjective` or if the data is empty.""" - optimization_config = not_none( + optimization_config = none_throws( optimization_config or experiment.optimization_config ) if not isinstance(optimization_config.objective, MultiObjective): @@ -729,7 +730,7 @@ def get_pareto_frontier_and_configs( observation_features=observation_features ) Y, Yvar = observation_data_to_array( - outcomes=modelbridge.outcomes, observation_data=not_none(observation_data) + outcomes=modelbridge.outcomes, observation_data=none_throws(observation_data) ) Y, Yvar = (array_to_tensor(Y), array_to_tensor(Yvar)) if arm_names is None: @@ -793,7 +794,7 @@ def get_pareto_frontier_and_configs( f, cov = f.detach().cpu().clone(), cov.detach().cpu().clone() indx = indx.tolist() frontier_observation_data = array_to_observation_data( - f=f.numpy(), cov=cov.numpy(), outcomes=not_none(modelbridge.outcomes) + f=f.numpy(), cov=cov.numpy(), outcomes=none_throws(modelbridge.outcomes) ) # Construct observations frontier_observations = [] @@ -1030,7 +1031,7 @@ def hypervolume( ) # Apply appropriate weights and thresholds obj, obj_t = get_weighted_mc_objective_and_objective_thresholds( - objective_weights=obj_w, objective_thresholds=not_none(obj_t) + objective_weights=obj_w, objective_thresholds=none_throws(obj_t) ) f_t = obj(f) obj_mask = obj_w.nonzero().view(-1) diff --git a/ax/modelbridge/registry.py b/ax/modelbridge/registry.py index 42745b8c1b6..6682297ca5d 100644 --- a/ax/modelbridge/registry.py +++ b/ax/modelbridge/registry.py @@ -72,9 +72,10 @@ ) from ax.utils.common.logger import get_logger from ax.utils.common.serialization import callable_from_reference, callable_to_reference -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP +from pyre_extensions import none_throws logger: Logger = get_logger(__name__) @@ -276,7 +277,7 @@ def __call__( assert self.value in MODEL_KEY_TO_MODEL_SETUP, f"Unknown model {self.value}" # All model bridges require either a 
search space or an experiment. assert search_space or experiment, "Search space or experiment required." - search_space = search_space or not_none(experiment).search_space + search_space = search_space or none_throws(experiment).search_space model_setup_info = MODEL_KEY_TO_MODEL_SETUP[self.value] model_class = model_setup_info.model_class bridge_class = model_setup_info.bridge_class @@ -335,7 +336,7 @@ def __call__( # Create model bridge with the consolidated kwargs. model_bridge = bridge_class( - search_space=search_space or not_none(experiment).search_space, + search_space=search_space or none_throws(experiment).search_space, experiment=experiment, data=data, model=model, @@ -362,7 +363,7 @@ def view_defaults(self) -> tuple[dict[str, Any], dict[str, Any]]: Returns: A tuple of default keyword arguments for the model and the model bridge. """ - model_setup_info = not_none(MODEL_KEY_TO_MODEL_SETUP.get(self.value)) + model_setup_info = none_throws(MODEL_KEY_TO_MODEL_SETUP.get(self.value)) return ( self._get_model_kwargs(info=model_setup_info), self._get_bridge_kwargs(info=model_setup_info), diff --git a/ax/modelbridge/tests/test_base_modelbridge.py b/ax/modelbridge/tests/test_base_modelbridge.py index 34ca5f1f751..1445439419c 100644 --- a/ax/modelbridge/tests/test_base_modelbridge.py +++ b/ax/modelbridge/tests/test_base_modelbridge.py @@ -33,7 +33,6 @@ from ax.utils.common.constants import Keys from ax.utils.common.logger import get_logger from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import not_none from ax.utils.testing.core_stubs import ( get_branin_experiment, get_branin_experiment_with_multi_objective, @@ -57,6 +56,7 @@ transform_2, ) from botorch.exceptions.warnings import InputDataWarning +from pyre_extensions import none_throws class BaseModelBridgeTest(TestCase): @@ -816,7 +816,7 @@ def test_set_status_quo(self) -> None: # status_quo is set self.assertIsNotNone(modelbridge.status_quo) # Status quo name is logged - self.assertEqual(modelbridge._status_quo_name, not_none(exp.status_quo).name) + self.assertEqual(modelbridge._status_quo_name, none_throws(exp.status_quo).name) # experiment with multiple status quos in different trials exp = get_branin_experiment( @@ -837,12 +837,12 @@ def test_set_status_quo(self) -> None: # status_quo is not set self.assertIsNone(modelbridge.status_quo) # Status quo name can still be logged - self.assertEqual(modelbridge._status_quo_name, not_none(exp.status_quo).name) + self.assertEqual(modelbridge._status_quo_name, none_throws(exp.status_quo).name) # a unique status_quo can be identified (by trial index) # if status_quo_features is specified status_quo_features = ObservationFeatures( - parameters=not_none(exp.status_quo).parameters, + parameters=none_throws(exp.status_quo).parameters, trial_index=0, ) modelbridge = ModelBridge( diff --git a/ax/modelbridge/tests/test_dispatch_utils.py b/ax/modelbridge/tests/test_dispatch_utils.py index 937dc46bbff..69827dd1bcc 100644 --- a/ax/modelbridge/tests/test_dispatch_utils.py +++ b/ax/modelbridge/tests/test_dispatch_utils.py @@ -25,7 +25,6 @@ from ax.modelbridge.transforms.winsorize import Winsorize from ax.models.winsorization_config import WinsorizationConfig from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import not_none from ax.utils.testing.core_stubs import ( get_branin_search_space, get_discrete_search_space, @@ -37,6 +36,7 @@ run_branin_experiment_with_generation_strategy, ) from ax.utils.testing.mock import mock_botorch_optimize +from 
pyre_extensions import none_throws class TestDispatchUtils(TestCase): @@ -119,7 +119,7 @@ def test_choose_generation_strategy(self) -> None: self.assertEqual(sobol_gpei._steps[0].model, Models.SOBOL) self.assertEqual(sobol_gpei._steps[0].num_trials, 5) self.assertEqual(sobol_gpei._steps[1].model, Models.BOTORCH_MODULAR) - model_kwargs = not_none(sobol_gpei._steps[1].model_kwargs) + model_kwargs = none_throws(sobol_gpei._steps[1].model_kwargs) self.assertEqual( set(model_kwargs.keys()), { @@ -234,7 +234,7 @@ def test_choose_generation_strategy(self) -> None: self.assertEqual(moo_mixed._steps[0].model, Models.SOBOL) self.assertEqual(moo_mixed._steps[0].num_trials, 5) self.assertEqual(moo_mixed._steps[1].model, Models.BO_MIXED) - model_kwargs = not_none(moo_mixed._steps[1].model_kwargs) + model_kwargs = none_throws(moo_mixed._steps[1].model_kwargs) self.assertEqual( set(model_kwargs.keys()), { @@ -328,10 +328,10 @@ def test_make_botorch_step_extra(self) -> None: } bo_step = _make_botorch_step(model_kwargs=model_kwargs) self.assertEqual( - not_none(bo_step.model_kwargs)["transforms"], [Winsorize, LogY] + none_throws(bo_step.model_kwargs)["transforms"], [Winsorize, LogY] ) self.assertEqual( - not_none(bo_step.model_kwargs)["transform_configs"], + none_throws(bo_step.model_kwargs)["transform_configs"], { "LogY": {"metrics": ["metric_1"]}, "Derelativize": {"use_raw_status_quo": False}, @@ -351,18 +351,18 @@ def test_disable_progbar(self) -> None: self.assertEqual(sobol_saasbo._steps[0].model, Models.SOBOL) self.assertNotIn( "disable_progbar", - not_none(sobol_saasbo._steps[0].model_kwargs), + none_throws(sobol_saasbo._steps[0].model_kwargs), ) self.assertEqual(sobol_saasbo._steps[1].model, Models.SAASBO) self.assertNotIn( "disable_progbar", - not_none(sobol_saasbo._steps[0].model_kwargs), + none_throws(sobol_saasbo._steps[0].model_kwargs), ) # TODO[T164389105] Rewrite choose_generation_strategy to be MBM first # Once this task is complete we should check disable_progbar gets # propagated correctly (right now it is dropped). Ex.: # self.assertEqual( - # not_none(sobol_saasbo._steps[1].model_kwargs)["disable_progbar"], + # none_throws(sobol_saasbo._steps[1].model_kwargs)["disable_progbar"], # disable_progbar, # ) run_branin_experiment_with_generation_strategy( @@ -382,12 +382,12 @@ def test_disable_progbar_for_non_saasbo_discards_the_model_kwarg(self) -> None: self.assertEqual(gp_saasbo._steps[0].model, Models.SOBOL) self.assertNotIn( "disable_progbar", - not_none(gp_saasbo._steps[0].model_kwargs), + none_throws(gp_saasbo._steps[0].model_kwargs), ) self.assertEqual(gp_saasbo._steps[1].model, Models.BOTORCH_MODULAR) self.assertNotIn( "disable_progbar", - not_none(gp_saasbo._steps[1].model_kwargs), + none_throws(gp_saasbo._steps[1].model_kwargs), ) run_branin_experiment_with_generation_strategy( generation_strategy=gp_saasbo @@ -399,7 +399,7 @@ def test_setting_random_seed(self) -> None: ) sobol.gen(experiment=get_experiment(), n=1) # First model is actually a bridge, second is the Sobol engine. 
- self.assertEqual(not_none(sobol.model).model.seed, 9) + self.assertEqual(none_throws(sobol.model).model.seed, 9) with self.subTest("warns if use_saasbo is true"): with self.assertLogs( @@ -492,7 +492,7 @@ def test_winsorization(self) -> None: search_space=get_branin_search_space(), winsorization_config=WinsorizationConfig(upper_quantile_margin=2), ) - tc = not_none(winsorized._steps[1].model_kwargs).get("transform_configs") + tc = none_throws(winsorized._steps[1].model_kwargs).get("transform_configs") self.assertIn("Winsorize", tc) self.assertDictEqual( tc["Winsorize"], @@ -512,7 +512,7 @@ def test_winsorization(self) -> None: search_space=get_branin_search_space(), derelativize_with_raw_status_quo=True, ) - tc = not_none(winsorized._steps[1].model_kwargs).get("transform_configs") + tc = none_throws(winsorized._steps[1].model_kwargs).get("transform_configs") self.assertIn( "Winsorize", tc, @@ -539,7 +539,7 @@ def test_no_winzorization_wins(self) -> None: self.assertNotIn( "Winsorize", - not_none(unwinsorized._steps[1].model_kwargs)["transform_configs"], + none_throws(unwinsorized._steps[1].model_kwargs)["transform_configs"], ) def test_num_trials(self) -> None: @@ -735,18 +735,18 @@ def test_jit_compile(self) -> None: self.assertEqual(sobol_saasbo._steps[0].model, Models.SOBOL) self.assertNotIn( "jit_compile", - not_none(sobol_saasbo._steps[0].model_kwargs), + none_throws(sobol_saasbo._steps[0].model_kwargs), ) self.assertEqual(sobol_saasbo._steps[1].model, Models.SAASBO) self.assertNotIn( "jit_compile", - not_none(sobol_saasbo._steps[0].model_kwargs), + none_throws(sobol_saasbo._steps[0].model_kwargs), ) # TODO[T164389105] Rewrite choose_generation_strategy to be MBM first # Once this task is complete we should check jit_compile gets # propagated correctly (right now it is dropped). 
Ex.: # self.assertEqual( - # not_none(sobol_saasbo._steps[1].model_kwargs)["jit_compile"], + # none_throws(sobol_saasbo._steps[1].model_kwargs)["jit_compile"], # jit_compile, # ) run_branin_experiment_with_generation_strategy( @@ -766,12 +766,12 @@ def test_jit_compile_for_non_saasbo_discards_the_model_kwarg(self) -> None: self.assertEqual(gp_saasbo._steps[0].model, Models.SOBOL) self.assertNotIn( "jit_compile", - not_none(gp_saasbo._steps[0].model_kwargs), + none_throws(gp_saasbo._steps[0].model_kwargs), ) self.assertEqual(gp_saasbo._steps[1].model, Models.BOTORCH_MODULAR) self.assertNotIn( "jit_compile", - not_none(gp_saasbo._steps[1].model_kwargs), + none_throws(gp_saasbo._steps[1].model_kwargs), ) run_branin_experiment_with_generation_strategy( generation_strategy=gp_saasbo, diff --git a/ax/modelbridge/tests/test_external_generation_node.py b/ax/modelbridge/tests/test_external_generation_node.py index 2bd032f67fc..bfa446fcb07 100644 --- a/ax/modelbridge/tests/test_external_generation_node.py +++ b/ax/modelbridge/tests/test_external_generation_node.py @@ -18,12 +18,12 @@ from ax.modelbridge.generation_strategy import GenerationStrategy from ax.modelbridge.random import RandomModelBridge from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import not_none from ax.utils.testing.core_stubs import ( get_branin_data, get_branin_experiment, get_sobol, ) +from pyre_extensions import none_throws class DummyNode(ExternalGenerationNode): @@ -43,7 +43,7 @@ def get_next_candidate( ) -> TParameterization: self.gen_count += 1 self.last_pending = deepcopy(pending_parameters) - return not_none(self.generator).gen(n=1).arms[0].parameters + return none_throws(self.generator).gen(n=1).arms[0].parameters class TestExternalGenerationNode(TestCase): @@ -97,7 +97,7 @@ def test_generation(self) -> None: self.assertEqual(node.gen_count, 9) self.assertEqual(node.update_count, 5) self.assertEqual(len(gr.arms), 5) - self.assertGreater(not_none(gr.fit_time), 0.0) - self.assertGreater(not_none(gr.gen_time), 0.0) + self.assertGreater(none_throws(gr.fit_time), 0.0) + self.assertGreater(none_throws(gr.gen_time), 0.0) self.assertEqual(gr._model_key, "dummy") self.assertEqual(len(node.last_pending), 4) diff --git a/ax/modelbridge/tests/test_generation_strategy.py b/ax/modelbridge/tests/test_generation_strategy.py index 1bbb3b8308e..5fe3282ce41 100644 --- a/ax/modelbridge/tests/test_generation_strategy.py +++ b/ax/modelbridge/tests/test_generation_strategy.py @@ -59,7 +59,7 @@ from ax.utils.common.equality import same_elements from ax.utils.common.mock import mock_patch_method_original from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from ax.utils.testing.core_stubs import ( get_branin_data, get_branin_experiment, @@ -69,6 +69,7 @@ get_hierarchical_search_space_experiment, ) from ax.utils.testing.mock import mock_botorch_optimize +from pyre_extensions import none_throws class TestGenerationStrategyWithoutModelBridgeMocks(TestCase): @@ -573,9 +574,9 @@ def test_sobol_MBM_strategy(self) -> None: "fit_on_init": True, }, ) - ms = not_none(g._model_state_after_gen).copy() + ms = none_throws(g._model_state_after_gen).copy() # Compare the model state to Sobol state. 
- sobol_model = not_none(gs.model).model + sobol_model = none_throws(gs.model).model self.assertTrue( np.array_equal( ms.pop("generated_points"), sobol_model.generated_points @@ -737,18 +738,18 @@ def test_trials_as_df(self) -> None: self.assertIsNone(sobol_generation_strategy.trials_as_df) # Now the trial should appear in the DF. trial = exp.new_trial(sobol_generation_strategy.gen(experiment=exp)) - trials_df = not_none(sobol_generation_strategy.trials_as_df) + trials_df = none_throws(sobol_generation_strategy.trials_as_df) self.assertFalse(trials_df.empty) self.assertEqual(trials_df.head()["Trial Status"][0], "CANDIDATE") # Changes in trial status should be reflected in the DF. trial._status = TrialStatus.RUNNING - trials_df = not_none(sobol_generation_strategy.trials_as_df) + trials_df = none_throws(sobol_generation_strategy.trials_as_df) self.assertEqual(trials_df.head()["Trial Status"][0], "RUNNING") # Check that rows are present for step 0 and 1 after moving to step 1 for _i in range(3): # attach necessary trials to fill up the Generation Strategy trial = exp.new_trial(sobol_generation_strategy.gen(experiment=exp)) - trials_df = not_none(sobol_generation_strategy.trials_as_df) + trials_df = none_throws(sobol_generation_strategy.trials_as_df) self.assertEqual(trials_df.head()["Generation Step"][0], ["GenerationStep_0"]) self.assertEqual(trials_df.head()["Generation Step"][2], ["GenerationStep_1"]) @@ -962,7 +963,7 @@ def test_gen_multiple(self) -> None: # Check case with pending features initially specified; we should get two # GRs now (remaining in Sobol step) even though we requested 3. - original_pending = not_none(get_pending(experiment=exp)) + original_pending = none_throws(get_pending(experiment=exp)) first_3_trials_obs_feats = [ ObservationFeatures.from_arm(arm=a, trial_index=idx) for idx, trial in exp.trials.items() @@ -1047,7 +1048,7 @@ def test_gen_for_multiple_trials_with_multiple_models(self) -> None: # Check case with pending features initially specified; we should get two # GRs now (remaining in Sobol step) even though we requested 3. - original_pending = not_none(get_pending(experiment=exp)) + original_pending = none_throws(get_pending(experiment=exp)) first_3_trials_obs_feats = [ ObservationFeatures.from_arm(arm=a, trial_index=idx) for idx, trial in exp.trials.items() @@ -1410,7 +1411,7 @@ def test_gen_with_multiple_nodes_pending_points(self) -> None: model_spec_gen_mock.reset_mock() # check that the pending points line up - original_pending = not_none(get_pending(experiment=exp)) + original_pending = none_throws(get_pending(experiment=exp)) first_3_trials_obs_feats = [ ObservationFeatures.from_arm(arm=a, trial_index=idx) for idx, trial in exp.trials.items() @@ -1518,9 +1519,9 @@ def test_gs_with_generation_nodes(self) -> None: "fit_on_init": True, }, ) - ms = not_none(g._model_state_after_gen).copy() + ms = none_throws(g._model_state_after_gen).copy() # Compare the model state to Sobol state. 
- sobol_model = not_none(self.sobol_MBM_GS_nodes.model).model + sobol_model = none_throws(self.sobol_MBM_GS_nodes.model).model self.assertTrue( np.array_equal( ms.pop("generated_points"), sobol_model.generated_points @@ -1848,20 +1849,22 @@ def test_trials_as_df_node_gs(self) -> None: trial = exp.new_batch_trial( generator_runs=gs.gen_with_multiple_nodes(exp, arms_per_node=arms_per_node) ) - trials_df = not_none(gs.trials_as_df) + trials_df = none_throws(gs.trials_as_df) self.assertFalse(trials_df.empty) self.assertEqual(trials_df.head()["Trial Status"][0], "CANDIDATE") self.assertEqual(trials_df.head()["Generation Model(s)"][0], ["Sobol"]) # Changes in trial status should be reflected in the DF. trial.run() - self.assertEqual(not_none(gs.trials_as_df).head()["Trial Status"][0], "RUNNING") + self.assertEqual( + none_throws(gs.trials_as_df).head()["Trial Status"][0], "RUNNING" + ) # Add a new trial which will be generated from multiple nodes, and check that # is properly reflected in the DF. trial = exp.new_batch_trial( generator_runs=gs.gen_with_multiple_nodes(exp, arms_per_node=arms_per_node) ) self.assertEqual( - not_none(gs.trials_as_df).head()["Generation Nodes"][1], + none_throws(gs.trials_as_df).head()["Generation Nodes"][1], ["mbm", "sobol_2", "sobol_3"], ) diff --git a/ax/modelbridge/tests/test_hierarchical_search_space.py b/ax/modelbridge/tests/test_hierarchical_search_space.py index 883d363a7ed..c9a94c641a4 100644 --- a/ax/modelbridge/tests/test_hierarchical_search_space.py +++ b/ax/modelbridge/tests/test_hierarchical_search_space.py @@ -28,8 +28,9 @@ from ax.runners.synthetic import SyntheticRunner from ax.utils.common.constants import Keys from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from ax.utils.testing.mock import mock_botorch_optimize +from pyre_extensions import none_throws class TestHierarchicalSearchSpace(TestCase): @@ -166,10 +167,10 @@ def _test_gen_base( for t in experiment.trials.values(): trial = checked_cast(Trial, t) - arm = not_none(trial.arm) + arm = none_throws(trial.arm) self.assertIn(len(arm.parameters), expected_num_candidate_params) # Check that the trials have the full parameterization recorded. - full_parameterization = not_none( + full_parameterization = none_throws( trial._get_candidate_metadata(arm_name=arm.name) )[Keys.FULL_PARAMETERIZATION] self.assertEqual(full_parameterization.keys(), hss.parameters.keys()) @@ -195,9 +196,9 @@ def _base_test_predict_and_cv( ) for t in experiment.trials.values(): trial = checked_cast(Trial, t) - arm = not_none(trial.arm) + arm = none_throws(trial.arm) final_parameterization = arm.parameters - full_parameterization = not_none( + full_parameterization = none_throws( trial._get_candidate_metadata(arm_name=arm.name) )[Keys.FULL_PARAMETERIZATION] # Predict with full parameterization -- this should always work. 
diff --git a/ax/modelbridge/tests/test_model_spec.py b/ax/modelbridge/tests/test_model_spec.py index 9332cffe268..8f43eee00ac 100644 --- a/ax/modelbridge/tests/test_model_spec.py +++ b/ax/modelbridge/tests/test_model_spec.py @@ -17,9 +17,9 @@ from ax.modelbridge.modelbridge_utils import extract_search_space_digest from ax.modelbridge.registry import Models from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import not_none from ax.utils.testing.core_stubs import get_branin_experiment from ax.utils.testing.mock import mock_botorch_optimize +from pyre_extensions import none_throws class BaseModelSpecTest(TestCase): @@ -170,7 +170,7 @@ def test_gen_attaches_empty_model_fit_metadata_if_fit_not_applicable(self) -> No ms = ModelSpec(model_enum=Models.SOBOL) ms.fit(experiment=self.experiment, data=self.data) gr = ms.gen(n=1) - gen_metadata = not_none(gr.gen_metadata) + gen_metadata = none_throws(gr.gen_metadata) self.assertEqual(gen_metadata["model_fit_quality"], None) self.assertEqual(gen_metadata["model_std_quality"], None) self.assertEqual(gen_metadata["model_fit_generalization"], None) @@ -180,7 +180,7 @@ def test_gen_attaches_model_fit_metadata_if_applicable(self) -> None: ms = ModelSpec(model_enum=Models.BOTORCH_MODULAR) ms.fit(experiment=self.experiment, data=self.data) gr = ms.gen(n=1) - gen_metadata = not_none(gr.gen_metadata) + gen_metadata = none_throws(gr.gen_metadata) self.assertIsInstance(gen_metadata["model_fit_quality"], float) self.assertIsInstance(gen_metadata["model_std_quality"], float) self.assertIsInstance(gen_metadata["model_fit_generalization"], float) diff --git a/ax/modelbridge/tests/test_modelbridge_utils.py b/ax/modelbridge/tests/test_modelbridge_utils.py index b95e8a9dad0..8a8970efb63 100644 --- a/ax/modelbridge/tests/test_modelbridge_utils.py +++ b/ax/modelbridge/tests/test_modelbridge_utils.py @@ -31,10 +31,10 @@ ) from ax.modelbridge.registry import Cont_X_trans, Y_trans from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import not_none from ax.utils.testing.core_stubs import get_robust_search_space, get_search_space from botorch.acquisition.risk_measures import VaR from botorch.utils.datasets import ContextualDataset, SupervisedDataset +from pyre_extensions import none_throws class TestModelBridgeUtils(TestCase): @@ -108,11 +108,13 @@ def test_extract_robust_digest(self) -> None: for p in rss.parameter_distributions: p.multiplicative = True rss.multiplicative = True - robust_digest = not_none(extract_robust_digest(rss, list(rss.parameters))) + robust_digest = none_throws( + extract_robust_digest(rss, list(rss.parameters)) + ) self.assertEqual(robust_digest.multiplicative, multiplicative) self.assertEqual(robust_digest.environmental_variables, []) self.assertIsNone(robust_digest.sample_environmental) - samples = not_none(robust_digest.sample_param_perturbations)() + samples = none_throws(robust_digest.sample_param_perturbations)() self.assertEqual(samples.shape, (8, 4)) constructor = np.ones if multiplicative else np.zeros self.assertTrue(np.equal(samples[:, 2:], constructor((8, 2))).all()) @@ -120,10 +122,10 @@ def test_extract_robust_digest(self) -> None: self.assertTrue(np.all(samples[:, 1] > 0)) # Check that it works as expected if param_names is missing some # non-distributional parameters. 
- robust_digest = not_none( + robust_digest = none_throws( extract_robust_digest(rss, list(rss.parameters)[:-1]) ) - samples = not_none(robust_digest.sample_param_perturbations)() + samples = none_throws(robust_digest.sample_param_perturbations)() self.assertEqual(samples.shape, (8, 3)) self.assertTrue(np.equal(samples[:, 2:], constructor((8, 1))).all()) self.assertTrue(np.all(samples[:, 1] > 0)) @@ -138,11 +140,11 @@ def test_extract_robust_digest(self) -> None: num_samples=8, environmental_variables=all_params[:2], ) - robust_digest = not_none(extract_robust_digest(rss, list(rss.parameters))) + robust_digest = none_throws(extract_robust_digest(rss, list(rss.parameters))) self.assertFalse(robust_digest.multiplicative) self.assertIsNone(robust_digest.sample_param_perturbations) self.assertEqual(robust_digest.environmental_variables, ["x", "y"]) - samples = not_none(robust_digest.sample_environmental)() + samples = none_throws(robust_digest.sample_environmental)() self.assertEqual(samples.shape, (8, 2)) # Both are continuous distributions, should be non-zero. self.assertTrue(np.all(samples != 0)) @@ -156,12 +158,14 @@ def test_extract_robust_digest(self) -> None: num_samples=8, environmental_variables=all_params[:1], ) - robust_digest = not_none(extract_robust_digest(rss, list(rss.parameters))) + robust_digest = none_throws(extract_robust_digest(rss, list(rss.parameters))) self.assertFalse(robust_digest.multiplicative) self.assertEqual( - not_none(robust_digest.sample_param_perturbations)().shape, (8, 3) + none_throws(robust_digest.sample_param_perturbations)().shape, (8, 3) + ) + self.assertEqual( + none_throws(robust_digest.sample_environmental)().shape, (8, 1) ) - self.assertEqual(not_none(robust_digest.sample_environmental)().shape, (8, 1)) self.assertEqual(robust_digest.environmental_variables, ["x"]) def test_feasible_hypervolume(self) -> None: diff --git a/ax/modelbridge/tests/test_prediction_utils.py b/ax/modelbridge/tests/test_prediction_utils.py index 2fb9c603ac0..442f3d4aa70 100644 --- a/ax/modelbridge/tests/test_prediction_utils.py +++ b/ax/modelbridge/tests/test_prediction_utils.py @@ -15,7 +15,7 @@ from ax.service.ax_client import AxClient from ax.service.utils.instantiation import ObjectiveProperties from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import not_none +from pyre_extensions import none_throws class TestPredictionUtils(TestCase): @@ -28,7 +28,7 @@ def test_predict_at_point(self) -> None: observation_features = ObservationFeatures(parameters={"x1": 0.3, "x2": 0.5}) y_hat, se_hat = predict_at_point( - model=not_none(ax_client.generation_strategy.model), + model=none_throws(ax_client.generation_strategy.model), obsf=observation_features, metric_names={"test_metric1"}, ) @@ -37,7 +37,7 @@ def test_predict_at_point(self) -> None: self.assertEqual(len(se_hat), 1) y_hat, se_hat = predict_at_point( - model=not_none(ax_client.generation_strategy.model), + model=none_throws(ax_client.generation_strategy.model), obsf=observation_features, metric_names={"test_metric1", "test_metric2", "test_metric:agg"}, scalarized_metric_config=[ @@ -51,7 +51,7 @@ def test_predict_at_point(self) -> None: self.assertEqual(len(se_hat), 3) y_hat, se_hat = predict_at_point( - model=not_none(ax_client.generation_strategy.model), + model=none_throws(ax_client.generation_strategy.model), obsf=observation_features, metric_names={"test_metric1"}, scalarized_metric_config=[ @@ -75,7 +75,7 @@ def test_predict_by_features(self) -> None: 20: 
ObservationFeatures(parameters={"x1": 0.8, "x2": 0.5}), } predictions_map = predict_by_features( - model=not_none(ax_client.generation_strategy.model), + model=none_throws(ax_client.generation_strategy.model), label_to_feature_dict=observation_features_dict, metric_names={"test_metric1"}, ) diff --git a/ax/modelbridge/tests/test_torch_modelbridge.py b/ax/modelbridge/tests/test_torch_modelbridge.py index ee0e9bedc42..81ba99cf374 100644 --- a/ax/modelbridge/tests/test_torch_modelbridge.py +++ b/ax/modelbridge/tests/test_torch_modelbridge.py @@ -39,7 +39,7 @@ from ax.models.torch_base import TorchGenResults, TorchModel from ax.utils.common.constants import Keys from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from ax.utils.testing.core_stubs import ( get_branin_data, get_branin_experiment, @@ -55,6 +55,7 @@ MultiTaskDataset, SupervisedDataset, ) +from pyre_extensions import none_throws def _get_mock_modelbridge( @@ -132,8 +133,8 @@ def test_TorchModelBridge( for y1, y2, yvar1, yvar2 in zip( datasets["y1"].Y.tolist(), datasets["y2"].Y.tolist(), - not_none(datasets["y1"].Yvar).tolist(), - not_none(datasets["y2"].Yvar).tolist(), + none_throws(datasets["y1"].Yvar).tolist(), + none_throws(datasets["y2"].Yvar).tolist(), ) ] observations = recombine_observations(observation_features, observation_data) @@ -500,10 +501,10 @@ def test_best_point( autospec=True, ): run = modelbridge.gen(n=1, optimization_config=oc) - arm, predictions = not_none(run.best_arm_predictions) - model_arm, model_predictions = not_none(modelbridge.model_best_point()) - predictions = not_none(predictions) - model_predictions = not_none(model_predictions) + arm, predictions = none_throws(run.best_arm_predictions) + model_arm, model_predictions = none_throws(modelbridge.model_best_point()) + predictions = none_throws(predictions) + model_predictions = none_throws(model_predictions) self.assertEqual(arm.parameters, {}) self.assertEqual(predictions[0], {"m": 1.0}) self.assertEqual(predictions[1], {"m": {"m": 2.0}}) @@ -756,7 +757,7 @@ def test_convert_observations(self) -> None: search_space_digest=search_space_digest, ) self.assertEqual(len(converted_datasets), 1) - dataset = not_none(converted_datasets[0]) + dataset = none_throws(converted_datasets[0]) self.assertIs(dataset.__class__, expected_class) if use_task: sort_idx = torch.argsort(raw_X[:, -1]) @@ -854,7 +855,7 @@ def test_convert_contextual_observations(self) -> None: ["c0", "c1", "c2"], ) self.assertDictEqual( - not_none( + none_throws( checked_cast(ContextualDataset, dataset).metric_decomposition ), metric_decomposition, diff --git a/ax/modelbridge/tests/test_torch_moo_modelbridge.py b/ax/modelbridge/tests/test_torch_moo_modelbridge.py index 2aaf1cd9800..939336d8fbf 100644 --- a/ax/modelbridge/tests/test_torch_moo_modelbridge.py +++ b/ax/modelbridge/tests/test_torch_moo_modelbridge.py @@ -40,7 +40,7 @@ from ax.service.utils.report_utils import exp_to_df from ax.utils.common.random import set_rng_seed from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from ax.utils.testing.core_stubs import ( get_branin_data_multi_objective, get_branin_experiment_with_multi_objective, @@ -52,6 +52,7 @@ from ax.utils.testing.mock import mock_botorch_optimize from ax.utils.testing.modeling_stubs import transform_1, transform_2 from botorch.utils.multi_objective.pareto 
import is_non_dominated +from pyre_extensions import none_throws PARETO_FRONTIER_EVALUATOR_PATH = ( f"{get_pareto_frontier_and_configs.__module__}.pareto_frontier_evaluator" @@ -81,7 +82,7 @@ def helper_test_pareto_frontier( ) for trial in exp.trials.values(): trial.mark_running(no_runner_required=True).mark_completed() - metrics_dict = not_none(exp.optimization_config).metrics + metrics_dict = none_throws(exp.optimization_config).metrics objective_bound = 5.0 objective_thresholds = [ ObjectiveThreshold( @@ -190,7 +191,9 @@ def helper_test_pareto_frontier( self.assertTrue(torch.equal(obj_w[:2], -torch.ones(2, dtype=torch.double))) self.assertTrue(obj_t is not None) self.assertTrue( - torch.equal(not_none(obj_t)[:2], torch.full((2,), 5.0, dtype=torch.double)) + torch.equal( + none_throws(obj_t)[:2], torch.full((2,), 5.0, dtype=torch.double) + ) ) observed_frontier2 = pareto_frontier( modelbridge=modelbridge, @@ -254,7 +257,7 @@ def test_get_pareto_frontier_and_configs_input_validation(self) -> None: ) for trial in exp.trials.values(): trial.mark_running(no_runner_required=True).mark_completed() - metrics_dict = not_none(exp.optimization_config).metrics + metrics_dict = none_throws(exp.optimization_config).metrics objective_thresholds = [ ObjectiveThreshold( metric=metrics_dict[f"branin_{letter}"], @@ -457,7 +460,7 @@ def test_infer_objective_thresholds(self, _, cuda: bool = False) -> None: ParameterConstraint(constraint_dict={"x1": 1.0}, bound=10.0) ] search_space.add_parameter_constraints(param_constraints) - oc = not_none(exp.optimization_config).clone() + oc = none_throws(exp.optimization_config).clone() oc.objective._objectives[0].minimize = True for use_partial_thresholds in (False, True): diff --git a/ax/modelbridge/tests/test_utils.py b/ax/modelbridge/tests/test_utils.py index c4314128573..1be57ddff77 100644 --- a/ax/modelbridge/tests/test_utils.py +++ b/ax/modelbridge/tests/test_utils.py @@ -33,7 +33,6 @@ from ax.modelbridge.registry import Models from ax.utils.common.constants import Keys from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import not_none from ax.utils.testing.core_stubs import ( get_experiment, get_hierarchical_search_space_experiment, @@ -60,7 +59,7 @@ def setUp(self) -> None: self.hss_sobol = Models.SOBOL(search_space=self.hss_exp.search_space) self.hss_gr = self.hss_sobol.gen(n=1) self.hss_trial = self.hss_exp.new_trial(self.hss_gr) - self.hss_arm = not_none(self.hss_trial.arm) + self.hss_arm = none_throws(self.hss_trial.arm) self.hss_cand_metadata = self.hss_trial._get_candidate_metadata( arm_name=self.hss_arm.name ) diff --git a/ax/modelbridge/torch.py b/ax/modelbridge/torch.py index 6b1099874e1..b438cd3765f 100644 --- a/ax/modelbridge/torch.py +++ b/ax/modelbridge/torch.py @@ -68,8 +68,8 @@ from ax.models.torch_base import TorchModel, TorchOptConfig from ax.models.types import TConfig from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import not_none from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset +from pyre_extensions import none_throws from torch import Tensor logger: Logger = get_logger(__name__) @@ -127,7 +127,7 @@ def __init__( # Handle init for multi-objective optimization. 
self.is_moo_problem: bool = False if optimization_config or (experiment and experiment.optimization_config): - optimization_config = not_none( + optimization_config = none_throws( optimization_config or experiment.optimization_config ) self.is_moo_problem = optimization_config.is_moo_problem @@ -149,7 +149,7 @@ def __init__( ) def feature_importances(self, metric_name: str) -> dict[str, float]: - importances_tensor = not_none(self.model).feature_importances() + importances_tensor = none_throws(self.model).feature_importances() importances_dict = dict(zip(self.outcomes, importances_tensor)) importances_arr = importances_dict[metric_name].flatten() return dict(zip(self.parameters, importances_arr)) @@ -209,7 +209,7 @@ def infer_objective_thresholds( ) obj_thresholds = infer_objective_thresholds( - model=not_none(model), + model=none_throws(model), objective_weights=torch_opt_config.objective_weights, bounds=search_space_digest.bounds, outcome_constraints=torch_opt_config.outcome_constraints, @@ -253,7 +253,7 @@ def model_best_point( optimization_config=base_gen_args.optimization_config, ) try: - xbest = not_none(self.model).best_point( + xbest = none_throws(self.model).best_point( search_space_digest=search_space_digest, torch_opt_config=torch_opt_config, ) @@ -453,7 +453,7 @@ def _cross_validate( device=self.device, ) # Use the model to do the cross validation - f_test, cov_test = not_none(self.model).cross_validate( + f_test, cov_test = none_throws(self.model).cross_validate( datasets=datasets, X_test=torch.as_tensor(X_test, dtype=self.dtype, device=self.device), search_space_digest=search_space_digest, @@ -528,7 +528,7 @@ def evaluate_acquisition_function( return self._evaluate_acquisition_function( observation_features=obs_feats, search_space=base_gen_args.search_space, - optimization_config=not_none(base_gen_args.optimization_config), + optimization_config=none_throws(base_gen_args.optimization_config), pending_observations=base_gen_args.pending_observations, fixed_features=base_gen_args.fixed_features, acq_options=acq_options, @@ -559,7 +559,7 @@ def _evaluate_acquisition_function( for obsf in observation_features ] ) - evals = not_none(self.model).evaluate_acquisition_function( + evals = none_throws(self.model).evaluate_acquisition_function( X=self._array_to_tensor(X), search_space_digest=search_space_digest, torch_opt_config=torch_opt_config, @@ -689,7 +689,7 @@ def _gen( ) # Generate the candidates - gen_results = not_none(self.model).gen( + gen_results = none_throws(self.model).gen( n=n, search_space_digest=search_space_digest, torch_opt_config=torch_opt_config, @@ -720,7 +720,7 @@ def _gen( candidate_metadata=gen_results.candidate_metadata, ) try: - xbest = not_none(self.model).best_point( + xbest = none_throws(self.model).best_point( search_space_digest=search_space_digest, torch_opt_config=torch_opt_config, ) @@ -747,7 +747,7 @@ def _predict( raise ValueError(FIT_MODEL_ERROR.format(action="_model_predict")) # Convert observation features to array X = observation_features_to_array(self.parameters, observation_features) - f, cov = not_none(self.model).predict(X=self._array_to_tensor(X)) + f, cov = none_throws(self.model).predict(X=self._array_to_tensor(X)) f = f.detach().cpu().clone().numpy() cov = cov.detach().cpu().clone().numpy() if f.shape[-2] != X.shape[-2]: @@ -853,7 +853,7 @@ def _get_transformed_model_gen_args( else None ) if risk_measure is not None: - if not not_none(self.model)._supports_robust_optimization: + if not 
none_throws(self.model)._supports_robust_optimization: raise UnsupportedError( f"{self.model.__class__.__name__} does not support robust " "optimization. Consider using modular BoTorch model instead." diff --git a/ax/modelbridge/transforms/cast.py b/ax/modelbridge/transforms/cast.py index 45ac255c321..6c5c04eb33a 100644 --- a/ax/modelbridge/transforms/cast.py +++ b/ax/modelbridge/transforms/cast.py @@ -13,7 +13,8 @@ from ax.exceptions.core import UserInputError from ax.modelbridge.transforms.base import Transform from ax.models.types import TConfig -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast +from pyre_extensions import none_throws if TYPE_CHECKING: # import as module to make sphinx-autodoc-typehints happy @@ -47,7 +48,7 @@ def __init__( modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, config: TConfig | None = None, ) -> None: - self.search_space: SearchSpace = not_none(search_space).clone() + self.search_space: SearchSpace = none_throws(search_space).clone() config = (config or {}).copy() self.flatten_hss: bool = checked_cast( bool, diff --git a/ax/modelbridge/transforms/convert_metric_names.py b/ax/modelbridge/transforms/convert_metric_names.py index 66852af075a..1fd875ad423 100644 --- a/ax/modelbridge/transforms/convert_metric_names.py +++ b/ax/modelbridge/transforms/convert_metric_names.py @@ -14,7 +14,7 @@ from ax.modelbridge.transforms.base import Transform from ax.models.types import TConfig from ax.utils.common.docutils import copy_doc -from ax.utils.common.typeutils import not_none +from pyre_extensions import none_throws if TYPE_CHECKING: # import as module to make sphinx-autodoc-typehints happy @@ -105,7 +105,7 @@ def untransform_observations( if not self.perform_untransform: return observations for obs in observations: - trial_index = int(not_none(obs.features.trial_index)) + trial_index = int(none_throws(obs.features.trial_index)) trial_type = self.trial_index_to_type[trial_index] reverse_map = self.reverse_metric_name_map.get(trial_type) diff --git a/ax/modelbridge/transforms/derelativize.py b/ax/modelbridge/transforms/derelativize.py index 366b08ef6b7..a197ce96609 100644 --- a/ax/modelbridge/transforms/derelativize.py +++ b/ax/modelbridge/transforms/derelativize.py @@ -18,7 +18,7 @@ from ax.modelbridge.transforms.base import Transform from ax.modelbridge.transforms.ivw import ivw_metric_merge from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import not_none +from pyre_extensions import none_throws if TYPE_CHECKING: @@ -67,7 +67,7 @@ def transform_optimization_config( "not fit with status quo." ) - sq = not_none(modelbridge.status_quo) + sq = none_throws(modelbridge.status_quo) # Only use model predictions if the status quo is in the search space (including # parameter constraints) and `use_raw_sq` is false. 
if not use_raw_sq and modelbridge.model_space.check_membership( diff --git a/ax/modelbridge/transforms/int_to_float.py b/ax/modelbridge/transforms/int_to_float.py index 7e75d07b652..5f161f50a13 100644 --- a/ax/modelbridge/transforms/int_to_float.py +++ b/ax/modelbridge/transforms/int_to_float.py @@ -20,7 +20,8 @@ from ax.modelbridge.transforms.utils import construct_new_search_space from ax.models.types import TConfig from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast +from pyre_extensions import none_throws if TYPE_CHECKING: # import as module to make sphinx-autodoc-typehints happy @@ -53,7 +54,7 @@ def __init__( modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, config: TConfig | None = None, ) -> None: - self.search_space: SearchSpace = not_none( + self.search_space: SearchSpace = none_throws( search_space, "IntToFloat requires search space" ) config = config or {} diff --git a/ax/modelbridge/transforms/relativize.py b/ax/modelbridge/transforms/relativize.py index 896c1a56998..bed943d81f1 100644 --- a/ax/modelbridge/transforms/relativize.py +++ b/ax/modelbridge/transforms/relativize.py @@ -26,8 +26,8 @@ from ax.modelbridge import ModelBridge from ax.modelbridge.transforms.base import Transform from ax.models.types import TConfig -from ax.utils.common.typeutils import not_none from ax.utils.stats.statstools import relativize, unrelativize +from pyre_extensions import none_throws if TYPE_CHECKING: # import as module to make sphinx-autodoc-typehints happy @@ -65,11 +65,11 @@ def __init__( config=config, ) # self.modelbridge should NOT be modified - self.modelbridge: ModelBridge = not_none( + self.modelbridge: ModelBridge = none_throws( modelbridge, f"{cls_name} transform requires a modelbridge" ) - self.status_quo_data_by_trial: dict[int, ObservationData] = not_none( + self.status_quo_data_by_trial: dict[int, ObservationData] = none_throws( self.modelbridge.status_quo_data_by_trial, f"{cls_name} requires status quo data.", ) diff --git a/ax/modelbridge/transforms/stratified_standardize_y.py b/ax/modelbridge/transforms/stratified_standardize_y.py index 82635601611..3ef5996871f 100644 --- a/ax/modelbridge/transforms/stratified_standardize_y.py +++ b/ax/modelbridge/transforms/stratified_standardize_y.py @@ -21,7 +21,8 @@ from ax.modelbridge.transforms.standardize_y import compute_standardization_parameters from ax.models.types import TConfig from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast +from pyre_extensions import none_throws if TYPE_CHECKING: @@ -117,7 +118,7 @@ def __init__( observation_features, observation_data = separate_observations(observations) Ys: defaultdict[tuple[str, TParamValue], list[float]] = defaultdict(list) for j, obsd in enumerate(observation_data): - v = not_none(observation_features[j].parameters[self.p_name]) + v = none_throws(observation_features[j].parameters[self.p_name]) strata = self.strata_mapping[v] for i, m in enumerate(obsd.metric_names): Ys[(m, strata)].append(obsd.means[i]) @@ -137,7 +138,7 @@ def transform_observations( ) -> list[Observation]: # Transform observations for obs in observations: - v = not_none(obs.features.parameters[self.p_name]) + v = none_throws(obs.features.parameters[self.p_name]) strata = self.strata_mapping[v] means = np.array([self.Ymean[(m, strata)] for m in obs.data.metric_names]) stds = 
np.array([self.Ystd[(m, strata)] for m in obs.data.metric_names]) @@ -158,7 +159,7 @@ def transform_optimization_config( f"StratifiedStandardizeY transform requires {self.p_name} to be fixed " "during generation." ) - v = not_none(fixed_features.parameters[self.p_name]) + v = none_throws(fixed_features.parameters[self.p_name]) strata = self.strata_mapping[v] for c in optimization_config.all_constraints: if c.relative: @@ -176,7 +177,7 @@ def untransform_observations( observations: list[Observation], ) -> list[Observation]: for obs in observations: - v = not_none(obs.features.parameters[self.p_name]) + v = none_throws(obs.features.parameters[self.p_name]) strata = self.strata_mapping[v] means = np.array([self.Ymean[(m, strata)] for m in obs.data.metric_names]) stds = np.array([self.Ystd[(m, strata)] for m in obs.data.metric_names]) @@ -193,7 +194,7 @@ def untransform_outcome_constraints( raise ValueError( f"StratifiedStandardizeY requires {self.p_name} to be fixed here" ) - v = not_none(fixed_features.parameters[self.p_name]) + v = none_throws(fixed_features.parameters[self.p_name]) strata = self.strata_mapping[v] for c in outcome_constraints: if c.relative: diff --git a/ax/modelbridge/transforms/tests/test_relativize_transform.py b/ax/modelbridge/transforms/tests/test_relativize_transform.py index 8de382ef3fa..7bcb0fe5c3b 100644 --- a/ax/modelbridge/transforms/tests/test_relativize_transform.py +++ b/ax/modelbridge/transforms/tests/test_relativize_transform.py @@ -31,7 +31,7 @@ ) from ax.models.base import Model from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from ax.utils.stats.statstools import relativize_data from ax.utils.testing.core_stubs import ( get_branin_data_batch, @@ -42,6 +42,7 @@ get_search_space, ) from hypothesis import assume, given, settings, strategies as st +from pyre_extensions import none_throws class RelativizeDataTest(TestCase): @@ -121,7 +122,7 @@ def test_relativize_transform_requires_a_modelbridge_to_have_status_quo_data( with_status_quo=True, ) # making status_quo out of design - not_none(exp._status_quo)._parameters["x1"] = 10000.0 + none_throws(exp._status_quo)._parameters["x1"] = 10000.0 for t in exp.trials.values(): t.mark_running(no_runner_required=True) exp.attach_data( @@ -137,14 +138,15 @@ def test_relativize_transform_requires_a_modelbridge_to_have_status_quo_data( data=data, ) mean_in_data = data.df.query( - f"arm_name == '{not_none(exp.status_quo).name}'" + f"arm_name == '{none_throws(exp.status_quo).name}'" )["mean"].item() # modelbridge.status_quo_data_by_trial is accurate self.assertEqual( - mean_in_data, not_none(modelbridge.status_quo_data_by_trial)[0].means[0] + mean_in_data, + none_throws(modelbridge.status_quo_data_by_trial)[0].means[0], ) # reset SQ - not_none(exp._status_quo)._parameters["x1"] = 0.0 + none_throws(exp._status_quo)._parameters["x1"] = 0.0 modelbridge = ModelBridge( search_space=exp.search_space, model=Model(), @@ -174,15 +176,16 @@ def test_relativize_transform_requires_a_modelbridge_to_have_status_quo_data( self.assertNotEqual(data, new_data) self.assertFalse(data.df.equals(new_data.df)) mean_in_data = new_data.df.query( - f"arm_name == '{not_none(new_exp.status_quo).name}'" + f"arm_name == '{none_throws(new_exp.status_quo).name}'" )["mean"].item() # modelbridge.status_quo_data_by_trial remains accurate self.assertEqual( - mean_in_data, not_none(modelbridge.status_quo_data_by_trial)[0].means[0] + mean_in_data, + 
none_throws(modelbridge.status_quo_data_by_trial)[0].means[0], ) # Can still find status_quo_data_by_trial when status_quo name is None - mb_sq = not_none(modelbridge._status_quo) + mb_sq = none_throws(modelbridge._status_quo) mb_sq.arm_name = None self.assertIsNotNone(modelbridge.status_quo_data_by_trial) self.assertEqual(len(modelbridge.status_quo_data_by_trial), 1) @@ -415,7 +418,7 @@ def test_multitask_data(self) -> None: status_quo=Observation( data=sq_obs_data[0], features=ObservationFeatures( - parameters=not_none(experiment.status_quo).parameters + parameters=none_throws(experiment.status_quo).parameters ), arm_name="status_quo", ), diff --git a/ax/modelbridge/transforms/time_as_feature.py b/ax/modelbridge/transforms/time_as_feature.py index d664e1313bb..91b992b1d9d 100644 --- a/ax/modelbridge/transforms/time_as_feature.py +++ b/ax/modelbridge/transforms/time_as_feature.py @@ -22,7 +22,8 @@ from ax.models.types import TConfig from ax.utils.common.logger import get_logger from ax.utils.common.timeutils import unixtime_to_pandas_ts -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast +from pyre_extensions import none_throws if TYPE_CHECKING: # import as module to make sphinx-autodoc-typehints happy @@ -68,7 +69,7 @@ def __init__( "Unable to use TimeAsFeature since not all observations have " "start time specified." ) - start_time = not_none(obsf.start_time).timestamp() + start_time = none_throws(obsf.start_time).timestamp() self.min_start_time = min(self.min_start_time, start_time) self.max_start_time = max(self.max_start_time, start_time) duration = self._get_duration(start_time=start_time, end_time=obsf.end_time) diff --git a/ax/modelbridge/transition_criterion.py b/ax/modelbridge/transition_criterion.py index c3174f89629..7727a33fc80 100644 --- a/ax/modelbridge/transition_criterion.py +++ b/ax/modelbridge/transition_criterion.py @@ -25,7 +25,7 @@ from ax.utils.common.base import SortableBase from ax.utils.common.logger import get_logger from ax.utils.common.serialization import SerializationMixin, serialize_init_args -from ax.utils.common.typeutils import not_none +from pyre_extensions import none_throws logger: Logger = get_logger(__name__) @@ -765,7 +765,7 @@ def check_aux_exp_purposes( """Helper method to check if all elements in expected_aux_exp_purposes are in (or not in) aux_exp_purposes""" if expected_aux_exp_purposes is not None: - for purpose in not_none(expected_aux_exp_purposes): + for purpose in none_throws(expected_aux_exp_purposes): purpose_present = purpose in aux_exp_purposes if purpose_present != include: return False diff --git a/ax/models/random/sobol.py b/ax/models/random/sobol.py index 28d7fbddfd3..c8a54821006 100644 --- a/ax/models/random/sobol.py +++ b/ax/models/random/sobol.py @@ -14,7 +14,7 @@ from ax.models.model_utils import tunable_feature_indices from ax.models.random.base import RandomModel from ax.models.types import TConfig -from ax.utils.common.typeutils import not_none +from pyre_extensions import none_throws from torch.quasirandom import SobolEngine @@ -115,7 +115,7 @@ def gen( rounding_func=rounding_func, ) if self.engine: - self.init_position = not_none(self.engine).num_generated + self.init_position = none_throws(self.engine).num_generated return points, weights def _gen_samples(self, n: int, tunable_d: int) -> npt.NDArray: @@ -136,4 +136,4 @@ def _gen_samples(self, n: int, tunable_d: int) -> npt.NDArray: raise ValueError( "Sobol Engine must be initialized before candidate 
generation." ) - return not_none(self.engine).draw(n, dtype=torch.double).numpy() + return none_throws(self.engine).draw(n, dtype=torch.double).numpy() diff --git a/ax/models/tests/test_botorch_defaults.py b/ax/models/tests/test_botorch_defaults.py index d90d144b2b4..d7a019e4106 100644 --- a/ax/models/tests/test_botorch_defaults.py +++ b/ax/models/tests/test_botorch_defaults.py @@ -21,7 +21,7 @@ NO_OBSERVED_POINTS_MESSAGE, ) from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from ax.utils.testing.mock import mock_botorch_optimize from botorch.acquisition.logei import ( qLogExpectedImprovement, @@ -50,6 +50,7 @@ from gpytorch.priors import GammaPrior from gpytorch.priors.lkj_prior import LKJCovariancePrior from gpytorch.priors.prior import Prior +from pyre_extensions import none_throws class BotorchDefaultsTest(TestCase): @@ -367,7 +368,7 @@ def test_get_acquisition_func(self) -> None: torch.tensor([[1.0], [-1.0], [0.0]]), # k x 1 ) X_observed = torch.zeros(2, d) - expected_constraints = not_none( + expected_constraints = none_throws( get_outcome_constraint_transforms(outcome_constraints) ) samples = torch.zeros(n, m) # to test constraints diff --git a/ax/models/tests/test_botorch_model.py b/ax/models/tests/test_botorch_model.py index 6d266b29d18..5cfe70ca957 100644 --- a/ax/models/tests/test_botorch_model.py +++ b/ax/models/tests/test_botorch_model.py @@ -28,7 +28,6 @@ from ax.models.torch.utils import sample_simplex from ax.models.torch_base import TorchOptConfig from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import not_none from ax.utils.testing.mock import mock_botorch_optimize from ax.utils.testing.torch_stubs import get_torch_test_data from botorch.acquisition.utils import get_infeasible_cost @@ -42,6 +41,7 @@ from gpytorch.mlls import ExactMarginalLogLikelihood, LeaveOneOutPseudoLikelihood from gpytorch.priors import GammaPrior from gpytorch.priors.lkj_prior import LKJCovariancePrior +from pyre_extensions import none_throws FIT_MODEL_MO_PATH = f"{get_and_fit_model.__module__}.fit_gpytorch_mll" @@ -475,7 +475,7 @@ def test_BotorchModel( ) # test get_rounding_func - dummy_rounding = not_none(get_rounding_func(rounding_func=dummy_func)) + dummy_rounding = none_throws(get_rounding_func(rounding_func=dummy_func)) X_temp = torch.rand(1, 2, 3, 4) self.assertTrue(torch.equal(X_temp, dummy_rounding(X_temp))) diff --git a/ax/models/tests/test_random.py b/ax/models/tests/test_random.py index a2e60922b26..af76bcda4db 100644 --- a/ax/models/tests/test_random.py +++ b/ax/models/tests/test_random.py @@ -10,7 +10,7 @@ import torch from ax.models.random.base import RandomModel from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import not_none +from pyre_extensions import none_throws class RandomModelTest(TestCase): @@ -44,7 +44,7 @@ def test_RandomModelGenUnconstrained(self) -> None: def test_ConvertEqualityConstraints(self) -> None: fixed_features = {3: 0.7, 1: 0.5} d = 4 - C, c = not_none( + C, c = none_throws( self.random_model._convert_equality_constraints(d, fixed_features) ) c_expected = torch.tensor([[0.5], [0.7]], dtype=torch.double) @@ -58,7 +58,7 @@ def test_ConvertEqualityConstraints(self) -> None: def test_ConvertInequalityConstraints(self) -> None: A = np.array([[1, 2], [3, 4]]) b = np.array([[5], [6]]) - A_result, b_result = not_none( + A_result, b_result = none_throws( 
self.random_model._convert_inequality_constraints((A, b)) ) A_expected = torch.tensor([[1, 2], [3, 4]], dtype=torch.double) diff --git a/ax/models/tests/test_torch_model_utils.py b/ax/models/tests/test_torch_model_utils.py index 28a7591cd1d..102cf25cfb3 100644 --- a/ax/models/tests/test_torch_model_utils.py +++ b/ax/models/tests/test_torch_model_utils.py @@ -17,12 +17,12 @@ tensor_callable_to_array_callable, ) from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import not_none from botorch.models import HeteroskedasticSingleTaskGP, SingleTaskGP from botorch.models.deterministic import GenericDeterministicModel from botorch.models.model import ModelList from botorch.models.model_list_gp_regression import ModelListGP from botorch.models.multitask import MultiTaskGP +from pyre_extensions import none_throws from torch import Tensor @@ -169,7 +169,7 @@ def test_with_obj_thresholds_can_subset(self) -> None: ) model_sub = subset_model_results.model obj_weights_sub = subset_model_results.objective_weights - ocs_sub = not_none(subset_model_results.outcome_constraints) + ocs_sub = none_throws(subset_model_results.outcome_constraints) obj_t_sub = subset_model_results.objective_thresholds self.assertTrue(torch.equal(subset_model_results.indices, torch.tensor([0]))) self.assertEqual(model_sub.num_outputs, 1) diff --git a/ax/models/tests/test_torch_utils.py b/ax/models/tests/test_torch_utils.py index 327f875e030..15ccd550b48 100644 --- a/ax/models/tests/test_torch_utils.py +++ b/ax/models/tests/test_torch_utils.py @@ -14,7 +14,7 @@ get_botorch_objective_and_transform, ) from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from botorch.acquisition.knowledge_gradient import qKnowledgeGradient from botorch.acquisition.logei import qLogNoisyExpectedImprovement from botorch.acquisition.monte_carlo import qNoisyExpectedImprovement @@ -34,6 +34,7 @@ from botorch.acquisition.risk_measures import Expectation from botorch.exceptions.errors import BotorchTensorDimensionError from botorch.models.model import Model +from pyre_extensions import none_throws class TorchUtilsTest(TestCase): @@ -68,7 +69,7 @@ def _to_obs_set(X: torch.Tensor) -> set[tuple[float]]: fixed_features=fixed_features, ) expected = Xs[0][1:] - self.assertEqual(_to_obs_set(expected), _to_obs_set(not_none(X_observed))) + self.assertEqual(_to_obs_set(expected), _to_obs_set(none_throws(X_observed))) # Filter too strict; return unfiltered X_observed fixed_features = {0: 1.0} @@ -79,7 +80,7 @@ def _to_obs_set(X: torch.Tensor) -> set[tuple[float]]: fixed_features=fixed_features, ) expected = Xs[0] - self.assertEqual(_to_obs_set(expected), _to_obs_set(not_none(X_observed))) + self.assertEqual(_to_obs_set(expected), _to_obs_set(none_throws(X_observed))) # Out of design observations are filtered out Xs = [torch.tensor([[2.0, 3.0], [3.0, 4.0]])] @@ -101,7 +102,7 @@ def _to_obs_set(X: torch.Tensor) -> set[tuple[float]]: fit_out_of_design=True, ) expected = Xs[0] - self.assertEqual(_to_obs_set(expected), _to_obs_set(not_none(X_observed))) + self.assertEqual(_to_obs_set(expected), _to_obs_set(none_throws(X_observed))) @patch( f"{get_botorch_objective_and_transform.__module__}.get_infeasible_cost", @@ -203,7 +204,7 @@ def test_get_botorch_objective_w_risk_measures( ) self.assertTrue( torch.allclose( - not_none(risk_measure)(torch.tensor([[1.0], [2.0]])), + none_throws(risk_measure)(torch.tensor([[1.0], [2.0]])), 
torch.tensor([-1.5]), ) ) @@ -217,7 +218,7 @@ def test_get_botorch_objective_w_risk_measures( Y = torch.tensor([[1.0, -1.0, 3.0], [2.0, -2.0, 3.0]]) self.assertTrue( torch.allclose( - not_none(risk_measure)(Y), + none_throws(risk_measure)(Y), torch.tensor([-3.0]), ) ) @@ -230,7 +231,7 @@ def test_get_botorch_objective_w_risk_measures( ) self.assertTrue( torch.allclose( - not_none(risk_measure)(Y), + none_throws(risk_measure)(Y), torch.tensor([-1.5, -3.0]), ) ) @@ -259,7 +260,7 @@ def test_get_botorch_objective_w_risk_measures( risk_measure.chebyshev_weights = [0.0, 1.0] self.assertTrue( torch.allclose( - not_none(risk_measure)(Y), + none_throws(risk_measure)(Y), torch.tensor([-4.0]), ) ) diff --git a/ax/models/torch/botorch_modular/acquisition.py b/ax/models/torch/botorch_modular/acquisition.py index c0344cd5ba0..18f4ec47b41 100644 --- a/ax/models/torch/botorch_modular/acquisition.py +++ b/ax/models/torch/botorch_modular/acquisition.py @@ -36,7 +36,6 @@ from ax.utils.common.base import Base from ax.utils.common.constants import Keys from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import not_none from botorch.acquisition.acquisition import AcquisitionFunction from botorch.acquisition.input_constructors import get_acqf_input_constructor from botorch.acquisition.knowledge_gradient import qKnowledgeGradient @@ -163,7 +162,7 @@ def __init__( objective_thresholds=self._objective_thresholds, ) objective_thresholds = ( - not_none(self._objective_thresholds)[subset_idcs] + none_throws(self._objective_thresholds)[subset_idcs] if subset_idcs is not None else self._objective_thresholds ) diff --git a/ax/models/torch/botorch_modular/sebo.py b/ax/models/torch/botorch_modular/sebo.py index 713042fd7de..5244843a31b 100644 --- a/ax/models/torch/botorch_modular/sebo.py +++ b/ax/models/torch/botorch_modular/sebo.py @@ -22,7 +22,6 @@ from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.models.torch_base import TorchOptConfig from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import not_none from botorch.acquisition.acquisition import AcquisitionFunction from botorch.acquisition.multi_objective.logei import ( qLogNoisyExpectedHypervolumeImprovement, @@ -38,6 +37,7 @@ ) from botorch.utils.datasets import SupervisedDataset from botorch.utils.transforms import unnormalize +from pyre_extensions import none_throws from torch import Tensor from torch.quasirandom import SobolEngine @@ -85,7 +85,7 @@ def __init__( clamp_tol=CLAMP_TOL, ) # update the training data in new surrogate - not_none(surrogate_f._training_data).append( + none_throws(surrogate_f._training_data).append( SupervisedDataset( X=X_sparse, Y=self.deterministic_model(X_sparse), @@ -161,7 +161,7 @@ def _transform_torch_config( ) if torch_opt_config.outcome_constraints is not None: # update the shape of A matrix in outcome_constraints - A, b = not_none(torch_opt_config.outcome_constraints) + A, b = none_throws(torch_opt_config.outcome_constraints) outcome_constraints_sebo = ( torch.cat([A, torch.zeros(A.shape[0], 1, **tkwargs)], dim=1), b, diff --git a/ax/models/torch/botorch_modular/utils.py b/ax/models/torch/botorch_modular/utils.py index 7b8be177010..26a273ddd4e 100644 --- a/ax/models/torch/botorch_modular/utils.py +++ b/ax/models/torch/botorch_modular/utils.py @@ -19,7 +19,7 @@ from ax.models.types import TConfig from ax.utils.common.constants import Keys from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import checked_cast, not_none +from 
ax.utils.common.typeutils import checked_cast from botorch.acquisition.acquisition import AcquisitionFunction from botorch.acquisition.logei import qLogNoisyExpectedImprovement from botorch.acquisition.multi_objective.logei import ( @@ -37,6 +37,7 @@ from botorch.utils.datasets import SupervisedDataset from botorch.utils.transforms import is_fully_bayesian from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood +from pyre_extensions import none_throws from torch import Tensor MIN_OBSERVED_NOISE_LEVEL = 1e-7 @@ -231,7 +232,7 @@ def convert_to_block_design( # data complies to block design, can concat with impunity Y = torch.cat([ds.Y for ds in datasets], dim=-1) if is_fixed: - Yvar = torch.cat([not_none(ds.Yvar) for ds in datasets], dim=-1) + Yvar = torch.cat([none_throws(ds.Yvar) for ds in datasets], dim=-1) else: Yvar = None datasets = [ diff --git a/ax/models/torch/botorch_moo.py b/ax/models/torch/botorch_moo.py index bf1c97b9044..671d785b4e6 100644 --- a/ax/models/torch/botorch_moo.py +++ b/ax/models/torch/botorch_moo.py @@ -45,9 +45,10 @@ from ax.utils.common.constants import Keys from ax.utils.common.docutils import copy_doc from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from botorch.acquisition.acquisition import AcquisitionFunction from botorch.models.model import Model +from pyre_extensions import none_throws from torch import Tensor @@ -254,7 +255,7 @@ def gen( if ( torch_opt_config.objective_thresholds is not None and torch_opt_config.objective_weights.shape[0] - != not_none(torch_opt_config.objective_thresholds).shape[0] + != none_throws(torch_opt_config.objective_thresholds).shape[0] ): raise AxError( "Objective weights and thresholds most both contain an element for" @@ -271,7 +272,7 @@ def gen( fixed_features=torch_opt_config.fixed_features, ) - model = not_none(self.model) + model = none_throws(self.model) full_objective_thresholds = torch_opt_config.objective_thresholds full_objective_weights = torch_opt_config.objective_weights full_outcome_constraints = torch_opt_config.outcome_constraints @@ -352,7 +353,7 @@ def gen( ): full_objective_thresholds = infer_objective_thresholds( model=model, - X_observed=not_none(X_observed), + X_observed=none_throws(X_observed), objective_weights=full_objective_weights, outcome_constraints=full_outcome_constraints, subset_idcs=idcs, diff --git a/ax/models/torch/botorch_moo_defaults.py b/ax/models/torch/botorch_moo_defaults.py index 33aedd58864..487973b6889 100644 --- a/ax/models/torch/botorch_moo_defaults.py +++ b/ax/models/torch/botorch_moo_defaults.py @@ -40,7 +40,7 @@ subset_model, ) from ax.models.torch_base import TorchModel -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from botorch.acquisition import get_acquisition_function from botorch.acquisition.acquisition import AcquisitionFunction from botorch.acquisition.multi_objective.logei import ( @@ -60,6 +60,7 @@ from botorch.posteriors.posterior_list import PosteriorList from botorch.utils.multi_objective.hypervolume import infer_reference_point from botorch.utils.multi_objective.pareto import is_non_dominated +from pyre_extensions import none_throws from torch import Tensor DEFAULT_EHVI_MC_SAMPLES = 128 @@ -587,7 +588,7 @@ def pareto_frontier_evaluator( # TODO: better input validation, making more explicit whether we are using # model predictions or not if X is not None: - Y, Yvar = 
not_none(model).predict(X) + Y, Yvar = none_throws(model).predict(X) # model.predict returns cpu tensors Y = Y.to(X.device) Yvar = Yvar.to(X.device) @@ -749,7 +750,7 @@ def infer_objective_thresholds( ) with torch.no_grad(): pred = _check_posterior_type( - not_none(model).posterior(not_none(X_observed)) + none_throws(model).posterior(none_throws(X_observed)) ).mean if outcome_constraints is not None: diff --git a/ax/models/torch/tests/test_sebo.py b/ax/models/torch/tests/test_sebo.py index b07bffbe41f..5599689917b 100644 --- a/ax/models/torch/tests/test_sebo.py +++ b/ax/models/torch/tests/test_sebo.py @@ -28,7 +28,6 @@ from ax.models.torch_base import TorchOptConfig from ax.utils.common.constants import Keys from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import not_none from ax.utils.testing.mock import mock_botorch_optimize from botorch.acquisition.multi_objective.monte_carlo import ( qNoisyExpectedHypervolumeImprovement, @@ -38,6 +37,7 @@ from botorch.models.gp_regression import SingleTaskGP from botorch.models.model import ModelList from botorch.utils.datasets import SupervisedDataset +from pyre_extensions import none_throws SEBOACQUISITION_PATH: str = SEBOAcquisition.__module__ @@ -150,7 +150,7 @@ def test_init(self) -> None: ) # Check that determinstic metric is added to surrogate surrogate = acquisition1.surrogate - model_list = not_none(surrogate._model) + model_list = none_throws(surrogate._model) self.assertIsInstance(model_list, ModelList) self.assertIsInstance(model_list.models[0], SingleTaskGP) self.assertIsInstance(model_list.models[1], GenericDeterministicModel) @@ -170,7 +170,7 @@ def test_init(self) -> None: ) self.assertTrue( torch.equal( - not_none(acquisition1.objective_thresholds), + none_throws(acquisition1.objective_thresholds), self.objective_thresholds_sebo, ) ) @@ -182,7 +182,7 @@ def test_init(self) -> None: ) self.assertEqual(acquisition2.penalty_name, "L1_norm") surrogate = acquisition2.surrogate - model_list = not_none(surrogate._model) + model_list = none_throws(surrogate._model) self.assertIsInstance(model_list.models[1]._f, functools.partial) self.assertIs(model_list.models[1]._f.func, L1_norm_func) diff --git a/ax/models/torch/tests/test_surrogate.py b/ax/models/torch/tests/test_surrogate.py index 3aa5a733214..2d6b12b7ef0 100644 --- a/ax/models/torch/tests/test_surrogate.py +++ b/ax/models/torch/tests/test_surrogate.py @@ -22,7 +22,7 @@ from ax.models.torch.botorch_modular.utils import choose_model_class, fit_botorch_model from ax.models.torch_base import TorchOptConfig from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from ax.utils.testing.mock import mock_botorch_optimize from ax.utils.testing.torch_stubs import get_torch_test_data from ax.utils.testing.utils import generic_equals @@ -40,7 +40,7 @@ from gpytorch.kernels import Kernel, MaternKernel, RBFKernel, ScaleKernel from gpytorch.likelihoods import FixedNoiseGaussianLikelihood, GaussianLikelihood from gpytorch.mlls import ExactMarginalLogLikelihood, LeaveOneOutPseudoLikelihood -from pyre_extensions import assert_is_instance +from pyre_extensions import assert_is_instance, none_throws from torch import Tensor from torch.nn import ModuleList # @manual -- autodeps can't figure it out. 
@@ -524,7 +524,7 @@ def test_construct_custom_model(self) -> None: self.training_data, search_space_digest=self.search_space_digest, ) - model = not_none(surrogate._model) + model = none_throws(surrogate._model) self.assertEqual(type(model.likelihood), GaussianLikelihood) noise_constraint.eval() # For the equality check. self.assertEqual( diff --git a/ax/models/torch/tests/test_utils.py b/ax/models/torch/tests/test_utils.py index 4ab0f4677c4..19d263f5fa1 100644 --- a/ax/models/torch/tests/test_utils.py +++ b/ax/models/torch/tests/test_utils.py @@ -29,7 +29,7 @@ from ax.models.torch_base import TorchOptConfig from ax.utils.common.constants import Keys from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from ax.utils.testing.torch_stubs import get_torch_test_data from botorch.acquisition import qLogNoisyExpectedImprovement from botorch.acquisition.multi_objective.logei import ( @@ -42,6 +42,7 @@ from botorch.models.model_list_gp_regression import ModelListGP from botorch.models.multitask import MultiTaskGP from botorch.utils.datasets import SupervisedDataset +from pyre_extensions import none_throws class BoTorchModelUtilsTest(TestCase): @@ -434,7 +435,7 @@ def test_convert_to_block_design(self) -> None: self.assertTrue(torch.equal(new_datasets[0].X, X)) self.assertTrue(torch.equal(new_datasets[0].Y, torch.cat(Ys, dim=-1))) self.assertTrue( - torch.equal(not_none(new_datasets[0].Yvar), torch.cat(Yvars, dim=-1)) + torch.equal(none_throws(new_datasets[0].Yvar), torch.cat(Yvars, dim=-1)) ) self.assertEqual(new_datasets[0].outcome_names, metric_names) @@ -496,7 +497,7 @@ def test_convert_to_block_design(self) -> None: ) self.assertTrue( torch.equal( - not_none(new_datasets[0].Yvar), + none_throws(new_datasets[0].Yvar), torch.cat([Yvar[:3] for Yvar in Yvars], dim=-1), ) ) @@ -505,7 +506,7 @@ def test_convert_to_block_design(self) -> None: def test_to_inequality_constraints(self) -> None: A = torch.tensor([[0, 1, -2, 3], [0, 1, 0, 0]]) b = torch.tensor([[1], [2]]) - ineq_constraints = not_none( + ineq_constraints = none_throws( _to_inequality_constraints(linear_constraints=(A, b)) ) self.assertEqual(len(ineq_constraints), 2) diff --git a/ax/plot/diagnostic.py b/ax/plot/diagnostic.py index 09ea83b4716..e00faa2d436 100644 --- a/ax/plot/diagnostic.py +++ b/ax/plot/diagnostic.py @@ -28,8 +28,8 @@ ) from ax.plot.helper import compose_annotation from ax.plot.scatter import _error_scatter_data, _error_scatter_trace -from ax.utils.common.typeutils import not_none from plotly import subplots +from pyre_extensions import none_throws # type alias @@ -677,7 +677,7 @@ def interact_batch_comparison( if isinstance(experiment, MultiTypeExperiment): observations = convert_mt_observations(observations, experiment) if not status_quo_name and experiment.status_quo: - status_quo_name = not_none(experiment.status_quo).name + status_quo_name = none_throws(experiment.status_quo).name plot_data = _get_batch_comparison_plot_data( observations, batch_x, batch_y, rel=rel, status_quo_name=status_quo_name ) diff --git a/ax/plot/helper.py b/ax/plot/helper.py index bc7c3d3f8f6..6a50ca8e8bc 100644 --- a/ax/plot/helper.py +++ b/ax/plot/helper.py @@ -27,7 +27,7 @@ from ax.modelbridge.transforms.ivw import IVW from ax.plot.base import DECIMALS, PlotData, PlotInSampleArm, PlotOutOfSampleArm, Z from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import not_none +from pyre_extensions import none_throws 
logger: Logger = get_logger(__name__) @@ -269,8 +269,8 @@ def _get_in_sample_arms( else: pred_y = obs_y pred_se = obs_se - in_sample_plot[not_none(obs.arm_name)] = PlotInSampleArm( - name=not_none(obs.arm_name), + in_sample_plot[none_throws(obs.arm_name)] = PlotInSampleArm( + name=none_throws(obs.arm_name), y=obs_y, se=obs_se, parameters=obs.features.parameters, @@ -792,7 +792,7 @@ def infer_is_relative( relative = {} constraint_relativity = {} if model._optimization_config: - constraints = not_none(model._optimization_config).outcome_constraints + constraints = none_throws(model._optimization_config).outcome_constraints constraint_relativity = { constraint.metric.name: constraint.relative for constraint in constraints } diff --git a/ax/plot/pareto_frontier.py b/ax/plot/pareto_frontier.py index cdf9551a520..027d9b04ecf 100644 --- a/ax/plot/pareto_frontier.py +++ b/ax/plot/pareto_frontier.py @@ -26,8 +26,9 @@ from ax.plot.helper import _format_CI, _format_dict, extend_range from ax.plot.pareto_utils import ParetoFrontierResults from ax.service.utils.best_point_mixin import BestPointMixin -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from plotly import express as px +from pyre_extensions import none_throws from scipy.stats import norm @@ -852,7 +853,7 @@ def _validate_experiment_and_get_optimization_config( f"{metric_names}." ) return None - return not_none(experiment.optimization_config) + return none_throws(experiment.optimization_config) def _validate_and_maybe_get_default_metric_names( @@ -861,15 +862,15 @@ def _validate_and_maybe_get_default_metric_names( ) -> tuple[str, str]: # Default metric_names is all metrics, producing an error if more than 2 if metric_names is None: - if not_none(optimization_config).is_moo_problem: + if none_throws(optimization_config).is_moo_problem: multi_objective = checked_cast( - MultiObjective, not_none(optimization_config).objective + MultiObjective, none_throws(optimization_config).objective ) metric_names = tuple(obj.metric.name for obj in multi_objective.objectives) else: raise UserInputError( "Inference of `metric_names` failed. Expected `MultiObjective` but " - f"got {not_none(optimization_config).objective}. Please specify " + f"got {none_throws(optimization_config).objective}. Please specify " "`metric_names` of length 2 or provide an experiment whose " "`optimization_config` has 2 objective metrics." ) @@ -976,7 +977,7 @@ def _validate_and_maybe_get_default_minimize( "includes 2 objectives. Returning None." 
) return None - minimize = tuple(not_none(i_min) for i_min in minimize) + minimize = tuple(none_throws(i_min) for i_min in minimize) # If only one bool provided, use for both dimensions elif isinstance(minimize, bool): minimize = (minimize, minimize) diff --git a/ax/plot/slice.py b/ax/plot/slice.py index 29b1c68ebcb..cb3c7623108 100644 --- a/ax/plot/slice.py +++ b/ax/plot/slice.py @@ -24,8 +24,8 @@ slice_config_to_trace, TNullableGeneratorRunsDict, ) -from ax.utils.common.typeutils import not_none from plotly import graph_objs as go +from pyre_extensions import none_throws # type aliases @@ -349,13 +349,13 @@ def interact_slice_plotly( is_log_dict: dict[str, bool] = {} if should_replace_slice_values: - slice_values = not_none(fixed_features).parameters + slice_values = none_throws(fixed_features).parameters else: fixed_features = ObservationFeatures(parameters={}) fixed_values = get_fixed_values(model, slice_values, trial_index) prediction_features = [] for x in grid: - predf = deepcopy(not_none(fixed_features)) + predf = deepcopy(none_throws(fixed_features)) predf.parameters = fixed_values.copy() predf.parameters[param_name] = x prediction_features.append(predf) diff --git a/ax/plot/trace.py b/ax/plot/trace.py index ed2858eceed..070f7ad45f0 100644 --- a/ax/plot/trace.py +++ b/ax/plot/trace.py @@ -17,9 +17,9 @@ from ax.plot.base import AxPlotConfig, AxPlotTypes from ax.plot.color import COLORS, DISCRETE_COLOR_SCALE, rgba from ax.utils.common.timeutils import timestamps_in_range -from ax.utils.common.typeutils import not_none from plotly import express as px from plotly.express.colors import sample_colorscale +from pyre_extensions import none_throws FIVE_MINUTES = timedelta(minutes=5) @@ -700,7 +700,7 @@ def get_running_trials_per_minute( trial_runtimes: list[tuple[int, datetime, datetime | None]] = [ ( trial.index, - not_none(trial._time_run_started), + none_throws(trial._time_run_started), trial._time_completed, # Time trial was completed, failed, or abandoned. ) for trial in experiment.trials.values() @@ -708,7 +708,7 @@ def get_running_trials_per_minute( ] earliest_start = min(tr[1] for tr in trial_runtimes) - latest_end = max(not_none(tr[2]) for tr in trial_runtimes if tr[2] is not None) + latest_end = max(none_throws(tr[2]) for tr in trial_runtimes if tr[2] is not None) running_during = { ts: [ @@ -717,7 +717,7 @@ def get_running_trials_per_minute( # Trial is running during a given timestamp if: # 1) it's run start time is at/before the timestamp, # 2) it's completion time has not yet come or is after the timestamp. 
- if t[1] <= ts and (True if t[2] is None else not_none(t[2]) >= ts) + if t[1] <= ts and (True if t[2] is None else none_throws(t[2]) >= ts) ] for ts in timestamps_in_range( earliest_start, diff --git a/ax/runners/torchx.py b/ax/runners/torchx.py index fbf7bc211e2..3ed51ee20e2 100644 --- a/ax/runners/torchx.py +++ b/ax/runners/torchx.py @@ -16,7 +16,7 @@ from ax.core.base_trial import BaseTrial, TrialStatus from ax.core.runner import Runner from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import not_none +from pyre_extensions import none_throws logger: Logger = get_logger(__name__) @@ -145,7 +145,7 @@ def run(self, trial: BaseTrial) -> dict[str, Any]: ) parameters = dict(self._component_const_params) - parameters.update(not_none(trial.arm).parameters) + parameters.update(none_throws(trial.arm).parameters) component_args = inspect.getfullargspec(self._component).args if "trial_idx" in component_args: parameters["trial_idx"] = trial.index diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index 09be6bf0df7..03ea368b4e4 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -87,8 +87,8 @@ from ax.utils.common.executils import retry_on_exception from ax.utils.common.logger import _round_floats_for_logging, get_logger from ax.utils.common.random import with_rng_seed -from ax.utils.common.typeutils import checked_cast, not_none -from pyre_extensions import assert_is_instance +from ax.utils.common.typeutils import checked_cast +from pyre_extensions import assert_is_instance, none_throws logger: Logger = get_logger(__name__) @@ -555,8 +555,8 @@ def get_next_trial( raise e logger.info( f"Generated new trial {trial.index} with parameters " - f"{round_floats_for_logging(item=not_none(trial.arm).parameters)} " - f"using model {not_none(trial.generator_run)._model_key}." + f"{round_floats_for_logging(item=none_throws(trial.arm).parameters)} " + f"using model {none_throws(trial.generator_run)._model_key}." ) trial.mark_running(no_runner_required=True) self._save_or_update_trial_in_db_if_possible( @@ -568,7 +568,7 @@ def get_next_trial( generation_strategy=self.generation_strategy, new_generator_runs=[self.generation_strategy._generator_runs[-1]], ) - return not_none(trial.arm).parameters, trial.index + return none_throws(trial.arm).parameters, trial.index def get_current_trial_generation_limit(self) -> tuple[int, bool]: """How many trials this ``AxClient`` instance can currently produce via @@ -878,7 +878,7 @@ def attach_trial( def get_trial_parameters(self, trial_index: int) -> TParameterization: """Retrieve the parameterization of the trial by the given index.""" - return not_none(self.get_trial(trial_index).arm).parameters + return none_throws(self.get_trial(trial_index).arm).parameters def get_trials_data_frame(self) -> pd.DataFrame: """Get a Pandas DataFrame representation of this experiment. The columns @@ -961,7 +961,7 @@ def _constrained_trial_objective_mean(trial: BaseTrial) -> float: ] ) hover_labels = [ - _format_dict(not_none(checked_cast(Trial, trial).arm).parameters) + _format_dict(none_throws(checked_cast(Trial, trial).arm).parameters) for trial in self.experiment.trials.values() if trial.status.is_completed ] @@ -1045,7 +1045,7 @@ def get_contour_plot( "Remaining parameters are affixed to the middle of their range." 
) return plot_contour( - model=not_none(self.generation_strategy.model), + model=none_throws(self.generation_strategy.model), param_x=param_x, param_y=param_y, metric_name=metric_name, @@ -1114,7 +1114,7 @@ def load_experiment_from_database( experiment, generation_strategy = self._load_experiment_and_generation_strategy( experiment_name=experiment_name ) - self._experiment = not_none( + self._experiment = none_throws( experiment, f"Experiment by name '{experiment_name}' not found." ) logger.info(f"Loaded {experiment}.") @@ -1211,9 +1211,9 @@ def get_model_predictions( metric_names_to_predict = ( set(metric_names) if metric_names is not None - else set(not_none(self.experiment.metrics).keys()) + else set(none_throws(self.experiment.metrics).keys()) ) - model = not_none( + model = none_throws( self.generation_strategy.model, "No model has been instantiated yet." ) @@ -1282,7 +1282,9 @@ def verify_trial_parameterization( """Whether the given parameterization matches that of the arm in the trial specified in the trial index. """ - return not_none(self.get_trial(trial_index).arm).parameters == parameterization + return ( + none_throws(self.get_trial(trial_index).arm).parameters == parameterization + ) def should_stop_trials_early( self, trial_indices: set[int] @@ -1464,7 +1466,7 @@ def from_json_snapshot( @property def experiment(self) -> Experiment: """Returns the experiment set on this Ax client.""" - return not_none( + return none_throws( self._experiment, ( "Experiment not set on Ax client. Must first " @@ -1479,14 +1481,14 @@ def get_trial(self, trial_index: int) -> Trial: @property def generation_strategy(self) -> GenerationStrategy: """Returns the generation strategy, set on this experiment.""" - return not_none( + return none_throws( self._generation_strategy, "No generation strategy has been set on this optimization yet.", ) @property def objective(self) -> Objective: - return not_none(self.experiment.optimization_config).objective + return none_throws(self.experiment.optimization_config).objective @property def objective_name(self) -> str: @@ -1668,7 +1670,7 @@ def _set_experiment( ) if self.db_settings_set: experiment_id, _ = self._get_experiment_and_generation_strategy_db_id( - experiment_name=not_none(name) + experiment_name=none_throws(name) ) if experiment_id: raise ValueError( @@ -1782,7 +1784,7 @@ def _gen_new_generator_run( else None ) with with_rng_seed(seed=self._random_seed): - return not_none(self.generation_strategy).gen( + return none_throws(self.generation_strategy).gen( experiment=self.experiment, n=n, pending_observations=self._get_pending_observation_features( @@ -1798,7 +1800,10 @@ def _find_last_trial_with_parameterization( contains an arm with that parameterization. """ for trial_idx in sorted(self.experiment.trials.keys(), reverse=True): - if not_none(self.get_trial(trial_idx).arm).parameters == parameterization: + if ( + none_throws(self.get_trial(trial_idx).arm).parameters + == parameterization + ): return trial_idx raise ValueError( f"No trial on experiment matches parameterization {parameterization}." 
diff --git a/ax/service/managed_loop.py b/ax/service/managed_loop.py index 9f85bf18451..16c62ddbc77 100644 --- a/ax/service/managed_loop.py +++ b/ax/service/managed_loop.py @@ -42,7 +42,7 @@ ) from ax.utils.common.executils import retry_on_exception from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import not_none +from pyre_extensions import none_throws logger: logging.Logger = get_logger(__name__) @@ -188,7 +188,7 @@ def _get_weights_by_arm( ) -> Iterable[tuple[Arm, float | None]]: if isinstance(trial, Trial): if trial.arm is not None: - return [(not_none(trial.arm), None)] + return [(none_throws(trial.arm), None)] return [] elif isinstance(trial, BatchTrial): return trial.normalized_arm_weights().items() @@ -218,7 +218,7 @@ def run_trial(self) -> None: trial_index=self.current_trial, sample_sizes={}, data_type=self.experiment.default_data_type, - metric_names=not_none( + metric_names=none_throws( self.experiment.optimization_config ).objective.metric_names, ) diff --git a/ax/service/tests/scheduler_test_utils.py b/ax/service/tests/scheduler_test_utils.py index 56095e70e3b..070f193efd0 100644 --- a/ax/service/tests/scheduler_test_utils.py +++ b/ax/service/tests/scheduler_test_utils.py @@ -68,7 +68,7 @@ from ax.utils.common.constants import Keys from ax.utils.common.testutils import TestCase from ax.utils.common.timeutils import current_timestamp_in_millis -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from ax.utils.testing.core_stubs import ( CustomTestMetric, CustomTestRunner, @@ -368,7 +368,7 @@ def _get_generation_strategy_strategy_for_test( experiment: Experiment, generation_strategy: GenerationStrategy | None = None, ) -> GenerationStrategyInterface: - return not_none(generation_strategy) + return none_throws(generation_strategy) @property def runner_registry(self) -> dict[type[Runner], int]: @@ -1260,7 +1260,7 @@ def test_sqa_storage_with_experiment_name(self) -> None: # interface to node-level from strategy-level (the latter is likely the # better option) TODO len(gs._generator_runs), - len(not_none(loaded_gs)._generator_runs), + len(none_throws(loaded_gs)._generator_runs), ) scheduler.run_all_trials() # Check that experiment and GS were saved and test reloading with reduced state. 
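Most call sites in the service and test modules below follow the same pattern: an attribute typed as Optional (trial.arm, experiment.optimization_config, a generation strategy's model) is unwrapped immediately before use, so the type checker sees a non-optional value and a missing value fails at the point of use instead of propagating. A small self-contained sketch of that pattern, with hypothetical names (Holder, require_config):

from typing import Optional

from pyre_extensions import none_throws


class Holder:
    # Toy container with an attribute that may legitimately be unset.
    def __init__(self, config: Optional[dict] = None) -> None:
        self.config = config


def require_config(holder: Holder) -> dict:
    # none_throws narrows Optional[dict] to dict for the type checker and
    # raises here if config was never set on this instance.
    return none_throws(holder.config, "No config set on this Holder.")


print(require_config(Holder(config={"num_trials": 5})))  # {'num_trials': 5}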
@@ -1710,9 +1710,9 @@ def test_get_best_trial(self) -> None: scheduler.run_n_trials(max_trials=1) - trial, params, _arm = not_none(scheduler.get_best_trial()) - just_params, _just_arm = not_none(scheduler.get_best_parameters()) - just_params_unmodeled, _just_arm_unmodled = not_none( + trial, params, _arm = none_throws(scheduler.get_best_trial()) + just_params, _just_arm = none_throws(scheduler.get_best_parameters()) + just_params_unmodeled, _just_arm_unmodled = none_throws( scheduler.get_best_parameters(use_model_predictions=False) ) with self.assertRaisesRegex( diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index 556141d9a87..4d0c8036cf6 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -74,7 +74,7 @@ from ax.storage.sqa_store.structs import DBSettings from ax.utils.common.random import with_rng_seed from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from ax.utils.measurement.synthetic_functions import Branin from ax.utils.testing.core_stubs import ( DummyEarlyStoppingStrategy, @@ -343,10 +343,10 @@ def test_status_quo_property(self) -> None: ) self.assertEqual(ax_client.status_quo, status_quo_params) with self.subTest("it returns a copy"): - not_none(ax_client.status_quo).update({"x": 2.0}) - not_none(ax_client.status_quo)["y"] = 2.0 - self.assertEqual(not_none(ax_client.status_quo)["x"], 1.0) - self.assertEqual(not_none(ax_client.status_quo)["y"], 1.0) + none_throws(ax_client.status_quo).update({"x": 2.0}) + none_throws(ax_client.status_quo)["y"] = 2.0 + self.assertEqual(none_throws(ax_client.status_quo)["x"], 1.0) + self.assertEqual(none_throws(ax_client.status_quo)["y"], 1.0) def test_set_optimization_config_to_moo_with_constraints(self) -> None: ax_client = AxClient() @@ -496,7 +496,7 @@ def test_default_generation_strategy_continuous(self, _a, _b, _c, _d) -> None: """ ax_client = get_branin_optimization() self.assertEqual( - [s.model for s in not_none(ax_client.generation_strategy)._steps], + [s.model for s in none_throws(ax_client.generation_strategy)._steps], [Models.SOBOL, Models.BOTORCH_MODULAR], ) with self.assertRaisesRegex(ValueError, ".* no trials"): @@ -711,7 +711,7 @@ def test_default_generation_strategy_continuous_for_moo( }, ) self.assertEqual( - [s.model for s in not_none(ax_client.generation_strategy)._steps], + [s.model for s in none_throws(ax_client.generation_strategy)._steps], [Models.SOBOL, Models.BOTORCH_MODULAR], ) with self.assertRaisesRegex(ValueError, ".* no trials"): @@ -1775,7 +1775,7 @@ def test_attach_trial_and_get_trial_parameters(self) -> None: ax_client.get_trial_parameters(trial_index=idx), {"x": 0, "y": 1} ) self.assertEqual( - not_none(ax_client.get_trial(trial_index=idx).arm).name, ARM_NAME + none_throws(ax_client.get_trial(trial_index=idx).arm).name, ARM_NAME ) with self.assertRaises(KeyError): ax_client.get_trial_parameters( @@ -2391,7 +2391,7 @@ def helper_test_get_pareto_optimal_points_from_sobol_step( "Sobol", ) - cfg = not_none(ax_client.experiment.optimization_config) + cfg = none_throws(ax_client.experiment.optimization_config) assert isinstance(cfg, MultiObjectiveOptimizationConfig) thresholds = np.array([t.bound for t in cfg.objective_thresholds]) diff --git a/ax/service/tests/test_best_point.py b/ax/service/tests/test_best_point.py index 11e1db65ce2..8b5863d5e30 100644 --- a/ax/service/tests/test_best_point.py +++ b/ax/service/tests/test_best_point.py 
@@ -19,7 +19,7 @@ from ax.service.utils.best_point import extract_Y_from_data from ax.service.utils.best_point_mixin import BestPointMixin from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from ax.utils.testing.core_stubs import ( get_arm_weights2, get_arms_from_dict, @@ -27,6 +27,7 @@ get_experiment_with_observations, get_experiment_with_trial, ) +from pyre_extensions import none_throws class TestBestPointMixin(TestCase): @@ -40,7 +41,7 @@ def test_get_trace(self) -> None: ) self.assertEqual(get_trace(exp), [11, 10, 9, 9, 5]) # Same experiment with maximize via new optimization config. - opt_conf = not_none(exp.optimization_config).clone() + opt_conf = none_throws(exp.optimization_config).clone() opt_conf.objective.minimize = False self.assertEqual(get_trace(exp, opt_conf), [11, 11, 11, 15, 15]) @@ -163,7 +164,7 @@ def test_get_best_observed_value(self) -> None: ) self.assertEqual(get_best(exp), 5) # Same experiment with maximize via new optimization config. - opt_conf = not_none(exp.optimization_config).clone() + opt_conf = none_throws(exp.optimization_config).clone() opt_conf.objective.minimize = False self.assertEqual(get_best(exp, opt_conf), 15) diff --git a/ax/service/tests/test_best_point_utils.py b/ax/service/tests/test_best_point_utils.py index 897b824eb21..9672db9dcff 100644 --- a/ax/service/tests/test_best_point_utils.py +++ b/ax/service/tests/test_best_point_utils.py @@ -33,7 +33,6 @@ logger as best_point_logger, ) from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import not_none from ax.utils.testing.core_stubs import ( get_branin_experiment, get_branin_metric, @@ -41,6 +40,7 @@ get_sobol, ) from ax.utils.testing.mock import mock_botorch_optimize +from pyre_extensions import none_throws best_point_module: str = _derel_opt_config_wrapper.__module__ DUMMY_OPTIMIZATION_CONFIG = "test_optimization_config" @@ -148,17 +148,17 @@ def test_best_raw_objective_point(self) -> None: constrained=True, minimize=False, ) - _, best_prediction = not_none(get_best_parameters(exp, Models)) - best_metrics = not_none(best_prediction)[0] + _, best_prediction = none_throws(get_best_parameters(exp, Models)) + best_metrics = none_throws(best_prediction)[0] self.assertDictEqual(best_metrics, {"m1": 3.0, "m2": 4.0}) # Tensor bounds are accepted. - constraint = not_none(exp.optimization_config).all_constraints[0] + constraint = none_throws(exp.optimization_config).all_constraints[0] # pyre-fixme[8]: Attribute `bound` declared in class `OutcomeConstraint` # has type `float` but is used as type `Tensor`. constraint.bound = torch.tensor(constraint.bound) - _, best_prediction = not_none(get_best_parameters(exp, Models)) - best_metrics = not_none(best_prediction)[0] + _, best_prediction = none_throws(get_best_parameters(exp, Models)) + best_metrics = none_throws(best_prediction)[0] self.assertDictEqual(best_metrics, {"m1": 3.0, "m2": 4.0}) def test_best_raw_objective_point_unsatisfiable(self) -> None: @@ -187,7 +187,7 @@ def test_best_raw_objective_point_unsatisfiable_relative(self) -> None: ) # Create altered optimization config with unsatisfiable relative constraint. 
- opt_conf = not_none(exp.optimization_config).clone() + opt_conf = none_throws(exp.optimization_config).clone() opt_conf.outcome_constraints[0].relative = True opt_conf.outcome_constraints[0].bound = 9999 @@ -245,7 +245,7 @@ def test_derel_opt_config_wrapper(self, mock_derelativize: MagicMock) -> None: observations=[[-1, 1, 1], [1, 2, 1], [3, 3, -1], [2, 4, 1], [2, 0, 1]], constrained=True, ) - input_optimization_config = not_none(exp.optimization_config) + input_optimization_config = none_throws(exp.optimization_config) optimization_config = _derel_opt_config_wrapper( optimization_config=input_optimization_config ) @@ -277,8 +277,8 @@ def test_derel_opt_config_wrapper(self, mock_derelativize: MagicMock) -> None: # ModelBridges will have specific addresses and so must be self-same to # pass equality checks. test_modelbridge_1 = get_tensor_converter_model( - experiment=not_none(exp), - data=not_none(exp).lookup_data(), + experiment=none_throws(exp), + data=none_throws(exp).lookup_data(), ) test_observations_1 = test_modelbridge_1.get_training_data() returned_value = _derel_opt_config_wrapper( @@ -311,8 +311,8 @@ def test_derel_opt_config_wrapper(self, mock_derelativize: MagicMock) -> None: # Observations and ModelBridge are not constructed from other inputs when # provided. test_modelbridge_2 = get_tensor_converter_model( - experiment=not_none(exp), - data=not_none(exp).lookup_data(), + experiment=none_throws(exp), + data=none_throws(exp).lookup_data(), ) test_observations_2 = test_modelbridge_2.get_training_data() with self.assertLogs(logger=best_point_logger, level="WARN") as lg, patch( @@ -481,7 +481,7 @@ def test_is_row_feasible(self) -> None: ) feasible_series = _is_row_feasible( df=exp.lookup_data().df, - optimization_config=not_none(exp.optimization_config), + optimization_config=none_throws(exp.optimization_config), ) expected_per_arm = [False, True, False, True, True] expected_series = _repeat_elements( @@ -501,7 +501,7 @@ def test_is_row_feasible(self) -> None: # with lookout for warnings(" OutcomeConstraint(m3 >= 0.0%) ignored."): feasible_series = _is_row_feasible( df=exp.lookup_data().df, - optimization_config=not_none(exp.optimization_config), + optimization_config=none_throws(exp.optimization_config), ) self.assertTrue( any(relative_constraint_warning in warning for warning in lg.output), @@ -515,10 +515,10 @@ def test_is_row_feasible(self) -> None: feasible_series, expected_series, check_names=False ) exp._status_quo = exp.trials[0].arms[0] - for constraint in not_none(exp.optimization_config).all_constraints: + for constraint in none_throws(exp.optimization_config).all_constraints: constraint.relative = True optimization_config = _derel_opt_config_wrapper( - optimization_config=not_none(exp.optimization_config), + optimization_config=none_throws(exp.optimization_config), experiment=exp, ) with self.assertLogs(logger=best_point_logger, level="WARN") as lg: diff --git a/ax/service/tests/test_report_utils.py b/ax/service/tests/test_report_utils.py index 923f76f748e..c6ffa3bebff 100644 --- a/ax/service/tests/test_report_utils.py +++ b/ax/service/tests/test_report_utils.py @@ -49,7 +49,7 @@ ) from ax.service.utils.scheduler_options import SchedulerOptions from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from ax.utils.testing.core_stubs import ( get_branin_experiment, get_branin_experiment_with_multi_objective, @@ -63,6 +63,7 @@ from ax.utils.testing.mock import 
mock_botorch_optimize from ax.utils.testing.modeling_stubs import get_generation_strategy from plotly import graph_objects as go +from pyre_extensions import none_throws OBJECTIVE_NAME = "branin" PARAMETER_COLUMNS = ["x1", "x2"] @@ -1115,7 +1116,7 @@ def test_compare_to_baseline_moo(self) -> None: digits=2, ) - result = not_none( + result = none_throws( compare_to_baseline( experiment=experiment, optimization_config=None, diff --git a/ax/service/utils/best_point.py b/ax/service/utils/best_point.py index 508d6d4235a..7dd4d9fcb93 100644 --- a/ax/service/utils/best_point.py +++ b/ax/service/utils/best_point.py @@ -51,8 +51,9 @@ ) from ax.plot.pareto_utils import get_tensor_converter_model from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from numpy import nan +from pyre_extensions import none_throws from torch import Tensor logger: Logger = get_logger(__name__) @@ -133,7 +134,7 @@ def get_best_raw_objective_point_with_trial_index( for _, row in objective_rows.iterrows() } - return best_trial_index, not_none(best_arm).parameters, vals + return best_trial_index, none_throws(best_arm).parameters, vals def get_best_raw_objective_point( @@ -273,7 +274,7 @@ def get_best_parameters_from_model_predictions_with_trial_index( best_arm, best_arm_predictions = res - return idx, not_none(best_arm).parameters, best_arm_predictions + return idx, none_throws(best_arm).parameters, best_arm_predictions return None @@ -595,7 +596,7 @@ def get_pareto_optimal_parameters( # { trial_index --> (parameterization, (means, covariances) } res: dict[int, tuple[TParameterization, TModelPredictArm]] = OrderedDict() for obs in pareto_optimal_observations: - res[int(not_none(obs.features.trial_index))] = ( + res[int(none_throws(obs.features.trial_index))] = ( obs.features.parameters, (obs.data.means_dict, obs.data.covariance_matrix), ) @@ -664,7 +665,7 @@ def _is_row_feasible( name = df["metric_name"] # When SEM is NaN we should treat it as if it were 0 - sems = not_none(df["sem"].fillna(0)) + sems = none_throws(df["sem"].fillna(0)) # Bounds computed for 95% confidence interval on Normal distribution lower_bound = df["mean"] - sems * 1.96 @@ -756,8 +757,8 @@ def _derel_opt_config_wrapper( ) elif not modelbridge: modelbridge = get_tensor_converter_model( - experiment=not_none(experiment), - data=not_none(experiment).lookup_data(), + experiment=none_throws(experiment), + data=none_throws(experiment).lookup_data(), ) else: # Both modelbridge and experiment specified. 
logger.warning( diff --git a/ax/service/utils/best_point_mixin.py b/ax/service/utils/best_point_mixin.py index eec3a6fe476..5996bac164c 100644 --- a/ax/service/utils/best_point_mixin.py +++ b/ax/service/utils/best_point_mixin.py @@ -44,8 +44,9 @@ fill_missing_thresholds_from_nadir, ) from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from botorch.utils.multi_objective.box_decompositions import DominatedPartitioning +from pyre_extensions import none_throws logger: Logger = get_logger(__name__) @@ -258,7 +259,7 @@ def _get_best_trial( trial_indices: Iterable[int] | None = None, use_model_predictions: bool = True, ) -> tuple[int, TParameterization, TModelPredictArm | None] | None: - optimization_config = optimization_config or not_none( + optimization_config = optimization_config or none_throws( experiment.optimization_config ) if optimization_config.is_moo_problem: @@ -317,7 +318,7 @@ def _get_best_observed_value( The best objective value so far. """ if optimization_config is None: - optimization_config = not_none(experiment.optimization_config) + optimization_config = none_throws(experiment.optimization_config) if optimization_config.is_moo_problem: raise NotImplementedError( "Please use `get_hypervolume` for multi-objective problems." @@ -333,7 +334,7 @@ def _get_best_observed_value( if predictions is None: return None - means = not_none(predictions)[0] + means = none_throws(predictions)[0] objective = optimization_config.objective if isinstance(objective, ScalarizedObjective): value = 0 @@ -352,7 +353,7 @@ def _get_pareto_optimal_parameters( trial_indices: Iterable[int] | None = None, use_model_predictions: bool = True, ) -> dict[int, tuple[TParameterization, TModelPredictArm]]: - optimization_config = optimization_config or not_none( + optimization_config = optimization_config or none_throws( experiment.optimization_config ) if not optimization_config.is_moo_problem: @@ -402,7 +403,7 @@ def _get_hypervolume( ) model = get_model_from_generator_run( - generator_run=not_none(generation_strategy.last_generator_run), + generator_run=none_throws(generation_strategy.last_generator_run), experiment=experiment, data=experiment.fetch_data(trial_indices=trial_indices), models_enum=models_enum, @@ -449,7 +450,7 @@ def _get_trace( Returns: A list of performance values at each iteration. """ - optimization_config = optimization_config or not_none( + optimization_config = optimization_config or none_throws( experiment.optimization_config ) # Get the names of the metrics in optimization config. @@ -495,7 +496,7 @@ def _get_trace( objective=optimization_config.objective, outcomes=metric_names, ) - objective_thresholds = to_tensor(not_none(objective_thresholds)) + objective_thresholds = to_tensor(none_throws(objective_thresholds)) else: objective_thresholds = None ( @@ -518,7 +519,7 @@ def _get_trace( weighted_objective_thresholds, ) = get_weighted_mc_objective_and_objective_thresholds( objective_weights=objective_weights, - objective_thresholds=not_none(objective_thresholds), + objective_thresholds=none_throws(objective_thresholds), ) Y_obj = obj(Y) infeas_value = weighted_objective_thresholds @@ -527,7 +528,9 @@ def _get_trace( infeas_value = Y_obj.min() # Account for feasibility. 
if outcome_constraints is not None: - cons_tfs = not_none(get_outcome_constraint_transforms(outcome_constraints)) + cons_tfs = none_throws( + get_outcome_constraint_transforms(outcome_constraints) + ) feas = torch.all(torch.stack([c(Y) <= 0 for c in cons_tfs], dim=-1), dim=-1) # Set the infeasible points to reference point or the worst observed value. Y_obj[~feas] = infeas_value @@ -573,7 +576,7 @@ def _get_trace_by_progression( bins: list[float] | None = None, final_progression_only: bool = False, ) -> tuple[list[float], list[float]]: - optimization_config = optimization_config or not_none( + optimization_config = optimization_config or none_throws( experiment.optimization_config ) objective = optimization_config.objective.metric.name diff --git a/ax/service/utils/early_stopping.py b/ax/service/utils/early_stopping.py index 936508f933f..0eb6d48e69e 100644 --- a/ax/service/utils/early_stopping.py +++ b/ax/service/utils/early_stopping.py @@ -8,7 +8,7 @@ from ax.core.experiment import Experiment from ax.early_stopping.strategies import BaseEarlyStoppingStrategy -from ax.utils.common.typeutils import not_none +from pyre_extensions import none_throws def should_stop_trials_early( @@ -31,7 +31,7 @@ def should_stop_trials_early( if early_stopping_strategy is None: return {} - early_stopping_strategy = not_none(early_stopping_strategy) + early_stopping_strategy = none_throws(early_stopping_strategy) return early_stopping_strategy.should_stop_trials_early( trial_indices=trial_indices, experiment=experiment ) diff --git a/ax/service/utils/instantiation.py b/ax/service/utils/instantiation.py index 5f96d8cabf0..acc0f05f7b0 100644 --- a/ax/service/utils/instantiation.py +++ b/ax/service/utils/instantiation.py @@ -49,8 +49,8 @@ checked_cast, checked_cast_optional, checked_cast_to_tuple, - not_none, ) +from pyre_extensions import none_throws DEFAULT_OBJECTIVE_NAME = "objective" @@ -193,7 +193,7 @@ def _to_parameter_type( field_name: str, ) -> ParameterType: if typ is None: - typ = type(not_none(vals[0])) + typ = type(none_throws(vals[0])) parameter_type = cls._get_parameter_type(typ) # pyre-ignore[6] assert all(isinstance(x, typ) for x in vals), ( f"Values in `{field_name}` not of the same type and no " diff --git a/ax/service/utils/report_utils.py b/ax/service/utils/report_utils.py index 5f3aa420e35..5b5c05dce92 100644 --- a/ax/service/utils/report_utils.py +++ b/ax/service/utils/report_utils.py @@ -63,9 +63,10 @@ from ax.service.utils.best_point import _derel_opt_config_wrapper, _is_row_feasible from ax.service.utils.early_stopping import get_early_stopping_metrics from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from ax.utils.sensitivity.sobol_measures import ax_parameter_sens from pandas.core.frame import DataFrame +from pyre_extensions import none_throws if TYPE_CHECKING: from ax.service.scheduler import Scheduler @@ -132,7 +133,7 @@ def _get_objective_trace_plot( plot_objective_value_vs_trial_index( exp_df=exp_df, metric_colname=metric_name, - minimize=not_none( + minimize=none_throws( optimization_config.objective.minimize if optimization_config.objective.metric.name == metric_name else experiment.metrics[metric_name].lower_is_better @@ -232,7 +233,7 @@ def _get_objective_v_param_plots( with gpytorch.settings.max_eager_kernel_size(float("inf")): output_plots.append( interact_contour_plotly( - model=not_none(model), + model=none_throws(model), metric_name=metric_name, 
parameters_to_use=params_to_use, ) @@ -362,7 +363,7 @@ def get_standard_plots( "standard plots." ) - objective = not_none(experiment.optimization_config).objective + objective = none_throws(experiment.optimization_config).objective if isinstance(objective, ScalarizedObjective): logger.warning( "get_standard_plots does not currently support ScalarizedObjective " @@ -694,7 +695,7 @@ def _get_generation_method_str(trial: BaseTrial) -> str: return trial_generation_property generation_methods = { - not_none(generator_run._model_key) + none_throws(generator_run._model_key) for generator_run in trial.generator_runs if generator_run._model_key is not None } @@ -897,9 +898,9 @@ def exp_to_df( # Add `FEASIBLE_COL_NAME` column according to constraints if any. if ( exp.optimization_config is not None - and len(not_none(exp.optimization_config).all_constraints) > 0 + and len(none_throws(exp.optimization_config).all_constraints) > 0 ): - optimization_config = not_none(exp.optimization_config) + optimization_config = none_throws(exp.optimization_config) try: if any(oc.relative for oc in optimization_config.all_constraints): optimization_config = _derel_opt_config_wrapper( @@ -1033,7 +1034,7 @@ def exp_to_df( metrics=metrics or list(exp.metrics.values()), ) - exp_df = not_none(not_none(exp_df).sort_values(["trial_index"])) + exp_df = none_throws(none_throws(exp_df).sort_values(["trial_index"])) initial_column_order = ( ["trial_index", "arm_name", "trial_status", "reason", "generation_method"] + (run_metadata_fields or []) @@ -1096,9 +1097,9 @@ def _get_metric_name_pairs( optimization_config = _validate_experiment_and_get_optimization_config( experiment=experiment ) - if not_none(optimization_config).is_moo_problem: + if none_throws(optimization_config).is_moo_problem: multi_objective = checked_cast( - MultiObjective, not_none(optimization_config).objective + MultiObjective, none_throws(optimization_config).objective ) metric_names = [obj.metric.name for obj in multi_objective.objectives] if len(metric_names) > use_first_n_metrics: @@ -1112,7 +1113,7 @@ def _get_metric_name_pairs( return metric_name_pairs raise UserInputError( "Inference of `metric_names` failed. Expected `MultiObjective` but " - f"got {not_none(optimization_config).objective}. Please provide an experiment " + f"got {none_throws(optimization_config).objective}. Please provide an experiment " "with a MultiObjective `optimization_config`." 
) @@ -1359,10 +1360,10 @@ def select_baseline_arm( if ( experiment.status_quo and not arms_df[ - arms_df["arm_name"] == not_none(experiment.status_quo).name + arms_df["arm_name"] == none_throws(experiment.status_quo).name ].empty ): - baseline_arm_name = not_none(experiment.status_quo).name + baseline_arm_name = none_throws(experiment.status_quo).name return baseline_arm_name, False if ( @@ -1497,7 +1498,7 @@ def compare_to_baseline_impl( result_message = ( result_message + (" \n* " if len(comparison_list) > 1 else "") - + not_none(comparison_message) + + none_throws(comparison_message) ) return result_message if result_message else None @@ -1520,7 +1521,7 @@ def compare_to_baseline( ) if not comparison_list: return None - comparison_list = not_none(comparison_list) + comparison_list = none_throws(comparison_list) return compare_to_baseline_impl(comparison_list) @@ -1569,7 +1570,9 @@ def warn_if_unpredictable_metrics( if experiment.optimization_config is None: metric_names = list(experiment.metrics.keys()) else: - metric_names = list(not_none(experiment.optimization_config).metrics.keys()) + metric_names = list( + none_throws(experiment.optimization_config).metrics.keys() + ) else: # Raise a ValueError if any metric names are invalid. bad_metric_names = set(metric_names) - set(experiment.metrics.keys()) diff --git a/ax/service/utils/with_db_settings_base.py b/ax/service/utils/with_db_settings_base.py index 5686b113788..e3b09d9e034 100644 --- a/ax/service/utils/with_db_settings_base.py +++ b/ax/service/utils/with_db_settings_base.py @@ -27,7 +27,7 @@ from ax.modelbridge.generation_strategy import GenerationStrategy from ax.utils.common.executils import retry_on_exception from ax.utils.common.logger import _round_floats_for_logging, get_logger -from ax.utils.common.typeutils import not_none +from pyre_extensions import none_throws RETRY_EXCEPTION_TYPES: tuple[type[Exception], ...] = () @@ -40,7 +40,7 @@ from sqlalchemy import __version__ as sqa_version # pyre-fixme[16]: Module `sqlalchemy` has no attribute `__version__`. - sqa_major_version = int(not_none(re.match(r"^\d*", sqa_version))[0]) + sqa_major_version = int(none_throws(re.match(r"^\d*", sqa_version))[0]) if sqa_major_version > 1: msg = ( "Ax currently requires a sqlalchemy version below 2.0. This will be " @@ -133,7 +133,7 @@ def db_settings(self) -> DBSettings: """DB settings set on this instance; guaranteed to be non-None.""" if self._db_settings is None: raise ValueError("No DB settings are set on this instance.") - return not_none(self._db_settings) + return none_throws(self._db_settings) def _get_experiment_and_generation_strategy_db_id( self, experiment_name: str @@ -173,7 +173,7 @@ def _maybe_save_experiment_and_generation_strategy( raise ValueError( "Experiment must specify a name to use storage functionality." 
) - exp_name = not_none(experiment.name) + exp_name = none_throws(experiment.name) exp_id, gs_id = self._get_experiment_and_generation_strategy_db_id( experiment_name=exp_name ) diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 965ad491397..d39e46801dd 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -62,8 +62,9 @@ TClassDecoderRegistry, TDecoderRegistry, ) -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from ax.utils.common.typeutils_torch import torch_type_from_str +from pyre_extensions import none_throws logger: Logger = get_logger(__name__) @@ -629,7 +630,7 @@ def _load_experiment_info( if trial.ttl_seconds is not None: exp._trials_have_ttl = True if exp.status_quo is not None: - sq = not_none(exp.status_quo) + sq = none_throws(exp.status_quo) exp._register_arm(sq) diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index 2283341be84..d61f029c2de 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -68,9 +68,8 @@ from ax.storage.utils import DomainType, MetricIntent, ParameterConstraintType from ax.utils.common.constants import Keys from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import not_none from pandas import read_json -from pyre_extensions import assert_is_instance +from pyre_extensions import assert_is_instance, none_throws from sqlalchemy.orm.exc import DetachedInstanceError logger: Logger = get_logger(__name__) @@ -112,7 +111,7 @@ def _auxiliary_experiments_by_purpose_from_experiment_sqa( from ax.storage.sqa_store.load import load_experiment auxiliary_experiments_by_purpose = {} - aux_exp_name_dict = not_none( + aux_exp_name_dict = none_throws( experiment_sqa.auxiliary_experiments_by_purpose ) for aux_exp_purpose_str, aux_exp_names in aux_exp_name_dict.items(): @@ -219,10 +218,10 @@ def _init_mt_experiment_from_sqa( else None ) trial_type_to_runner = { - not_none(sqa_runner.trial_type): self.runner_from_sqa(sqa_runner) + none_throws(sqa_runner.trial_type): self.runner_from_sqa(sqa_runner) for sqa_runner in experiment_sqa.runners } - default_trial_type = not_none(experiment_sqa.default_trial_type) + default_trial_type = none_throws(experiment_sqa.default_trial_type) properties = dict(experiment_sqa.properties or {}) default_data_type = experiment_sqa.default_data_type experiment = MultiTypeExperiment( @@ -242,7 +241,7 @@ def _init_mt_experiment_from_sqa( sqa_metric = sqa_metric_dict[tracking_metric.name] experiment.add_tracking_metric( tracking_metric, - trial_type=not_none(sqa_metric.trial_type), + trial_type=none_throws(sqa_metric.trial_type), canonical_name=sqa_metric.canonical_name, ) return experiment @@ -301,7 +300,7 @@ def experiment_from_sqa( for arm in trial.arms: experiment._register_arm(arm) if experiment.status_quo is not None: - sq = not_none(experiment.status_quo) + sq = none_throws(experiment.status_quo) experiment._register_arm(sq) experiment._time_created = experiment_sqa.time_created experiment._experiment_type = self.get_enum_name( @@ -327,8 +326,8 @@ def parameter_from_sqa(self, parameter_sqa: SQAParameter) -> Parameter: parameter = RangeParameter( name=parameter_sqa.name, parameter_type=parameter_sqa.parameter_type, - lower=float(not_none(parameter_sqa.lower)), - upper=float(not_none(parameter_sqa.upper)), + lower=float(none_throws(parameter_sqa.lower)), + upper=float(none_throws(parameter_sqa.upper)), 
log_scale=parameter_sqa.log_scale or False, digits=parameter_sqa.digits, is_fidelity=parameter_sqa.is_fidelity or False, @@ -342,7 +341,7 @@ def parameter_from_sqa(self, parameter_sqa: SQAParameter) -> Parameter: f" parameter {parameter_sqa.name}." ) if bool(parameter_sqa.is_task) and target_value is None: - target_value = not_none(parameter_sqa.choice_values)[0] + target_value = none_throws(parameter_sqa.choice_values)[0] logger.debug( f"Target value is null for parameter {parameter_sqa.name}. " f"Defaulting to first choice {target_value}." @@ -350,7 +349,7 @@ def parameter_from_sqa(self, parameter_sqa: SQAParameter) -> Parameter: parameter = ChoiceParameter( name=parameter_sqa.name, parameter_type=parameter_sqa.parameter_type, - values=not_none(parameter_sqa.choice_values), + values=none_throws(parameter_sqa.choice_values), is_fidelity=parameter_sqa.is_fidelity or False, target_value=target_value, is_ordered=parameter_sqa.is_ordered, @@ -471,8 +470,8 @@ def environmental_variable_from_sqa(self, parameter_sqa: SQAParameter) -> Parame parameter = RangeParameter( name=parameter_sqa.name, parameter_type=parameter_sqa.parameter_type, - lower=float(not_none(parameter_sqa.lower)), - upper=float(not_none(parameter_sqa.upper)), + lower=float(none_throws(parameter_sqa.lower)), + upper=float(none_throws(parameter_sqa.upper)), log_scale=parameter_sqa.log_scale or False, digits=parameter_sqa.digits, is_fidelity=parameter_sqa.is_fidelity or False, @@ -688,14 +687,14 @@ def generator_run_from_sqa( ): best_arm = Arm( name=generator_run_sqa.best_arm_name, - parameters=not_none(generator_run_sqa.best_arm_parameters), + parameters=none_throws(generator_run_sqa.best_arm_parameters), ) best_arm_predictions = ( best_arm, - tuple(not_none(generator_run_sqa.best_arm_predictions)), + tuple(none_throws(generator_run_sqa.best_arm_predictions)), ) model_predictions = ( - tuple(not_none(generator_run_sqa.model_predictions)) + tuple(none_throws(generator_run_sqa.model_predictions)) if generator_run_sqa.model_predictions is not None else None ) @@ -1154,7 +1153,7 @@ def _scalarized_objective_from_sqa(self, parent_metric_sqa: SQAMetric) -> Object scalarized_objective = ScalarizedObjective( metrics=list(metrics), weights=list(weights), - minimize=not_none(parent_metric_sqa.minimize), + minimize=none_throws(parent_metric_sqa.minimize), ) scalarized_objective.db_id = parent_metric_sqa.id return scalarized_objective @@ -1173,9 +1172,9 @@ def _outcome_constraint_from_sqa( ) return OutcomeConstraint( metric=metric, - bound=float(not_none(metric_sqa.bound)), - op=not_none(metric_sqa.op), - relative=not_none(metric_sqa.relative), + bound=float(none_throws(metric_sqa.bound)), + op=none_throws(metric_sqa.op), + relative=none_throws(metric_sqa.relative), ) def _scalarized_outcome_constraint_from_sqa( @@ -1219,9 +1218,9 @@ def _scalarized_outcome_constraint_from_sqa( scalarized_outcome_constraint = ScalarizedOutcomeConstraint( metrics=list(metrics), weights=list(weights), - bound=float(not_none(metric_sqa.bound)), - op=not_none(metric_sqa.op), - relative=not_none(metric_sqa.relative), + bound=float(none_throws(metric_sqa.bound)), + op=none_throws(metric_sqa.op), + relative=none_throws(metric_sqa.relative), ) scalarized_outcome_constraint.db_id = metric_sqa.id return scalarized_outcome_constraint @@ -1236,8 +1235,8 @@ def _objective_threshold_from_sqa( ) ot = ObjectiveThreshold( metric=metric, - bound=float(not_none(metric_sqa.bound)), - relative=not_none(metric_sqa.relative), + bound=float(none_throws(metric_sqa.bound)), + 
relative=none_throws(metric_sqa.relative), op=metric_sqa.op, ) # ObjectiveThreshold constructor clones the passed-in metric, which means diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index 5230573e37a..adaf1cba271 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -67,7 +67,8 @@ from ax.utils.common.constants import Keys from ax.utils.common.logger import get_logger from ax.utils.common.serialization import serialize_init_args -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast +from pyre_extensions import none_throws logger: Logger = get_logger(__name__) @@ -166,8 +167,8 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment: status_quo_name = None status_quo_parameters = None if experiment.status_quo is not None: - status_quo_name = not_none(experiment.status_quo).name - status_quo_parameters = not_none(experiment.status_quo).parameters + status_quo_name = none_throws(experiment.status_quo).name + status_quo_parameters = none_throws(experiment.status_quo).parameters trials = [] for trial in experiment.trials.values(): @@ -204,7 +205,7 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment: metric.name ] elif experiment.runner: - runners.append(self.runner_to_sqa(not_none(experiment.runner))) + runners.append(self.runner_to_sqa(none_throws(experiment.runner))) # pyre-ignore[9]: Expected `Base` for 1st...yping.Type[Experiment]`. experiment_class: type[SQAExperiment] = self.config.class_to_sqa_class[ @@ -946,7 +947,7 @@ def trial_to_sqa( runner = None if trial.runner: - runner = self.runner_to_sqa(runner=not_none(trial.runner)) + runner = self.runner_to_sqa(runner=none_throws(trial.runner)) abandoned_arms = [] generator_runs = [] @@ -956,7 +957,7 @@ def trial_to_sqa( if isinstance(trial, Trial) and trial.generator_run: gr_sqa = self.generator_run_to_sqa( - generator_run=not_none(trial.generator_run), + generator_run=none_throws(trial.generator_run), reduced_state=generator_run_reduced_state, ) generator_runs.append(gr_sqa) diff --git a/ax/storage/sqa_store/load.py b/ax/storage/sqa_store/load.py index 4979d57a18b..847c748cb41 100644 --- a/ax/storage/sqa_store/load.py +++ b/ax/storage/sqa_store/load.py @@ -35,7 +35,8 @@ from ax.storage.utils import MetricIntent from ax.utils.common.constants import Keys -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast +from pyre_extensions import none_throws from sqlalchemy.orm import defaultload, lazyload, noload from sqlalchemy.orm.exc import DetachedInstanceError @@ -546,7 +547,7 @@ def get_generation_strategy_sqa_reduced_state( # Swap last generator run with no state for a generator run with # state. 
- gs_sqa.generator_runs[len(gs_sqa.generator_runs) - 1] = not_none(last_gr_sqa) + gs_sqa.generator_runs[len(gs_sqa.generator_runs) - 1] = none_throws(last_gr_sqa) return gs_sqa diff --git a/ax/storage/sqa_store/save.py b/ax/storage/sqa_store/save.py index d0f9c381670..bbf44302fce 100644 --- a/ax/storage/sqa_store/save.py +++ b/ax/storage/sqa_store/save.py @@ -40,7 +40,8 @@ from ax.storage.sqa_store.utils import copy_db_ids from ax.utils.common.base import Base from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast +from pyre_extensions import none_throws logger: Logger = get_logger(__name__) @@ -146,7 +147,7 @@ def _save_generation_strategy( decode_args={"experiment": experiment}, ) - return not_none(generation_strategy.db_id) + return none_throws(generation_strategy.db_id) def save_or_update_trial( diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 910582a3c6f..75566a76330 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -91,7 +91,7 @@ from ax.utils.common.logger import get_logger from ax.utils.common.serialization import serialize_init_args from ax.utils.common.testutils import TestCase -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from ax.utils.testing.core_stubs import ( CustomTestMetric, CustomTestRunner, @@ -128,6 +128,7 @@ sobol_gpei_generation_node_gs, ) from plotly import graph_objects as go, io as pio +from pyre_extensions import none_throws logger: Logger = get_logger(__name__) @@ -436,7 +437,7 @@ def test_LoadExperimentSkipMetricsAndRunners(self) -> None: gr = trial.generator_runs[0] if multi_objective and not immutable: objectives = checked_cast( - MultiObjective, not_none(gr.optimization_config).objective + MultiObjective, none_throws(gr.optimization_config).objective ).objectives for i, objective in enumerate(objectives): metric = objective.metric @@ -1500,7 +1501,7 @@ def test_EncodeDecodeGenerationStrategy(self) -> None: self.assertIsInstance(new_generation_strategy._steps[0].model, Models) self.assertEqual(len(new_generation_strategy._generator_runs), 2) self.assertEqual( - not_none(new_generation_strategy._experiment)._name, experiment._name + none_throws(new_generation_strategy._experiment)._name, experiment._name ) def test_EncodeDecodeGenerationNodeGSWithAdvancedSettings(self) -> None: @@ -1555,7 +1556,7 @@ def test_EncodeDecodeGenerationNodeGSWithAdvancedSettings(self) -> None: ) self.assertEqual(len(new_generation_strategy._generator_runs), 2) self.assertEqual( - not_none(new_generation_strategy._experiment)._name, experiment._name + none_throws(new_generation_strategy._experiment)._name, experiment._name ) def test_EncodeDecodeGenerationNodeBasedGenerationStrategy(self) -> None: @@ -1612,7 +1613,7 @@ def test_EncodeDecodeGenerationNodeBasedGenerationStrategy(self) -> None: ) self.assertEqual(len(new_generation_strategy._generator_runs), 2) self.assertEqual( - not_none(new_generation_strategy._experiment)._name, experiment._name + none_throws(new_generation_strategy._experiment)._name, experiment._name ) def test_EncodeDecodeGenerationStrategyReducedState(self) -> None: @@ -1659,7 +1660,7 @@ def test_EncodeDecodeGenerationStrategyReducedState(self) -> None: self.assertIsInstance(new_generation_strategy._steps[0].model, Models) 
self.assertEqual(len(new_generation_strategy._generator_runs), 2) self.assertEqual( - not_none(new_generation_strategy._experiment)._name, experiment._name + none_throws(new_generation_strategy._experiment)._name, experiment._name ) experiment.new_trial(new_generation_strategy.gen(experiment=experiment)) @@ -1719,7 +1720,7 @@ def test_EncodeDecodeGenerationStrategyReducedStateLoadExperiment(self) -> None: self.assertIsInstance(new_generation_strategy._steps[0].model, Models) self.assertEqual(len(new_generation_strategy._generator_runs), 2) self.assertEqual( - not_none(new_generation_strategy._experiment)._name, experiment._name + none_throws(new_generation_strategy._experiment)._name, experiment._name ) experiment.new_trial(new_generation_strategy.gen(experiment=experiment)) @@ -1770,12 +1771,12 @@ def test_UpdateGenerationStrategy(self) -> None: self.assertEqual(generation_strategy, loaded_generation_strategy) self.assertIsNotNone(loaded_generation_strategy._experiment) self.assertEqual( - not_none(generation_strategy._experiment).description, + none_throws(generation_strategy._experiment).description, experiment.description, ) self.assertEqual( - not_none(generation_strategy._experiment).description, - not_none(loaded_generation_strategy._experiment).description, + none_throws(generation_strategy._experiment).description, + none_throws(loaded_generation_strategy._experiment).description, ) def test_GeneratorRunGenMetadata(self) -> None: diff --git a/ax/utils/measurement/synthetic_functions.py b/ax/utils/measurement/synthetic_functions.py index 844dd689b50..e117074bb8c 100644 --- a/ax/utils/measurement/synthetic_functions.py +++ b/ax/utils/measurement/synthetic_functions.py @@ -13,9 +13,9 @@ import numpy.typing as npt import torch from ax.utils.common.docutils import copy_doc -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from botorch.test_functions import synthetic as botorch_synthetic -from pyre_extensions import override +from pyre_extensions import none_throws, override T = TypeVar("T") @@ -31,7 +31,7 @@ class SyntheticFunction(ABC): def informative_failure_on_none(self, attr: T | None) -> T: if attr is None: raise NotImplementedError(f"{self.name} does not specify property.") - return not_none(attr) + return none_throws(attr) @property def name(self) -> str: diff --git a/ax/utils/sensitivity/derivative_measures.py b/ax/utils/sensitivity/derivative_measures.py index 78c8426b4da..5cd2be6d9bb 100644 --- a/ax/utils/sensitivity/derivative_measures.py +++ b/ax/utils/sensitivity/derivative_measures.py @@ -11,7 +11,7 @@ from typing import Any import torch -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast from ax.utils.sensitivity.derivative_gp import posterior_derivative from botorch.models.model import Model from botorch.posteriors.gpytorch import GPyTorchPosterior @@ -20,6 +20,7 @@ from botorch.utils.sampling import draw_sobol_samples from botorch.utils.transforms import unnormalize from gpytorch.distributions import MultivariateNormal +from pyre_extensions import none_throws def sample_discrete_parameters( @@ -126,7 +127,7 @@ def __init__( if self.derivative_gp: posterior = posterior_derivative( - model, self.input_mc_samples, not_none(self.kernel_type) + model, self.input_mc_samples, none_throws(self.kernel_type) ) else: self.input_mc_samples.requires_grad = True @@ -163,7 +164,7 @@ def aggregation( ) -> torch.Tensor: gradients_measure = torch.tensor( [ - 
-                torch.mean(transform_fun(not_none(self.mean_gradients)[:, i]))
+                torch.mean(transform_fun(none_throws(self.mean_gradients)[:, i]))
                 for i in range(self.dim)
             ]
         )
@@ -177,7 +178,7 @@ def aggregation(
                     [
                         torch.mean(
                             transform_fun(
-                                not_none(self.mean_gradients_btsp)[b][:, i]
+                                none_throws(self.mean_gradients_btsp)[b][:, i]
                             )
                         )
                         for i in range(self.dim)
@@ -326,13 +327,13 @@ def _compute_gradient_quantities(
             )
             self.samples_gradients_btsp = []
             for j in range(self.num_gp_samples):
-                not_none(self.samples_gradients_btsp).append(
+                none_throws(self.samples_gradients_btsp).append(
                     torch.cat(
                         [
                             torch.index_select(
-                                not_none(self.samples_gradients)[j], 0, indices
+                                none_throws(self.samples_gradients)[j], 0, indices
                             ).unsqueeze(0)
-                            for indices in not_none(self.bootstrap_indices)
+                            for indices in none_throws(self.bootstrap_indices)
                         ],
                         dim=0,
                     )
@@ -347,7 +348,7 @@ def aggregation(
                 torch.tensor(
                     [
                         torch.mean(
-                            transform_fun(not_none(self.samples_gradients)[j][:, i])
+                            transform_fun(none_throws(self.samples_gradients)[j][:, i])
                         )
                         for i in range(self.dim)
                     ]
                 )
@@ -379,7 +380,7 @@ def aggregation(
                     [
                         torch.mean(
                             transform_fun(
-                                not_none(self.samples_gradients_btsp)[j][b][:, i]
+                                none_throws(self.samples_gradients_btsp)[j][b][:, i]
                             )
                         )
                         for i in range(self.dim)
diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py
index 4e7e120234a..afdd488d0b4 100644
--- a/ax/utils/testing/core_stubs.py
+++ b/ax/utils/testing/core_stubs.py
@@ -108,7 +108,7 @@
 from ax.utils.common.constants import Keys
 from ax.utils.common.logger import get_logger
 from ax.utils.common.random import set_rng_seed
-from ax.utils.common.typeutils import checked_cast, not_none
+from ax.utils.common.typeutils import checked_cast
 from ax.utils.measurement.synthetic_functions import branin
 from botorch.acquisition.acquisition import AcquisitionFunction
 from botorch.acquisition.monte_carlo import qExpectedImprovement
@@ -122,6 +122,7 @@
 from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
 from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
 from gpytorch.priors.torch_priors import GammaPrior, LogNormalPrior
+from pyre_extensions import none_throws

 logger: Logger = get_logger(__name__)

@@ -681,9 +682,9 @@ def get_branin_with_multi_task(with_multi_objective: bool = False) -> Experiment
     sobol_generator = get_sobol(search_space=exp.search_space, seed=TEST_SOBOL_SEED)
     sobol_run = sobol_generator.gen(n=5)
     exp.new_batch_trial(optimize_for_power=True).add_generator_run(sobol_run)
-    not_none(exp.trials.get(0)).run()
+    none_throws(exp.trials.get(0)).run()
     exp.new_batch_trial(optimize_for_power=True).add_generator_run(sobol_run)
-    not_none(exp.trials.get(1)).run()
+    none_throws(exp.trials.get(1)).run()
     return exp

@@ -1998,10 +1999,10 @@ def get_branin_data(
                 {
                     "trial_index": trial.index,
                     "metric_name": "branin",
-                    "arm_name": not_none(checked_cast(Trial, trial).arm).name,
+                    "arm_name": none_throws(checked_cast(Trial, trial).arm).name,
                     "mean": branin(
-                        float(not_none(not_none(trial.arm).parameters["x1"])),
-                        float(not_none(not_none(trial.arm).parameters["x2"])),
+                        float(none_throws(none_throws(trial.arm).parameters["x1"])),
+                        float(none_throws(none_throws(trial.arm).parameters["x2"])),
                     ),
                     "sem": 0.0,
                 }
@@ -2029,8 +2030,8 @@ def get_branin_data_batch(batch: BatchTrial) -> Data:
         else:
             means.append(
                 branin(
-                    float(not_none(arm.parameters["x1"])),
-                    float(not_none(arm.parameters["x2"])),
+                    float(none_throws(arm.parameters["x1"])),
+                    float(none_throws(arm.parameters["x2"])),
                 )
             )
     return Data(
diff --git a/ax/utils/testing/preference_stubs.py b/ax/utils/testing/preference_stubs.py
index af89b64acab..8523c89b7df 100644
--- a/ax/utils/testing/preference_stubs.py
+++ b/ax/utils/testing/preference_stubs.py
@@ -15,8 +15,9 @@
 from ax.core.types import TEvaluationOutcome, TParameterization
 from ax.service.utils.instantiation import InstantiationBase
 from ax.utils.common.constants import Keys
-from ax.utils.common.typeutils import checked_cast, not_none
+from ax.utils.common.typeutils import checked_cast
 from botorch.utils.sampling import draw_sobol_samples
+from pyre_extensions import none_throws

 # from ExperimentType in ae/lazarus/fb/utils/if/ae.thrift
 PBO_EXPERIMENT_TYPE: str = "PREFERENCE_LEARNING"
@@ -155,7 +156,7 @@ def get_pbo_experiment(
     for t in range(num_experimental_trials):
         arm = {}
         for i, param_name in enumerate(experiment.search_space.parameters.keys()):
-            arm[param_name] = not_none(X)[t, i].item()
+            arm[param_name] = none_throws(X)[t, i].item()
         gr = (
             # pyre-ignore: Incompatible parameter type [6]
             GeneratorRun([Arm(arm), Arm(sq)])
@@ -192,7 +193,7 @@ def get_pbo_experiment(
                     metric_name=param_name, metric_names=parameter_names
                 )
             else:
-                param_dict[param_name] = not_none(X)[t * 2 + j, i].item()
+                param_dict[param_name] = none_throws(X)[t * 2 + j, i].item()
             arms.append(Arm(parameters=param_dict))
         gr = GeneratorRun(arms)
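
For reference, pyre_extensions.none_throws has the same runtime behavior as the not_none helper it replaces throughout this patch: it raises when passed None and otherwise returns its argument with the Optional stripped from the inferred type. A minimal sketch of the usage pattern, assuming nothing beyond pyre_extensions; the lookup helper below is hypothetical and used only for illustration:

    from pyre_extensions import none_throws

    def lookup(name: str) -> int | None:
        # Hypothetical helper that may or may not find a value.
        return {"x1": 1}.get(name)

    # Returns 1, typed as int rather than int | None.
    value: int = none_throws(lookup("x1"))
    # none_throws(lookup("missing")) would raise instead of silently returning None.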