From 555f895ef7db68844bdc417397522bd5f787fb8d Mon Sep 17 00:00:00 2001 From: Tim Fischer Date: Sun, 11 Feb 2024 05:22:57 +0000 Subject: [PATCH] add word_frequencies to sourcedocumentdata --- ..._word_frequencies_to_sourcedocumentdata.py | 29 ++++++++++ .../app/core/data/dto/source_document_data.py | 6 ++- .../app/core/data/orm/source_document_data.py | 7 +++ .../steps/text/write_pptd_to_database.py | 6 +++ backend/src/migration/migrate.py | 53 +++++++++++++++++++ 5 files changed, 100 insertions(+), 1 deletion(-) create mode 100644 backend/src/alembic/versions/b0ac316511e1_add_word_frequencies_to_sourcedocumentdata.py diff --git a/backend/src/alembic/versions/b0ac316511e1_add_word_frequencies_to_sourcedocumentdata.py b/backend/src/alembic/versions/b0ac316511e1_add_word_frequencies_to_sourcedocumentdata.py new file mode 100644 index 000000000..6129169d7 --- /dev/null +++ b/backend/src/alembic/versions/b0ac316511e1_add_word_frequencies_to_sourcedocumentdata.py @@ -0,0 +1,29 @@ +"""add word_frequencies to SourceDocumentData + +Revision ID: b0ac316511e1 +Revises: 3bd76cc03486 +Create Date: 2024-02-10 17:50:19.307561 + +""" +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "b0ac316511e1" +down_revision: Union[str, None] = "3bd76cc03486" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "sourcedocumentdata", + sa.Column("word_frequencies", sa.String(), server_default="[]", nullable=False), + ) + + +def downgrade() -> None: + op.drop_column("sourcedocumentdata", "word_frequencies") diff --git a/backend/src/app/core/data/dto/source_document_data.py b/backend/src/app/core/data/dto/source_document_data.py index d10a292ae..fe68918a2 100644 --- a/backend/src/app/core/data/dto/source_document_data.py +++ b/backend/src/app/core/data/dto/source_document_data.py @@ -37,4 +37,8 @@ class SourceDocumentDataRead(SourceDocumentDataBase): # Properties for creation class SourceDocumentDataCreate(SourceDocumentDataBase): - pass + word_frequencies: str = Field( + description=( + "JSON Representation of List[WordFrequency] of the SourceDocument" + ), + ) diff --git a/backend/src/app/core/data/orm/source_document_data.py b/backend/src/app/core/data/orm/source_document_data.py index 1ad322b94..f4907ce69 100644 --- a/backend/src/app/core/data/orm/source_document_data.py +++ b/backend/src/app/core/data/orm/source_document_data.py @@ -29,6 +29,13 @@ class SourceDocumentDataORM(ORMBase): sentence_ends: Mapped[List[int]] = mapped_column( ARRAY(Integer), nullable=False, index=False ) + # JSON representation of List[{word: str, count: int}] + word_frequencies: Mapped[str] = mapped_column( + String, + server_default="[]", + nullable=False, + index=False, + ) @property def tokens(self): diff --git a/backend/src/app/preprocessing/pipeline/steps/text/write_pptd_to_database.py b/backend/src/app/preprocessing/pipeline/steps/text/write_pptd_to_database.py index 992894120..297ba780b 100644 --- a/backend/src/app/preprocessing/pipeline/steps/text/write_pptd_to_database.py +++ b/backend/src/app/preprocessing/pipeline/steps/text/write_pptd_to_database.py
@@ -1,5 +1,6 @@ import traceback +import srsly from loguru import logger from psycopg2 import OperationalError from sqlalchemy.orm import Session @@ -52,6 +53,10 @@ def _create_and_persist_sdoc(db: Session, pptd: PreProTextDoc) -> SourceDocument def _persist_sdoc_data( db: Session, sdoc_db_obj: SourceDocumentORM, pptd: PreProTextDoc ) -> None: + word_frequencies_str = srsly.json_dumps( + [{"word": word, "count": count} for word, count in pptd.word_freqs.items()] + ) + sdoc_data = SourceDocumentDataCreate( id=sdoc_db_obj.id, content=pptd.text, @@ -60,6 +65,7 @@ def _persist_sdoc_data( token_ends=[e for _, e in pptd.token_character_offsets], sentence_starts=[s.start for s in pptd.sentences], sentence_ends=[s.end for s in pptd.sentences], + word_frequencies=word_frequencies_str, ) crud_sdoc_data.create(db=db, create_dto=sdoc_data) diff --git a/backend/src/migration/migrate.py b/backend/src/migration/migrate.py index 1dbc7374d..d0751e57d 100644 --- a/backend/src/migration/migrate.py +++ b/backend/src/migration/migrate.py @@ -6,6 +6,7 @@ from alembic.command import upgrade from alembic.config import Config +from app.core.data.crud.crud_base import NoSuchElementError from app.core.data.crud.project_metadata import crud_project_meta from app.core.data.crud.source_document_data import crud_sdoc_data from app.core.data.crud.source_document_metadata import crud_sdoc_meta @@ -75,6 +76,11 @@ def run_required_migrations(): db_version.version = 8 db.commit() print("MIGRATED IMAGE WIDTH HEIGHT!") + if db_version.version < 9: + __migrate_word_frequencies(db) + db_version.version = 9 + db.commit() + print("MIGRATED WORD FREQUENCIES!") def __migrate_database_schema() -> None: @@ -439,3 +445,50 @@ def __migrate_image_width_height(db: Session): db.add(height_pm) db.commit() + + +def __migrate_word_frequencies(db: Session): + import srsly + + from app.core.data.dto.word_frequency import WordFrequencyRead + from app.core.data.orm.word_frequency import WordFrequencyORM + + projects = 
db.query(ProjectORM).all() + for project in projects: + logger.info( + "Migration: Migrating word_frequencies project {}...", + project.id, + ) + + sdoc_ids = ( + db.query(SourceDocumentORM.id) + .filter( + SourceDocumentORM.project_id == project.id, + SourceDocumentORM.doctype == DocType.text, + ) + .all() + ) + sdoc_ids = [sdoc_id[0] for sdoc_id in sdoc_ids] + + for sdoc_id in sdoc_ids: + result = ( + db.query(WordFrequencyORM) + .filter( + WordFrequencyORM.sdoc_id == sdoc_id, + ) + .all() + ) + word_frequencies = [WordFrequencyRead.model_validate(row) for row in result] + word_frequencies_str = srsly.json_dumps( + [{"word": wf.word, "count": wf.count} for wf in word_frequencies] + ) + + # update SourceDocumentData + try: + db_obj = crud_sdoc_data.read(db=db, id=sdoc_id) + except NoSuchElementError: + continue + setattr(db_obj, "word_frequencies", word_frequencies_str) + db.add(db_obj) + + db.commit()