add word_frequencies to sourcedocumentdata
bigabig committed Feb 11, 2024
1 parent bd61281 commit 555f895
Showing 5 changed files with 100 additions and 1 deletion.
@@ -0,0 +1,29 @@
"""add word_frequencies to SourceDocumentData
Revision ID: b9de10411f61
Revises: 3bd76cc03486
Create Date: 2024-02-10 17:50:19.307561
"""
from typing import Sequence, Union

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "b9de10411f61"
down_revision: Union[str, None] = "3bd76cc03486"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.add_column(
        "sourcedocumentdata",
        sa.Column("word_frequencies", sa.String(), server_default="[]", nullable=False),
    )


def downgrade() -> None:
    op.drop_column("sourcedocumentdata", "word_frequencies")
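
A minimal sketch of applying this revision programmatically, mirroring the alembic.command imports that backend/src/migration/migrate.py already uses further down; the alembic.ini path and the explicit target revision are assumptions for illustration, not part of this commit.

    from alembic.command import upgrade
    from alembic.config import Config

    # assumed location of the project's Alembic configuration
    config = Config("alembic.ini")
    # apply all pending migrations up to and including this revision
    upgrade(config, "b9de10411f61")

Running `alembic upgrade head` from the backend's Alembic directory would have the same effect.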
6 changes: 5 additions & 1 deletion backend/src/app/core/data/dto/source_document_data.py
@@ -37,4 +37,8 @@ class SourceDocumentDataRead(SourceDocumentDataBase):

# Properties for creation
class SourceDocumentDataCreate(SourceDocumentDataBase):
    pass
    word_frequencies: str = Field(
        description=(
            "JSON Representation of List[WordFrequency] of the SourceDocument"
        ),
    )
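
A hedged sketch of the intended round-trip for the new field: the string holds a JSON dump of List[{word, count}], produced with srsly during preprocessing (see the _persist_sdoc_data change further down) and parsed back wherever the frequencies are needed. The sample counts and variable names here are illustrative only.

    import srsly

    # example word counts as they would come out of preprocessing
    word_freqs = {"the": 12, "word": 3}
    word_frequencies_str = srsly.json_dumps(
        [{"word": w, "count": c} for w, c in word_freqs.items()]
    )
    # the stored string parses back into a list of dicts
    parsed = srsly.json_loads(word_frequencies_str)
    assert parsed == [{"word": "the", "count": 12}, {"word": "word", "count": 3}]

Because the ORM column added below carries server_default="[]", rows created before the data migration runs parse to an empty list rather than failing.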
7 changes: 7 additions & 0 deletions backend/src/app/core/data/orm/source_document_data.py
@@ -29,6 +29,13 @@ class SourceDocumentDataORM(ORMBase):
    sentence_ends: Mapped[List[int]] = mapped_column(
        ARRAY(Integer), nullable=False, index=False
    )
    # JSON representation of List[{word: str, count: int}]
    word_frequencies: Mapped[str] = mapped_column(
        String,
        server_default="[]",
        nullable=False,
        index=False,
    )

    @property
    def tokens(self):
@@ -1,5 +1,6 @@
import traceback

import srsly
from loguru import logger
from psycopg2 import OperationalError
from sqlalchemy.orm import Session
@@ -52,6 +53,10 @@ def _create_and_persist_sdoc(db: Session, pptd: PreProTextDoc) -> SourceDocument
def _persist_sdoc_data(
    db: Session, sdoc_db_obj: SourceDocumentORM, pptd: PreProTextDoc
) -> None:
    word_frequencies_str = srsly.json_dumps(
        [{"word": word, "count": count} for word, count in pptd.word_freqs.items()]
    )

    sdoc_data = SourceDocumentDataCreate(
        id=sdoc_db_obj.id,
        content=pptd.text,
@@ -60,6 +65,7 @@ def _persist_sdoc_data(
        token_ends=[e for _, e in pptd.token_character_offsets],
        sentence_starts=[s.start for s in pptd.sentences],
        sentence_ends=[s.end for s in pptd.sentences],
        word_frequencies=word_frequencies_str,
    )
    crud_sdoc_data.create(db=db, create_dto=sdoc_data)

53 changes: 53 additions & 0 deletions backend/src/migration/migrate.py
@@ -6,6 +6,7 @@

from alembic.command import upgrade
from alembic.config import Config
from app.core.data.crud.crud_base import NoSuchElementError
from app.core.data.crud.project_metadata import crud_project_meta
from app.core.data.crud.source_document_data import crud_sdoc_data
from app.core.data.crud.source_document_metadata import crud_sdoc_meta
@@ -75,6 +76,11 @@ def run_required_migrations():
            db_version.version = 8
            db.commit()
            print("MIGRATED IMAGE WIDTH HEIGHT!")
        if db_version.version < 9:
            __migrate_word_frequencies(db)
            db_version.version = 9
            db.commit()
            print("MIGRATED WORD FREQUENCIES!")


def __migrate_database_schema() -> None:
@@ -439,3 +445,50 @@ def __migrate_image_width_height(db: Session):
            db.add(height_pm)

        db.commit()


def __migrate_word_frequencies(db: Session):
    import srsly

    from app.core.data.dto.word_frequency import WordFrequencyRead
    from app.core.data.orm.word_frequency import WordFrequencyORM

    projects = db.query(ProjectORM).all()
    for project in projects:
        logger.info(
            "Migration: Migrating word_frequencies project {}...",
            project.id,
        )

        sdoc_ids = (
            db.query(SourceDocumentORM.id)
            .filter(
                SourceDocumentORM.project_id == project.id,
                SourceDocumentORM.doctype == DocType.text,
            )
            .all()
        )
        sdoc_ids = [sdoc_id[0] for sdoc_id in sdoc_ids]

        for sdoc_id in sdoc_ids:
            result = (
                db.query(WordFrequencyORM)
                .filter(
                    WordFrequencyORM.sdoc_id == sdoc_id,
                )
                .all()
            )
            word_frequencies = [WordFrequencyRead.model_validate(row) for row in result]
            word_frequencies_str = srsly.json_dumps(
                [{"word": wf.word, "count": wf.count} for wf in word_frequencies]
            )

            # update SourceDocumentData
            try:
                db_obj = crud_sdoc_data.read(db=db, id=sdoc_id)
            except NoSuchElementError:
                continue
            setattr(db_obj, "word_frequencies", word_frequencies_str)
            db.add(db_obj)

        db.commit()
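
A hedged post-migration spot check, not part of this commit: load one migrated row and confirm the stored string parses back into the expected shape. crud_sdoc_data.read and srsly.json_loads are the same calls used above; the document id is a placeholder.

    import srsly

    # placeholder id; any migrated text SourceDocument works here
    db_obj = crud_sdoc_data.read(db=db, id=1)
    freqs = srsly.json_loads(db_obj.word_frequencies)
    assert all("word" in f and "count" in f for f in freqs)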
