diff --git a/.ruff.toml b/.ruff.toml deleted file mode 100644 index 5405879f6..000000000 --- a/.ruff.toml +++ /dev/null @@ -1,8 +0,0 @@ -src = ["backend/src"] - -[lint] - -# I: isort-compatible import sorting -# W291: Trailing whitespace -# W292: Add newline to end of file -extend-select = ["I", "W292", "W291"] diff --git a/.vscode/extensions.json b/.vscode/extensions.json index 9a32fa34d..4bd8ddfaf 100644 --- a/.vscode/extensions.json +++ b/.vscode/extensions.json @@ -8,6 +8,7 @@ // backend "ms-python.python", "ms-python.debugpy", + "ms-python.vscode-pylance", "charliermarsh.ruff", // frontend "dbaeumer.vscode-eslint", diff --git a/.vscode/settings.json b/.vscode/settings.json index 779be299e..e2ab9305c 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,6 +1,5 @@ { "python.envFile": "${workspaceFolder}/backend/.env", - "python.analysis.extraPaths": ["./backend/src"], "python.autoComplete.extraPaths": ["./backend/src"], "editor.formatOnSave": true, "prettier.configPath": "./frontend/package.json", @@ -8,5 +7,6 @@ "typescript.preferences.importModuleSpecifierEnding": "js", "editor.codeActionsOnSave": { "source.organizeImports": "explicit" - } + }, + "python.analysis.diagnosticMode": "workspace" } diff --git a/backend/src/app/celery/background_jobs/__init__.py b/backend/src/app/celery/background_jobs/__init__.py index 9383b9d4a..6572f5984 100644 --- a/backend/src/app/celery/background_jobs/__init__.py +++ b/backend/src/app/celery/background_jobs/__init__.py @@ -26,14 +26,6 @@ def start_trainer_job_async( start_trainer_job_task.apply_async(kwargs={"trainer_job_id": trainer_job_id}) -def use_trainer_model_async( - trainer_job_id: str, -) -> Any: - from app.celery.background_jobs.tasks import use_trainer_model_task - - return use_trainer_model_task.apply_async(kwargs={"trainer_job_id": trainer_job_id}) - - def import_uploaded_archive_apply_async( archive_file_path: Path, project_id: int ) -> Any: diff --git a/backend/src/app/celery/background_jobs/tasks.py b/backend/src/app/celery/background_jobs/tasks.py index 3e004155a..f7df1a755 100644 --- a/backend/src/app/celery/background_jobs/tasks.py +++ b/backend/src/app/celery/background_jobs/tasks.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import List, Tuple +from typing import Tuple from app.celery.background_jobs.cota import start_cota_refinement_job_ from app.celery.background_jobs.crawl import start_crawler_job_ @@ -14,7 +14,6 @@ ) from app.celery.background_jobs.trainer import ( start_trainer_job_, - use_trainer_model_task_, ) from app.celery.celery_worker import celery_worker from app.core.data.dto.crawler_job import CrawlerJobRead @@ -41,15 +40,6 @@ def start_trainer_job_task(trainer_job_id: str) -> None: start_trainer_job_(trainer_job_id=trainer_job_id) -@celery_worker.task( - acks_late=True, - autoretry_for=(Exception,), - retry_kwargs={"max_retries": 0, "countdown": 5}, -) -def use_trainer_model_task(trainer_job_id: str) -> List[float]: - return use_trainer_model_task_(trainer_job_id=trainer_job_id) - - @celery_worker.task(acks_late=True) def start_export_job(export_job: ExportJobRead) -> None: start_export_job_(export_job=export_job) diff --git a/backend/src/app/celery/background_jobs/trainer.py b/backend/src/app/celery/background_jobs/trainer.py index 03b2d0e03..039f24848 100644 --- a/backend/src/app/celery/background_jobs/trainer.py +++ b/backend/src/app/celery/background_jobs/trainer.py @@ -1,5 +1,3 @@ -from typing import List - from loguru import logger from app.trainer.trainer_service import 
TrainerService @@ -11,7 +9,3 @@ def start_trainer_job_(trainer_job_id: str) -> None: ts_result = ts._start_trainer_job_sync(trainer_job_id=trainer_job_id) logger.info(f"TrainerJob {trainer_job_id} has finished! Result: {ts_result}") - - -def use_trainer_model_task_(trainer_job_id: str) -> List[float]: - return ts._use_trainer_model_sync(trainer_job_id=trainer_job_id) diff --git a/backend/src/app/core/analysis/analysis_service.py b/backend/src/app/core/analysis/analysis_service.py index 90d351165..83561dbe2 100644 --- a/backend/src/app/core/analysis/analysis_service.py +++ b/backend/src/app/core/analysis/analysis_service.py @@ -165,7 +165,7 @@ def find_code_occurrences( CodeORM.id == code_id, ) ) - query = query.group_by(SourceDocumentORM, CodeORM, SpanTextORM.text) + query = query.group_by(SourceDocumentORM.id, CodeORM.id, SpanTextORM.text) res = query.all() span_code_occurrences = [ CodeOccurrence( @@ -205,7 +205,9 @@ def find_code_occurrences( ) ) query = query.group_by( - SourceDocumentORM, CodeORM, BBoxAnnotationORM.annotation_document_id + SourceDocumentORM.id, + CodeORM.id, + BBoxAnnotationORM.annotation_document_id, ) res = query.all() bbox_code_occurrences = [ diff --git a/backend/src/app/core/analysis/cota/pipeline/steps/init_search_space.py b/backend/src/app/core/analysis/cota/pipeline/steps/init_search_space.py index 054850159..62ddb1ca7 100644 --- a/backend/src/app/core/analysis/cota/pipeline/steps/init_search_space.py +++ b/backend/src/app/core/analysis/cota/pipeline/steps/init_search_space.py @@ -1,3 +1,4 @@ +import uuid from datetime import datetime from typing import Dict, List, Optional @@ -40,7 +41,7 @@ def init_search_space(cargo: Cargo) -> Cargo: top_k=cota.training_settings.search_space_topk, threshold=cota.training_settings.search_space_threshold, filter=Filter[SearchColumns]( - items=[], logic_operator=LogicalOperator.and_ + id=str(uuid.uuid4()), items=[], logic_operator=LogicalOperator.and_ ), ), ) diff --git a/backend/src/app/core/analysis/timeline.py b/backend/src/app/core/analysis/timeline.py index ede9d1a67..e26daf9f0 100644 --- a/backend/src/app/core/analysis/timeline.py +++ b/backend/src/app/core/analysis/timeline.py @@ -160,17 +160,19 @@ def timeline_analysis( .subquery() ) + subquery.c + sdoc_ids_agg = aggregate_ids(SourceDocumentORM.id, label="sdoc_ids") query = db.query( sdoc_ids_agg, - *group_by.apply(subquery.c["date"]), # EXTRACT (WEEK FROM TIMESTAMP ...) 
+ *group_by.apply(subquery.c[1]), # type: ignore ).join(subquery, SourceDocumentORM.id == subquery.c.id) query = apply_filtering( query=query, filter=filter, db=db, subquery_dict=subquery.c ) - query = query.group_by(*group_by.apply(column=subquery.c["date"])) + query = query.group_by(*group_by.apply(column=subquery.c["date"])) # type: ignore result_rows = query.all() diff --git a/backend/src/app/core/data/crud/action.py b/backend/src/app/core/data/crud/action.py index ed1dac5ca..896c416dd 100644 --- a/backend/src/app/core/data/crud/action.py +++ b/backend/src/app/core/data/crud/action.py @@ -4,12 +4,12 @@ from fastapi.encoders import jsonable_encoder from sqlalchemy.orm import Session -from app.core.data.crud.crud_base import CRUDBase +from app.core.data.crud.crud_base import CRUDBase, UpdateNotAllowed from app.core.data.dto.action import ActionCreate, ActionTargetObjectType, ActionType from app.core.data.orm.action import ActionORM -class CRUDAction(CRUDBase[ActionORM, ActionCreate, None]): +class CRUDAction(CRUDBase[ActionORM, ActionCreate, UpdateNotAllowed]): def create(self, db: Session, *, create_dto: ActionCreate) -> ActionORM: # we have to override this to avoid recursion dto_obj_data = jsonable_encoder(create_dto) diff --git a/backend/src/app/core/data/crud/crud_base.py b/backend/src/app/core/data/crud/crud_base.py index 3d9e90ffe..998cf64c9 100644 --- a/backend/src/app/core/data/crud/crud_base.py +++ b/backend/src/app/core/data/crud/crud_base.py @@ -13,6 +13,10 @@ UpdateDTOType = TypeVar("UpdateDTOType", bound=BaseModel) +class UpdateNotAllowed(BaseModel): + pass + + class NoSuchElementError(Exception): def __init__(self, model: Type[ORMModelType], **kwargs): self.model = model diff --git a/backend/src/app/core/data/crud/object_handle.py b/backend/src/app/core/data/crud/object_handle.py index 7efc6c874..5db86c841 100644 --- a/backend/src/app/core/data/crud/object_handle.py +++ b/backend/src/app/core/data/crud/object_handle.py @@ -7,7 +7,7 @@ from app.core.data.crud.action import crud_action from app.core.data.crud.bbox_annotation import crud_bbox_anno from app.core.data.crud.code import crud_code -from app.core.data.crud.crud_base import CRUDBase, NoSuchElementError +from app.core.data.crud.crud_base import CRUDBase, NoSuchElementError, UpdateNotAllowed from app.core.data.crud.document_tag import crud_document_tag from app.core.data.crud.memo import crud_memo from app.core.data.crud.project import crud_project @@ -30,7 +30,7 @@ from app.core.db.sql_service import SQLService -class CRUDObjectHandle(CRUDBase[ObjectHandleORM, ObjectHandleCreate, None]): +class CRUDObjectHandle(CRUDBase[ObjectHandleORM, ObjectHandleCreate, UpdateNotAllowed]): __obj_id_crud_map = { "code_id": crud_code, "document_tag_id": crud_document_tag, @@ -65,11 +65,18 @@ def create(self, db: Session, *, create_dto: ObjectHandleCreate) -> ObjectHandle if isinstance(e.orig, UniqueViolation): db.close() # Flo: close the session because we have to start a new transaction with SQLService().db_session() as sess: - for obj_id_key, obj_id_val in create_dto.model_dump().items(): - if obj_id_val: - return self.read_by_attached_object_id( - db=sess, obj_id_key=obj_id_key, obj_id_val=obj_id_val - ) + obj_id_key, obj_id_val = next( + filter( + lambda item: item[0] is not None and item[1] is not None, + create_dto.model_dump().items(), + ), + (None, None), + ) + if obj_id_key is not None and obj_id_val is not None: + return self.read_by_attached_object_id( + db=sess, obj_id_key=obj_id_key, obj_id_val=obj_id_val + ) + raise e 
else: # Flo: re-raise Exception since it's not a UC Violation raise e diff --git a/backend/src/app/core/data/crud/source_document_data.py b/backend/src/app/core/data/crud/source_document_data.py index 18ed66179..0e4953d1a 100644 --- a/backend/src/app/core/data/crud/source_document_data.py +++ b/backend/src/app/core/data/crud/source_document_data.py @@ -1,4 +1,4 @@ -from app.core.data.crud.crud_base import CRUDBase +from app.core.data.crud.crud_base import CRUDBase, UpdateNotAllowed from app.core.data.dto.source_document_data import SourceDocumentDataCreate from app.core.data.orm.source_document_data import SourceDocumentDataORM @@ -7,7 +7,7 @@ class CRUDSourceDocumentData( CRUDBase[ SourceDocumentDataORM, SourceDocumentDataCreate, - None, + UpdateNotAllowed, ] ): pass diff --git a/backend/src/app/core/data/crud/source_document_link.py b/backend/src/app/core/data/crud/source_document_link.py index 5537bb7c6..5f0da0a64 100644 --- a/backend/src/app/core/data/crud/source_document_link.py +++ b/backend/src/app/core/data/crud/source_document_link.py @@ -2,18 +2,19 @@ from sqlalchemy.orm import Session -from app.core.data.crud.crud_base import CRUDBase, ORMModelType, UpdateDTOType +from app.core.data.crud.crud_base import ( + CRUDBase, + UpdateNotAllowed, +) from app.core.data.dto.source_document_link import SourceDocumentLinkCreate from app.core.data.orm.source_document import SourceDocumentORM from app.core.data.orm.source_document_link import SourceDocumentLinkORM class CRUDSourceDocumentLink( - CRUDBase[SourceDocumentLinkORM, SourceDocumentLinkCreate, None] + CRUDBase[SourceDocumentLinkORM, SourceDocumentLinkCreate, UpdateNotAllowed] ): - def update( - self, db: Session, *, id: int, update_dto: UpdateDTOType - ) -> ORMModelType: + def update(self, db: Session, *, id: int, update_dto): raise NotImplementedError() def resolve_filenames_to_sdoc_ids( @@ -37,8 +38,7 @@ def resolve_filenames_to_sdoc_ids( ), SourceDocumentORM.project_id == proj_id, ) - # noinspection PyTypeChecker - sdoc_fn_to_id: Dict[str, int] = dict(query2.all()) + sdoc_fn_to_id: Dict[str, int] = {filename: id for filename, id in query2.all()} resolved_links: List[SourceDocumentLinkORM] = [] diff --git a/backend/src/app/core/data/crud/span_text.py b/backend/src/app/core/data/crud/span_text.py index 3fbed7642..4e71b78c5 100644 --- a/backend/src/app/core/data/crud/span_text.py +++ b/backend/src/app/core/data/crud/span_text.py @@ -2,14 +2,14 @@ from sqlalchemy.orm import Session -from app.core.data.crud.crud_base import CRUDBase +from app.core.data.crud.crud_base import CRUDBase, UpdateNotAllowed from app.core.data.dto.span_text import SpanTextCreate from app.core.data.orm.span_text import SpanTextORM -class CRUDSpanText(CRUDBase[SpanTextORM, SpanTextCreate, None]): - def update(self, db: Session, *, id: int, update_dto) -> SpanTextORM: - # Flo: We no not want to update SourceDocument +class CRUDSpanText(CRUDBase[SpanTextORM, SpanTextCreate, UpdateNotAllowed]): + def update(self, db: Session, *, id: int, update_dto): + # Flo: We no not want to update SpanText raise NotImplementedError() def create(self, db: Session, *, create_dto: SpanTextCreate) -> SpanTextORM: diff --git a/backend/src/app/core/data/crud/word_frequency.py b/backend/src/app/core/data/crud/word_frequency.py index 18e4a0345..4ebc9b380 100644 --- a/backend/src/app/core/data/crud/word_frequency.py +++ b/backend/src/app/core/data/crud/word_frequency.py @@ -2,7 +2,7 @@ from sqlalchemy.orm import Session -from app.core.data.crud.crud_base import CRUDBase +from 
app.core.data.crud.crud_base import CRUDBase, UpdateNotAllowed from app.core.data.doc_type import DocType from app.core.data.dto.word_frequency import WordFrequencyCreate, WordFrequencyRead from app.core.data.orm.source_document import SourceDocumentORM @@ -13,9 +13,12 @@ class CrudWordFrequency( CRUDBase[ WordFrequencyORM, WordFrequencyCreate, - None, + UpdateNotAllowed, ] ): + def update(self, db: Session, *, id: int, update_dto): + raise NotImplementedError() + def read_by_project_and_doctype( self, db: Session, *, project_id: int, doctype: DocType ) -> List[WordFrequencyRead]: diff --git a/backend/src/app/core/data/dto/export_job.py b/backend/src/app/core/data/dto/export_job.py index 523aa40e6..5d51a48c4 100644 --- a/backend/src/app/core/data/dto/export_job.py +++ b/backend/src/app/core/data/dto/export_job.py @@ -91,7 +91,7 @@ class ExportJobParameters(BaseModel): export_job_type: ExportJobType = Field( description="The type of the export job (what to export)" ) - export_format: Optional[ExportFormat] = Field( + export_format: ExportFormat = Field( description="The format of the exported data.", default=ExportFormat.CSV, ) diff --git a/backend/src/app/core/data/dto/object_handle.py b/backend/src/app/core/data/dto/object_handle.py index 29729e3bb..95e59956b 100644 --- a/backend/src/app/core/data/dto/object_handle.py +++ b/backend/src/app/core/data/dto/object_handle.py @@ -15,12 +15,13 @@ class ObjectHandleBaseDTO(BaseModel): document_tag_id: Optional[int] = None memo_id: Optional[int] = None - # noinspection PyMethodParameters - @model_validator(mode="before") # TODO: Before == root? - def check_at_least_one_not_null(cls, values): - for val in values: - if val: - return values + @model_validator(mode="after") + def check_at_least_one_not_null(self): + # make sure that at least one of the fields is not null + values = self.model_dump() + for val in values.values(): + if val is not None: + return self raise ValueError("At least one of the fields has to be not null!") diff --git a/backend/src/app/core/data/dto/source_document_metadata.py b/backend/src/app/core/data/dto/source_document_metadata.py index 880d93496..0d3f2c050 100644 --- a/backend/src/app/core/data/dto/source_document_metadata.py +++ b/backend/src/app/core/data/dto/source_document_metadata.py @@ -184,7 +184,7 @@ def with_value( id=sdoc_metadata_id, str_value=None, boolean_value=None, - date_value=value if value is not None else datetime.now(), + date_value=datetime.fromisoformat(value), int_value=None, list_value=None, source_document_id=source_document_id, diff --git a/backend/src/app/core/data/export/export_service.py b/backend/src/app/core/data/export/export_service.py index be6165c97..c9cbc8da9 100644 --- a/backend/src/app/core/data/export/export_service.py +++ b/backend/src/app/core/data/export/export_service.py @@ -603,13 +603,10 @@ def _export_all_user_annotations_from_sdoc( ) # export the data - export_data = None + export_data = pd.DataFrame() for adoc in all_adocs: adoc_data = self.__generate_export_df_for_adoc(db=db, adoc=adoc) - if export_data is None: - export_data = adoc_data - else: - export_data = pd.concat((export_data, adoc_data)) + export_data = pd.concat((export_data, adoc_data)) # write single file for all annos of that doc export_file = self.__write_export_data_to_temp_file( @@ -656,15 +653,12 @@ def _export_user_memos_from_proj( f"There are no memos for User {user_id} in Project {project_id}!" 
) - export_data = None + export_data = pd.DataFrame() for memo in memos: memo_data = self.__generate_export_df_for_memo( db=db, memo_id=memo.id, memo=memo ) - if export_data is None: - export_data = memo_data - else: - export_data = pd.concat((export_data, memo_data)) + export_data = pd.concat((export_data, memo_data)) export_file = self.__write_export_data_to_temp_file( data=export_data, diff --git a/backend/src/app/core/data/meta_type.py b/backend/src/app/core/data/meta_type.py index 2ab81ffd8..123ef021b 100644 --- a/backend/src/app/core/data/meta_type.py +++ b/backend/src/app/core/data/meta_type.py @@ -1,6 +1,8 @@ +from datetime import datetime from enum import Enum +from typing import List -from sqlalchemy import Column +from sqlalchemy.orm import QueryableAttribute from app.core.data.orm.source_document_metadata import SourceDocumentMetadataORM from app.core.filters.filtering_operators import FilterOperator @@ -13,8 +15,9 @@ class MetaType(str, Enum): BOOLEAN = "BOOLEAN" LIST = "LIST" - # TODO: was ist der richtige typ? - def get_metadata_column(self) -> Column: + def get_metadata_column( + self, + ) -> QueryableAttribute[str | int | bool | datetime | List[str] | None]: match self: case MetaType.STRING: return SourceDocumentMetadataORM.str_value diff --git a/backend/src/app/core/data/orm/memo.py b/backend/src/app/core/data/orm/memo.py index 4c1d1841e..20cd45a11 100644 --- a/backend/src/app/core/data/orm/memo.py +++ b/backend/src/app/core/data/orm/memo.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String, func from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -19,10 +19,10 @@ class MemoORM(ORMBase): String, nullable=False, index=False ) # TODO Flo: This will go to ES soon! starred: Mapped[bool] = mapped_column(Boolean, nullable=False, index=True) - created: Mapped[Optional[int]] = mapped_column( + created: Mapped[int] = mapped_column( DateTime, server_default=func.now(), index=True ) - updated: Mapped[Optional[datetime]] = mapped_column( + updated: Mapped[datetime] = mapped_column( DateTime, server_default=func.now(), onupdate=func.current_timestamp() ) @@ -41,7 +41,7 @@ class MemoORM(ORMBase): nullable=False, index=True, ) - attached_to: Mapped[Optional["ObjectHandleORM"]] = relationship( + attached_to: Mapped["ObjectHandleORM"] = relationship( "ObjectHandleORM", uselist=False, back_populates="attached_memos", diff --git a/backend/src/app/core/data/repo/repo_service.py b/backend/src/app/core/data/repo/repo_service.py index 18a44cae7..f986e1491 100644 --- a/backend/src/app/core/data/repo/repo_service.py +++ b/backend/src/app/core/data/repo/repo_service.py @@ -66,6 +66,13 @@ def __init__(self, dst_path: Path): ) +class UnsupportedDocTypeForMimeType(Exception): + def __init__(self, mime_type: str): + super().__init__( + f"Unsupported DocType! Cannot infer DocType from MimeType '{mime_type}'." 
+ ) + + class ErroneousArchiveException(Exception): def __init__(self, archive_path: Path): super().__init__(f"Error with Archive {archive_path}") diff --git a/backend/src/app/core/db/sql_service.py b/backend/src/app/core/db/sql_service.py index fc6788f23..c9814d67e 100644 --- a/backend/src/app/core/db/sql_service.py +++ b/backend/src/app/core/db/sql_service.py @@ -2,7 +2,6 @@ from typing import Generator from loguru import logger -from pydantic import PostgresDsn from sqlalchemy import create_engine, inspect from sqlalchemy.engine import Engine from sqlalchemy.orm import Session, sessionmaker @@ -16,17 +15,9 @@ class SQLService(metaclass=SingletonMeta): def __new__(cls, *args, **kwargs): try: - db_uri = PostgresDsn.build( - scheme="postgresql", - username=conf.postgres.user, - password=conf.postgres.password, - host=conf.postgres.host, - port=int(conf.postgres.port), - path=f"{conf.postgres.db}", - ) - + db_uri = f"postgresql://{conf.postgres.user}:{conf.postgres.password}@{conf.postgres.host}:{conf.postgres.port}/{conf.postgres.db}" engine = create_engine( - str(db_uri), + db_uri, pool_pre_ping=True, pool_size=conf.postgres.pool.pool_size, max_overflow=conf.postgres.pool.max_overflow, diff --git a/backend/src/app/core/filters/filtering.py b/backend/src/app/core/filters/filtering.py index 367cb8c98..029a1c1bf 100644 --- a/backend/src/app/core/filters/filtering.py +++ b/backend/src/app/core/filters/filtering.py @@ -36,6 +36,8 @@ def get_sqlalchemy_operator(self): T = TypeVar("T", bound=AbstractColumns) +FilterValue = Union[bool, str, int, List[str], List[List[str]]] + class FilterExpression(BaseModel, Generic[T]): id: str @@ -49,7 +51,7 @@ class FilterExpression(BaseModel, Generic[T]): DateOperator, BooleanOperator, ] - value: Union[bool, str, int, List[str], List[List[str]]] + value: FilterValue def get_sqlalchemy_expression(self, **kwargs): if isinstance(self.column, int): diff --git a/backend/src/app/core/filters/filtering_operators.py b/backend/src/app/core/filters/filtering_operators.py index a45140db1..1bf57be37 100644 --- a/backend/src/app/core/filters/filtering_operators.py +++ b/backend/src/app/core/filters/filtering_operators.py @@ -1,7 +1,9 @@ from enum import Enum -from typing import List, Union -from sqlalchemy import Column, not_ +from sqlalchemy import not_ +from sqlalchemy.orm import QueryableAttribute + +from app.core.filters.filtering import FilterValue class FilterValueType(Enum): @@ -28,7 +30,7 @@ class BooleanOperator(Enum): EQUALS = "BOOLEAN_EQUALS" NOT_EQUALS = "BOOLEAN_NOT_EQUALS" - def apply(self, column: Column, value: bool): + def apply(self, column: QueryableAttribute, value: FilterValue): if not isinstance(value, bool): raise ValueError("Invalid value type for BooleanOperator (requires bool)!") @@ -46,7 +48,7 @@ class StringOperator(Enum): STARTS_WITH = "STRING_STARTS_WITH" ENDS_WITH = "STRING_ENDS_WITH" - def apply(self, column: Column, value: str): + def apply(self, column: QueryableAttribute, value: FilterValue): if not isinstance(value, str): raise ValueError("Invalid value type for StringOperator (requires str)!") @@ -67,7 +69,14 @@ class IDOperator(Enum): EQUALS = "ID_EQUALS" NOT_EQUALS = "ID_NOT_EQUALS" - def apply(self, column: Column, value: int | str): + def apply( + self, + column: QueryableAttribute, + value: FilterValue, + ): + if not isinstance(value, (int, str)): + raise ValueError("Invalid value type for IDOperator (requires int or str)!") + match self: case IDOperator.EQUALS: return column == value @@ -83,7 +92,7 @@ class NumberOperator(Enum): 
GTE = "NUMBER_GTE" LTE = "NUMBER_LTE" - def apply(self, column: Column, value: int): + def apply(self, column: QueryableAttribute, value: FilterValue): if not isinstance(value, int): raise ValueError("Invalid value type for NumberOperator (requires int)!") @@ -106,7 +115,17 @@ class IDListOperator(Enum): CONTAINS = "ID_LIST_CONTAINS" NOT_CONTAINS = "ID_LIST_NOT_CONTAINS" - def apply(self, column, value: Union[str, List[str]]): + def apply(self, column, value: FilterValue): + if not isinstance(value, (str, list)): + raise ValueError( + "Invalid value type for IDListOperator (requires str or list)!" + ) + if len(value) > 0 and not isinstance(value[0], str): + raise ValueError( + "Invalid value type for ListOperator (requires List[str])!" + ) + + # value should be Union[str, List[str]] if isinstance(column, tuple): if isinstance(value, str) and (len(column) == 2): # Column is tuple of ORMs, e.g. (SourceDocumentORM.document_tags, DocumentTagORM.id) @@ -143,7 +162,7 @@ class ListOperator(Enum): CONTAINS = "LIST_CONTAINS" NOT_CONTAINS = "LIST_NOT_CONTAINS" - def apply(self, column, value: List[str]): + def apply(self, column: QueryableAttribute, value: FilterValue): if not isinstance(value, list): raise ValueError( "Invalid value type for ListOperator (requires List[str])!" @@ -167,7 +186,7 @@ class DateOperator(Enum): GTE = "DATE_GTE" LTE = "DATE_LTE" - def apply(self, column: Column, value: str): + def apply(self, column: QueryableAttribute, value: FilterValue): if not isinstance(value, str): raise ValueError("Invalid value type for DateOperator (requires str)!") diff --git a/backend/src/app/core/filters/sorting.py b/backend/src/app/core/filters/sorting.py index 107bb7134..85d49f841 100644 --- a/backend/src/app/core/filters/sorting.py +++ b/backend/src/app/core/filters/sorting.py @@ -2,8 +2,8 @@ from typing import Generic, List, TypeVar, Union from pydantic import BaseModel -from sqlalchemy import Column, asc, desc -from sqlalchemy.orm import Session +from sqlalchemy import asc, desc +from sqlalchemy.orm import QueryableAttribute, Session from app.core.data.crud.project_metadata import crud_project_meta from app.core.data.dto.project_metadata import ProjectMetadataRead @@ -14,7 +14,7 @@ class SortDirection(str, Enum): ASC = "asc" DESC = "desc" - def apply(self, column: Column): + def apply(self, column: QueryableAttribute): match self: case SortDirection.ASC: return asc(column).nulls_last() diff --git a/backend/src/app/core/search/elasticsearch_service.py b/backend/src/app/core/search/elasticsearch_service.py index 4d8e97774..8f701fb2f 100644 --- a/backend/src/app/core/search/elasticsearch_service.py +++ b/backend/src/app/core/search/elasticsearch_service.py @@ -51,13 +51,22 @@ def __new__(cls, *args, **kwargs): f"Cannot find ElasticSearch Document Index Mapping: {doc_mappings_path}" ) - memo_mappings = srsly.read_json(memo_mappings_path) doc_mappings = srsly.read_json(doc_mappings_path) - - cls.doc_index_fields = set(doc_mappings["properties"].keys()) - cls.memo_index_fields = set(memo_mappings["properties"].keys()) - + if isinstance(doc_mappings, dict) and "properties" in doc_mappings: + cls.doc_index_fields = set(doc_mappings["properties"].keys()) + else: + raise ValueError( + "Invalid doc_mappings format or 'properties' key missing" + ) cls.doc_mappings = doc_mappings + + memo_mappings = srsly.read_json(memo_mappings_path) + if isinstance(memo_mappings, dict) and "properties" in memo_mappings: + cls.memo_index_fields = set(memo_mappings["properties"].keys()) + else: + raise ValueError( + 
"Invalid doc_mappings format or 'properties' key missing" + ) cls.memo_mappings = memo_mappings # ElasticSearch Connection @@ -113,7 +122,7 @@ def __create_index( *, index: str, mappings: Dict[str, Any], - settings: Dict[str, Any] = None, + settings: Optional[Dict[str, Any]] = None, replace_if_exists: bool = False, ) -> None: if replace_if_exists and self.__client.indices.exists(index=index): @@ -136,9 +145,14 @@ def __get_index_name(self, proj_id: int, index_type: str = "doc") -> str: def create_project_indices(self, *, proj_id: int) -> None: # create the ES Index for Documents - doc_settings = conf.elasticsearch.index_settings.docs - if doc_settings is not None: - doc_settings = srsly.read_json(doc_settings) + doc_settings_path = Path(conf.elasticsearch.index_settings.docs) + if not doc_settings_path.exists(): + raise FileNotFoundError( + f"Cannot find ElasticSearch Doc Index Settings: {doc_settings_path}" + ) + doc_settings = srsly.read_json(doc_settings_path) + if not isinstance(doc_settings, dict): + raise ValueError("Invalid doc_settings format.") self.__create_index( index=self.__get_index_name(proj_id=proj_id, index_type="doc"), @@ -148,9 +162,15 @@ def create_project_indices(self, *, proj_id: int) -> None: ) # create the ES Index for Memos - memo_settings = conf.elasticsearch.index_settings.memos - if memo_settings is not None: - memo_settings = srsly.read_json(memo_settings) + memo_settings_path = Path(conf.elasticsearch.index_settings.memos) + if not memo_settings_path.exists(): + raise FileNotFoundError( + f"Cannot find ElasticSearch Memo Index Settings: {memo_settings_path}" + ) + memo_settings = srsly.read_json(memo_settings_path) + if not isinstance(memo_settings, dict): + raise ValueError("Invalid memo_settings format.") + self.__create_index( index=self.__get_index_name(proj_id=proj_id, index_type="memo"), mappings=self.memo_mappings, diff --git a/backend/src/app/core/search/simsearch_service.py b/backend/src/app/core/search/simsearch_service.py index 1d5b11c99..030494d26 100644 --- a/backend/src/app/core/search/simsearch_service.py +++ b/backend/src/app/core/search/simsearch_service.py @@ -118,9 +118,7 @@ def __new__(cls, *args, **kwargs): return super(SimSearchService, cls).__new__(cls) - def _encode_text( - self, text: Union[str, List[str]], return_avg_emb: bool = False - ) -> np.ndarray: + def _encode_text(self, text: List[str], return_avg_emb: bool = False) -> np.ndarray: encoded_query = self.rms.clip_text_embedding(ClipTextEmbeddingInput(text=text)) if len(encoded_query.embeddings) == 1: return encoded_query.numpy().squeeze() @@ -196,7 +194,7 @@ def add_image_sdoc_to_index( "sdoc_id": sdoc_id, }, class_name=self._image_class_name, - vector=image_emb, + vector=image_emb.tolist(), ) def remove_sdoc_from_index(self, doctype: str, sdoc_id: int): diff --git a/backend/src/app/core/security.py b/backend/src/app/core/security.py index cd40dac43..5a818526f 100644 --- a/backend/src/app/core/security.py +++ b/backend/src/app/core/security.py @@ -1,6 +1,6 @@ import secrets from datetime import UTC, datetime, timedelta -from typing import Dict, Optional, Tuple +from typing import Dict, Tuple from jose import jwt from loguru import logger @@ -38,7 +38,7 @@ def generate_jwt(user: UserORM) -> Tuple[str, datetime]: return (token, expire) -def decode_jwt(token: str) -> Optional[Dict]: +def decode_jwt(token: str) -> Dict: try: return jwt.decode( token, __jwt_secret, algorithms=__algo, options={"verify_aud": False} diff --git 
a/backend/src/app/preprocessing/pipeline/steps/image/create_pptd_from_caption.py b/backend/src/app/preprocessing/pipeline/steps/image/create_pptd_from_caption.py index b51f7a0ea..b3e955794 100644 --- a/backend/src/app/preprocessing/pipeline/steps/image/create_pptd_from_caption.py +++ b/backend/src/app/preprocessing/pipeline/steps/image/create_pptd_from_caption.py @@ -6,17 +6,18 @@ def create_pptd_from_caption(cargo: PipelineCargo) -> PipelineCargo: - sdoc_id = cargo.data["sdoc_id"] ppid: PreProImageDoc = cargo.data["ppid"] caption = ppid.metadata["caption"] + if isinstance(caption, list): + caption = " ".join(caption) + # we don't need to set the filepath and filename as they are not used for the text # tasks we apply on the caption. pptd = PreProTextDoc( filepath=Path("/this/is/a/fake_path.txt"), filename="fake_path.txt", project_id=ppid.project_id, - sdoc_id=sdoc_id, text=caption, html=f"
<p>{caption}</p>
", metadata={"language": "en"}, diff --git a/backend/src/app/preprocessing/pipeline/steps/image/run_object_detection.py b/backend/src/app/preprocessing/pipeline/steps/image/run_object_detection.py index 75eddae5b..4ceee524f 100644 --- a/backend/src/app/preprocessing/pipeline/steps/image/run_object_detection.py +++ b/backend/src/app/preprocessing/pipeline/steps/image/run_object_detection.py @@ -20,7 +20,6 @@ def run_object_detection(cargo: PipelineCargo) -> PipelineCargo: y_min=box.y_min, x_max=box.x_max, y_max=box.y_max, - confidence=box.confidence, ) ppid.bboxes.append(bbox) diff --git a/backend/src/app/preprocessing/pipeline/steps/text/extract_text_from_html_and_create_source_mapping.py b/backend/src/app/preprocessing/pipeline/steps/text/extract_text_from_html_and_create_source_mapping.py index 29ab60c3e..798b1efc0 100644 --- a/backend/src/app/preprocessing/pipeline/steps/text/extract_text_from_html_and_create_source_mapping.py +++ b/backend/src/app/preprocessing/pipeline/steps/text/extract_text_from_html_and_create_source_mapping.py @@ -8,13 +8,15 @@ class CustomLineHTMLParser(HTMLParser): + result: List[Dict[str, Union[str, int]]] + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.result = None + self.result = [] def reset(self): super().reset() - self.result = None + self.result = [] @property def current_index(self): @@ -35,13 +37,21 @@ class HTMLTextMapper(CustomLineHTMLParser): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.result = [] - self.text = None + self.text = { + "text": "", + "start": 0, + "end": 0, + } self.end_spaces = 0 def reset(self): super().reset() self.result = [] - self.text = None + self.text = { + "text": "", + "start": 0, + "end": 0, + } def handle_data(self, data: str): # only add text if it is not only whitespaces! 
@@ -70,11 +80,14 @@ def handle_comment(self, data): self.text_end() def text_end(self): - if self.text: - self.text["end"] = self.current_index - self.end_spaces - self.result.append(self.text) - self.text = "" - self.end_spaces = 0 + self.text["end"] = self.current_index - self.end_spaces + self.result.append(self.text) + self.text = { + "text": "", + "start": 0, + "end": 0, + } + self.end_spaces = 0 def close(self): super().close() diff --git a/backend/src/app/preprocessing/pipeline/steps/text/run_spacy_pipeline.py b/backend/src/app/preprocessing/pipeline/steps/text/run_spacy_pipeline.py index f0d470968..7502e3e75 100644 --- a/backend/src/app/preprocessing/pipeline/steps/text/run_spacy_pipeline.py +++ b/backend/src/app/preprocessing/pipeline/steps/text/run_spacy_pipeline.py @@ -8,6 +8,8 @@ def run_spacy_pipeline(cargo: PipelineCargo) -> PipelineCargo: pptd: PreProTextDoc = cargo.data["pptd"] + + assert isinstance(pptd.metadata["language"], str), "Language is not a string" spacy_input: SpacyInput = SpacyInput( text=pptd.text, language=pptd.metadata["language"], diff --git a/backend/src/app/preprocessing/pipeline/steps/video/generate_webp_thumbnail_for_video.py b/backend/src/app/preprocessing/pipeline/steps/video/generate_webp_thumbnail_for_video.py index 919eca819..7a0cc1bdf 100644 --- a/backend/src/app/preprocessing/pipeline/steps/video/generate_webp_thumbnail_for_video.py +++ b/backend/src/app/preprocessing/pipeline/steps/video/generate_webp_thumbnail_for_video.py @@ -14,8 +14,12 @@ def generate_webp_thumbnail_for_video(cargo: PipelineCargo) -> PipelineCargo: ppvd: PreProVideoDoc = cargo.data["ppvd"] + assert isinstance(ppvd.metadata["duration"], str), "Duration is not a string" + assert isinstance(ppvd.metadata["width"], str), "Width is not a string" + half_time = float(ppvd.metadata["duration"]) // 2 frame_width = int(ppvd.metadata["width"]) + try: # get the frame at half time of the video half_time_frame, err = ( diff --git a/backend/src/app/preprocessing/preprocessing_service.py b/backend/src/app/preprocessing/preprocessing_service.py index a01f0ee25..946e3b484 100644 --- a/backend/src/app/preprocessing/preprocessing_service.py +++ b/backend/src/app/preprocessing/preprocessing_service.py @@ -31,6 +31,7 @@ from app.core.data.repo.repo_service import ( FileNotFoundInRepositoryError, RepoService, + UnsupportedDocTypeForMimeType, UnsupportedDocTypeForSourceDocument, ) from app.core.db.sql_service import SQLService @@ -53,6 +54,11 @@ def _store_uploaded_files_and_create_payloads( payloads: List[PreprocessingJobPayloadCreateWithoutPreproJobId] = [] for uploaded_file in uploaded_files: mime_type = uploaded_file.content_type + if mime_type is None: + raise HTTPException( + detail="Could not determine MIME type of uploaded file!", + status_code=406, + ) if not mime_type_supported(mime_type=mime_type): raise HTTPException( detail=f"Document with MIME type {mime_type} not supported!", @@ -74,6 +80,11 @@ def _store_uploaded_files_and_create_payloads( continue doc_type = get_doc_type(mime_type=mime_type) + if doc_type is None: + raise HTTPException( + detail=f"Document with MIME type {mime_type} not supported!", + status_code=406, + ) payloads.append( PreprocessingJobPayloadCreateWithoutPreproJobId( @@ -103,6 +114,9 @@ def _create_ppj_payloads_from_unimported_project_files( try: mime_type = magic.from_file(file_path, mime=True) doc_type = get_doc_type(mime_type=mime_type) + if not doc_type: + logger.error(f"Unsupported DocType (for MIME Type {mime_type})!") + raise 
UnsupportedDocTypeForMimeType(mime_type=mime_type) payloads.append( PreprocessingJobPayloadCreateWithoutPreproJobId( @@ -116,6 +130,7 @@ def _create_ppj_payloads_from_unimported_project_files( except ( FileNotFoundInRepositoryError, UnsupportedDocTypeForSourceDocument, + UnsupportedDocTypeForMimeType, Exception, ) as e: logger.warning( @@ -291,9 +306,7 @@ def prepare_and_start_preprocessing_job_async( def _get_pipeline(self, doc_type: DocType) -> PreprocessingPipeline: if doc_type not in self._pipelines: - self._pipelines[doc_type] = PreprocessingPipeline( - doc_type=doc_type, num_workers=1, force_sequential=True - ) + self._pipelines[doc_type] = PreprocessingPipeline(doc_type=doc_type) return self._pipelines[doc_type] def get_text_pipeline(self) -> PreprocessingPipeline: diff --git a/backend/src/app/preprocessing/ray_model_worker/generate_ray_model_worker_specs.py b/backend/src/app/preprocessing/ray_model_worker/generate_ray_model_worker_specs.py index c3ec14616..d1074bcc2 100644 --- a/backend/src/app/preprocessing/ray_model_worker/generate_ray_model_worker_specs.py +++ b/backend/src/app/preprocessing/ray_model_worker/generate_ray_model_worker_specs.py @@ -20,9 +20,12 @@ def get_all_apps( return apps -def rename_app_names(generated_spec_fp: Path, spec_out_fp: Path) -> dict: +def rename_app_names(generated_spec_fp: Path, spec_out_fp: Path): print("Renaming generated app names...") spec = srsly.read_yaml(generated_spec_fp) + if not isinstance(spec, dict) or "applications" not in spec: + raise ValueError("Invalid spec format: 'applications' key not found") + for app in spec["applications"]: app["name"] = app["route_prefix"].split("/")[1] diff --git a/backend/src/app/preprocessing/ray_model_worker/main.py b/backend/src/app/preprocessing/ray_model_worker/main.py deleted file mode 100644 index ca17a0d7c..000000000 --- a/backend/src/app/preprocessing/ray_model_worker/main.py +++ /dev/null @@ -1,50 +0,0 @@ -import logging - -from deployments.dbert import DistilBertDeployment -from deployments.dbert import router as dbert_router -from deployments.spacy import SpacyDeployment -from deployments.spacy import router as spacy_router -from deployments.whisper import WhisperDeployment -from deployments.whisper import router as whisper_router -from dto.dbert import DbertInput, DbertOutput -from fastapi import FastAPI -from ray import serve -from ray.serve.deployment import Application - -logger = logging.getLogger("ray.serve") - -api = FastAPI() - -api.include_router(whisper_router, prefix="/whisper") -api.include_router(spacy_router, prefix="/spacy") -api.include_router(dbert_router, prefix="/dbert") - - -@serve.deployment( - num_replicas=1, - route_prefix="/", -) -@serve.ingress(api) -class APIIngress: - def __init__(self, **kwargs) -> None: - logger.info(f"{kwargs=}") - self.whisper: Application = kwargs["whisper_model_handle"] - self.dbert: Application = kwargs["dbert_model_handle"] - self.spacy: Application = kwargs["spacy_model_handle"] - - @api.get("/classify", response_model=DbertOutput) - async def classify(self, dbert_input: DbertInput): - predict_ref = await self.dbert.classify.remote(dbert_input.sentence) - predict_result = await predict_ref - return predict_result - - -Whisper: Application = WhisperDeployment.bind() -DBert: Application = DistilBertDeployment.bind() -Spacy: Application = SpacyDeployment.bind() - -app = APIIngress.bind( - whisper_model_handle=Whisper, - dbert_model_handle=DBert, - spacy_model_handle=Spacy, -) diff --git 
a/backend/src/app/preprocessing/ray_model_worker/models/blip2.py b/backend/src/app/preprocessing/ray_model_worker/models/blip2.py index 6f42bb64c..36c50c3f1 100644 --- a/backend/src/app/preprocessing/ray_model_worker/models/blip2.py +++ b/backend/src/app/preprocessing/ray_model_worker/models/blip2.py @@ -48,14 +48,16 @@ def __init__(self): f"Loading Blip2ForConditionalGeneration {MODEL} with {PRECISION_BIT} precision ..." ) - captioning_model: Blip2ForConditionalGeneration = ( - Blip2ForConditionalGeneration.from_pretrained( - MODEL, - load_in_8bit=load_in_8bit, - device_map=device_map, - torch_dtype=data_type, - ) + captioning_model = Blip2ForConditionalGeneration.from_pretrained( + pretrained_model_name_or_path=MODEL, + load_in_8bit=load_in_8bit, + device_map=device_map, + torch_dtype=data_type, ) + assert isinstance( + captioning_model, Blip2ForConditionalGeneration + ), "Failed to load captioning model" + captioning_model.eval() self.data_type = data_type self.feature_extractor = image_processor diff --git a/backend/src/app/preprocessing/ray_model_worker/models/clip.py b/backend/src/app/preprocessing/ray_model_worker/models/clip.py index 4b46dcb3c..be52e9f4b 100644 --- a/backend/src/app/preprocessing/ray_model_worker/models/clip.py +++ b/backend/src/app/preprocessing/ray_model_worker/models/clip.py @@ -6,6 +6,7 @@ ClipImageEmbeddingInput, ClipTextEmbeddingInput, ) +from numpy import ndarray from PIL import Image from ray import serve from ray_config import build_ray_model_deployment_config, conf @@ -53,6 +54,7 @@ def text_embedding(self, input: ClipTextEmbeddingInput) -> ClipEmbeddingOutput: device=TEXT_DEVICE, convert_to_numpy=True, ) + assert isinstance(encoded_text, ndarray), "Failed to encode texts" return ClipEmbeddingOutput(embeddings=encoded_text.tolist()) @@ -61,13 +63,15 @@ def image_embedding(self, input: ClipImageEmbeddingInput) -> ClipEmbeddingOutput with torch.no_grad(): encoded_images = self.image_encoder.encode( - sentences=images, + sentences=images, # type: ignore batch_size=IMAGE_BATCH_SIZE, show_progress_bar=False, normalize_embeddings=True, device=IMAGE_DEVICE, convert_to_numpy=True, ) + assert isinstance(encoded_images, ndarray), "Failed to encode images" + # close the images for img in images: img.close() diff --git a/backend/src/app/preprocessing/ray_model_worker/models/cota.py b/backend/src/app/preprocessing/ray_model_worker/models/cota.py index 257bd624e..8f48e05fc 100644 --- a/backend/src/app/preprocessing/ray_model_worker/models/cota.py +++ b/backend/src/app/preprocessing/ray_model_worker/models/cota.py @@ -21,10 +21,8 @@ logger = logging.getLogger("ray.serve") -cota_conf: Dict = build_ray_model_deployment_config("cota") - -@serve.deployment(**cota_conf) +@serve.deployment(**build_ray_model_deployment_config("cota")) class CotaModel: def finetune_apply_compute(self, input: RayCOTAJobInput) -> RayCOTAJobResponse: # 1 finetune diff --git a/backend/src/app/preprocessing/ray_model_worker/models/detr.py b/backend/src/app/preprocessing/ray_model_worker/models/detr.py index 97a843c2c..8d032d6e3 100644 --- a/backend/src/app/preprocessing/ray_model_worker/models/detr.py +++ b/backend/src/app/preprocessing/ray_model_worker/models/detr.py @@ -23,16 +23,21 @@ class DETRModel: def __init__(self): logger.debug(f"Loading DetrFeatureExtractor {MODEL} ...") feature_extractor = DetrFeatureExtractor.from_pretrained(MODEL, device=DEVICE) + assert isinstance( + feature_extractor, DetrFeatureExtractor + ), "Failed to load feature extractor" logger.debug(f"Loading 
DetrForObjectDetection {MODEL} ...") - object_detection_model: DetrForObjectDetection = ( - DetrForObjectDetection.from_pretrained(MODEL) - ) + object_detection_model = DetrForObjectDetection.from_pretrained(MODEL) + assert isinstance( + object_detection_model, DetrForObjectDetection + ), "Failed to load object detection model" + object_detection_model.to(DEVICE) object_detection_model.eval() - self.feature_extractor: DetrFeatureExtractor = feature_extractor - self.object_detection_model: DetrForObjectDetection = object_detection_model + self.feature_extractor = feature_extractor + self.object_detection_model = object_detection_model def object_detection(self, input: DETRFilePathInput) -> DETRObjectDetectionOutput: with Image.open(input.image_fp) as img: diff --git a/backend/src/app/preprocessing/ray_model_worker/models/spacy.py b/backend/src/app/preprocessing/ray_model_worker/models/spacy.py index c31ea4f5a..fe801b216 100644 --- a/backend/src/app/preprocessing/ray_model_worker/models/spacy.py +++ b/backend/src/app/preprocessing/ray_model_worker/models/spacy.py @@ -5,7 +5,7 @@ from dto.spacy import SpacyInput, SpacyPipelineOutput, SpacySpan, SpacyToken from ray import serve from ray_config import build_ray_model_deployment_config, conf -from spacy import Language +from spacy.language import Language cc = conf.spacy @@ -31,7 +31,7 @@ def __init__(self): if len(DEVICE) > 4 and ":" in DEVICE else 0 ) - spacy.require_gpu(gpu_id=device_id) + spacy.require_gpu(gpu_id=device_id) # type: ignore nlp: Dict[str, Language] = dict() @@ -65,7 +65,6 @@ def pipeline(self, input: SpacyInput) -> SpacyPipelineOutput: text=token.text, start_char=token.idx, end_char=token.idx + len(token.text), - label=token.ent_type_, pos=token.pos_, lemma=token.lemma_, is_stopword=token.is_stop, diff --git a/backend/src/app/preprocessing/ray_model_worker/models/vit_gpt2.py b/backend/src/app/preprocessing/ray_model_worker/models/vit_gpt2.py index cb5137a75..e2f121a5b 100644 --- a/backend/src/app/preprocessing/ray_model_worker/models/vit_gpt2.py +++ b/backend/src/app/preprocessing/ray_model_worker/models/vit_gpt2.py @@ -26,14 +26,16 @@ class ViTGPT2Model: def __init__(self): logger.debug(f"Loading ViTFeatureExtractor {MODEL} ...") - image_processor: ViTFeatureExtractor = ViTFeatureExtractor.from_pretrained( - MODEL - ) + image_processor = ViTFeatureExtractor.from_pretrained(MODEL) + assert isinstance( + image_processor, ViTFeatureExtractor + ), "Failed to load feature extractor" logger.debug(f"Loading VisionEncoderDecoderModel {MODEL} ...") - captioning_model: VisionEncoderDecoderModel = ( - VisionEncoderDecoderModel.from_pretrained(MODEL) - ) + captioning_model = VisionEncoderDecoderModel.from_pretrained(MODEL) + assert isinstance( + captioning_model, VisionEncoderDecoderModel + ), "Failed to load captioning model" captioning_model.to(DEVICE) captioning_model.eval() diff --git a/backend/src/app/preprocessing/ray_model_worker/models/whisper.py b/backend/src/app/preprocessing/ray_model_worker/models/whisper.py index 10010a54b..ef823d4ee 100644 --- a/backend/src/app/preprocessing/ray_model_worker/models/whisper.py +++ b/backend/src/app/preprocessing/ray_model_worker/models/whisper.py @@ -1,7 +1,7 @@ import logging import os from pathlib import Path -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List import numpy as np import torch @@ -57,9 +57,7 @@ def transcribe_fpi(self, input: WhisperFilePathInput) -> WhisperTranscriptionOut ) with torch.no_grad(): - result: Tuple[Dict[str, Any], Any] = 
self.model.transcribe( - audio=audionp, **transcribe_options - ) + result = self.model.transcribe(audio=audionp, **transcribe_options) transcriptions = list(result[0]) segments: List[SegmentTranscription] = [] @@ -69,6 +67,9 @@ def transcribe_fpi(self, input: WhisperFilePathInput) -> WhisperTranscriptionOut start_ms=int(segment.start * 1000), end_ms=int(segment.end * 1000), ) + if segment.words is None: + continue + for word in segment.words: words.append( WordTranscription( diff --git a/backend/src/app/preprocessing/ray_model_worker/ray_config.py b/backend/src/app/preprocessing/ray_model_worker/ray_config.py index 6c5fd5b44..531e4dd97 100644 --- a/backend/src/app/preprocessing/ray_model_worker/ray_config.py +++ b/backend/src/app/preprocessing/ray_model_worker/ray_config.py @@ -1,6 +1,6 @@ import logging import os -from typing import Any, Dict +from typing import Dict, TypedDict from omegaconf import DictConfig, OmegaConf @@ -9,12 +9,17 @@ # global config __conf_file__ = os.getenv("RAY_CONFIG", "./config.yaml") conf = OmegaConf.load(__conf_file__) -assert isinstance(conf, DictConfig), f"Cannot load Ray Config from {__conf_file__}" logger.info(f"Loaded config '{__conf_file__}'") -def build_ray_model_deployment_config(name: str) -> Dict[str, Dict[str, Any]]: +class RayDeploymentConfig(TypedDict): + ray_actor_options: Dict + autoscaling_config: Dict + + +def build_ray_model_deployment_config(name: str) -> RayDeploymentConfig: + assert isinstance(conf, DictConfig), f"Invalid Ray Config format ({__conf_file__})" cc = conf.get(name, None) if cc is None: raise KeyError(f"Cannot access {name} in {__conf_file__}") diff --git a/backend/src/app/trainer/trainer_service.py b/backend/src/app/trainer/trainer_service.py index 52ff26e26..0b713a615 100644 --- a/backend/src/app/trainer/trainer_service.py +++ b/backend/src/app/trainer/trainer_service.py @@ -1,9 +1,7 @@ -from typing import List - from loguru import logger from sqlalchemy.orm import Session -from app.celery.background_jobs import start_trainer_job_async, use_trainer_model_async +from app.celery.background_jobs import start_trainer_job_async from app.core.data.crud.project import crud_project from app.core.data.dto.background_job_base import BackgroundJobStatus from app.core.data.dto.trainer_job import ( @@ -37,13 +35,13 @@ def create_and_start_trainer_job_async( return trainer_job_read - def use_trainer_model(self, *, db: Session, trainer_job_id: str) -> List[float]: + def use_trainer_model(self, *, db: Session, trainer_job_id: str): # make sure the trainer job exists! trainer_job = self.redis.load_trainer_job(trainer_job_id) # make sure the project exists! 
crud_project.read(db=db, id=trainer_job.parameters.project_id) - return use_trainer_model_async(trainer_job_id=trainer_job_id).get() + # return use_trainer_model_async(trainer_job_id=trainer_job_id).get() def _start_trainer_job_sync(self, trainer_job_id: str) -> TrainerJobRead: trainer_job = self.redis.load_trainer_job(trainer_job_id) diff --git a/backend/src/main.py b/backend/src/main.py index 715a9317a..777bd8c87 100644 --- a/backend/src/main.py +++ b/backend/src/main.py @@ -14,7 +14,7 @@ from loguru import logger from psycopg2.errors import UniqueViolation from sqlalchemy.exc import IntegrityError -from uvicorn.main import uvicorn +from uvicorn.main import run ##################################################################################################################### # READ BEFORE CHANGING # @@ -286,7 +286,7 @@ def main() -> None: is_debug = conf.api.production_mode == "0" - uvicorn.run( + run( "main:app", host="0.0.0.0", port=port, diff --git a/docker/monkey_patch_docker_compose_for_backend_tests.py b/docker/monkey_patch_docker_compose_for_backend_tests.py index 3a2d62dc1..80d0a120f 100644 --- a/docker/monkey_patch_docker_compose_for_backend_tests.py +++ b/docker/monkey_patch_docker_compose_for_backend_tests.py @@ -28,5 +28,4 @@ data["services"][a].pop("deploy", None) with open("compose-test.yml", "w") as f: - dumpy = yaml.dump(data, f) - f.write = dumpy + yaml.dump(data, f) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..db59d72fd --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,18 @@ +[tool.ruff] +src = ["backend/src"] + +[tool.ruff.lint] + +# I: isort-compatible import sorting +# W291: Trailing whitespace +# W292: Add newline to end of file +extend-select = ["I", "W292", "W291"] + +[tool.pyright] +include = ["backend/src", "tools"] +exclude = ["**/__pycache__", + "backend/src/dev_notebooks", +] +extraPaths = ["backend/src"] +reportIncompatibleMethodOverride = false +reportIncompatibleVariableOverride = false diff --git a/tools/crawler/pipelines/htmlclean_pipeline.py b/tools/crawler/pipelines/htmlclean_pipeline.py index 4cd8702f0..06a5f6733 100644 --- a/tools/crawler/pipelines/htmlclean_pipeline.py +++ b/tools/crawler/pipelines/htmlclean_pipeline.py @@ -26,7 +26,7 @@ def process_item(self, item: GenericWebsiteItem, spider): cleaner = clean.Cleaner( safe_attrs_only=True, safe_attrs=safe_attrs, kill_tags=kill_tags ) - html_content = cleaner.clean_html(html_content) + html_content = str(cleaner.clean_html(html_content)) # manually remove tags html_content = html_content.replace("
", "") diff --git a/tools/crawler/pipelines/image_pipeline.py b/tools/crawler/pipelines/image_pipeline.py index d27e98e30..41c03970e 100644 --- a/tools/crawler/pipelines/image_pipeline.py +++ b/tools/crawler/pipelines/image_pipeline.py @@ -8,7 +8,7 @@ class MyImagesPipeline(ImagesPipeline): - def file_path(self, request, response=None, info=None, *, item=None): + def file_path(self, request, response=None, info=None, *, item): # the name of the html page (without .html) file_name = item["file_name"] diff --git a/tools/crawler/pipelines/readability_pipeline.py b/tools/crawler/pipelines/readability_pipeline.py index d9fbb89e9..4ba2ab36b 100644 --- a/tools/crawler/pipelines/readability_pipeline.py +++ b/tools/crawler/pipelines/readability_pipeline.py @@ -5,7 +5,7 @@ # useful for handling different item types with a single interface -from readability.readability import Readability +from readability.readability import Readability # type: ignore from crawler.items import GenericWebsiteItem diff --git a/tools/crawler/spiders/crawl_spider_base.py b/tools/crawler/spiders/crawl_spider_base.py index 3727a9469..b4118b145 100644 --- a/tools/crawler/spiders/crawl_spider_base.py +++ b/tools/crawler/spiders/crawl_spider_base.py @@ -54,14 +54,11 @@ def write_raw_response( filename = self.generate_filename(response=response) filename_with_extension = f"{filename}.html" - try: - with open( - self.output_dir / filename_with_extension, "w", encoding="UTF-8" - ) as f: - f.write(response.body.decode(response.encoding)) - except UnicodeDecodeError: - with open(self.output_dir / filename_with_extension, "wb") as f2: - f2.write(response.body) + with open( + self.output_dir / filename_with_extension, "w", encoding="UTF-8" + ) as f: + f.write(response.text) + self.log(f"Saved raw html {filename_with_extension}") def _create_cookies_dict(self, cookie: str) -> Dict[str, str]: @@ -95,10 +92,7 @@ def init_item( filename if filename else self.generate_filename(response=response) ) - try: - item["raw_html"] = response.body.decode(response.encoding) - except UnicodeDecodeError: - item["raw_html"] = response.body + item["raw_html"] = response.text if html: item["extracted_html"] = html diff --git a/tools/crawler/spiders/forums/ilforumdegliincel.py b/tools/crawler/spiders/forums/ilforumdegliincel.py index 340b66bc7..400e1ef12 100644 --- a/tools/crawler/spiders/forums/ilforumdegliincel.py +++ b/tools/crawler/spiders/forums/ilforumdegliincel.py @@ -23,8 +23,13 @@ def __init__(self, thread_id=None, max_pages=10, *args, **kwargs): def parse(self, response, **kwargs): # set current thread - self.current_thread = re.search(r"(\?t=\w*)", response.url).group(1) - thread_id = re.search(r"(\?t=)(\w*)", self.current_thread).group(2) + match = re.search(r"(\?t=\w*)", response.url) + if match: + self.current_thread = match.group(1) + + # extract thread id + match = re.search(r"(\?t=)(\w*)", self.current_thread) + thread_id = match.group(2) if match else -1 # find the number of pages of this thread pages_element = response.css( diff --git a/tools/crawler/spiders/forums/ilforumdeibrutti.py b/tools/crawler/spiders/forums/ilforumdeibrutti.py index 5248c2a60..a140c3193 100644 --- a/tools/crawler/spiders/forums/ilforumdeibrutti.py +++ b/tools/crawler/spiders/forums/ilforumdeibrutti.py @@ -24,8 +24,13 @@ def __init__(self, thread_id=None, max_pages=10, *args, **kwargs): def parse(self, response, **kwargs): # set current thread - self.current_thread = re.search(r"(\?t=\w*)", response.url).group(1) - thread_id = re.search(r"(\?t=)(\w*)", 
self.current_thread).group(2) + match = re.search(r"(\?t=\w*)", response.url) + if match: + self.current_thread = match.group(1) + + # extract thread id + match = re.search(r"(\?t=)(\w*)", self.current_thread) + thread_id = match.group(2) if match else -1 # find the number of pages of this thread pages_element = response.css( diff --git a/tools/crawler/spiders/forums/unbruttoforum.py b/tools/crawler/spiders/forums/unbruttoforum.py index 39c2a66c2..604bc4abf 100644 --- a/tools/crawler/spiders/forums/unbruttoforum.py +++ b/tools/crawler/spiders/forums/unbruttoforum.py @@ -24,8 +24,13 @@ def __init__(self, thread_id=None, max_pages=10, *args, **kwargs): def parse(self, response, **kwargs): # set current thread - self.current_thread = re.search(r"(\?t=\w*)", response.url).group(1) - thread_id = re.search(r"(\?t=)(\w*)", self.current_thread).group(2) + match = re.search(r"(\?t=\w*)", response.url) + if match: + self.current_thread = match.group(1) + + # extract thread id + match = re.search(r"(\?t=)(\w*)", self.current_thread) + thread_id = match.group(2) if match else -1 # find the number of pages of this thread pages_element = response.css( diff --git a/tools/crawler/spiders/spider_base.py b/tools/crawler/spiders/spider_base.py index 8d39c10ed..e85e389d7 100644 --- a/tools/crawler/spiders/spider_base.py +++ b/tools/crawler/spiders/spider_base.py @@ -53,14 +53,10 @@ def write_raw_response( filename = self.generate_filename(response=response) filename_with_extension = f"{filename}.html" - try: - with open( - self.output_dir / filename_with_extension, "w", encoding="UTF-8" - ) as f: - f.write(response.body.decode(response.encoding)) - except UnicodeDecodeError: - with open(self.output_dir / filename_with_extension, "wb") as f2: - f2.write(response.body) + with open( + self.output_dir / filename_with_extension, "w", encoding="UTF-8" + ) as f: + f.write(response.text) self.log(f"Saved raw html {filename_with_extension}") def _create_cookies_dict(self, cookie: str) -> Dict[str, str]: @@ -94,10 +90,7 @@ def init_item( filename if filename else self.generate_filename(response=response) ) - try: - item["raw_html"] = response.body.decode(response.encoding) - except UnicodeDecodeError: - item["raw_html"] = response.body + item["raw_html"] = response.text if html: item["extracted_html"] = html diff --git a/tools/importer/dats_importer.py b/tools/importer/dats_importer.py index 2ef4bcd91..ec7387224 100755 --- a/tools/importer/dats_importer.py +++ b/tools/importer/dats_importer.py @@ -238,10 +238,12 @@ files = list(temp.values()) -def upload_file_batch(file_batch: List[Tuple[str, Tuple[str, bytes, str]]]): +def upload_file_batch( + project_id: int, file_batch: List[Tuple[str, Tuple[str, bytes, str]]] +): # file upload preprocessing_job = api.upload_files( - proj_id=project["id"], + proj_id=project_id, files=file_batch, filter_duplicate_files_before_upload=args.filter_duplicate_files_before_upload, ) @@ -294,7 +296,9 @@ def upload_file_batch(file_batch: List[Tuple[str, Tuple[str, bytes, str]]]): desc="Uploading batches... 
", total=num_batches, ): - upload_file_batch(file_batch=files[i : i + args.batch_size]) + upload_file_batch( + project_id=project["id"], file_batch=files[i : i + args.batch_size] + ) api.refresh_login() if args.max_num_docs != -1 and (i + args.batch_size) >= args.max_num_docs: break diff --git a/tools/importer/dats_importer_metadata.py b/tools/importer/dats_importer_metadata.py index 55421a542..7fc4dab03 100755 --- a/tools/importer/dats_importer_metadata.py +++ b/tools/importer/dats_importer_metadata.py @@ -92,7 +92,11 @@ for key, metatype in zip(args.metadata_keys, args.metadata_types): if key not in project_metadata_map: project_metadata = api.create_project_metadata( - proj_id=project["id"], key=key, metatype=metatype, doctype=args.doctype + proj_id=project["id"], + key=key, + metatype=metatype, + doctype=args.doctype, + description=key, ) project_metadata_map[key] = project_metadata