From 6bdf0a5a13d097c247655644b3f541e7d906ebac Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Fri, 18 Oct 2024 14:45:01 -0700 Subject: [PATCH] Configs for Ax API entry point class (#2918) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2918 The new Ax API will rely on configs to bundle together related information while the user sets up their optimization. These classes will have a stable and backwards compatible API (in contrast to their ax.core counterparts), better reflect how a user conceptualizes their optimization without leaking in "implementation details", and be a centralized place for validation. Reviewed By: lena-kashtelyan Differential Revision: D58022351 fbshipit-source-id: 2fc434c056a82262568cf8463d4358130c36e98c --- ax/preview/api/configs.py | 108 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 ax/preview/api/configs.py diff --git a/ax/preview/api/configs.py b/ax/preview/api/configs.py new file mode 100644 index 00000000000..5f544f53d69 --- /dev/null +++ b/ax/preview/api/configs.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field +from enum import Enum +from typing import Dict, List, Optional, Union + +from ax.core.types import TParamValue + +# Note: I'm not sold these should be dataclasses, just using this as a placeholder + + +class DomainType(Enum): + """ + The DomainType enum allows the ParameterConfig to know whether to expect inputs for + a RangeParameter or ChoiceParameter (or FixedParameter) during the parameter + instantiation and validation process. + """ + + RANGE = "range" + CHOICE = "choice" + + +class ParameterType(Enum): + """ + The ParameterType enum allows users to specify the type of a parameter. + """ + + INT = "int" + FLOAT = "float" + BOOL = "bool" + STR = "str" + + +class ParameterScaling(Enum): + """ + The ParameterScaling enum allows users to specify which scaling to apply during + candidate generation. This is useful for parameters that should not be explored + on the same scale, such as learning rates and batch sizes. + """ + + LINEAR = "linear" + LOG = "log" + + +@dataclass +class ParameterConfig: + """ + ParameterConfig allows users to specify the parameters of an experiment and will + internally validate the inputs to ensure they are valid for the given DomainType. + """ + + name: str + domain_type: DomainType + parameter_type: ParameterType | None = None + + # Fields for RANGE + bounds: Optional[tuple[float, float]] = None + step_size: Optional[float] = None + scaling: Optional[ParameterScaling] = None + + # Fields for CHOICE ("FIXED" is Choice with len(values) == 1) + values: Optional[Union[List[float], List[str], List[bool]]] = None + is_ordered: Optional[bool] = None + dependent_parameters: Optional[Dict[TParamValue, str]] = None + + +@dataclass +class ExperimentConfig: + """ + ExperimentConfig allows users to specify the SearchSpace and OptimizationConfig of + an Experiment and validates their inputs jointly. + + This will also be the construct that handles transforming string-based inputs (the + objective, parameter constraints, and output constraints) into their corresponding + Ax class using SymPy. + """ + + name: str + parameters: List[ParameterConfig] + # Parameter constraints will be parsed via SymPy + # Ex: "num_layers1 <= num_layers2", "compound_a + compound_b <= 1" + parameter_constraints: List[str] = field(default_factory=list) + + description: str | None = None + owner: str | None = None + + +@dataclass +class GenerationStrategyConfig: + # This will hold the args to choose_generation_strategy + num_trials: Optional[int] = None + num_initialization_trials: Optional[int] = None + maximum_parallelism: Optional[int] = None + + +@dataclass +class OrchestrationConfig: + parallelism: int = 1 + tolerated_trial_failure_rate: float = 0.5 + seconds_between_polls: float = 1.0 + + +@dataclass +class DatabaseConfig: + url: str