Improve sdoc read #450

Merged: 3 commits, Oct 18, 2024
48 changes: 44 additions & 4 deletions backend/src/api/endpoints/source_document.py
@@ -34,8 +34,8 @@
 from app.core.data.dto.source_document import (
     SourceDocumentRead,
     SourceDocumentUpdate,
-    SourceDocumentWithDataRead,
 )
+from app.core.data.dto.source_document_data import SourceDocumentDataRead
 from app.core.data.dto.source_document_metadata import (
     SourceDocumentMetadataReadResolved,
 )
@@ -54,7 +54,7 @@

 @router.get(
     "/{sdoc_id}",
-    response_model=SourceDocumentWithDataRead,
+    response_model=SourceDocumentRead,
     summary="Returns the SourceDocument with the given ID if it exists",
 )
 def get_by_id(
@@ -63,13 +63,53 @@ def get_by_id(
     sdoc_id: int,
     only_if_finished: bool = True,
     authz_user: AuthzUser = Depends(),
-) -> SourceDocumentWithDataRead:
+) -> SourceDocumentRead:
     authz_user.assert_in_same_project_as(Crud.SOURCE_DOCUMENT, sdoc_id)

     if not only_if_finished:
         crud_sdoc.get_status(db=db, sdoc_id=sdoc_id, raise_error_on_unfinished=True)

-    return crud_sdoc.read_with_data(db=db, id=sdoc_id)
+    return SourceDocumentRead.model_validate(crud_sdoc.read(db=db, id=sdoc_id))
+
+
+@router.get(
+    "/data/{sdoc_id}",
+    response_model=SourceDocumentDataRead,
+    summary="Returns the SourceDocumentData with the given ID if it exists",
+)
+def get_by_id_with_data(
+    *,
+    db: Session = Depends(get_db_session),
+    sdoc_id: int,
+    only_if_finished: bool = True,
+    authz_user: AuthzUser = Depends(),
+) -> SourceDocumentDataRead:
+    authz_user.assert_in_same_project_as(Crud.SOURCE_DOCUMENT, sdoc_id)
+
+    if not only_if_finished:
+        crud_sdoc.get_status(db=db, sdoc_id=sdoc_id, raise_error_on_unfinished=True)
+
+    sdoc_data = crud_sdoc.read_data(db=db, id=sdoc_id)
+    if sdoc_data is None:
+        # if data is None, the document is not a text document;
+        # instead of returning HTML, we return the URL to the image / video / audio file
+        sdoc = SourceDocumentRead.model_validate(crud_sdoc.read(db=db, id=sdoc_id))
+        url = RepoService().get_sdoc_url(
+            sdoc=sdoc,
+            relative=True,
+            webp=True,
+            thumbnail=False,
+        )
+        return SourceDocumentDataRead(
+            id=sdoc_id,
+            project_id=sdoc.project_id,
+            token_character_offsets=[],
+            tokens=[],
+            sentences=[],
+            sentence_character_offsets=[],
+            html=url,
+        )
+    else:
+        return SourceDocumentDataRead.model_validate(sdoc_data)


 @router.delete(
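
With the split, clients make a cheap call for document metadata and a separate, heavier call for content. A minimal sketch of how a client might use the two routes; the base URL, the auth header, and any response fields beyond those visible in this diff are assumptions:

import requests

BASE = "http://localhost:5500/sdoc"  # assumed API root, adjust to your deployment
HEADERS = {"Authorization": "Bearer <token>"}  # assumed auth scheme

sdoc_id = 42  # hypothetical document ID

# GET /{sdoc_id}: lightweight SourceDocumentRead (metadata only)
sdoc = requests.get(f"{BASE}/{sdoc_id}", headers=HEADERS).json()

# GET /data/{sdoc_id}: SourceDocumentDataRead; for image/video/audio documents
# the token and sentence lists are empty and `html` carries the file URL instead
data = requests.get(f"{BASE}/data/{sdoc_id}", headers=HEADERS).json()
if data["tokens"]:
    print(f"text document with {len(data['sentences'])} sentences")
else:
    print("media document, file at:", data["html"])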
@@ -8,8 +8,8 @@
     COTASentence,
 )
 from app.core.data.dto.search import SearchColumns, SimSearchQuery
-from app.core.data.dto.source_document import SourceDocumentWithDataRead
 from app.core.data.orm.source_document import SourceDocumentORM
+from app.core.data.orm.source_document_data import SourceDocumentDataORM
 from app.core.data.orm.source_document_metadata import SourceDocumentMetadataORM
 from app.core.db.sql_service import SQLService
 from app.core.filters.filtering import Filter, LogicalOperator
@@ -91,24 +91,26 @@ def add_sentences_to_search_space(

         # get the data from the database
         with sqls.db_session() as db:
-            sdoc_data = crud_sdoc.read_with_data_batch(db=db, ids=sdoc_ids)
+            sdoc_datas = crud_sdoc.read_data_batch(db=db, ids=sdoc_ids)

         # map the data
-        sdoc_id2sdocreadwithdata: Dict[int, SourceDocumentWithDataRead] = {
-            sdoc_data_read.id: sdoc_data_read for sdoc_data_read in sdoc_data
+        sdoc_id2sdocdata: Dict[int, SourceDocumentDataORM] = {
+            sdoc_data_read.id: sdoc_data_read
+            for sdoc_data_read in sdoc_datas
+            if sdoc_data_read is not None
         }

         sentences = []
         for cota_sent in search_space:
-            if cota_sent.sdoc_id not in sdoc_id2sdocreadwithdata:
+            if cota_sent.sdoc_id not in sdoc_id2sdocdata:
                 raise ValueError(
-                    f"Could not find SourceDocumentWithDataRead for sdoc_id {cota_sent.sdoc_id}!"
+                    f"Could not find SourceDocumentDataORM for sdoc_id {cota_sent.sdoc_id}!"
                 )
-            sdoc_data_read = sdoc_id2sdocreadwithdata[cota_sent.sdoc_id]
+            sdoc_data_read = sdoc_id2sdocdata[cota_sent.sdoc_id]

             if cota_sent.sentence_id >= len(sdoc_data_read.sentences):
                 raise ValueError(
-                    f"Could not find sentence with id {cota_sent.sentence_id} in SourceDocumentWithDataRead with id {sdoc_data_read.id}!"
+                    f"Could not find sentence with id {cota_sent.sentence_id} in SourceDocumentDataORM with id {sdoc_data_read.id}!"
                 )
             sentences.append(sdoc_data_read.sentences[cota_sent.sentence_id])
78 changes: 12 additions & 66 deletions backend/src/app/core/data/crud/source_document.py
@@ -4,7 +4,7 @@
 from sqlalchemy import and_, desc, func, or_
 from sqlalchemy.orm import Session

-from app.core.data.crud.crud_base import CRUDBase, NoSuchElementError
+from app.core.data.crud.crud_base import CRUDBase
 from app.core.data.crud.source_document_metadata import crud_sdoc_meta
 from app.core.data.dto.action import ActionType
 from app.core.data.dto.document_tag import DocumentTagRead
@@ -14,7 +14,6 @@
     SourceDocumentRead,
     SourceDocumentReadAction,
     SourceDocumentUpdate,
-    SourceDocumentWithDataRead,
 )
 from app.core.data.dto.source_document_data import SourceDocumentDataRead
 from app.core.data.dto.source_document_metadata import SourceDocumentMetadataRead
@@ -57,78 +56,25 @@ def get_status(
             raise SourceDocumentPreprocessingUnfinishedError(sdoc_id=sdoc_id)
         return status

-    def read_with_data(self, db: Session, *, id: int) -> SourceDocumentWithDataRead:
+    def read_data(self, db: Session, *, id: int) -> Optional[SourceDocumentDataRead]:
         db_obj = (
-            db.query(self.model, SourceDocumentDataORM)
-            .join(SourceDocumentDataORM, isouter=True)
-            .filter(self.model.id == id)
+            db.query(SourceDocumentDataORM)
+            .filter(SourceDocumentDataORM.id == id)
             .first()
         )
-        if not db_obj:
-            raise NoSuchElementError(self.model, id=id)
-        sdoc, data = db_obj.tuple()
-        sdoc_read = SourceDocumentRead.model_validate(sdoc)
-
-        # sdoc data is None for audio and video documents
-        if data is None:
-            sdoc_data_read = SourceDocumentDataRead(
-                id=sdoc.id,
-                content="",
-                html="",
-                token_starts=[],
-                token_ends=[],
-                sentence_starts=[],
-                sentence_ends=[],
-                tokens=[],
-                token_character_offsets=[],
-                sentences=[],
-                sentence_character_offsets=[],
-            )
-        else:
-            sdoc_data_read = SourceDocumentDataRead.model_validate(data)
-        return SourceDocumentWithDataRead(
-            **(sdoc_read.model_dump() | sdoc_data_read.model_dump())
-        )
+        return SourceDocumentDataRead.model_validate(db_obj) if db_obj else None

-    def read_with_data_batch(
+    def read_data_batch(
         self, db: Session, *, ids: List[int]
-    ) -> List[SourceDocumentWithDataRead]:
+    ) -> List[Optional[SourceDocumentDataORM]]:
         db_objs = (
-            db.query(SourceDocumentORM, SourceDocumentDataORM)
-            .join(SourceDocumentDataORM, isouter=True)
-            .filter(SourceDocumentORM.id.in_(ids))
+            db.query(SourceDocumentDataORM)
+            .filter(SourceDocumentDataORM.id.in_(ids))
             .all()
         )

-        results = []
-        for db_obj in db_objs:
-            sdoc, data = db_obj
-            sdoc_read = SourceDocumentRead.model_validate(sdoc)
-
-            if data is None:
-                sdoc_data_read = SourceDocumentDataRead(
-                    id=sdoc.id,
-                    content="",
-                    html="",
-                    token_starts=[],
-                    token_ends=[],
-                    sentence_starts=[],
-                    sentence_ends=[],
-                    tokens=[],
-                    token_character_offsets=[],
-                    sentences=[],
-                    sentence_character_offsets=[],
-                )
-            else:
-                sdoc_data_read = SourceDocumentDataRead.model_validate(data)
-
-            results.append(
-                SourceDocumentWithDataRead(
-                    **(sdoc_read.model_dump() | sdoc_data_read.model_dump())
-                )
-            )
-
-        return results
+        # create id, data map
+        id2data = {db_obj.id: db_obj for db_obj in db_objs}
+        return [id2data.get(id) for id in ids]

     def remove(self, db: Session, *, id: int) -> SourceDocumentORM:
         # Import SimSearchService here to prevent a cyclic dependency
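
Unlike the old read_with_data_batch, which fabricated empty data objects for media documents, read_data_batch keeps its result aligned with the input ids and yields None wherever no data row exists, leaving the handling to the caller. A sketch of that contract, with hypothetical IDs (1 and 3 text documents, 2 a video):

with sqls.db_session() as db:
    datas = crud_sdoc.read_data_batch(db=db, ids=[1, 2, 3])
    # -> [SourceDocumentDataORM(id=1), None, SourceDocumentDataORM(id=3)]
    for sdoc_id, data in zip([1, 2, 3], datas):
        if data is None:
            print(f"sdoc {sdoc_id}: no data row (media document)")
        else:
            print(f"sdoc {sdoc_id}: {len(data.sentences)} sentences")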
5 changes: 0 additions & 5 deletions backend/src/app/core/data/dto/source_document.py
@@ -7,7 +7,6 @@
 from app.core.data.doc_type import DocType
 from app.core.data.dto.document_tag import DocumentTagRead
 from app.core.data.dto.dto_base import UpdateDTOBase
-from app.core.data.dto.source_document_data import SourceDocumentDataRead
 from app.core.data.dto.source_document_metadata import SourceDocumentMetadataRead

 SDOC_FILENAME_MAX_LENGTH = 200
@@ -57,7 +56,3 @@ class SourceDocumentReadAction(SourceDocumentRead):

 class SourceDocumentCreate(SourceDocumentBaseDTO):
     pass
-
-
-class SourceDocumentWithDataRead(SourceDocumentRead, SourceDocumentDataRead):
-    pass
11 changes: 6 additions & 5 deletions backend/src/app/core/data/dto/source_document_data.py
@@ -21,16 +21,17 @@ class SourceDocumentDataBase(BaseModel):
     )


-class SourceDocumentDataRead(SourceDocumentDataBase):
+class SourceDocumentDataRead(BaseModel):
     id: int = Field(description="ID of the SourceDocument")
+    project_id: int = Field(
+        description="ID of the Project the SourceDocument belongs to"
+    )
+    html: str = Field(description="Processed HTML of the SourceDocument")
     tokens: List[str] = Field(description="List of tokens in the SourceDocument")
     token_character_offsets: List[Tuple[int, int]] = Field(
         description="List of character offsets of each token"
     )
-
     sentences: List[str] = Field(description="List of sentences in the SourceDocument")
     sentence_character_offsets: List[Tuple[int, int]] = Field(
         description="List of character offsets of each sentence"
     )

     model_config = ConfigDict(from_attributes=True)
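
Since model_config sets from_attributes=True, the new DTO validates directly against the ORM row; project_id, tokens, and sentences are then served by the @property accessors added to SourceDocumentDataORM below. A minimal sketch, assuming an open session and the ID of a text document:

sdoc_data_orm = crud_sdoc.read_data_batch(db=db, ids=[sdoc_id])[0]
if sdoc_data_orm is not None:
    # pydantic reads plain columns (id, html) and computed properties
    # (project_id, tokens, token_character_offsets, sentences, ...) alike
    dto = SourceDocumentDataRead.model_validate(sdoc_data_orm)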
27 changes: 21 additions & 6 deletions backend/src/app/core/data/llm/llm_service.py
@@ -226,11 +226,16 @@ def _llm_document_tagging(
         )

         # read sdocs
-        sdoc_datas = crud_sdoc.read_with_data_batch(db=db, ids=sdoc_ids)
+        sdoc_datas = crud_sdoc.read_data_batch(db=db, ids=sdoc_ids)

         # automatic document tagging
         result: List[DocumentTaggingResult] = []
-        for idx, sdoc_data in enumerate(sdoc_datas):
+        for idx, (sdoc_id, sdoc_data) in enumerate(zip(sdoc_ids, sdoc_datas)):
+            if sdoc_data is None:
+                raise ValueError(
+                    f"Could not find SourceDocumentDataORM for sdoc_id {sdoc_id}!"
+                )
+
             # get current tag ids
             current_tag_ids = [
                 tag.id for tag in crud_sdoc.read(db=db, id=sdoc_data.id).document_tags
@@ -316,10 +321,15 @@ def _llm_metadata_extraction(
         )

         # read sdocs
-        sdoc_datas = crud_sdoc.read_with_data_batch(db=db, ids=sdoc_ids)
+        sdoc_datas = crud_sdoc.read_data_batch(db=db, ids=sdoc_ids)
         # automatic metadata extraction
         result: List[MetadataExtractionResult] = []
-        for idx, sdoc_data in enumerate(sdoc_datas):
+        for idx, (sdoc_id, sdoc_data) in enumerate(zip(sdoc_ids, sdoc_datas)):
+            if sdoc_data is None:
+                raise ValueError(
+                    f"Could not find SourceDocumentDataORM for sdoc_id {sdoc_id}!"
+                )
+
             # get current metadata values
             current_metadata = [
                 SourceDocumentMetadataReadResolved.model_validate(metadata)
@@ -426,12 +436,17 @@ def _llm_annotation(
         )

         # read sdocs
-        sdoc_datas = crud_sdoc.read_with_data_batch(db=db, ids=sdoc_ids)
+        sdoc_datas = crud_sdoc.read_data_batch(db=db, ids=sdoc_ids)

         # automatic annotation
         annotation_id = 0
         result: List[AnnotationResult] = []
-        for idx, sdoc_data in enumerate(sdoc_datas):
+        for idx, (sdoc_id, sdoc_data) in enumerate(zip(sdoc_ids, sdoc_datas)):
+            if sdoc_data is None:
+                raise ValueError(
+                    f"Could not find SourceDocumentDataORM for sdoc_id {sdoc_id}!"
+                )
+
             # get language
             language = crud_sdoc_meta.read_by_sdoc_and_key(
                 db=db, sdoc_id=sdoc_data.id, key="language"
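
The same zip-and-raise guard now appears in _llm_document_tagging, _llm_metadata_extraction, and _llm_annotation. If it spreads further, it could be pulled into a small strict wrapper; this is a sketch of that idea, not code from this PR:

from typing import List

from sqlalchemy.orm import Session

from app.core.data.crud.source_document import crud_sdoc
from app.core.data.orm.source_document_data import SourceDocumentDataORM


def read_data_batch_strict(
    db: Session, sdoc_ids: List[int]
) -> List[SourceDocumentDataORM]:
    """Like crud_sdoc.read_data_batch, but raises instead of returning None entries."""
    sdoc_datas = crud_sdoc.read_data_batch(db=db, ids=sdoc_ids)
    missing = [i for i, d in zip(sdoc_ids, sdoc_datas) if d is None]
    if missing:
        raise ValueError(f"Could not find SourceDocumentDataORM for sdoc_ids {missing}!")
    return sdoc_datas  # every entry is non-None at this point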
14 changes: 12 additions & 2 deletions backend/src/app/core/data/orm/source_document_data.py
@@ -1,11 +1,14 @@
-from typing import List
+from typing import TYPE_CHECKING, List

 from sqlalchemy import ForeignKey, Integer, String
 from sqlalchemy.dialects.postgresql import ARRAY
-from sqlalchemy.orm import Mapped, mapped_column
+from sqlalchemy.orm import Mapped, mapped_column, relationship

 from app.core.data.orm.orm_base import ORMBase

+if TYPE_CHECKING:
+    from app.core.data.orm.source_document import SourceDocumentORM
+

 class SourceDocumentDataORM(ORMBase):
     id: Mapped[int] = mapped_column(
@@ -15,6 +18,9 @@ class SourceDocumentDataORM(ORMBase):
         nullable=False,
         index=True,
     )
+    source_document: Mapped["SourceDocumentORM"] = relationship(
+        "SourceDocumentORM", back_populates="data"
+    )
     content: Mapped[str] = mapped_column(String, nullable=False, index=False)
     html: Mapped[str] = mapped_column(String, nullable=False, index=False)
     token_starts: Mapped[List[int]] = mapped_column(
@@ -30,6 +36,10 @@ class SourceDocumentDataORM(ORMBase):
         ARRAY(Integer), nullable=False, index=False
     )

+    @property
+    def project_id(self) -> int:
+        return self.source_document.project_id
+
     @property
     def tokens(self):
         return [self.content[s:e] for s, e in zip(self.token_starts, self.token_ends)]
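
back_populates="data" presumes a matching attribute on the parent side. SourceDocumentORM is not part of this diff, so the counterpart declaration here is an assumption sketched for context:

# in app/core/data/orm/source_document.py (not shown in this PR)
class SourceDocumentORM(ORMBase):
    ...
    data: Mapped[Optional["SourceDocumentDataORM"]] = relationship(
        "SourceDocumentDataORM", uselist=False, back_populates="source_document"
    )

Note that the new project_id property touches self.source_document, so with SQLAlchemy's default lazy loading it issues one extra SELECT per row on first access; validating many SourceDocumentDataRead objects in a loop may warrant eager-loading the relationship.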
2 changes: 2 additions & 0 deletions frontend/src/api/QueryKey.ts
@@ -53,6 +53,8 @@ export const QueryKey = {

   // a single document (by sdoc id)
   SDOC: "sdoc",
+  // a single document's data (by sdoc id)
+  SDOC_DATA: "sdocData",
   // all tags of a document (by sdoc id)
   SDOC_TAGS: "sdocTags",
   // Count how many source documents each tag has