Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Delete not_none #3003

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions ax/benchmark/tests/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from ax.storage.json_store.load import load_experiment
from ax.storage.json_store.save import save_experiment
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import not_none
from ax.utils.testing.benchmark_stubs import (
get_moo_surrogate,
get_multi_objective_benchmark_problem,
Expand Down Expand Up @@ -103,7 +102,7 @@ def test_storage(self) -> None:

# test saving to temporary file
with tempfile.NamedTemporaryFile(mode="w", delete=True, suffix=".json") as f:
save_experiment(not_none(res.experiment), f.name)
save_experiment(none_throws(res.experiment), f.name)
res.experiment_storage_id = f.name
res.experiment = None
self.assertIsNone(res.experiment)
Expand Down Expand Up @@ -155,7 +154,9 @@ def test_replication_sobol_synthetic(self) -> None:
for problem in problems:
res = benchmark_replication(problem=problem, method=method, seed=0)

self.assertEqual(problem.num_trials, len(not_none(res.experiment).trials))
self.assertEqual(
problem.num_trials, len(none_throws(res.experiment).trials)
)
self.assertTrue(np.isfinite(res.score_trace).all())
self.assertTrue(np.all(res.score_trace <= 100))
experiment = none_throws(res.experiment)
Expand All @@ -181,7 +182,7 @@ def test_replication_sobol_surrogate(self) -> None:

self.assertEqual(
problem.num_trials,
len(not_none(res.experiment).trials),
len(none_throws(res.experiment).trials),
)

self.assertTrue(np.isfinite(res.score_trace).all())
Expand Down Expand Up @@ -340,7 +341,7 @@ def test_replication_mbm(self) -> None:
res = benchmark_replication(problem=problem, method=method, seed=0)
self.assertEqual(
problem.num_trials,
len(not_none(res.experiment).trials),
len(none_throws(res.experiment).trials),
)
self.assertTrue(np.all(res.score_trace <= 100))
self.assertEqual(method.name, method.generation_strategy.name)
Expand All @@ -357,11 +358,11 @@ def test_replication_moo_sobol(self) -> None:

self.assertEqual(
problem.num_trials,
len(not_none(res.experiment).trials),
len(none_throws(res.experiment).trials),
)
self.assertEqual(
problem.num_trials * 2,
len(not_none(res.experiment).fetch_data().df),
len(none_throws(res.experiment).fetch_data().df),
)

self.assertTrue(np.all(res.score_trace <= 100))
Expand All @@ -377,7 +378,7 @@ def test_benchmark_one_method_problem(self) -> None:
self.assertEqual(len(agg.results), 2)
self.assertTrue(
all(
len(not_none(result.experiment).trials) == problem.num_trials
len(none_throws(result.experiment).trials) == problem.num_trials
for result in agg.results
),
"All experiments must have 4 trials",
Expand Down Expand Up @@ -469,5 +470,5 @@ def test_replication_with_generation_node(self) -> None:
problem = get_single_objective_benchmark_problem()
res = benchmark_replication(problem=problem, method=method, seed=0)

self.assertEqual(problem.num_trials, len(not_none(res.experiment).trials))
self.assertEqual(problem.num_trials, len(none_throws(res.experiment).trials))
self.assertTrue(np.isnan(res.score_trace).all())
6 changes: 3 additions & 3 deletions ax/benchmark/tests/test_benchmark_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.registry import Models
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import not_none
from pyre_extensions import none_throws


class TestBenchmarkMethod(TestCase):
Expand All @@ -33,7 +33,7 @@ def test_benchmark_method(self) -> None:

# test that `fit_tracking_metrics` has been correctly set to False
for step in method.generation_strategy._steps:
self.assertFalse(not_none(step.model_kwargs).get("fit_tracking_metrics"))
self.assertFalse(none_throws(step.model_kwargs).get("fit_tracking_metrics"))

self.assertEqual(method.scheduler_options, options)
self.assertEqual(options.max_pending_trials, 1)
Expand All @@ -60,7 +60,7 @@ def test_benchmark_method(self) -> None:
)
for node in method.generation_strategy._nodes:
self.assertFalse(
not_none(node.model_spec_to_gen_from.model_kwargs).get(
none_throws(node.model_spec_to_gen_from.model_kwargs).get(
"fit_tracking_metrics"
)
)
16 changes: 8 additions & 8 deletions ax/core/base_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ax.core.types import TCandidateMetadata, TEvaluationOutcome
from ax.exceptions.core import UnsupportedError
from ax.utils.common.base import SortableBase
from ax.utils.common.typeutils import not_none
from pyre_extensions import none_throws


if TYPE_CHECKING:
Expand Down Expand Up @@ -283,7 +283,7 @@ def index(self) -> int:
def status(self) -> TrialStatus:
"""The status of the trial in the experimentation lifecycle."""
self._mark_failed_if_past_TTL()
return not_none(self._status)
return none_throws(self._status)

@status.setter
def status(self, status: TrialStatus) -> None:
Expand Down Expand Up @@ -413,9 +413,9 @@ def run(self) -> BaseTrial:
if self._runner is None:
raise ValueError("No runner set on trial or experiment.")

self.update_run_metadata(not_none(self._runner).run(self))
self.update_run_metadata(none_throws(self._runner).run(self))

if not_none(self._runner).staging_required:
if none_throws(self._runner).staging_required:
self.mark_staged()
else:
self.mark_running()
Expand Down Expand Up @@ -452,7 +452,7 @@ def stop(self, new_status: TrialStatus, reason: str | None = None) -> BaseTrial:
self.assign_runner()
if self._runner is None:
raise ValueError("No runner set on trial or experiment.")
runner = not_none(self._runner)
runner = none_throws(self._runner)

self._stop_metadata = runner.stop(self, reason=reason)
self.mark_as(new_status)
Expand Down Expand Up @@ -713,7 +713,7 @@ def mark_abandoned(
Returns:
The trial instance.
"""
if not unsafe and not_none(self._status).is_terminal:
if not unsafe and none_throws(self._status).is_terminal:
raise ValueError("Cannot abandon a trial in a terminal state.")

self._abandoned_reason = reason
Expand Down Expand Up @@ -792,12 +792,12 @@ def _mark_failed_if_past_TTL(self) -> None:
"""If trial has TTL set and is running, check if the TTL has elapsed
and mark the trial failed if so.
"""
if self.ttl_seconds is None or not not_none(self._status).is_running:
if self.ttl_seconds is None or not none_throws(self._status).is_running:
return
time_run_started = self._time_run_started
assert time_run_started is not None
dt = datetime.now() - time_run_started
if dt > timedelta(seconds=not_none(self.ttl_seconds)):
if dt > timedelta(seconds=none_throws(self.ttl_seconds)):
self.mark_failed()

@property
Expand Down
17 changes: 10 additions & 7 deletions ax/core/batch_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
from ax.utils.common.docutils import copy_doc
from ax.utils.common.equality import datetime_equals, equality_typechecker
from ax.utils.common.logger import _round_floats_for_logging, get_logger
from ax.utils.common.typeutils import checked_cast, not_none
from ax.utils.common.typeutils import checked_cast
from pyre_extensions import none_throws

logger: Logger = get_logger(__name__)

Expand Down Expand Up @@ -317,7 +318,9 @@ def add_generator_run(
generator_run.index = len(self._generator_run_structs) - 1

if self.status_quo is not None and self.optimize_for_power:
self.set_status_quo_and_optimize_power(status_quo=not_none(self.status_quo))
self.set_status_quo_and_optimize_power(
status_quo=none_throws(self.status_quo)
)

if generator_run._generation_step_index is not None:
self._set_generation_step_index(
Expand Down Expand Up @@ -403,7 +406,7 @@ def set_status_quo_and_optimize_power(self, status_quo: Arm) -> BatchTrial:
return self

# arm_weights should always have at least one arm now
arm_weights = not_none(self.arm_weights)
arm_weights = none_throws(self.arm_weights)
sum_weights = sum(w for arm, w in arm_weights.items() if arm != status_quo)
optimal_status_quo_weight_override = np.sqrt(sum_weights)
self.set_status_quo_with_weight(
Expand Down Expand Up @@ -480,7 +483,7 @@ def is_factorial(self) -> bool:
param_levels: defaultdict[str, dict[str | float, int]] = defaultdict(dict)
for arm in self.arms:
for param_name, param_value in arm.parameters.items():
param_levels[param_name][not_none(param_value)] = 1
param_levels[param_name][none_throws(param_value)] = 1
param_cardinality = 1
for param_values in param_levels.values():
param_cardinality *= len(param_values)
Expand Down Expand Up @@ -701,7 +704,7 @@ def _get_candidate_metadata(self, arm_name: str) -> TCandidateMetadata:
for gr_struct in self._generator_run_structs:
gr = gr_struct.generator_run
if gr and gr.candidate_metadata_by_arm_signature and arm in gr.arms:
return not_none(gr.candidate_metadata_by_arm_signature).get(
return none_throws(gr.candidate_metadata_by_arm_signature).get(
arm.signature
)
return None
Expand All @@ -710,8 +713,8 @@ def _validate_batch_trial_data(self, data: Data) -> None:
"""Utility function to validate batch data before further processing."""
if (
self.status_quo
and not_none(self.status_quo).name in self.arms_by_name
and not_none(self.status_quo).name not in data.df["arm_name"].values
and none_throws(self.status_quo).name in self.arms_by_name
and none_throws(self.status_quo).name not in data.df["arm_name"].values
):
raise AxError(
f"Trial #{self.index} was completed with data that did "
Expand Down
5 changes: 3 additions & 2 deletions ax/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
TClassDecoderRegistry,
TDecoderRegistry,
)
from ax.utils.common.typeutils import checked_cast, not_none
from ax.utils.common.typeutils import checked_cast
from pyre_extensions import none_throws

TBaseData = TypeVar("TBaseData", bound="BaseData")
DF_REPR_MAX_LENGTH = 1000
Expand Down Expand Up @@ -235,7 +236,7 @@ def df_hash(self) -> str:
str: The hash of the DataFrame.

"""
return md5(not_none(self.df.to_json()).encode("utf-8")).hexdigest()
return md5(none_throws(self.df.to_json()).encode("utf-8")).hexdigest()

def get_filtered_results(
self: TBaseData, **filters: dict[str, Any]
Expand Down
17 changes: 9 additions & 8 deletions ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@
from ax.utils.common.logger import _round_floats_for_logging, get_logger
from ax.utils.common.result import Err, Ok
from ax.utils.common.timeutils import current_timestamp_in_millis
from ax.utils.common.typeutils import checked_cast, not_none
from ax.utils.common.typeutils import checked_cast
from pyre_extensions import none_throws

logger: logging.Logger = get_logger(__name__)

Expand Down Expand Up @@ -366,7 +367,7 @@ def is_moo_problem(self) -> bool:
"""Whether the experiment's optimization config contains multiple objectives."""
if self.optimization_config is None:
return False
return not_none(self.optimization_config).is_moo_problem
return none_throws(self.optimization_config).is_moo_problem

@property
def data_by_trial(self) -> dict[int, OrderedDict[int, Data]]:
Expand Down Expand Up @@ -1296,7 +1297,7 @@ def warm_start_from_old_experiment(
"Only experiments with 1-arm trials currently supported."
)
self.search_space.check_membership(
not_none(trial.arm).parameters,
none_throws(trial.arm).parameters,
raise_error=search_space_check_membership_raise_error,
)
dat, ts = old_experiment.lookup_data_for_trial(trial_index=trial.index)
Expand Down Expand Up @@ -1331,7 +1332,7 @@ def warm_start_from_old_experiment(
{trial.index: new_trial.index}, inplace=True
)
new_df["arm_name"].replace(
{not_none(trial.arm).name: not_none(new_trial.arm).name},
{none_throws(trial.arm).name: none_throws(new_trial.arm).name},
inplace=True,
)
# Attach updated data to new trial on experiment.
Expand Down Expand Up @@ -1783,20 +1784,20 @@ def add_arm_and_prevent_naming_collision(
# the arm name. Preserves all names not matching the automatic naming format.
# experiment is not named, clear the arm's name.
# `arm_index` is 0 since all trials are single-armed.
old_arm_name = not_none(old_trial.arm).name
old_arm_name = none_throws(old_trial.arm).name
has_default_name = bool(old_arm_name == old_trial._get_default_name(arm_index=0))
if has_default_name:
new_arm = not_none(old_trial.arm).clone(clear_name=True)
new_arm = none_throws(old_trial.arm).clone(clear_name=True)
if old_experiment_name is not None:
new_arm.name = f"{old_arm_name}_{old_experiment_name}"
new_trial.add_arm(new_arm)
else:
try:
new_trial.add_arm(not_none(old_trial.arm).clone(clear_name=False))
new_trial.add_arm(none_throws(old_trial.arm).clone(clear_name=False))
except ValueError as e:
warnings.warn(
f"Attaching arm {old_trial.arm} to trial {new_trial} while preserving "
f"its name failed with error: {e}. Retrying with `clear_name=True`.",
stacklevel=2,
)
new_trial.add_arm(not_none(old_trial.arm).clone(clear_name=True))
new_trial.add_arm(none_throws(old_trial.arm).clone(clear_name=True))
4 changes: 2 additions & 2 deletions ax/core/generation_strategy_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ax.core.observation import ObservationFeatures
from ax.exceptions.core import AxError, UnsupportedError
from ax.utils.common.base import Base
from ax.utils.common.typeutils import not_none
from pyre_extensions import none_throws


class GenerationStrategyInterface(ABC, Base):
Expand Down Expand Up @@ -148,7 +148,7 @@ def experiment(self) -> Experiment:
"""Experiment, currently set on this generation strategy."""
if self._experiment is None:
raise AxError("No experiment set on generation strategy.")
return not_none(self._experiment)
return none_throws(self._experiment)

@experiment.setter
def experiment(self, experiment: Experiment) -> None:
Expand Down
4 changes: 2 additions & 2 deletions ax/core/generator_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from ax.exceptions.core import UnsupportedError
from ax.utils.common.base import Base, SortableBase
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none
from pyre_extensions import none_throws

logger: Logger = get_logger(__name__)

Expand Down Expand Up @@ -303,7 +303,7 @@ def model_predictions_by_arm(self) -> dict[str, TModelPredictArm] | None:
predictions: dict[str, TModelPredictArm] = {}
for idx, cond in enumerate(self.arms):
predictions[cond.signature] = extract_arm_predictions(
model_predictions=not_none(self._model_predictions), arm_idx=idx
model_predictions=none_throws(self._model_predictions), arm_idx=idx
)
return predictions

Expand Down
6 changes: 3 additions & 3 deletions ax/core/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ax.exceptions.core import UserInputError
from ax.utils.common.base import SortableBase
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none
from pyre_extensions import none_throws

logger: Logger = get_logger(__name__)

Expand Down Expand Up @@ -56,7 +56,7 @@ def __init__(self, metric: Metric, minimize: bool | None = None) -> None:
f"{minimize=}."
)
self._metric: Metric = metric
self.minimize: bool = not_none(minimize)
self.minimize: bool = none_throws(minimize)

@property
def metric(self) -> Metric:
Expand Down Expand Up @@ -136,7 +136,7 @@ def __init__(
objectives.append(Objective(metric=metric, minimize=minimize))

# pyre-fixme[4]: Attribute must be annotated.
self._objectives = not_none(objectives)
self._objectives = none_throws(objectives)

# For now, assume all objectives are weighted equally.
# This might be used in the future to change emphasis on the
Expand Down
5 changes: 3 additions & 2 deletions ax/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
from ax.utils.common.base import Base
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import checked_cast, not_none
from ax.utils.common.typeutils import checked_cast
from pyre_extensions import none_throws

logger: Logger = get_logger(__name__)

Expand Down Expand Up @@ -308,7 +309,7 @@ def _observations_from_dataframe(
metadata = trial._get_candidate_metadata(arm_name) or {}
if Keys.TRIAL_COMPLETION_TIMESTAMP not in metadata:
if trial._time_completed is not None:
metadata[Keys.TRIAL_COMPLETION_TIMESTAMP] = not_none(
metadata[Keys.TRIAL_COMPLETION_TIMESTAMP] = none_throws(
trial._time_completed
).timestamp()
obs_kwargs[Keys.METADATA] = metadata
Expand Down
Loading