Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove hardcoded test settings #19

Merged
merged 5 commits on Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions .env
Original file line number Diff line number Diff line change
@@ -1,7 +1 @@
MPI_DB_TYPE=postgres
MPI_DBNAME=testdb
MPI_HOST=localhost
MPI_PORT=5432
MPI_USER=postgres
MPI_PASSWORD=pw
DB_URI="sqlite:///db.sqlite3"
DB_URI="postgresql+psycopg2://postgres:pw@localhost:5432/postgres"
4 changes: 1 addition & 3 deletions scripts/local_server.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ cd "$(dirname "$0")/.."

PORT=${1:-8000}

DB_PID=$(docker run -d --rm -p 5432:5432 -e POSTGRES_PASSWORD=pw -e POSTGRES_DB=testdb postgres:13-alpine)
DB_PID=$(docker run -d --rm -p 5432:5432 -e POSTGRES_PASSWORD=pw postgres:13-alpine)

cleanup() {
docker stop ${DB_PID} > /dev/null 2>&1
Expand All @@ -28,7 +28,5 @@ done

trap cleanup EXIT

# Read in environment variables defined in .env
export $(grep -v '^#' .env | xargs)
# Start the API server
uvicorn recordlinker.main:app --app-dir src --reload --host 0 --port ${PORT} --log-config src/recordlinker/log_config.yml
3 changes: 2 additions & 1 deletion scripts/test_unit.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ cd "$(dirname "$0")/.."

TESTS=${1:-tests/unit/}

DB_PID=$(docker run -d --rm -p 5432:5432 -e POSTGRES_PASSWORD=pw -e POSTGRES_DB=testdb postgres:13-alpine)
DB_PID=$(docker run -d --rm -p 5432:5432 -e POSTGRES_PASSWORD=pw postgres:13-alpine)

cleanup() {
docker stop ${DB_PID} > /dev/null 2>&1
Expand All @@ -28,4 +28,5 @@ done

trap cleanup EXIT

# Run the tests
pytest ${TESTS}
44 changes: 11 additions & 33 deletions src/recordlinker/config.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,24 @@
from functools import lru_cache
from typing import Optional
import typing

from pydantic import Field
from pydantic_settings import BaseSettings
import pydantic
import pydantic_settings


class Settings(BaseSettings):
mpi_db_type: str = Field(
description="The type of database used by the MPI",
class Settings(pydantic_settings.BaseSettings):
model_config = pydantic_settings.SettingsConfigDict(
env_file='.env', env_file_encoding='utf-8',
)
mpi_dbname: str = Field(
description="The name of the database used by the MPI",
)
mpi_host: str = Field(
description="The host name of the MPI database",
)
mpi_user: str = Field(
description="The name of the user used to connect to the MPI database",
)
mpi_password: str = Field(
description="The password used to connect to the MPI database",
)
mpi_port: str = Field(description="The port used to connect to the MPI database")
connection_pool_size: Optional[int] = Field(

db_uri: str = pydantic.Field(description="The URI for the MPI database")
connection_pool_size: typing.Optional[int] = pydantic.Field(
description="The number of MPI database connections in the connection pool",
default=5,
)
connection_pool_max_overflow: Optional[int] = Field(
connection_pool_max_overflow: typing.Optional[int] = pydantic.Field(
description="The maximum number of MPI database connections that can be opened "
"above the connection pool size",
default=10,
)


@lru_cache()
def get_settings() -> dict:
"""
Load the values specified in the Settings class from the environment and return a
dictionary containing them. The dictionary is cached to reduce overhead accessing
these values.

:return: A dictionary with keys specified by the Settings. The value of each key is
read from the corresponding environment variable.
"""
return Settings().dict()
settings = Settings()
25 changes: 0 additions & 25 deletions src/recordlinker/linkage/config.py

This file was deleted.

19 changes: 5 additions & 14 deletions src/recordlinker/linkage/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
from sqlalchemy.dialects.postgresql import aggregate_order_by
from sqlalchemy.dialects.postgresql import array_agg

from recordlinker.config import settings
from recordlinker.linkage.core import BaseMPIConnectorClient
from recordlinker.linkage.dal import DataAccessLayer
from recordlinker.linkage.utils import extract_value_with_resource_path
from recordlinker.linkage.utils import load_mpi_env_vars_os


class DIBBsMPIConnectorClient(BaseMPIConnectorClient):
Expand All @@ -29,24 +29,15 @@ class DIBBsMPIConnectorClient(BaseMPIConnectorClient):

"""

def __init__(self, pool_size: int = 5, max_overflow: int = 10):
def __init__(self):
"""
Initialize the MPI connector client with the MPI database.
:param pool_size: The number of connections to keep open to the database.
:param max_overflow: The number of connections to allow in connection pool.
"""
dbsettings = load_mpi_env_vars_os()
dbuser = dbsettings.get("user")
dbname = dbsettings.get("dbname")
dbpwd = dbsettings.get("password")
dbhost = dbsettings.get("host")
dbport = dbsettings.get("port")
self.dal = DataAccessLayer()
self.dal.get_connection(
engine_url=f"postgresql+psycopg2://{dbuser}:"
+ f"{dbpwd}@{dbhost}:{dbport}/{dbname}",
pool_size=pool_size,
max_overflow=max_overflow,
engine_url=settings.db_uri,
pool_size=settings.connection_pool_size,
max_overflow=settings.connection_pool_max_overflow,
)
self.dal.initialize_schema()
self.column_to_fhirpaths = {
Expand Down
19 changes: 0 additions & 19 deletions src/recordlinker/linkage/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,6 @@
import fhirpathpy
import rapidfuzz

from recordlinker.config import get_settings


def load_mpi_env_vars_os():
"""
Simple helper function to load some of the environment variables
needed to make a database connection as part of the DB migrations.
"""
dbsettings = {
"dbname": get_settings().get("mpi_dbname"),
"user": get_settings().get("mpi_user"),
"password": get_settings().get("mpi_password"),
"host": get_settings().get("mpi_host"),
"port": get_settings().get("mpi_port"),
"db_type": get_settings().get("mpi_db_type"),
}
return dbsettings


# TODO: Not sure if we will need this or not
# leaving in utils for now until it's determined that
# we won't need to use this within any of the DAL/MPI/LINK
Expand Down
13 changes: 4 additions & 9 deletions src/recordlinker/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,18 @@
from pydantic import Field

from recordlinker.base_service import BaseService
from recordlinker.config import settings
from recordlinker.linkage.algorithms import DIBBS_BASIC
from recordlinker.linkage.algorithms import DIBBS_ENHANCED
from recordlinker.linkage.link import add_person_resource
from recordlinker.linkage.link import link_record_against_mpi
from recordlinker.linkage.mpi import DIBBsMPIConnectorClient
from recordlinker.utils import get_settings
from recordlinker.utils import read_json_from_assets
from recordlinker.utils import run_migrations

# Ensure MPI is configured as expected.
run_migrations()
settings = get_settings()
MPI_CLIENT = DIBBsMPIConnectorClient(
pool_size=settings["connection_pool_size"],
max_overflow=settings["connection_pool_max_overflow"],
)
MPI_CLIENT = DIBBsMPIConnectorClient()
# Instantiate FastAPI via DIBBs' BaseService class
app = BaseService(
service_name="DIBBs Record Linkage Service",
Expand Down Expand Up @@ -148,13 +144,12 @@ async def link_record(

# Check that DB type is appropriately set up as Postgres so
# we can fail fast if it's not
db_type = get_settings().get("mpi_db_type", "")
if db_type != "postgres":
if not settings.db_uri.startswith("postgres"):
response.status_code = status.HTTP_422_UNPROCESSABLE_ENTITY
return {
"found_match": False,
"updated_bundle": input_bundle,
"message": f"Unsupported database type {db_type} supplied. "
"message": f"Unsupported database {settings.db_uri} supplied. "
+ "Make sure your environment variables include an entry "
+ "for `mpi_db_type` and that it is set to 'postgres'.",
}
Expand Down
51 changes: 14 additions & 37 deletions src/recordlinker/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import json
import logging
import os
import pathlib
import subprocess
from typing import Literal

from sqlalchemy import text
from sqlalchemy.engine import url

from recordlinker.config import get_settings
from recordlinker.config import settings
from recordlinker.linkage.dal import DataAccessLayer
from recordlinker.linkage.mpi import DIBBsMPIConnectorClient

Expand All @@ -32,18 +32,23 @@ def run_pyway(

logger = logging.getLogger(__name__)

# Extract the database type and its parts from the MPI database URI.
db_parts = url.make_url(settings.db_uri)
db_type = db_parts.drivername.split("+")[0]
if db_type == "postgresql":
db_type = "postgres"

# Prepare the pyway command.
migrations_dir = str(pathlib.Path(__file__).parent.parent.parent / "migrations")
settings = get_settings()
pyway_args = [
"--database-table public.pyway",
f"--database-migration-dir {migrations_dir}",
f"--database-type {settings['mpi_db_type']}",
f"--database-host {settings['mpi_host']}",
f"--database-port {settings['mpi_port']}",
f"--database-name {settings['mpi_dbname']}",
f"--database-username {settings['mpi_user']}",
f"--database-password {settings['mpi_password']}",
f"--database-type {db_type}",
f"--database-host {db_parts.host}",
f"--database-port {db_parts.port}",
f"--database-name {db_parts.database}",
f"--database-username {db_parts.username}",
f"--database-password {db_parts.password}",
]

full_command = ["pyway", pyway_command] + pyway_args
Expand Down Expand Up @@ -107,34 +112,6 @@ def run_migrations():
raise Exception(validation_response.stderr.decode("utf-8"))


def set_mpi_env_vars():
ericbuckley marked this conversation as resolved.
Show resolved Hide resolved
"""
Utility function for testing purposes that sets the environment variables
of the testing suite to prespecified valid values, and clears out any
old values from the DB Settings cache.
"""
os.environ["mpi_db_type"] = "postgres"
os.environ["mpi_dbname"] = "testdb"
os.environ["mpi_user"] = "postgres"
os.environ["mpi_password"] = "pw"
os.environ["mpi_host"] = "localhost"
os.environ["mpi_port"] = "5432"
get_settings.cache_clear()


def pop_mpi_env_vars():
"""
Utility function for testing purposes that removes the environment variables
used for database access from the testing environment.
"""
os.environ.pop("mpi_db_type", None)
os.environ.pop("mpi_dbname", None)
os.environ.pop("mpi_user", None)
os.environ.pop("mpi_password", None)
os.environ.pop("mpi_host", None)
os.environ.pop("mpi_port", None)


def _clean_up(dal: DataAccessLayer | None = None) -> None:
"""
Utility function for testing purposes that makes tests idempotent by cleaning up
Expand Down
13 changes: 4 additions & 9 deletions tests/unit/test_dal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from recordlinker.linkage.dal import DataAccessLayer
from recordlinker.linkage.mpi import DIBBsMPIConnectorClient
from recordlinker.utils import _clean_up
from recordlinker.config import settings
from sqlalchemy import Engine
from sqlalchemy import select
from sqlalchemy import Table
Expand All @@ -13,9 +14,7 @@

def _init_db() -> DataAccessLayer:
dal = DataAccessLayer()
dal.get_connection(
engine_url="postgresql+psycopg2://postgres:pw@localhost:5432/testdb"
)
dal.get_connection(engine_url=settings.db_uri)
_clean_up(dal)

# load ddl
Expand Down Expand Up @@ -48,9 +47,7 @@ def test_init_dal():

def test_get_connection():
dal = DataAccessLayer()
dal.get_connection(
engine_url="postgresql+psycopg2://postgres:pw@localhost:5432/testdb"
)
dal.get_connection(engine_url=settings.db_uri)

assert dal.engine is not None
assert isinstance(dal.engine, Engine)
Expand All @@ -68,9 +65,7 @@ def test_get_connection():

def test_get_session():
dal = DataAccessLayer()
dal.get_connection(
engine_url="postgresql+psycopg2://postgres:pw@localhost:5432/testdb"
)
dal.get_connection(engine_url=settings.db_uri)
dal.get_session()


Expand Down
14 changes: 2 additions & 12 deletions tests/unit/test_linkage.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from json.decoder import JSONDecodeError

import pytest
from recordlinker.config import settings
from recordlinker.linkage.algorithms import DIBBS_BASIC
from recordlinker.linkage.algorithms import DIBBS_ENHANCED
from recordlinker.linkage.dal import DataAccessLayer
Expand Down Expand Up @@ -41,19 +42,8 @@


def _init_db() -> DataAccessLayer:
os.environ = {
"mpi_dbname": "testdb",
"mpi_user": "postgres",
"mpi_password": "pw",
"mpi_host": "localhost",
"mpi_port": "5432",
"mpi_db_type": "postgres",
}

dal = DataAccessLayer()
dal.get_connection(
engine_url="postgresql+psycopg2://postgres:pw@localhost:5432/testdb"
)
dal.get_connection(engine_url=settings.db_uri)
_clean_up(dal)

# load ddl
Expand Down
Loading
Loading