Skip to content

Commit

Permalink
feat: add orm classes
Browse files Browse the repository at this point in the history
  • Loading branch information
cbrinson-rise8 committed Sep 19, 2024
1 parent f9f3bb3 commit 3feb4dc
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 5 deletions.
8 changes: 7 additions & 1 deletion .env
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@
DB_URI="postgresql+psycopg2://postgres:pw@localhost:5432/postgres"
MPI_DB_TYPE=postgres
MPI_DBNAME=testdb
MPI_HOST=localhost
MPI_PORT=5432
MPI_USER=postgres
MPI_PASSWORD=pw
DB_URI="sqlite:///db.sqlite3"
51 changes: 51 additions & 0 deletions alembic/versions/0db31a429322_create_algorithm_tables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""create algorithm tables
Revision ID: 0db31a429322
Revises: 6052c193a26a
Create Date: 2024-09-19 11:45:31.232579
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '0db31a429322'
down_revision: Union[str, None] = '6052c193a26a'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('algorithm',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('is_default', sa.Boolean(), nullable=False),
sa.Column('label', sa.String(length=255), nullable=False),
sa.Column('description', sa.String(), nullable=False),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('label')
)
op.create_table('algorithm_pass',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('algorithm_id', sa.Integer(), nullable=False),
sa.Column('blocking_keys', sa.JSON(), nullable=False),
sa.Column('evaluators', sa.JSON(), nullable=False),
sa.Column('rule', sa.String(length=255), nullable=False),
sa.Column('cluster_ratio', sa.Float(), nullable=False),
sa.Column('kwargs', sa.JSON(), nullable=False),
sa.ForeignKeyConstraint(['algorithm_id'], ['algorithm.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_mpi_blocking_key_key'), 'mpi_blocking_key', ['key'], unique=False)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_mpi_blocking_key_key'), table_name='mpi_blocking_key')
op.drop_table('algorithm_pass')
op.drop_table('algorithm')
# ### end Alembic commands ###
50 changes: 46 additions & 4 deletions src/recordlinker/linkage/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import uuid

from sqlalchemy import event
from sqlalchemy import ForeignKey
from sqlalchemy import JSON
from sqlalchemy import orm
Expand All @@ -9,14 +10,12 @@
class Base(orm.DeclarativeBase):
pass


class Person(Base):
__tablename__ = "mpi_person"

id: orm.Mapped[int] = orm.mapped_column(primary_key=True)
internal_id: orm.Mapped[uuid.UUID] = orm.mapped_column(default=uuid.uuid4)


class ExternalPerson(Base):
__tablename__ = "mpi_external_person"

Expand All @@ -25,15 +24,13 @@ class ExternalPerson(Base):
external_id: orm.Mapped[str] = orm.mapped_column(String(255))
source: orm.Mapped[str] = orm.mapped_column(String(255))


class Patient(Base):
__tablename__ = "mpi_patient"

id: orm.Mapped[int] = orm.mapped_column(primary_key=True)
person_id: orm.Mapped[int] = orm.mapped_column(ForeignKey("mpi_person.id"))
data: orm.Mapped[dict] = orm.mapped_column(JSON)


class BlockingKey(Base):
__tablename__ = "mpi_blocking_key"

Expand All @@ -47,3 +44,48 @@ class BlockingValue(Base):
patient_id: orm.Mapped[int] = orm.mapped_column(ForeignKey("mpi_patient.id"))
blockingkey_id: orm.Mapped[int] = orm.mapped_column(ForeignKey("mpi_blocking_key.id"))
value: orm.Mapped[str] = orm.mapped_column(String(50), index=True)

class Algorithm(Base):
__tablename__ = "algorithm"

id: orm.Mapped[int] = orm.mapped_column(primary_key=True)
is_default: orm.Mapped[bool] = orm.mapped_column(default=False)
label: orm.Mapped[str] = orm.mapped_column(String(255), unique=True)
description: orm.Mapped[str]

def check_only_one_default(connection, target):
"""
Check if there is already a default algorithm before inserting or updating.
Called before an insert or update operation on the
Algorithm table. If the `is_default` attribute of the target object is
set to True, it checks the database to ensure that no other algorithm
is marked as default. If another default algorithm exists, an exception
is raised to prevent the operation.
Parameters:
connection: The database connection being used for the operation.
target: The instance of the Algorithm class being inserted or updated.
Raises:
Exception: If another algorithm is already marked as default.
"""

if target.is_default:
existing_default = connection.execute("SELECT COUNT(*) FROM algorithm WHERE is_default = TRUE").scalar()
if(existing_default > 0):
raise ValueError("There can only be one default algorithm.")

event.listen(Algorithm, 'before_insert', check_only_one_default)
event.listen(Algorithm, 'before_update', check_only_one_default)

class AlgorithmPass(Base):
__tablename__ = "algorithm_pass"

id: orm.Mapped[int] = orm.mapped_column(primary_key=True)
algorithm_id: orm.Mapped[int] = orm.mapped_column(ForeignKey("algorithm.id"))
blocking_keys: orm.Mapped[list[int]] = orm.mapped_column(JSON)
evaluators: orm.Mapped[list[str]] = orm.mapped_column(JSON)
rule: orm.Mapped[str] = orm.mapped_column(String(255))
cluster_ratio: orm.Mapped[float]
kwargs: orm.Mapped[dict] = orm.mapped_column(JSON)

0 comments on commit 3feb4dc

Please sign in to comment.