Skip to content

Commit

Permalink
Merge pull request #93 from ecmwf-projects/bug-connection-string
Browse files Browse the repository at this point in the history
bugfix: right encoding of db connection tokens
  • Loading branch information
alex75 authored Dec 12, 2023
2 parents 0daeb37 + cd424b3 commit d96d36b
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 19 deletions.
9 changes: 7 additions & 2 deletions alembic.ini
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,13 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne
# are written from script.py.mako
# output_encoding = utf-8

sqlalchemy.url = driver://user:pass@localhost/dbname

drivername =
username =
password =
host =
port =
database =
sqlalchemy.url = %(drivername)s://%(username)s:%(password)s@%(host)s:%(port)s/%(database)s

[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
Expand Down
20 changes: 12 additions & 8 deletions alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ def run_migrations_offline() -> None:
Calls to alembic.context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
url_props = dict()
for prop in ["drivername", "username", "password", "host", "port", "database"]:
url_props[prop] = config.get_main_option(prop)
url_props["port"] = url_props["port"] and int(url_props["port"]) or None # type: ignore
url = sa.engine.URL.create(**url_props) # type: ignore
alembic.context.configure(
url=url,
target_metadata=cads_broker.database.BaseModel.metadata,
Expand All @@ -50,13 +54,13 @@ def run_migrations_online() -> None:
In this scenario we need to create an Engine
and associate a connection with the alembic.context.
"""
connectable = sa.engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=sa.pool.NullPool,
)

with connectable.connect() as connection:
url_props = dict()
for prop in ["drivername", "username", "password", "host", "port", "database"]:
url_props[prop] = config.get_main_option(prop)
url_props["port"] = url_props["port"] and int(url_props["port"]) or None # type: ignore
url = sa.engine.URL.create(**url_props) # type: ignore
engine = sa.create_engine(url, poolclass=sa.pool.NullPool)
with engine.connect() as connection:
alembic.context.configure(
connection=connection,
target_metadata=cads_broker.database.BaseModel.metadata,
Expand Down
25 changes: 17 additions & 8 deletions cads_broker/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pydantic
import pydantic_core
import pydantic_settings
import sqlalchemy as sa
import structlog

dbsettings = None
Expand Down Expand Up @@ -63,20 +64,28 @@ def db_connection_env_vars_must_be_set(
@property
def connection_string(self) -> str:
"""Create reader psql connection string."""
return (
f"postgresql://{self.compute_db_user}"
f":{self.compute_db_password}@{self.compute_db_host}"
f"/{self.compute_db_name}"
url = sa.engine.URL.create(
drivername="postgresql",
username=self.compute_db_user,
password=self.compute_db_password,
host=self.compute_db_host,
database=self.compute_db_name,
)
ret_value = url.render_as_string(False)
return ret_value

@property
def connection_string_read(self) -> str:
"""Create reader psql connection string."""
return (
f"postgresql://{self.compute_db_user}"
f":{self.compute_db_password}@{self.compute_db_host_read}"
f"/{self.compute_db_name}"
url = sa.engine.URL.create(
drivername="postgresql",
username=self.compute_db_user,
password=self.compute_db_password,
host=self.compute_db_host_read,
database=self.compute_db_name,
)
ret_value = url.render_as_string(False)
return ret_value


def ensure_settings(settings: SqlalchemySettings | None = None) -> SqlalchemySettings:
Expand Down
6 changes: 5 additions & 1 deletion cads_broker/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,11 @@ def init_database(connection_string: str, force: bool = False) -> sa.engine.Engi
os.chdir(migration_directory)
alembic_config_path = os.path.join(migration_directory, "alembic.ini")
alembic_cfg = alembic.config.Config(alembic_config_path)
alembic_cfg.set_main_option("sqlalchemy.url", connection_string)
for option in ["drivername", "username", "password", "host", "port", "database"]:
value = getattr(engine.url, option)
if value is None:
value = ""
alembic_cfg.set_main_option(option, str(value))
if not sqlalchemy_utils.database_exists(engine.url):
sqlalchemy_utils.create_database(engine.url)
# cleanup and create the schema
Expand Down
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@

import pytest
from psycopg import Connection
from pytest_postgresql import factories
from sqlalchemy.orm import sessionmaker

from cads_broker import database

postgresql_proc2 = factories.postgresql_proc(password="my@strange@password")
postgresql2 = factories.postgresql("postgresql_proc2")


@pytest.fixture()
def session_obj(postgresql: Connection[str]) -> sessionmaker:
Expand Down
48 changes: 48 additions & 0 deletions tests/test_02_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,54 @@ def test_init_database(postgresql: Connection[str]) -> None:
conn.close()


def test_init_database_with_password(postgresql2: Connection[str]) -> None:
connection_url = sa.engine.URL.create(
drivername="postgresql+psycopg2",
username=postgresql2.info.user,
password=postgresql2.info.password,
host=postgresql2.info.host,
port=postgresql2.info.port,
database=postgresql2.info.dbname,
)
connection_string = connection_url.render_as_string(False)
engine = sa.create_engine(connection_string)
conn = engine.connect()
query = sa.text(
"SELECT table_name FROM information_schema.tables WHERE table_schema='public'"
)
# start with an empty db structure
expected_tables_at_beginning: set[str] = set()
assert set(conn.execute(query).scalars()) == expected_tables_at_beginning # type: ignore

# verify create structure
db.init_database(connection_string, force=True)
expected_tables_complete = set(db.BaseModel.metadata.tables).union(
{"alembic_version"}
)
assert set(conn.execute(query).scalars()) == expected_tables_complete # type: ignore

adaptor_properties = mock_config()
request = mock_system_request(adaptor_properties_hash=adaptor_properties.hash)
session_obj = sa.orm.sessionmaker(engine)
with session_obj() as session:
session.add(adaptor_properties)
session.add(request)
session.commit()

db.init_database(connection_string)
assert set(conn.execute(query).scalars()) == expected_tables_complete # type: ignore
with session_obj() as session:
requests = db.get_accepted_requests(session=session)
assert len(requests) == 1

db.init_database(connection_string, force=True)
assert set(conn.execute(query).scalars()) == expected_tables_complete # type: ignore
with session_obj() as session:
requests = db.get_accepted_requests(session=session)
assert len(requests) == 0
conn.close()


def test_ensure_session_obj(
postgresql: Connection[str], session_obj: sessionmaker, temp_environ: Any
) -> None:
Expand Down

0 comments on commit d96d36b

Please sign in to comment.