Skip to content

Commit

Permalink
Update OSS reference to adaptive_experiment to instead use ax (#1867)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1867

# This diff
Update OSS reference to `adaptive_experiment` to instead use `ax`.

Reviewed By: lena-kashtelyan

Differential Revision: D49471512

fbshipit-source-id: 0affc95d0d75dca7c1a3376c514504e2113c0023
  • Loading branch information
Cesar-Cardoso authored and facebook-github-bot committed Sep 22, 2023
1 parent 4153a57 commit 24c9265
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 25 deletions.
32 changes: 7 additions & 25 deletions ax/storage/sqa_store/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@

from __future__ import annotations

from contextlib import contextmanager, nullcontext
from os import remove as remove_file
from typing import Any, Callable, ContextManager, Generator, Optional, TypeVar
from contextlib import contextmanager
from typing import Any, Callable, Generator, Optional, TypeVar

from sqlalchemy import create_engine
from sqlalchemy.engine.base import Engine
Expand Down Expand Up @@ -202,11 +201,6 @@ def init_test_engine_and_session_factory(
)


def remove_test_db_file(tier_or_path: str) -> None:
"""Remove the test DB file from system, useful for cleanup in tests."""
remove_file(tier_or_path)


def create_all_tables(engine: Engine) -> None:
"""Create all tables that inherit from Base.
Expand All @@ -219,21 +213,17 @@ def create_all_tables(engine: Engine) -> None:
define a mapped class that inherits from `Base` must be imported.
"""
if (
engine.dialect.name == "mysql"
and engine.dialect.default_schema_name == "adaptive_experiment"
):
raise Exception("Cannot mutate tables in XDB. Use AOSC.")
if engine.dialect.name == "mysql" and engine.dialect.default_schema_name == "ax":
raise ValueError(
"The open-source Ax table creation is likely not applicable in this case,"
+ "please contact the Adaptive Experimentation team if you need help."
)
Base.metadata.create_all(engine)


def get_session() -> Session:
"""Fetch a SQLAlchemy session with a connection to a DB.
Unless `init_engine_and_session_factory` is called first with custom
args, this will automatically initialize a connection to
`xdb.adaptive_experiment`.
Returns:
Session: an instance of a SQLAlchemy session.
Expand Down Expand Up @@ -274,11 +264,3 @@ def session_scope() -> Generator[Session, None, None]:
raise
finally:
session.close()


def optional_session_scope(
session: Optional[Session] = None,
) -> ContextManager[Session]:
if session is not None:
return nullcontext(session)
return session_scope()
16 changes: 16 additions & 0 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from ax.storage.registry_bundle import RegistryBundle
from ax.storage.runner_registry import CORE_RUNNER_REGISTRY, register_runner
from ax.storage.sqa_store.db import (
create_all_tables,
create_test_engine,
get_engine,
get_session,
init_engine_and_session_factory,
Expand Down Expand Up @@ -1687,3 +1689,17 @@ def test_GeneratorRunValidatedFields(self) -> None:
newly_loaded_gr = newly_loaded_exp.trials.get(0).generator_run
for instrumented_attr in GR_LARGE_MODEL_ATTRS:
self.assertIsNotNone(getattr(newly_loaded_gr, f"_{instrumented_attr.key}"))

@patch("ax.storage.sqa_store.db.SESSION_FACTORY", None)
def test_MissingSessionFactory(self) -> None:
with self.assertRaises(ValueError):
get_session()
with self.assertRaises(ValueError):
get_engine()

def test_CreateAllTablesException(self) -> None:
engine = create_test_engine()
engine.dialect.name = "mysql"
engine.dialect.default_schema_name = "ax"
with self.assertRaises(ValueError):
create_all_tables(engine)

0 comments on commit 24c9265

Please sign in to comment.