From 24c92655af1546d1fa4e4ce3236313da9b3edde4 Mon Sep 17 00:00:00 2001 From: Cesar Cardoso Date: Fri, 22 Sep 2023 10:19:32 -0700 Subject: [PATCH] Update OSS reference to adaptive_experiment to instead use ax (#1867) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/1867 # This diff Update OSS reference to `adaptive_experiment` to instead use `ax`. Reviewed By: lena-kashtelyan Differential Revision: D49471512 fbshipit-source-id: 0affc95d0d75dca7c1a3376c514504e2113c0023 --- ax/storage/sqa_store/db.py | 32 +++++--------------- ax/storage/sqa_store/tests/test_sqa_store.py | 16 ++++++++++ 2 files changed, 23 insertions(+), 25 deletions(-) diff --git a/ax/storage/sqa_store/db.py b/ax/storage/sqa_store/db.py index acea045698f..49f558d90bc 100644 --- a/ax/storage/sqa_store/db.py +++ b/ax/storage/sqa_store/db.py @@ -6,9 +6,8 @@ from __future__ import annotations -from contextlib import contextmanager, nullcontext -from os import remove as remove_file -from typing import Any, Callable, ContextManager, Generator, Optional, TypeVar +from contextlib import contextmanager +from typing import Any, Callable, Generator, Optional, TypeVar from sqlalchemy import create_engine from sqlalchemy.engine.base import Engine @@ -202,11 +201,6 @@ def init_test_engine_and_session_factory( ) -def remove_test_db_file(tier_or_path: str) -> None: - """Remove the test DB file from system, useful for cleanup in tests.""" - remove_file(tier_or_path) - - def create_all_tables(engine: Engine) -> None: """Create all tables that inherit from Base. @@ -219,21 +213,17 @@ def create_all_tables(engine: Engine) -> None: define a mapped class that inherits from `Base` must be imported. """ - if ( - engine.dialect.name == "mysql" - and engine.dialect.default_schema_name == "adaptive_experiment" - ): - raise Exception("Cannot mutate tables in XDB. Use AOSC.") + if engine.dialect.name == "mysql" and engine.dialect.default_schema_name == "ax": + raise ValueError( + "The open-source Ax table creation is likely not applicable in this case," + + "please contact the Adaptive Experimentation team if you need help." + ) Base.metadata.create_all(engine) def get_session() -> Session: """Fetch a SQLAlchemy session with a connection to a DB. - Unless `init_engine_and_session_factory` is called first with custom - args, this will automatically initialize a connection to - `xdb.adaptive_experiment`. - Returns: Session: an instance of a SQLAlchemy session. @@ -274,11 +264,3 @@ def session_scope() -> Generator[Session, None, None]: raise finally: session.close() - - -def optional_session_scope( - session: Optional[Session] = None, -) -> ContextManager[Session]: - if session is not None: - return nullcontext(session) - return session_scope() diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index cbd74c0acd6..f80598ee6e1 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -27,6 +27,8 @@ from ax.storage.registry_bundle import RegistryBundle from ax.storage.runner_registry import CORE_RUNNER_REGISTRY, register_runner from ax.storage.sqa_store.db import ( + create_all_tables, + create_test_engine, get_engine, get_session, init_engine_and_session_factory, @@ -1687,3 +1689,17 @@ def test_GeneratorRunValidatedFields(self) -> None: newly_loaded_gr = newly_loaded_exp.trials.get(0).generator_run for instrumented_attr in GR_LARGE_MODEL_ATTRS: self.assertIsNotNone(getattr(newly_loaded_gr, f"_{instrumented_attr.key}")) + + @patch("ax.storage.sqa_store.db.SESSION_FACTORY", None) + def test_MissingSessionFactory(self) -> None: + with self.assertRaises(ValueError): + get_session() + with self.assertRaises(ValueError): + get_engine() + + def test_CreateAllTablesException(self) -> None: + engine = create_test_engine() + engine.dialect.name = "mysql" + engine.dialect.default_schema_name = "ax" + with self.assertRaises(ValueError): + create_all_tables(engine)