Merge pull request #450 from uhh-lt/improve-sdoc-read

Improve sdoc read

bigabig authored Oct 18, 2024
2 parents 3243817 + f3dd09a commit 522a268

Showing 29 changed files with 405 additions and 440 deletions.
48 changes: 44 additions & 4 deletions backend/src/api/endpoints/source_document.py
@@ -34,8 +34,8 @@
 from app.core.data.dto.source_document import (
     SourceDocumentRead,
     SourceDocumentUpdate,
-    SourceDocumentWithDataRead,
 )
+from app.core.data.dto.source_document_data import SourceDocumentDataRead
 from app.core.data.dto.source_document_metadata import (
     SourceDocumentMetadataReadResolved,
 )
@@ -54,7 +54,7 @@
 
 @router.get(
     "/{sdoc_id}",
-    response_model=SourceDocumentWithDataRead,
+    response_model=SourceDocumentRead,
     summary="Returns the SourceDocument with the given ID if it exists",
 )
 def get_by_id(
@@ -63,13 +63,53 @@ def get_by_id(
     sdoc_id: int,
     only_if_finished: bool = True,
     authz_user: AuthzUser = Depends(),
-) -> SourceDocumentWithDataRead:
+) -> SourceDocumentRead:
     authz_user.assert_in_same_project_as(Crud.SOURCE_DOCUMENT, sdoc_id)
 
     if not only_if_finished:
         crud_sdoc.get_status(db=db, sdoc_id=sdoc_id, raise_error_on_unfinished=True)
 
-    return crud_sdoc.read_with_data(db=db, id=sdoc_id)
+    return SourceDocumentRead.model_validate(crud_sdoc.read(db=db, id=sdoc_id))
+
+
+@router.get(
+    "/data/{sdoc_id}",
+    response_model=SourceDocumentDataRead,
+    summary="Returns the SourceDocumentData with the given ID if it exists",
+)
+def get_by_id_with_data(
+    *,
+    db: Session = Depends(get_db_session),
+    sdoc_id: int,
+    only_if_finished: bool = True,
+    authz_user: AuthzUser = Depends(),
+) -> SourceDocumentDataRead:
+    authz_user.assert_in_same_project_as(Crud.SOURCE_DOCUMENT, sdoc_id)
+
+    if not only_if_finished:
+        crud_sdoc.get_status(db=db, sdoc_id=sdoc_id, raise_error_on_unfinished=True)
+
+    sdoc_data = crud_sdoc.read_data(db=db, id=sdoc_id)
+    if sdoc_data is None:
+        # if data is none, that means the document is not a text document
+        # instead of returning html, we return the URL to the image / video / audio file
+        sdoc = SourceDocumentRead.model_validate(crud_sdoc.read(db=db, id=sdoc_id))
+        url = RepoService().get_sdoc_url(
+            sdoc=sdoc,
+            relative=True,
+            webp=True,
+            thumbnail=False,
+        )
+        return SourceDocumentDataRead(
+            id=sdoc_id,
+            project_id=sdoc.project_id,
+            token_character_offsets=[],
+            tokens=[],
+            sentences=[],
+            html=url,
+        )
+    else:
+        return SourceDocumentDataRead.model_validate(sdoc_data)
 
 
 @router.delete(
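Taken together, the endpoint changes split the old combined read into two routes: GET /{sdoc_id} now returns only the SourceDocumentRead metadata, and the new GET /data/{sdoc_id} returns the SourceDocumentDataRead content (tokens, sentences, HTML), falling back to the media file URL in the html field for non-text documents. A minimal client sketch of the new call pattern; the base URL, router prefix, and auth header are assumptions, not part of this diff:

    import requests

    BASE = "http://localhost:5500/sdoc"  # assumed host and router prefix
    HEADERS = {"Authorization": "Bearer <token>"}  # assumed auth scheme

    sdoc_id = 1

    # 1) metadata only (no joined data row needed)
    sdoc = requests.get(f"{BASE}/{sdoc_id}", headers=HEADERS).json()

    # 2) content, fetched separately and only when needed
    data = requests.get(f"{BASE}/data/{sdoc_id}", headers=HEADERS).json()

    if data["tokens"]:
        print(f"text document with {len(data['sentences'])} sentences")
    else:
        print(f"media document, file served at {data['html']}")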
@@ -8,8 +8,8 @@
     COTASentence,
 )
 from app.core.data.dto.search import SearchColumns, SimSearchQuery
-from app.core.data.dto.source_document import SourceDocumentWithDataRead
 from app.core.data.orm.source_document import SourceDocumentORM
+from app.core.data.orm.source_document_data import SourceDocumentDataORM
 from app.core.data.orm.source_document_metadata import SourceDocumentMetadataORM
 from app.core.db.sql_service import SQLService
 from app.core.filters.filtering import Filter, LogicalOperator
@@ -91,24 +91,26 @@ def add_sentences_to_search_space(
 
     # get the data from the database
     with sqls.db_session() as db:
-        sdoc_data = crud_sdoc.read_with_data_batch(db=db, ids=sdoc_ids)
+        sdoc_datas = crud_sdoc.read_data_batch(db=db, ids=sdoc_ids)
 
     # map the data
-    sdoc_id2sdocreadwithdata: Dict[int, SourceDocumentWithDataRead] = {
-        sdoc_data_read.id: sdoc_data_read for sdoc_data_read in sdoc_data
+    sdoc_id2sdocdata: Dict[int, SourceDocumentDataORM] = {
+        sdoc_data_read.id: sdoc_data_read
+        for sdoc_data_read in sdoc_datas
+        if sdoc_data_read is not None
     }
 
     sentences = []
     for cota_sent in search_space:
-        if cota_sent.sdoc_id not in sdoc_id2sdocreadwithdata:
+        if cota_sent.sdoc_id not in sdoc_id2sdocdata:
             raise ValueError(
-                f"Could not find SourceDocumentWithDataRead for sdoc_id {cota_sent.sdoc_id}!"
+                f"Could not find SourceDocumentDataORM for sdoc_id {cota_sent.sdoc_id}!"
             )
-        sdoc_data_read = sdoc_id2sdocreadwithdata[cota_sent.sdoc_id]
+        sdoc_data_read = sdoc_id2sdocdata[cota_sent.sdoc_id]
 
         if cota_sent.sentence_id >= len(sdoc_data_read.sentences):
             raise ValueError(
-                f"Could not find sentence with id {cota_sent.sentence_id} in SourceDocumentWithDataRead with id {sdoc_data_read.id}!"
+                f"Could not find sentence with id {cota_sent.sentence_id} in SourceDocumentDataORM with id {sdoc_data_read.id}!"
             )
         sentences.append(sdoc_data_read.sentences[cota_sent.sentence_id])
 
78 changes: 12 additions & 66 deletions backend/src/app/core/data/crud/source_document.py
@@ -4,7 +4,7 @@
 from sqlalchemy import and_, desc, func, or_
 from sqlalchemy.orm import Session
 
-from app.core.data.crud.crud_base import CRUDBase, NoSuchElementError
+from app.core.data.crud.crud_base import CRUDBase
 from app.core.data.crud.source_document_metadata import crud_sdoc_meta
 from app.core.data.dto.action import ActionType
 from app.core.data.dto.document_tag import DocumentTagRead
@@ -14,7 +14,6 @@
     SourceDocumentRead,
     SourceDocumentReadAction,
     SourceDocumentUpdate,
-    SourceDocumentWithDataRead,
 )
 from app.core.data.dto.source_document_data import SourceDocumentDataRead
 from app.core.data.dto.source_document_metadata import SourceDocumentMetadataRead
@@ -57,78 +56,25 @@ def get_status(
             raise SourceDocumentPreprocessingUnfinishedError(sdoc_id=sdoc_id)
         return status
 
-    def read_with_data(self, db: Session, *, id: int) -> SourceDocumentWithDataRead:
+    def read_data(self, db: Session, *, id: int) -> Optional[SourceDocumentDataRead]:
         db_obj = (
-            db.query(self.model, SourceDocumentDataORM)
-            .join(SourceDocumentDataORM, isouter=True)
-            .filter(self.model.id == id)
+            db.query(SourceDocumentDataORM)
+            .filter(SourceDocumentDataORM.id == id)
             .first()
         )
-        if not db_obj:
-            raise NoSuchElementError(self.model, id=id)
-        sdoc, data = db_obj.tuple()
-        sdoc_read = SourceDocumentRead.model_validate(sdoc)
-
-        # sdoc data is None for audio and video documents
-        if data is None:
-            sdoc_data_read = SourceDocumentDataRead(
-                id=sdoc.id,
-                content="",
-                html="",
-                token_starts=[],
-                token_ends=[],
-                sentence_starts=[],
-                sentence_ends=[],
-                tokens=[],
-                token_character_offsets=[],
-                sentences=[],
-                sentence_character_offsets=[],
-            )
-        else:
-            sdoc_data_read = SourceDocumentDataRead.model_validate(data)
-        return SourceDocumentWithDataRead(
-            **(sdoc_read.model_dump() | sdoc_data_read.model_dump())
-        )
+        return SourceDocumentDataRead.model_validate(db_obj) if db_obj else None
 
-    def read_with_data_batch(
+    def read_data_batch(
         self, db: Session, *, ids: List[int]
-    ) -> List[SourceDocumentWithDataRead]:
+    ) -> List[Optional[SourceDocumentDataORM]]:
         db_objs = (
-            db.query(SourceDocumentORM, SourceDocumentDataORM)
-            .join(SourceDocumentDataORM, isouter=True)
-            .filter(SourceDocumentORM.id.in_(ids))
+            db.query(SourceDocumentDataORM)
+            .filter(SourceDocumentDataORM.id.in_(ids))
             .all()
         )
-
-        results = []
-        for db_obj in db_objs:
-            sdoc, data = db_obj
-            sdoc_read = SourceDocumentRead.model_validate(sdoc)
-
-            if data is None:
-                sdoc_data_read = SourceDocumentDataRead(
-                    id=sdoc.id,
-                    content="",
-                    html="",
-                    token_starts=[],
-                    token_ends=[],
-                    sentence_starts=[],
-                    sentence_ends=[],
-                    tokens=[],
-                    token_character_offsets=[],
-                    sentences=[],
-                    sentence_character_offsets=[],
-                )
-            else:
-                sdoc_data_read = SourceDocumentDataRead.model_validate(data)
-
-            results.append(
-                SourceDocumentWithDataRead(
-                    **(sdoc_read.model_dump() | sdoc_data_read.model_dump())
-                )
-            )
-
-        return results
+        # create id, data map
+        id2data = {db_obj.id: db_obj for db_obj in db_objs}
+        return [id2data.get(id) for id in ids]
 
     def remove(self, db: Session, *, id: int) -> SourceDocumentORM:
         # Import SimSearchService here to prevent a cyclic dependency
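Note that read_data_batch keeps its result aligned with the requested ids: rows come back from the database in arbitrary order and may be missing entirely (audio/video documents have no data row), so they are keyed by id and re-emitted in the caller's order with None as a placeholder. A self-contained sketch of that pattern, with a stand-in class instead of SourceDocumentDataORM:

    class Row:  # stand-in for SourceDocumentDataORM
        def __init__(self, id: int):
            self.id = id

    db_objs = [Row(3), Row(1)]  # query result, order not guaranteed
    ids = [1, 2, 3]             # caller's order; id 2 has no data row

    # key by id, then re-emit in the caller's order
    id2data = {obj.id: obj for obj in db_objs}
    result = [id2data.get(i) for i in ids]

    assert [r.id if r else None for r in result] == [1, None, 3]

The llm_service.py loops below rely on exactly this alignment when they zip(sdoc_ids, sdoc_datas).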
5 changes: 0 additions & 5 deletions backend/src/app/core/data/dto/source_document.py
@@ -7,7 +7,6 @@
 from app.core.data.doc_type import DocType
 from app.core.data.dto.document_tag import DocumentTagRead
 from app.core.data.dto.dto_base import UpdateDTOBase
-from app.core.data.dto.source_document_data import SourceDocumentDataRead
 from app.core.data.dto.source_document_metadata import SourceDocumentMetadataRead
 
 SDOC_FILENAME_MAX_LENGTH = 200
Expand Down Expand Up @@ -57,7 +56,3 @@ class SourceDocumentReadAction(SourceDocumentRead):

class SourceDocumentCreate(SourceDocumentBaseDTO):
pass


class SourceDocumentWithDataRead(SourceDocumentRead, SourceDocumentDataRead):
pass
11 changes: 6 additions & 5 deletions backend/src/app/core/data/dto/source_document_data.py
@@ -21,16 +21,17 @@ class SourceDocumentDataBase(BaseModel):
     )
 
 
-class SourceDocumentDataRead(SourceDocumentDataBase):
+class SourceDocumentDataRead(BaseModel):
+    id: int = Field(description="ID of the SourceDocument")
+    project_id: int = Field(
+        description="ID of the Project the SourceDocument belongs to"
+    )
+    html: str = Field(description="Processed HTML of the SourceDocument")
     tokens: List[str] = Field(description="List of tokens in the SourceDocument")
     token_character_offsets: List[Tuple[int, int]] = Field(
         description="List of character offsets of each token"
     )
 
     sentences: List[str] = Field(description="List of sentences in the SourceDocument")
     sentence_character_offsets: List[Tuple[int, int]] = Field(
         description="List of character offsets of each sentence"
     )
 
     model_config = ConfigDict(from_attributes=True)
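The slimmed-down SourceDocumentDataRead can still be built straight from the ORM object via model_validate because from_attributes=True reads plain Python attributes, including @property values such as tokens, sentences, and the new project_id on SourceDocumentDataORM. A minimal sketch of that mechanism (toy classes, not the real DTOs):

    from typing import List

    from pydantic import BaseModel, ConfigDict

    class Thing:  # stand-in for an ORM object
        content = "ab cd"

        @property
        def tokens(self) -> List[str]:
            return self.content.split()

    class ThingRead(BaseModel):
        tokens: List[str]

        model_config = ConfigDict(from_attributes=True)

    # from_attributes=True lets pydantic read the @property value
    print(ThingRead.model_validate(Thing()).tokens)  # ['ab', 'cd']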
27 changes: 21 additions & 6 deletions backend/src/app/core/data/llm/llm_service.py
@@ -226,11 +226,16 @@ def _llm_document_tagging(
         )
 
         # read sdocs
-        sdoc_datas = crud_sdoc.read_with_data_batch(db=db, ids=sdoc_ids)
+        sdoc_datas = crud_sdoc.read_data_batch(db=db, ids=sdoc_ids)
 
         # automatic document tagging
         result: List[DocumentTaggingResult] = []
-        for idx, sdoc_data in enumerate(sdoc_datas):
+        for idx, (sdoc_id, sdoc_data) in enumerate(zip(sdoc_ids, sdoc_datas)):
+            if sdoc_data is None:
+                raise ValueError(
+                    f"Could not find SourceDocumentDataORM for sdoc_id {sdoc_id}!"
+                )
+
             # get current tag ids
             current_tag_ids = [
                 tag.id for tag in crud_sdoc.read(db=db, id=sdoc_data.id).document_tags
@@ -316,10 +321,15 @@ def _llm_metadata_extraction(
         )
 
         # read sdocs
-        sdoc_datas = crud_sdoc.read_with_data_batch(db=db, ids=sdoc_ids)
+        sdoc_datas = crud_sdoc.read_data_batch(db=db, ids=sdoc_ids)
         # automatic metadata extraction
         result: List[MetadataExtractionResult] = []
-        for idx, sdoc_data in enumerate(sdoc_datas):
+        for idx, (sdoc_id, sdoc_data) in enumerate(zip(sdoc_ids, sdoc_datas)):
+            if sdoc_data is None:
+                raise ValueError(
+                    f"Could not find SourceDocumentDataORM for sdoc_id {sdoc_id}!"
+                )
+
             # get current metadata values
             current_metadata = [
                 SourceDocumentMetadataReadResolved.model_validate(metadata)
@@ -426,12 +436,17 @@ def _llm_annotation(
         )
 
         # read sdocs
-        sdoc_datas = crud_sdoc.read_with_data_batch(db=db, ids=sdoc_ids)
+        sdoc_datas = crud_sdoc.read_data_batch(db=db, ids=sdoc_ids)
 
         # automatic annotation
         annotation_id = 0
         result: List[AnnotationResult] = []
-        for idx, sdoc_data in enumerate(sdoc_datas):
+        for idx, (sdoc_id, sdoc_data) in enumerate(zip(sdoc_ids, sdoc_datas)):
+            if sdoc_data is None:
+                raise ValueError(
+                    f"Could not find SourceDocumentDataORM for sdoc_id {sdoc_id}!"
+                )
+
             # get language
             language = crud_sdoc_meta.read_by_sdoc_and_key(
                 db=db, sdoc_id=sdoc_data.id, key="language"
14 changes: 12 additions & 2 deletions backend/src/app/core/data/orm/source_document_data.py
@@ -1,11 +1,14 @@
-from typing import List
+from typing import TYPE_CHECKING, List
 
 from sqlalchemy import ForeignKey, Integer, String
 from sqlalchemy.dialects.postgresql import ARRAY
-from sqlalchemy.orm import Mapped, mapped_column
+from sqlalchemy.orm import Mapped, mapped_column, relationship
 
 from app.core.data.orm.orm_base import ORMBase
 
+if TYPE_CHECKING:
+    from app.core.data.orm.source_document import SourceDocumentORM
+
 
 class SourceDocumentDataORM(ORMBase):
     id: Mapped[int] = mapped_column(
@@ -15,6 +18,9 @@ class SourceDocumentDataORM(ORMBase):
         nullable=False,
         index=True,
     )
+    source_document: Mapped["SourceDocumentORM"] = relationship(
+        "SourceDocumentORM", back_populates="data"
+    )
     content: Mapped[str] = mapped_column(String, nullable=False, index=False)
     html: Mapped[str] = mapped_column(String, nullable=False, index=False)
     token_starts: Mapped[List[int]] = mapped_column(
@@ -30,6 +36,10 @@
         ARRAY(Integer), nullable=False, index=False
     )
 
+    @property
+    def project_id(self) -> int:
+        return self.source_document.project_id
+
     @property
     def tokens(self):
         return [self.content[s:e] for s, e in zip(self.token_starts, self.token_ends)]
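The new project_id property reads through the source_document relationship, so accessing it on an instance outside an active session can trigger a lazy load or fail. back_populates="data" also requires a matching relationship on SourceDocumentORM; that side is not part of this diff, but it presumably looks roughly like this one-to-one sketch:

    # assumed counterpart in SourceDocumentORM (not shown in this PR)
    data: Mapped[Optional["SourceDocumentDataORM"]] = relationship(
        "SourceDocumentDataORM", back_populates="source_document"
    )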
2 changes: 2 additions & 0 deletions frontend/src/api/QueryKey.ts
@@ -53,6 +53,8 @@ export const QueryKey = {
 
   // a single document (by sdoc id)
   SDOC: "sdoc",
+  // a single document's data (by sdoc id)
+  SDOC_DATA: "sdocData",
   // all tags of a document (by sdoc id)
   SDOC_TAGS: "sdocTags",
   // Count how many source documents each tag has