Skip to content

Commit

Permalink
Merge pull request #448 from uhh-lt/fix-backend-problems
Browse files Browse the repository at this point in the history
Fix backend problems
  • Loading branch information
bigabig authored Oct 17, 2024
2 parents d711d7a + b5d8ece commit a6cc3de
Show file tree
Hide file tree
Showing 60 changed files with 324 additions and 275 deletions.
8 changes: 0 additions & 8 deletions .ruff.toml

This file was deleted.

1 change: 1 addition & 0 deletions .vscode/extensions.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// backend
"ms-python.python",
"ms-python.debugpy",
"ms-python.vscode-pylance",
"charliermarsh.ruff",
// frontend
"dbaeumer.vscode-eslint",
Expand Down
4 changes: 2 additions & 2 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
{
"python.envFile": "${workspaceFolder}/backend/.env",
"python.analysis.extraPaths": ["./backend/src"],
"python.autoComplete.extraPaths": ["./backend/src"],
"editor.formatOnSave": true,
"prettier.configPath": "./frontend/package.json",
"javascript.preferences.importModuleSpecifierEnding": "js",
"typescript.preferences.importModuleSpecifierEnding": "js",
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit"
}
},
"python.analysis.diagnosticMode": "workspace"
}
8 changes: 0 additions & 8 deletions backend/src/app/celery/background_jobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,6 @@ def start_trainer_job_async(
start_trainer_job_task.apply_async(kwargs={"trainer_job_id": trainer_job_id})


def use_trainer_model_async(
trainer_job_id: str,
) -> Any:
from app.celery.background_jobs.tasks import use_trainer_model_task

return use_trainer_model_task.apply_async(kwargs={"trainer_job_id": trainer_job_id})


def import_uploaded_archive_apply_async(
archive_file_path: Path, project_id: int
) -> Any:
Expand Down
12 changes: 1 addition & 11 deletions backend/src/app/celery/background_jobs/tasks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import List, Tuple
from typing import Tuple

from app.celery.background_jobs.cota import start_cota_refinement_job_
from app.celery.background_jobs.crawl import start_crawler_job_
Expand All @@ -14,7 +14,6 @@
)
from app.celery.background_jobs.trainer import (
start_trainer_job_,
use_trainer_model_task_,
)
from app.celery.celery_worker import celery_worker
from app.core.data.dto.crawler_job import CrawlerJobRead
Expand All @@ -41,15 +40,6 @@ def start_trainer_job_task(trainer_job_id: str) -> None:
start_trainer_job_(trainer_job_id=trainer_job_id)


@celery_worker.task(
acks_late=True,
autoretry_for=(Exception,),
retry_kwargs={"max_retries": 0, "countdown": 5},
)
def use_trainer_model_task(trainer_job_id: str) -> List[float]:
return use_trainer_model_task_(trainer_job_id=trainer_job_id)


@celery_worker.task(acks_late=True)
def start_export_job(export_job: ExportJobRead) -> None:
start_export_job_(export_job=export_job)
Expand Down
6 changes: 0 additions & 6 deletions backend/src/app/celery/background_jobs/trainer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import List

from loguru import logger

from app.trainer.trainer_service import TrainerService
Expand All @@ -11,7 +9,3 @@ def start_trainer_job_(trainer_job_id: str) -> None:
ts_result = ts._start_trainer_job_sync(trainer_job_id=trainer_job_id)

logger.info(f"TrainerJob {trainer_job_id} has finished! Result: {ts_result}")


def use_trainer_model_task_(trainer_job_id: str) -> List[float]:
return ts._use_trainer_model_sync(trainer_job_id=trainer_job_id)
6 changes: 4 additions & 2 deletions backend/src/app/core/analysis/analysis_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def find_code_occurrences(
CodeORM.id == code_id,
)
)
query = query.group_by(SourceDocumentORM, CodeORM, SpanTextORM.text)
query = query.group_by(SourceDocumentORM.id, CodeORM.id, SpanTextORM.text)
res = query.all()
span_code_occurrences = [
CodeOccurrence(
Expand Down Expand Up @@ -205,7 +205,9 @@ def find_code_occurrences(
)
)
query = query.group_by(
SourceDocumentORM, CodeORM, BBoxAnnotationORM.annotation_document_id
SourceDocumentORM.id,
CodeORM.id,
BBoxAnnotationORM.annotation_document_id,
)
res = query.all()
bbox_code_occurrences = [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import uuid
from datetime import datetime
from typing import Dict, List, Optional

Expand Down Expand Up @@ -40,7 +41,7 @@ def init_search_space(cargo: Cargo) -> Cargo:
top_k=cota.training_settings.search_space_topk,
threshold=cota.training_settings.search_space_threshold,
filter=Filter[SearchColumns](
items=[], logic_operator=LogicalOperator.and_
id=str(uuid.uuid4()), items=[], logic_operator=LogicalOperator.and_
),
),
)
Expand Down
6 changes: 4 additions & 2 deletions backend/src/app/core/analysis/timeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,17 +160,19 @@ def timeline_analysis(
.subquery()
)

subquery.c

sdoc_ids_agg = aggregate_ids(SourceDocumentORM.id, label="sdoc_ids")

query = db.query(
sdoc_ids_agg,
*group_by.apply(subquery.c["date"]), # EXTRACT (WEEK FROM TIMESTAMP ...)
*group_by.apply(subquery.c[1]), # type: ignore
).join(subquery, SourceDocumentORM.id == subquery.c.id)

query = apply_filtering(
query=query, filter=filter, db=db, subquery_dict=subquery.c
)
query = query.group_by(*group_by.apply(column=subquery.c["date"]))
query = query.group_by(*group_by.apply(column=subquery.c["date"])) # type: ignore

result_rows = query.all()

Expand Down
4 changes: 2 additions & 2 deletions backend/src/app/core/data/crud/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session

from app.core.data.crud.crud_base import CRUDBase
from app.core.data.crud.crud_base import CRUDBase, UpdateNotAllowed
from app.core.data.dto.action import ActionCreate, ActionTargetObjectType, ActionType
from app.core.data.orm.action import ActionORM


class CRUDAction(CRUDBase[ActionORM, ActionCreate, None]):
class CRUDAction(CRUDBase[ActionORM, ActionCreate, UpdateNotAllowed]):
def create(self, db: Session, *, create_dto: ActionCreate) -> ActionORM:
# we have to override this to avoid recursion
dto_obj_data = jsonable_encoder(create_dto)
Expand Down
4 changes: 4 additions & 0 deletions backend/src/app/core/data/crud/crud_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
UpdateDTOType = TypeVar("UpdateDTOType", bound=BaseModel)


class UpdateNotAllowed(BaseModel):
pass


class NoSuchElementError(Exception):
def __init__(self, model: Type[ORMModelType], **kwargs):
self.model = model
Expand Down
21 changes: 14 additions & 7 deletions backend/src/app/core/data/crud/object_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from app.core.data.crud.action import crud_action
from app.core.data.crud.bbox_annotation import crud_bbox_anno
from app.core.data.crud.code import crud_code
from app.core.data.crud.crud_base import CRUDBase, NoSuchElementError
from app.core.data.crud.crud_base import CRUDBase, NoSuchElementError, UpdateNotAllowed
from app.core.data.crud.document_tag import crud_document_tag
from app.core.data.crud.memo import crud_memo
from app.core.data.crud.project import crud_project
Expand All @@ -30,7 +30,7 @@
from app.core.db.sql_service import SQLService


class CRUDObjectHandle(CRUDBase[ObjectHandleORM, ObjectHandleCreate, None]):
class CRUDObjectHandle(CRUDBase[ObjectHandleORM, ObjectHandleCreate, UpdateNotAllowed]):
__obj_id_crud_map = {
"code_id": crud_code,
"document_tag_id": crud_document_tag,
Expand Down Expand Up @@ -65,11 +65,18 @@ def create(self, db: Session, *, create_dto: ObjectHandleCreate) -> ObjectHandle
if isinstance(e.orig, UniqueViolation):
db.close() # Flo: close the session because we have to start a new transaction
with SQLService().db_session() as sess:
for obj_id_key, obj_id_val in create_dto.model_dump().items():
if obj_id_val:
return self.read_by_attached_object_id(
db=sess, obj_id_key=obj_id_key, obj_id_val=obj_id_val
)
obj_id_key, obj_id_val = next(
filter(
lambda item: item[0] is not None and item[1] is not None,
create_dto.model_dump().items(),
),
(None, None),
)
if obj_id_key is not None and obj_id_val is not None:
return self.read_by_attached_object_id(
db=sess, obj_id_key=obj_id_key, obj_id_val=obj_id_val
)
raise e
else:
# Flo: re-raise Exception since it's not a UC Violation
raise e
Expand Down
4 changes: 2 additions & 2 deletions backend/src/app/core/data/crud/source_document_data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from app.core.data.crud.crud_base import CRUDBase
from app.core.data.crud.crud_base import CRUDBase, UpdateNotAllowed
from app.core.data.dto.source_document_data import SourceDocumentDataCreate
from app.core.data.orm.source_document_data import SourceDocumentDataORM

Expand All @@ -7,7 +7,7 @@ class CRUDSourceDocumentData(
CRUDBase[
SourceDocumentDataORM,
SourceDocumentDataCreate,
None,
UpdateNotAllowed,
]
):
pass
Expand Down
14 changes: 7 additions & 7 deletions backend/src/app/core/data/crud/source_document_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,19 @@

from sqlalchemy.orm import Session

from app.core.data.crud.crud_base import CRUDBase, ORMModelType, UpdateDTOType
from app.core.data.crud.crud_base import (
CRUDBase,
UpdateNotAllowed,
)
from app.core.data.dto.source_document_link import SourceDocumentLinkCreate
from app.core.data.orm.source_document import SourceDocumentORM
from app.core.data.orm.source_document_link import SourceDocumentLinkORM


class CRUDSourceDocumentLink(
CRUDBase[SourceDocumentLinkORM, SourceDocumentLinkCreate, None]
CRUDBase[SourceDocumentLinkORM, SourceDocumentLinkCreate, UpdateNotAllowed]
):
def update(
self, db: Session, *, id: int, update_dto: UpdateDTOType
) -> ORMModelType:
def update(self, db: Session, *, id: int, update_dto):
raise NotImplementedError()

def resolve_filenames_to_sdoc_ids(
Expand All @@ -37,8 +38,7 @@ def resolve_filenames_to_sdoc_ids(
),
SourceDocumentORM.project_id == proj_id,
)
# noinspection PyTypeChecker
sdoc_fn_to_id: Dict[str, int] = dict(query2.all())
sdoc_fn_to_id: Dict[str, int] = {filename: id for filename, id in query2.all()}

resolved_links: List[SourceDocumentLinkORM] = []

Expand Down
8 changes: 4 additions & 4 deletions backend/src/app/core/data/crud/span_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

from sqlalchemy.orm import Session

from app.core.data.crud.crud_base import CRUDBase
from app.core.data.crud.crud_base import CRUDBase, UpdateNotAllowed
from app.core.data.dto.span_text import SpanTextCreate
from app.core.data.orm.span_text import SpanTextORM


class CRUDSpanText(CRUDBase[SpanTextORM, SpanTextCreate, None]):
def update(self, db: Session, *, id: int, update_dto) -> SpanTextORM:
# Flo: We no not want to update SourceDocument
class CRUDSpanText(CRUDBase[SpanTextORM, SpanTextCreate, UpdateNotAllowed]):
def update(self, db: Session, *, id: int, update_dto):
# Flo: We do not want to update SpanText
raise NotImplementedError()

def create(self, db: Session, *, create_dto: SpanTextCreate) -> SpanTextORM:
Expand Down
7 changes: 5 additions & 2 deletions backend/src/app/core/data/crud/word_frequency.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from sqlalchemy.orm import Session

from app.core.data.crud.crud_base import CRUDBase
from app.core.data.crud.crud_base import CRUDBase, UpdateNotAllowed
from app.core.data.doc_type import DocType
from app.core.data.dto.word_frequency import WordFrequencyCreate, WordFrequencyRead
from app.core.data.orm.source_document import SourceDocumentORM
Expand All @@ -13,9 +13,12 @@ class CrudWordFrequency(
CRUDBase[
WordFrequencyORM,
WordFrequencyCreate,
None,
UpdateNotAllowed,
]
):
def update(self, db: Session, *, id: int, update_dto):
raise NotImplementedError()

def read_by_project_and_doctype(
self, db: Session, *, project_id: int, doctype: DocType
) -> List[WordFrequencyRead]:
Expand Down
2 changes: 1 addition & 1 deletion backend/src/app/core/data/dto/export_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class ExportJobParameters(BaseModel):
export_job_type: ExportJobType = Field(
description="The type of the export job (what to export)"
)
export_format: Optional[ExportFormat] = Field(
export_format: ExportFormat = Field(
description="The format of the exported data.",
default=ExportFormat.CSV,
)
Expand Down
13 changes: 7 additions & 6 deletions backend/src/app/core/data/dto/object_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@ class ObjectHandleBaseDTO(BaseModel):
document_tag_id: Optional[int] = None
memo_id: Optional[int] = None

# noinspection PyMethodParameters
@model_validator(mode="before") # TODO: Before == root?
def check_at_least_one_not_null(cls, values):
for val in values:
if val:
return values
@model_validator(mode="after")
def check_at_least_one_not_null(self):
# make sure that at least one of the fields is not null
values = self.model_dump()
for val in values.values():
if val is not None:
return self
raise ValueError("At least one of the fields has to be not null!")


Expand Down
2 changes: 1 addition & 1 deletion backend/src/app/core/data/dto/source_document_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def with_value(
id=sdoc_metadata_id,
str_value=None,
boolean_value=None,
date_value=value if value is not None else datetime.now(),
date_value=datetime.fromisoformat(value),
int_value=None,
list_value=None,
source_document_id=source_document_id,
Expand Down
14 changes: 4 additions & 10 deletions backend/src/app/core/data/export/export_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,13 +603,10 @@ def _export_all_user_annotations_from_sdoc(
)

# export the data
export_data = None
export_data = pd.DataFrame()
for adoc in all_adocs:
adoc_data = self.__generate_export_df_for_adoc(db=db, adoc=adoc)
if export_data is None:
export_data = adoc_data
else:
export_data = pd.concat((export_data, adoc_data))
export_data = pd.concat((export_data, adoc_data))

# write single file for all annos of that doc
export_file = self.__write_export_data_to_temp_file(
Expand Down Expand Up @@ -656,15 +653,12 @@ def _export_user_memos_from_proj(
f"There are no memos for User {user_id} in Project {project_id}!"
)

export_data = None
export_data = pd.DataFrame()
for memo in memos:
memo_data = self.__generate_export_df_for_memo(
db=db, memo_id=memo.id, memo=memo
)
if export_data is None:
export_data = memo_data
else:
export_data = pd.concat((export_data, memo_data))
export_data = pd.concat((export_data, memo_data))

export_file = self.__write_export_data_to_temp_file(
data=export_data,
Expand Down
Loading

0 comments on commit a6cc3de

Please sign in to comment.