Skip to content

Commit

Permalink
Replace not_none with none_throws in Ax
Browse files Browse the repository at this point in the history
Differential Revision: D65275698
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Oct 31, 2024
1 parent 46fa5a5 commit 9eeec96
Show file tree
Hide file tree
Showing 96 changed files with 542 additions and 485 deletions.
19 changes: 10 additions & 9 deletions ax/benchmark/tests/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from ax.storage.json_store.load import load_experiment
from ax.storage.json_store.save import save_experiment
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import not_none
from ax.utils.testing.benchmark_stubs import (
get_moo_surrogate,
get_multi_objective_benchmark_problem,
Expand Down Expand Up @@ -102,7 +101,7 @@ def test_storage(self) -> None:

# test saving to temporary file
with tempfile.NamedTemporaryFile(mode="w", delete=True, suffix=".json") as f:
save_experiment(not_none(res.experiment), f.name)
save_experiment(none_throws(res.experiment), f.name)
res.experiment_storage_id = f.name
res.experiment = None
self.assertIsNone(res.experiment)
Expand Down Expand Up @@ -154,7 +153,9 @@ def test_replication_sobol_synthetic(self) -> None:
for problem in problems:
res = benchmark_replication(problem=problem, method=method, seed=0)

self.assertEqual(problem.num_trials, len(not_none(res.experiment).trials))
self.assertEqual(
problem.num_trials, len(none_throws(res.experiment).trials)
)
self.assertTrue(np.isfinite(res.score_trace).all())
self.assertTrue(np.all(res.score_trace <= 100))
experiment = none_throws(res.experiment)
Expand All @@ -180,7 +181,7 @@ def test_replication_sobol_surrogate(self) -> None:

self.assertEqual(
problem.num_trials,
len(not_none(res.experiment).trials),
len(none_throws(res.experiment).trials),
)

self.assertTrue(np.isfinite(res.score_trace).all())
Expand Down Expand Up @@ -328,7 +329,7 @@ def test_replication_mbm(self) -> None:
res = benchmark_replication(problem=problem, method=method, seed=0)
self.assertEqual(
problem.num_trials,
len(not_none(res.experiment).trials),
len(none_throws(res.experiment).trials),
)
self.assertTrue(np.all(res.score_trace <= 100))
self.assertEqual(method.name, method.generation_strategy.name)
Expand All @@ -345,11 +346,11 @@ def test_replication_moo_sobol(self) -> None:

self.assertEqual(
problem.num_trials,
len(not_none(res.experiment).trials),
len(none_throws(res.experiment).trials),
)
self.assertEqual(
problem.num_trials * 2,
len(not_none(res.experiment).fetch_data().df),
len(none_throws(res.experiment).fetch_data().df),
)

self.assertTrue(np.all(res.score_trace <= 100))
Expand All @@ -365,7 +366,7 @@ def test_benchmark_one_method_problem(self) -> None:
self.assertEqual(len(agg.results), 2)
self.assertTrue(
all(
len(not_none(result.experiment).trials) == problem.num_trials
len(none_throws(result.experiment).trials) == problem.num_trials
for result in agg.results
),
"All experiments must have 4 trials",
Expand Down Expand Up @@ -457,5 +458,5 @@ def test_replication_with_generation_node(self) -> None:
problem = get_single_objective_benchmark_problem()
res = benchmark_replication(problem=problem, method=method, seed=0)

self.assertEqual(problem.num_trials, len(not_none(res.experiment).trials))
self.assertEqual(problem.num_trials, len(none_throws(res.experiment).trials))
self.assertTrue(np.isnan(res.score_trace).all())
6 changes: 3 additions & 3 deletions ax/benchmark/tests/test_benchmark_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.registry import Models
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import not_none
from pyre_extensions import none_throws


class TestBenchmarkMethod(TestCase):
Expand All @@ -33,7 +33,7 @@ def test_benchmark_method(self) -> None:

# test that `fit_tracking_metrics` has been correctly set to False
for step in method.generation_strategy._steps:
self.assertFalse(not_none(step.model_kwargs).get("fit_tracking_metrics"))
self.assertFalse(none_throws(step.model_kwargs).get("fit_tracking_metrics"))

self.assertEqual(method.scheduler_options, options)
self.assertEqual(options.max_pending_trials, 1)
Expand All @@ -60,7 +60,7 @@ def test_benchmark_method(self) -> None:
)
for node in method.generation_strategy._nodes:
self.assertFalse(
not_none(node.model_spec_to_gen_from.model_kwargs).get(
none_throws(node.model_spec_to_gen_from.model_kwargs).get(
"fit_tracking_metrics"
)
)
16 changes: 8 additions & 8 deletions ax/core/base_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ax.core.types import TCandidateMetadata, TEvaluationOutcome
from ax.exceptions.core import UnsupportedError
from ax.utils.common.base import SortableBase
from ax.utils.common.typeutils import not_none
from pyre_extensions import none_throws


if TYPE_CHECKING:
Expand Down Expand Up @@ -283,7 +283,7 @@ def index(self) -> int:
def status(self) -> TrialStatus:
"""The status of the trial in the experimentation lifecycle."""
self._mark_failed_if_past_TTL()
return not_none(self._status)
return none_throws(self._status)

@status.setter
def status(self, status: TrialStatus) -> None:
Expand Down Expand Up @@ -413,9 +413,9 @@ def run(self) -> BaseTrial:
if self._runner is None:
raise ValueError("No runner set on trial or experiment.")

self.update_run_metadata(not_none(self._runner).run(self))
self.update_run_metadata(none_throws(self._runner).run(self))

if not_none(self._runner).staging_required:
if none_throws(self._runner).staging_required:
self.mark_staged()
else:
self.mark_running()
Expand Down Expand Up @@ -452,7 +452,7 @@ def stop(self, new_status: TrialStatus, reason: str | None = None) -> BaseTrial:
self.assign_runner()
if self._runner is None:
raise ValueError("No runner set on trial or experiment.")
runner = not_none(self._runner)
runner = none_throws(self._runner)

self._stop_metadata = runner.stop(self, reason=reason)
self.mark_as(new_status)
Expand Down Expand Up @@ -713,7 +713,7 @@ def mark_abandoned(
Returns:
The trial instance.
"""
if not unsafe and not_none(self._status).is_terminal:
if not unsafe and none_throws(self._status).is_terminal:
raise ValueError("Cannot abandon a trial in a terminal state.")

self._abandoned_reason = reason
Expand Down Expand Up @@ -792,12 +792,12 @@ def _mark_failed_if_past_TTL(self) -> None:
"""If trial has TTL set and is running, check if the TTL has elapsed
and mark the trial failed if so.
"""
if self.ttl_seconds is None or not not_none(self._status).is_running:
if self.ttl_seconds is None or not none_throws(self._status).is_running:
return
time_run_started = self._time_run_started
assert time_run_started is not None
dt = datetime.now() - time_run_started
if dt > timedelta(seconds=not_none(self.ttl_seconds)):
if dt > timedelta(seconds=none_throws(self.ttl_seconds)):
self.mark_failed()

@property
Expand Down
17 changes: 10 additions & 7 deletions ax/core/batch_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
from ax.utils.common.docutils import copy_doc
from ax.utils.common.equality import datetime_equals, equality_typechecker
from ax.utils.common.logger import _round_floats_for_logging, get_logger
from ax.utils.common.typeutils import checked_cast, not_none
from ax.utils.common.typeutils import checked_cast
from pyre_extensions import none_throws

logger: Logger = get_logger(__name__)

Expand Down Expand Up @@ -317,7 +318,9 @@ def add_generator_run(
generator_run.index = len(self._generator_run_structs) - 1

if self.status_quo is not None and self.optimize_for_power:
self.set_status_quo_and_optimize_power(status_quo=not_none(self.status_quo))
self.set_status_quo_and_optimize_power(
status_quo=none_throws(self.status_quo)
)

if generator_run._generation_step_index is not None:
self._set_generation_step_index(
Expand Down Expand Up @@ -403,7 +406,7 @@ def set_status_quo_and_optimize_power(self, status_quo: Arm) -> BatchTrial:
return self

# arm_weights should always have at least one arm now
arm_weights = not_none(self.arm_weights)
arm_weights = none_throws(self.arm_weights)
sum_weights = sum(w for arm, w in arm_weights.items() if arm != status_quo)
optimal_status_quo_weight_override = np.sqrt(sum_weights)
self.set_status_quo_with_weight(
Expand Down Expand Up @@ -480,7 +483,7 @@ def is_factorial(self) -> bool:
param_levels: defaultdict[str, dict[str | float, int]] = defaultdict(dict)
for arm in self.arms:
for param_name, param_value in arm.parameters.items():
param_levels[param_name][not_none(param_value)] = 1
param_levels[param_name][none_throws(param_value)] = 1
param_cardinality = 1
for param_values in param_levels.values():
param_cardinality *= len(param_values)
Expand Down Expand Up @@ -701,7 +704,7 @@ def _get_candidate_metadata(self, arm_name: str) -> TCandidateMetadata:
for gr_struct in self._generator_run_structs:
gr = gr_struct.generator_run
if gr and gr.candidate_metadata_by_arm_signature and arm in gr.arms:
return not_none(gr.candidate_metadata_by_arm_signature).get(
return none_throws(gr.candidate_metadata_by_arm_signature).get(
arm.signature
)
return None
Expand All @@ -710,8 +713,8 @@ def _validate_batch_trial_data(self, data: Data) -> None:
"""Utility function to validate batch data before further processing."""
if (
self.status_quo
and not_none(self.status_quo).name in self.arms_by_name
and not_none(self.status_quo).name not in data.df["arm_name"].values
and none_throws(self.status_quo).name in self.arms_by_name
and none_throws(self.status_quo).name not in data.df["arm_name"].values
):
raise AxError(
f"Trial #{self.index} was completed with data that did "
Expand Down
5 changes: 3 additions & 2 deletions ax/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
TClassDecoderRegistry,
TDecoderRegistry,
)
from ax.utils.common.typeutils import checked_cast, not_none
from ax.utils.common.typeutils import checked_cast
from pyre_extensions import none_throws

TBaseData = TypeVar("TBaseData", bound="BaseData")
DF_REPR_MAX_LENGTH = 1000
Expand Down Expand Up @@ -235,7 +236,7 @@ def df_hash(self) -> str:
str: The hash of the DataFrame.
"""
return md5(not_none(self.df.to_json()).encode("utf-8")).hexdigest()
return md5(none_throws(self.df.to_json()).encode("utf-8")).hexdigest()

def get_filtered_results(
self: TBaseData, **filters: dict[str, Any]
Expand Down
17 changes: 9 additions & 8 deletions ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@
from ax.utils.common.logger import _round_floats_for_logging, get_logger
from ax.utils.common.result import Err, Ok
from ax.utils.common.timeutils import current_timestamp_in_millis
from ax.utils.common.typeutils import checked_cast, not_none
from ax.utils.common.typeutils import checked_cast
from pyre_extensions import none_throws

logger: logging.Logger = get_logger(__name__)

Expand Down Expand Up @@ -366,7 +367,7 @@ def is_moo_problem(self) -> bool:
"""Whether the experiment's optimization config contains multiple objectives."""
if self.optimization_config is None:
return False
return not_none(self.optimization_config).is_moo_problem
return none_throws(self.optimization_config).is_moo_problem

@property
def data_by_trial(self) -> dict[int, OrderedDict[int, Data]]:
Expand Down Expand Up @@ -1296,7 +1297,7 @@ def warm_start_from_old_experiment(
"Only experiments with 1-arm trials currently supported."
)
self.search_space.check_membership(
not_none(trial.arm).parameters,
none_throws(trial.arm).parameters,
raise_error=search_space_check_membership_raise_error,
)
dat, ts = old_experiment.lookup_data_for_trial(trial_index=trial.index)
Expand Down Expand Up @@ -1331,7 +1332,7 @@ def warm_start_from_old_experiment(
{trial.index: new_trial.index}, inplace=True
)
new_df["arm_name"].replace(
{not_none(trial.arm).name: not_none(new_trial.arm).name},
{none_throws(trial.arm).name: none_throws(new_trial.arm).name},
inplace=True,
)
# Attach updated data to new trial on experiment.
Expand Down Expand Up @@ -1783,20 +1784,20 @@ def add_arm_and_prevent_naming_collision(
# the arm name. Preserves all names not matching the automatic naming format.
# experiment is not named, clear the arm's name.
# `arm_index` is 0 since all trials are single-armed.
old_arm_name = not_none(old_trial.arm).name
old_arm_name = none_throws(old_trial.arm).name
has_default_name = bool(old_arm_name == old_trial._get_default_name(arm_index=0))
if has_default_name:
new_arm = not_none(old_trial.arm).clone(clear_name=True)
new_arm = none_throws(old_trial.arm).clone(clear_name=True)
if old_experiment_name is not None:
new_arm.name = f"{old_arm_name}_{old_experiment_name}"
new_trial.add_arm(new_arm)
else:
try:
new_trial.add_arm(not_none(old_trial.arm).clone(clear_name=False))
new_trial.add_arm(none_throws(old_trial.arm).clone(clear_name=False))
except ValueError as e:
warnings.warn(
f"Attaching arm {old_trial.arm} to trial {new_trial} while preserving "
f"its name failed with error: {e}. Retrying with `clear_name=True`.",
stacklevel=2,
)
new_trial.add_arm(not_none(old_trial.arm).clone(clear_name=True))
new_trial.add_arm(none_throws(old_trial.arm).clone(clear_name=True))
4 changes: 2 additions & 2 deletions ax/core/generation_strategy_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ax.core.observation import ObservationFeatures
from ax.exceptions.core import AxError, UnsupportedError
from ax.utils.common.base import Base
from ax.utils.common.typeutils import not_none
from pyre_extensions import none_throws


class GenerationStrategyInterface(ABC, Base):
Expand Down Expand Up @@ -148,7 +148,7 @@ def experiment(self) -> Experiment:
"""Experiment, currently set on this generation strategy."""
if self._experiment is None:
raise AxError("No experiment set on generation strategy.")
return not_none(self._experiment)
return none_throws(self._experiment)

@experiment.setter
def experiment(self, experiment: Experiment) -> None:
Expand Down
4 changes: 2 additions & 2 deletions ax/core/generator_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from ax.exceptions.core import UnsupportedError
from ax.utils.common.base import Base, SortableBase
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none
from pyre_extensions import none_throws

logger: Logger = get_logger(__name__)

Expand Down Expand Up @@ -303,7 +303,7 @@ def model_predictions_by_arm(self) -> dict[str, TModelPredictArm] | None:
predictions: dict[str, TModelPredictArm] = {}
for idx, cond in enumerate(self.arms):
predictions[cond.signature] = extract_arm_predictions(
model_predictions=not_none(self._model_predictions), arm_idx=idx
model_predictions=none_throws(self._model_predictions), arm_idx=idx
)
return predictions

Expand Down
6 changes: 3 additions & 3 deletions ax/core/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ax.exceptions.core import UserInputError
from ax.utils.common.base import SortableBase
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none
from pyre_extensions import none_throws

logger: Logger = get_logger(__name__)

Expand Down Expand Up @@ -56,7 +56,7 @@ def __init__(self, metric: Metric, minimize: bool | None = None) -> None:
f"{minimize=}."
)
self._metric: Metric = metric
self.minimize: bool = not_none(minimize)
self.minimize: bool = none_throws(minimize)

@property
def metric(self) -> Metric:
Expand Down Expand Up @@ -136,7 +136,7 @@ def __init__(
objectives.append(Objective(metric=metric, minimize=minimize))

# pyre-fixme[4]: Attribute must be annotated.
self._objectives = not_none(objectives)
self._objectives = none_throws(objectives)

# For now, assume all objectives are weighted equally.
# This might be used in the future to change emphasis on the
Expand Down
5 changes: 3 additions & 2 deletions ax/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
from ax.utils.common.base import Base
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import checked_cast, not_none
from ax.utils.common.typeutils import checked_cast
from pyre_extensions import none_throws

logger: Logger = get_logger(__name__)

Expand Down Expand Up @@ -308,7 +309,7 @@ def _observations_from_dataframe(
metadata = trial._get_candidate_metadata(arm_name) or {}
if Keys.TRIAL_COMPLETION_TIMESTAMP not in metadata:
if trial._time_completed is not None:
metadata[Keys.TRIAL_COMPLETION_TIMESTAMP] = not_none(
metadata[Keys.TRIAL_COMPLETION_TIMESTAMP] = none_throws(
trial._time_completed
).timestamp()
obs_kwargs[Keys.METADATA] = metadata
Expand Down
Loading

0 comments on commit 9eeec96

Please sign in to comment.