Skip to content

Commit

Permalink
Reap AX_OBJECT_FIELD_OVERRIDES (#2896)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #2896

Reviewed By: lena-kashtelyan

Differential Revision: D64484546
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Oct 18, 2024
1 parent 202dc33 commit e2f1d53
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 41 deletions.
9 changes: 1 addition & 8 deletions ax/service/utils/with_db_settings_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from collections.abc import Iterable

from logging import INFO, Logger
from typing import Any, Optional
from typing import Optional

from ax.analysis.analysis import AnalysisCard

Expand Down Expand Up @@ -95,12 +95,6 @@ class WithDBSettingsBase:

_db_settings: Optional[DBSettings] = None

# Mapping of object types to mapping of fields to override values
# loaded objects will all be instantiated with fields set to
# override value
# current valid object types are: "runner"
AX_OBJECT_FIELD_OVERRIDES: dict[str, Any] = {}

def __init__(
self,
db_settings: Optional[DBSettings] = None,
Expand Down Expand Up @@ -256,7 +250,6 @@ def _load_experiment_and_generation_strategy(
decoder=self.db_settings.decoder,
reduced_state=reduced_state,
load_trials_in_batches_of_size=LOADING_MINI_BATCH_SIZE,
ax_object_field_overrides=self.AX_OBJECT_FIELD_OVERRIDES,
skip_runners_and_metrics=skip_runners_and_metrics,
)
if not isinstance(experiment, Experiment):
Expand Down
30 changes: 3 additions & 27 deletions ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from enum import Enum
from io import StringIO
from logging import Logger
from typing import Any, cast, Union
from typing import cast, Union

import pandas as pd
from ax.analysis.analysis import AnalysisCard
Expand Down Expand Up @@ -138,7 +138,6 @@ def _auxiliary_experiments_by_purpose_from_experiment_sqa(
def _init_experiment_from_sqa(
self,
experiment_sqa: SQAExperiment,
ax_object_field_overrides: dict[str, Any] | None = None,
load_auxiliary_experiments: bool = True,
) -> Experiment:
"""First step of conversion within experiment_from_sqa."""
Expand All @@ -162,14 +161,7 @@ def _init_experiment_from_sqa(
if len(experiment_sqa.runners) == 0:
runner = None
elif len(experiment_sqa.runners) == 1:
runner_kwargs = (
ax_object_field_overrides.get("runner")
if ax_object_field_overrides is not None
else None
)
runner = self.runner_from_sqa(
runner_sqa=experiment_sqa.runners[0], runner_kwargs=runner_kwargs
)
runner = self.runner_from_sqa(runner_sqa=experiment_sqa.runners[0])
else:
raise ValueError(
"Multiple runners on experiment "
Expand Down Expand Up @@ -259,7 +251,6 @@ def experiment_from_sqa(
self,
experiment_sqa: SQAExperiment,
reduced_state: bool = False,
ax_object_field_overrides: dict[str, Any] | None = None,
load_auxiliary_experiments: bool = True,
) -> Experiment:
"""Convert SQLAlchemy Experiment to Ax Experiment.
Expand All @@ -269,10 +260,6 @@ def experiment_from_sqa(
reduced_state: Whether to load experiment with a slightly reduced state
(without abandoned arms on experiment and without model state,
search space, and optimization config on generator runs).
ax_object_field_overrides: Mapping of object types to mapping of fields
to override values loaded objects will all be instantiated with fields
set to override value
current valid object types are: "runner"
load_auxiliary_experiment: whether to load auxiliary experiments.
"""
subclass = (experiment_sqa.properties or {}).get(Keys.SUBCLASS)
Expand All @@ -281,15 +268,13 @@ def experiment_from_sqa(
else:
experiment = self._init_experiment_from_sqa(
experiment_sqa,
ax_object_field_overrides=ax_object_field_overrides,
load_auxiliary_experiments=load_auxiliary_experiments,
)
trials = [
self.trial_from_sqa(
trial_sqa=trial,
experiment=experiment,
reduced_state=reduced_state,
ax_object_field_overrides=ax_object_field_overrides,
)
for trial in experiment_sqa.trials
]
Expand Down Expand Up @@ -870,9 +855,7 @@ def generation_strategy_from_sqa(

return gs

def runner_from_sqa(
self, runner_sqa: SQARunner, runner_kwargs: dict[str, Any] | None = None
) -> Runner:
def runner_from_sqa(self, runner_sqa: SQARunner) -> Runner:
"""Convert SQLAlchemy Runner to Ax Runner."""
if runner_sqa.runner_type not in self.config.reverse_runner_registry:
raise SQADecodeError(
Expand All @@ -886,7 +869,6 @@ def runner_from_sqa(
decoder_registry=self.config.json_decoder_registry,
class_decoder_registry=self.config.json_class_decoder_registry,
)
args.update(runner_kwargs or {})
# pyre-ignore[45]: Cannot instantiate abstract class `Runner`.
runner = runner_class(**args)
runner.db_id = runner_sqa.id
Expand All @@ -897,7 +879,6 @@ def trial_from_sqa(
trial_sqa: SQATrial,
experiment: Experiment,
reduced_state: bool = False,
ax_object_field_overrides: dict[str, Any] | None = None,
) -> BaseTrial:
"""Convert SQLAlchemy Trial to Ax Trial.
Expand Down Expand Up @@ -1004,11 +985,6 @@ def trial_from_sqa(
trial._runner = (
self.runner_from_sqa(
trial_sqa.runner,
runner_kwargs=(
ax_object_field_overrides.get("runner")
if ax_object_field_overrides is not None
else None
),
)
if trial_sqa.runner
else None
Expand Down
6 changes: 0 additions & 6 deletions ax/storage/sqa_store/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def _load_experiment(
decoder: Decoder,
reduced_state: bool = False,
load_trials_in_batches_of_size: int | None = None,
ax_object_field_overrides: dict[str, Any] | None = None,
skip_runners_and_metrics: bool = False,
load_auxiliary_experiments: bool = True,
) -> Experiment:
Expand All @@ -100,10 +99,6 @@ def _load_experiment(
reduced_state: Whether to load experiment and generation strategy
load_trials_in_batches_of_size: Number of trials to be fetched from database
per batch
ax_object_field_overrides: Mapping of object types to mapping of fields
to override values loaded objects will all be instantiated with fields
set to override value
current valid object types are: "runner"
load_auxiliary_experiments: whether to load auxiliary experiments.
"""

Expand Down Expand Up @@ -170,7 +165,6 @@ def _load_experiment(
return decoder.experiment_from_sqa(
experiment_sqa=experiment_sqa,
reduced_state=reduced_state,
ax_object_field_overrides=ax_object_field_overrides,
load_auxiliary_experiments=load_auxiliary_experiments,
)

Expand Down

0 comments on commit e2f1d53

Please sign in to comment.