diff --git a/alembic.ini b/alembic.ini index d9f94e0c..70747ace 100644 --- a/alembic.ini +++ b/alembic.ini @@ -60,8 +60,13 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne # are written from script.py.mako # output_encoding = utf-8 -sqlalchemy.url = driver://user:pass@localhost/dbname - +drivername = +username = +password = +host = +port = +database = +sqlalchemy.url = %(drivername)s://%(username)s:%(password)s@%(host)s:%(port)s/%(database)s [post_write_hooks] # post_write_hooks defines scripts or Python functions that are run diff --git a/alembic/env.py b/alembic/env.py index c7c0ddd7..66997211 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -33,7 +33,11 @@ def run_migrations_offline() -> None: Calls to alembic.context.execute() here emit the given string to the script output. """ - url = config.get_main_option("sqlalchemy.url") + url_props = dict() + for prop in ["drivername", "username", "password", "host", "port", "database"]: + url_props[prop] = config.get_main_option(prop) + url_props["port"] = url_props["port"] and int(url_props["port"]) or None # type: ignore + url = sa.engine.URL.create(**url_props) # type: ignore alembic.context.configure( url=url, target_metadata=cads_broker.database.BaseModel.metadata, @@ -50,13 +54,13 @@ def run_migrations_online() -> None: In this scenario we need to create an Engine and associate a connection with the alembic.context. """ - connectable = sa.engine_from_config( - config.get_section(config.config_ini_section, {}), - prefix="sqlalchemy.", - poolclass=sa.pool.NullPool, - ) - - with connectable.connect() as connection: + url_props = dict() + for prop in ["drivername", "username", "password", "host", "port", "database"]: + url_props[prop] = config.get_main_option(prop) + url_props["port"] = url_props["port"] and int(url_props["port"]) or None # type: ignore + url = sa.engine.URL.create(**url_props) # type: ignore + engine = sa.create_engine(url, poolclass=sa.pool.NullPool) + with engine.connect() as connection: alembic.context.configure( connection=connection, target_metadata=cads_broker.database.BaseModel.metadata, diff --git a/cads_broker/config.py b/cads_broker/config.py index 6a64128b..4ddfd4ed 100644 --- a/cads_broker/config.py +++ b/cads_broker/config.py @@ -19,6 +19,7 @@ import pydantic import pydantic_core import pydantic_settings +import sqlalchemy as sa import structlog dbsettings = None @@ -63,20 +64,28 @@ def db_connection_env_vars_must_be_set( @property def connection_string(self) -> str: """Create reader psql connection string.""" - return ( - f"postgresql://{self.compute_db_user}" - f":{self.compute_db_password}@{self.compute_db_host}" - f"/{self.compute_db_name}" + url = sa.engine.URL.create( + drivername="postgresql", + username=self.compute_db_user, + password=self.compute_db_password, + host=self.compute_db_host, + database=self.compute_db_name, ) + ret_value = url.render_as_string(False) + return ret_value @property def connection_string_read(self) -> str: """Create reader psql connection string.""" - return ( - f"postgresql://{self.compute_db_user}" - f":{self.compute_db_password}@{self.compute_db_host_read}" - f"/{self.compute_db_name}" + url = sa.engine.URL.create( + drivername="postgresql", + username=self.compute_db_user, + password=self.compute_db_password, + host=self.compute_db_host_read, + database=self.compute_db_name, ) + ret_value = url.render_as_string(False) + return ret_value def ensure_settings(settings: SqlalchemySettings | None = None) -> SqlalchemySettings: diff --git a/cads_broker/database.py b/cads_broker/database.py index 7761efd5..23c11acf 100644 --- a/cads_broker/database.py +++ b/cads_broker/database.py @@ -584,7 +584,11 @@ def init_database(connection_string: str, force: bool = False) -> sa.engine.Engi os.chdir(migration_directory) alembic_config_path = os.path.join(migration_directory, "alembic.ini") alembic_cfg = alembic.config.Config(alembic_config_path) - alembic_cfg.set_main_option("sqlalchemy.url", connection_string) + for option in ["drivername", "username", "password", "host", "port", "database"]: + value = getattr(engine.url, option) + if value is None: + value = "" + alembic_cfg.set_main_option(option, str(value)) if not sqlalchemy_utils.database_exists(engine.url): sqlalchemy_utils.create_database(engine.url) # cleanup and create the schema diff --git a/tests/conftest.py b/tests/conftest.py index a35382b4..c2ed4c2a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,10 +3,14 @@ import pytest from psycopg import Connection +from pytest_postgresql import factories from sqlalchemy.orm import sessionmaker from cads_broker import database +postgresql_proc2 = factories.postgresql_proc(password="my@strange@password") +postgresql2 = factories.postgresql("postgresql_proc2") + @pytest.fixture() def session_obj(postgresql: Connection[str]) -> sessionmaker: diff --git a/tests/test_02_database.py b/tests/test_02_database.py index a272ab1c..e9bc54e0 100644 --- a/tests/test_02_database.py +++ b/tests/test_02_database.py @@ -758,6 +758,54 @@ def test_init_database(postgresql: Connection[str]) -> None: conn.close() +def test_init_database_with_password(postgresql2: Connection[str]) -> None: + connection_url = sa.engine.URL.create( + drivername="postgresql+psycopg2", + username=postgresql2.info.user, + password=postgresql2.info.password, + host=postgresql2.info.host, + port=postgresql2.info.port, + database=postgresql2.info.dbname, + ) + connection_string = connection_url.render_as_string(False) + engine = sa.create_engine(connection_string) + conn = engine.connect() + query = sa.text( + "SELECT table_name FROM information_schema.tables WHERE table_schema='public'" + ) + # start with an empty db structure + expected_tables_at_beginning: set[str] = set() + assert set(conn.execute(query).scalars()) == expected_tables_at_beginning # type: ignore + + # verify create structure + db.init_database(connection_string, force=True) + expected_tables_complete = set(db.BaseModel.metadata.tables).union( + {"alembic_version"} + ) + assert set(conn.execute(query).scalars()) == expected_tables_complete # type: ignore + + adaptor_properties = mock_config() + request = mock_system_request(adaptor_properties_hash=adaptor_properties.hash) + session_obj = sa.orm.sessionmaker(engine) + with session_obj() as session: + session.add(adaptor_properties) + session.add(request) + session.commit() + + db.init_database(connection_string) + assert set(conn.execute(query).scalars()) == expected_tables_complete # type: ignore + with session_obj() as session: + requests = db.get_accepted_requests(session=session) + assert len(requests) == 1 + + db.init_database(connection_string, force=True) + assert set(conn.execute(query).scalars()) == expected_tables_complete # type: ignore + with session_obj() as session: + requests = db.get_accepted_requests(session=session) + assert len(requests) == 0 + conn.close() + + def test_ensure_session_obj( postgresql: Connection[str], session_obj: sessionmaker, temp_environ: Any ) -> None: