diff --git a/.github/workflows/backend_check_schema.yml b/.github/workflows/backend_check_schema.yml index 3e789d174..65ecb67d3 100644 --- a/.github/workflows/backend_check_schema.yml +++ b/.github/workflows/backend_check_schema.yml @@ -16,6 +16,7 @@ jobs: env: API_PRODUCTION_WORKERS: 1 RAY_ENABLED: False + OLLAMA_ENABLED: False COMPOSE_PROFILES: "background" steps: - uses: actions/checkout@v3 @@ -25,7 +26,7 @@ jobs: ./setup-folders.sh cp .env.example .env chmod -R a+rwx backend_repo/ models_cache/ spacy_models/ - python monkey_patch_docker_compose_for_backend_tests.py --disable_ray + python monkey_patch_docker_compose_for_backend_tests.py --disable_ray --disable_ollama export GID=$(id -g) docker compose -f compose-test.yml up -d --quiet-pull postgres echo Waiting for containers to start... diff --git a/.github/workflows/backend_unit_tests.yml b/.github/workflows/backend_unit_tests.yml index 0760325fc..2c9702044 100644 --- a/.github/workflows/backend_unit_tests.yml +++ b/.github/workflows/backend_unit_tests.yml @@ -52,6 +52,7 @@ jobs: # disable backend and frontend COMPOSE_PROFILES: "background" RAY_ENABLED: False + OLLAMA_ENABLED: False POSTGRES_DB: dats-test JWT_SECRET: ${{ secrets.JWT_SECRET }} steps: @@ -66,7 +67,7 @@ jobs: ./setup-folders.sh cp .env.example .env chmod -R a+rwx backend_repo/ models_cache/ spacy_models/ - python monkey_patch_docker_compose_for_backend_tests.py --disable_ray + python monkey_patch_docker_compose_for_backend_tests.py --disable_ray --disable_ollama export GID=$(id -g) docker compose -f compose-test.yml up -d --quiet-pull echo Waiting for containers to start... 
diff --git a/.github/workflows/update-openapi-spec.yml b/.github/workflows/update-openapi-spec.yml index 9304252ab..1f7ddbd06 100644 --- a/.github/workflows/update-openapi-spec.yml +++ b/.github/workflows/update-openapi-spec.yml @@ -18,6 +18,7 @@ jobs: env: API_PRODUCTION_WORKERS: 1 RAY_ENABLED: False + OLLAMA_ENABLED: False API_EXPOSED: 5500 VITE_APP_SERVER: http://localhost:5500 steps: @@ -33,7 +34,7 @@ jobs: ./setup-folders.sh cp .env.example .env chmod -R a+rwx backend_repo/ models_cache/ spacy_models/ - python monkey_patch_docker_compose_for_backend_tests.py --disable_ray + python monkey_patch_docker_compose_for_backend_tests.py --disable_ray --disable_ollama export GID=$(id -g) docker compose -f compose-test.yml up -d --quiet-pull --wait --wait-timeout 300 echo Waiting for containers to start... diff --git a/backend/.env.example b/backend/.env.example index b80dd72d7..8c6077458 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -70,6 +70,11 @@ REDIS_PASSWORD=dats123 WEAVIATE_HOST=localhost WEAVIATE_PORT=13241 +OLLAMA_ENABLED=True +OLLAMA_HOST=localhost +OLLAMA_PORT=13242 +OLLAMA_MODEL=gemma2:latest + # Mail sending configuration MAIL_ENABLED=False MAIL_FROM=dats@uni-hamburg.de diff --git a/backend/.env.testing.example b/backend/.env.testing.example index 13b683919..f6b3e700a 100644 --- a/backend/.env.testing.example +++ b/backend/.env.testing.example @@ -5,6 +5,7 @@ # This way, you can keep only necessary overrides in .env.testing RAY_ENABLED=False +OLLAMA_ENABLED=False POSTGRES_DB=dats-testing # These are separate variables from `WEAVIATE_PORT` etc. # because we need to spin up separate containers for testing diff --git a/backend/Dockerfile b/backend/Dockerfile index 8a0d6cb49..1900fbfe3 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -1,7 +1,7 @@ # docker build -f Dockerfile -t uhhlt/dats_backend: . 
# docker push uhhlt/dats_backend: -FROM ubuntu:jammy-20221020 as ubuntu +FROM ubuntu:jammy-20221020 AS ubuntu CMD ["/bin/bash"] # makes CUDA devices visible to the container by default diff --git a/backend/requirements.txt b/backend/requirements.txt index 5c292412d..1832967fe 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,4 +1,5 @@ mammoth==1.6.0 +ollama==0.3.1 pymupdf==1.23.4 pytest-order==1.2.1 Scrapy==2.10.0 diff --git a/backend/src/alembic/versions/45549c9c4ff2_add_project_metadata_description.py b/backend/src/alembic/versions/45549c9c4ff2_add_project_metadata_description.py new file mode 100644 index 000000000..3bbe4081e --- /dev/null +++ b/backend/src/alembic/versions/45549c9c4ff2_add_project_metadata_description.py @@ -0,0 +1,44 @@ +"""add project metadata description + +Revision ID: 45549c9c4ff2 +Revises: 2b91203d1bb6 +Create Date: 2024-09-16 08:41:55.296647 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "45549c9c4ff2" +down_revision: Union[str, None] = "2b91203d1bb6" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # add new column + op.add_column( + "projectmetadata", sa.Column("description", sa.String(), nullable=True) + ) + + # edit all existing rows to have a description + op.execute( + """ + UPDATE projectmetadata + SET description = 'Placeholder description' + WHERE description IS NULL + """ + ) + + # make the column not nullable + op.alter_column("projectmetadata", "description", nullable=False) + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column("projectmetadata", "description") + # ### end Alembic commands ### diff --git a/backend/src/api/endpoints/document_tag.py b/backend/src/api/endpoints/document_tag.py index 9f581a116..d0a85745d 100644 --- a/backend/src/api/endpoints/document_tag.py +++ b/backend/src/api/endpoints/document_tag.py @@ -14,6 +14,7 @@ DocumentTagCreate, DocumentTagRead, DocumentTagUpdate, + SourceDocumentDocumentTagLinks, SourceDocumentDocumentTagMultiLink, ) from app.core.data.dto.memo import AttachedObjectType, MemoCreate, MemoInDB, MemoRead @@ -114,6 +115,36 @@ def unlink_multiple_tags( ) + +@router.patch( + "/bulk/set", + response_model=int, + summary="Sets SourceDocuments' tags to the provided tags", +) +def set_document_tags_batch( + *, + db: Session = Depends(get_db_session), + links: List[SourceDocumentDocumentTagLinks], + authz_user: AuthzUser = Depends(), + validate: Validate = Depends(), +) -> int: + sdoc_ids = [link.source_document_id for link in links] + tag_ids = list(set([tag_id for link in links for tag_id in link.document_tag_ids])) + # TODO this is a little inefficient, but at the moment + # the frontend is never sending more than one id at a time + authz_user.assert_in_same_project_as_many(Crud.SOURCE_DOCUMENT, sdoc_ids) + authz_user.assert_in_same_project_as_many(Crud.DOCUMENT_TAG, tag_ids) + + validate.validate_objects_in_same_project( + [(Crud.SOURCE_DOCUMENT, sdoc_id) for sdoc_id in sdoc_ids] + + [(Crud.DOCUMENT_TAG, tag_id) for tag_id in tag_ids] + ) + + return crud_document_tag.set_document_tags_batch( + db=db, + links={link.source_document_id: link.document_tag_ids for link in links}, + ) + + @router.get( + "/{tag_id}", response_model=DocumentTagRead, diff --git a/backend/src/api/endpoints/llm.py b/backend/src/api/endpoints/llm.py new file mode 100644 index 000000000..7aa38ae34 --- /dev/null +++ b/backend/src/api/endpoints/llm.py @@ -0,0 +1,68 @@ +from typing import List + +from fastapi import APIRouter, Depends + +from api.dependencies import
get_current_user +from app.celery.background_jobs import prepare_and_start_llm_job_async +from app.core.authorization.authz_user import AuthzUser +from app.core.data.dto.llm_job import LLMJobParameters, LLMJobRead, LLMPromptTemplates +from app.core.data.llm.llm_service import LLMService + +router = APIRouter( + prefix="/llm", dependencies=[Depends(get_current_user)], tags=["llm"] +) + +llms: LLMService = LLMService() + + +@router.post( + "", + response_model=LLMJobRead, + summary="Returns the LLMJob for the given Parameters", +) +def start_llm_job( + *, llm_job_params: LLMJobParameters, authz_user: AuthzUser = Depends() +) -> LLMJobRead: + authz_user.assert_in_project(llm_job_params.project_id) + + return prepare_and_start_llm_job_async(llm_job_params=llm_job_params) + + +@router.get( + "/{llm_job_id}", + response_model=LLMJobRead, + summary="Returns the LLMJob for the given ID if it exists", +) +def get_llm_job(*, llm_job_id: str, authz_user: AuthzUser = Depends()) -> LLMJobRead: + job = llms.get_llm_job(llm_job_id=llm_job_id) + authz_user.assert_in_project(job.parameters.project_id) + + return job + + +@router.get( + "/project/{project_id}", + response_model=List[LLMJobRead], + summary="Returns all LLMJobRead for the given project ID if it exists", +) +def get_all_llm_jobs( + *, project_id: int, authz_user: AuthzUser = Depends() +) -> List[LLMJobRead]: + authz_user.assert_in_project(project_id) + + llm_jobs = llms.get_all_llm_jobs(project_id=project_id) + llm_jobs.sort(key=lambda x: x.created, reverse=True) + return llm_jobs + + +@router.post( + "/create_prompt_templates", + response_model=List[LLMPromptTemplates], + summary="Returns the system and user prompt templates for the given llm task in all supported languages", +) +def create_prompt_templates( + *, llm_job_params: LLMJobParameters, authz_user: AuthzUser = Depends() +) -> List[LLMPromptTemplates]: + authz_user.assert_in_project(llm_job_params.project_id) + + return 
llms.create_prompt_templates(llm_job_params=llm_job_params) diff --git a/backend/src/api/endpoints/source_document_metadata.py b/backend/src/api/endpoints/source_document_metadata.py index 1ac20022a..cecc3b5e2 100644 --- a/backend/src/api/endpoints/source_document_metadata.py +++ b/backend/src/api/endpoints/source_document_metadata.py @@ -1,3 +1,5 @@ +from typing import List + from fastapi import APIRouter, Depends from sqlalchemy.orm import Session @@ -7,6 +9,7 @@ from app.core.data.crud import Crud from app.core.data.crud.source_document_metadata import crud_sdoc_meta from app.core.data.dto.source_document_metadata import ( + SourceDocumentMetadataBulkUpdate, SourceDocumentMetadataCreate, SourceDocumentMetadataRead, SourceDocumentMetadataReadResolved, @@ -82,6 +85,27 @@ def update_by_id( return SourceDocumentMetadataRead.model_validate(db_obj) +@router.patch( + "/bulk/update", + response_model=List[SourceDocumentMetadataRead], + summary="Updates multiple metadata objects at once.", +) +def update_bulk( + *, + db: Session = Depends(get_db_session), + metadatas: List[SourceDocumentMetadataBulkUpdate], + authz_user: AuthzUser = Depends(), +) -> List[SourceDocumentMetadataRead]: + authz_user.assert_in_same_project_as_many( + Crud.SOURCE_DOCUMENT_METADATA, [m.id for m in metadatas] + ) + + print("HI!") + + db_objs = crud_sdoc_meta.update_bulk(db=db, update_dtos=metadatas) + return [SourceDocumentMetadataRead.model_validate(db_obj) for db_obj in db_objs] + + @router.delete( "/{metadata_id}", response_model=SourceDocumentMetadataRead, diff --git a/backend/src/api/endpoints/span_annotation.py b/backend/src/api/endpoints/span_annotation.py index d30dc1604..50e720986 100644 --- a/backend/src/api/endpoints/span_annotation.py +++ b/backend/src/api/endpoints/span_annotation.py @@ -13,6 +13,7 @@ from app.core.data.dto.code import CodeRead from app.core.data.dto.memo import AttachedObjectType, MemoCreate, MemoInDB, MemoRead from app.core.data.dto.span_annotation import ( + 
SpanAnnotationCreateBulkWithCodeId, SpanAnnotationCreateWithCodeId, SpanAnnotationRead, SpanAnnotationReadResolved, @@ -66,6 +67,46 @@ def add_span_annotation( return span_dto +@router.put( + "/bulk/create", + response_model=Union[List[SpanAnnotationRead], List[SpanAnnotationReadResolved]], + summary="Creates a SpanAnnotations in Bulk", +) +def add_span_annotations_bulk( + *, + db: Session = Depends(get_db_session), + spans: List[SpanAnnotationCreateBulkWithCodeId], + resolve_code: bool = Depends(resolve_code_param), + authz_user: AuthzUser = Depends(), + validate: Validate = Depends(), +) -> Union[List[SpanAnnotationRead], List[SpanAnnotationReadResolved]]: + for span in spans: + authz_user.assert_in_same_project_as(Crud.CODE, span.code_id) + authz_user.assert_in_same_project_as(Crud.SOURCE_DOCUMENT, span.sdoc_id) + validate.validate_objects_in_same_project( + [ + (Crud.CODE, span.code_id), + (Crud.SOURCE_DOCUMENT, span.sdoc_id), + ] + ) + + db_objs = crud_span_anno.create_bulk(db=db, create_dtos=spans) + span_dtos = [SpanAnnotationRead.model_validate(db_obj) for db_obj in db_objs] + if resolve_code: + return [ + SpanAnnotationReadResolved( + **span_dto.model_dump(exclude={"current_code_id", "span_text_id"}), + code=CodeRead.model_validate(db_obj.current_code.code), + span_text=db_obj.span_text.text, + user_id=db_obj.annotation_document.user_id, + sdoc_id=db_obj.annotation_document.source_document_id, + ) + for span_dto, db_obj in zip(span_dtos, db_objs) + ] + else: + return span_dtos + + @router.get( "/{span_id}", response_model=Union[SpanAnnotationRead, SpanAnnotationReadResolved], diff --git a/backend/src/app/celery/background_jobs/__init__.py b/backend/src/app/celery/background_jobs/__init__.py index d7cb3dd57..9383b9d4a 100644 --- a/backend/src/app/celery/background_jobs/__init__.py +++ b/backend/src/app/celery/background_jobs/__init__.py @@ -4,7 +4,9 @@ from app.core.data.crawler.crawler_service import CrawlerService from app.core.data.dto.crawler_job import 
CrawlerJobParameters, CrawlerJobRead from app.core.data.dto.export_job import ExportJobParameters, ExportJobRead +from app.core.data.dto.llm_job import LLMJobParameters, LLMJobRead from app.core.data.export.export_service import ExportService +from app.core.data.llm.llm_service import LLMService from app.preprocessing.pipeline.model.pipeline_cargo import PipelineCargo @@ -76,6 +78,17 @@ def prepare_and_start_crawling_job_async( return cj +def prepare_and_start_llm_job_async( + llm_job_params: LLMJobParameters, +) -> LLMJobRead: + from app.celery.background_jobs.tasks import start_llm_job + + llms: LLMService = LLMService() + llm_job = llms.prepare_llm_job(llm_job_params) + start_llm_job.apply_async(kwargs={"llm_job": llm_job}) + return llm_job + + def execute_text_preprocessing_pipeline_apply_async( cargos: List[PipelineCargo], ) -> None: diff --git a/backend/src/app/celery/background_jobs/llm.py b/backend/src/app/celery/background_jobs/llm.py new file mode 100644 index 000000000..348c954ce --- /dev/null +++ b/backend/src/app/celery/background_jobs/llm.py @@ -0,0 +1,18 @@ +from loguru import logger + +from app.core.data.dto.llm_job import LLMJobRead +from app.core.data.llm.llm_service import LLMService + +llms: LLMService = LLMService() + + +def start_llm_job_(llm_job: LLMJobRead) -> None: + logger.info( + ( + f"Starting LLMJob {llm_job.id}", + f" with parameters:\n\t{llm_job.parameters.model_dump_json(indent=2)}", + ) + ) + llms.start_llm_job_sync(llm_job_id=llm_job.id) + + logger.info(f"LLMJob {llm_job.id} has finished!") diff --git a/backend/src/app/celery/background_jobs/tasks.py b/backend/src/app/celery/background_jobs/tasks.py index fb4389639..3e004155a 100644 --- a/backend/src/app/celery/background_jobs/tasks.py +++ b/backend/src/app/celery/background_jobs/tasks.py @@ -4,6 +4,7 @@ from app.celery.background_jobs.cota import start_cota_refinement_job_ from app.celery.background_jobs.crawl import start_crawler_job_ from app.celery.background_jobs.export import 
start_export_job_ +from app.celery.background_jobs.llm import start_llm_job_ from app.celery.background_jobs.preprocess import ( execute_audio_preprocessing_pipeline_, execute_image_preprocessing_pipeline_, @@ -18,6 +19,7 @@ from app.celery.celery_worker import celery_worker from app.core.data.dto.crawler_job import CrawlerJobRead from app.core.data.dto.export_job import ExportJobRead +from app.core.data.dto.llm_job import LLMJobRead from app.preprocessing.pipeline.model.pipeline_cargo import PipelineCargo @@ -59,6 +61,11 @@ def start_crawler_job(crawler_job: CrawlerJobRead) -> Tuple[Path, int]: return archive_file_path, project_id +@celery_worker.task(acks_late=True) +def start_llm_job(llm_job: LLMJobRead) -> None: + start_llm_job_(llm_job=llm_job) + + @celery_worker.task( acks_late=True, autoretry_for=(Exception,), diff --git a/backend/src/app/core/data/crud/annotation_document.py b/backend/src/app/core/data/crud/annotation_document.py index 39fe272fc..808cbc029 100644 --- a/backend/src/app/core/data/crud/annotation_document.py +++ b/backend/src/app/core/data/crud/annotation_document.py @@ -24,6 +24,25 @@ def update_timestamp( update_dto=AnnotationDocumentUpdate(updated=datetime.datetime.now()), ) + def exists_or_create( + self, db: Session, *, user_id: int, sdoc_id: int + ) -> AnnotationDocumentORM: + db_obj = ( + db.query(self.model) + .filter( + self.model.user_id == user_id, self.model.source_document_id == sdoc_id + ) + .first() + ) + if db_obj is None: + return self.create( + db=db, + create_dto=AnnotationDocumentCreate( + user_id=user_id, source_document_id=sdoc_id + ), + ) + return db_obj + def read_by_user(self, db: Session, *, user_id: int) -> List[AnnotationDocumentORM]: return db.query(self.model).filter(self.model.user_id == user_id).all() diff --git a/backend/src/app/core/data/crud/document_tag.py b/backend/src/app/core/data/crud/document_tag.py index 8579116d7..ab469d5b8 100644 --- a/backend/src/app/core/data/crud/document_tag.py +++ 
b/backend/src/app/core/data/crud/document_tag.py @@ -156,6 +156,45 @@ def unlink_multiple_document_tags( return len(del_rows) + def set_document_tags( + self, db: Session, *, sdoc_id: int, tag_ids: List[int] + ) -> int: + """ + Link/Unlink DocTags so that sdoc has exactly the tags + """ + # current state + from app.core.data.crud.source_document import crud_sdoc + + current_tag_ids = [ + tag.id for tag in crud_sdoc.read(db, id=sdoc_id).document_tags + ] + + # find tags to be added and removed + add_tag_ids = list(set(tag_ids) - set(current_tag_ids)) + del_tag_ids = list(set(current_tag_ids) - set(tag_ids)) + + modifications = 0 + if len(del_tag_ids) > 0: + modifications += self.unlink_multiple_document_tags( + db, sdoc_ids=[sdoc_id], tag_ids=del_tag_ids + ) + if len(add_tag_ids) > 0: + modifications += self.link_multiple_document_tags( + db, sdoc_ids=[sdoc_id], tag_ids=add_tag_ids + ) + + return modifications + + def set_document_tags_batch( + self, db: Session, *, links: Dict[int, List[int]] + ) -> int: + modifications = 0 + for sdoc_id, tag_ids in links.items(): + modifications += self.set_document_tags( + db, sdoc_id=sdoc_id, tag_ids=tag_ids + ) + return modifications + # Return a dictionary in the following format: # tag id => count of documents that have this tag # for all tags in the database diff --git a/backend/src/app/core/data/crud/project_metadata.py b/backend/src/app/core/data/crud/project_metadata.py index e54abe4b5..29b05bcf7 100644 --- a/backend/src/app/core/data/crud/project_metadata.py +++ b/backend/src/app/core/data/crud/project_metadata.py @@ -135,6 +135,7 @@ def create_project_metadata_for_project( metatype=project_metadata["metatype"], read_only=project_metadata["read_only"], doctype=project_metadata["doctype"], + description=project_metadata["description"], ) db_obj = self.create(db=db, create_dto=create_dto) created.append(db_obj) diff --git a/backend/src/app/core/data/crud/source_document_metadata.py 
b/backend/src/app/core/data/crud/source_document_metadata.py index 949e1ac38..18e5a77a8 100644 --- a/backend/src/app/core/data/crud/source_document_metadata.py +++ b/backend/src/app/core/data/crud/source_document_metadata.py @@ -8,6 +8,7 @@ from app.core.data.dto.project_metadata import ProjectMetadataRead from app.core.data.dto.source_document_metadata import ( SourceDocumentMetadataBaseDTO, + SourceDocumentMetadataBulkUpdate, SourceDocumentMetadataCreate, SourceDocumentMetadataUpdate, ) @@ -124,6 +125,21 @@ def update( return metadata_orm + def update_bulk( + self, db: Session, *, update_dtos: List[SourceDocumentMetadataBulkUpdate] + ) -> List[SourceDocumentMetadataORM]: + db_objs = [] + for update_dto in update_dtos: + db_obj = self.update( + db=db, + metadata_id=update_dto.id, + update_dto=SourceDocumentMetadataUpdate( + **update_dto.model_dump(exclude={"id"}) + ), + ) + db_objs.append(db_obj) + return db_objs + def read_by_project( self, db: Session, diff --git a/backend/src/app/core/data/crud/span_annotation.py b/backend/src/app/core/data/crud/span_annotation.py index 4847fca83..ef26c1f96 100644 --- a/backend/src/app/core/data/crud/span_annotation.py +++ b/backend/src/app/core/data/crud/span_annotation.py @@ -1,10 +1,11 @@ -from typing import List, Optional +from typing import Dict, List, Optional import srsly from fastapi.encoders import jsonable_encoder from sqlalchemy.orm import Session from app.core.data.crud.annotation_document import crud_adoc +from app.core.data.crud.code import crud_code from app.core.data.crud.crud_base import CRUDBase from app.core.data.crud.span_group import crud_span_group from app.core.data.crud.span_text import crud_span_text @@ -12,6 +13,7 @@ from app.core.data.dto.code import CodeRead from app.core.data.dto.span_annotation import ( SpanAnnotationCreate, + SpanAnnotationCreateBulkWithCodeId, SpanAnnotationCreateWithCodeId, SpanAnnotationRead, SpanAnnotationReadResolved, @@ -113,6 +115,52 @@ def create_multi( return db_objs + 
def create_bulk( + self, db: Session, *, create_dtos: List[SpanAnnotationCreateBulkWithCodeId] + ) -> List[SpanAnnotationORM]: + # group by user and sdoc_id + # identify codes + annotations_by_user_sdoc = { + (create_dto.user_id, create_dto.sdoc_id): [] for create_dto in create_dtos + } + for create_dto in create_dtos: + annotations_by_user_sdoc[(create_dto.user_id, create_dto.sdoc_id)].append( + create_dto + ) + + # find or create annotation documents for each user and sdoc_id + adoc_id_by_user_sdoc = {} + for user_id, sdoc_id in annotations_by_user_sdoc.keys(): + adoc_id_by_user_sdoc[(user_id, sdoc_id)] = crud_adoc.exists_or_create( + db=db, user_id=user_id, sdoc_id=sdoc_id + ).id + + # find all codes + code_ids = list(set([create_dto.code_id for create_dto in create_dtos])) + db_codes = crud_code.read_by_ids(db=db, ids=code_ids) + cid2ccid: Dict[int, int] = {} + for db_code in db_codes: + cid2ccid[db_code.id] = db_code.current_code.id + + # create the annotations + return self.create_multi( + db=db, + create_dtos=[ + SpanAnnotationCreate( + begin=create_dto.begin, + end=create_dto.end, + span_text=create_dto.span_text, + begin_token=create_dto.begin_token, + end_token=create_dto.end_token, + current_code_id=cid2ccid[create_dto.code_id], + annotation_document_id=adoc_id_by_user_sdoc[ + (create_dto.user_id, create_dto.sdoc_id) + ], + ) + for create_dto in create_dtos + ], + ) + def read_by_adoc( self, db: Session, *, adoc_id: int, skip: int = 0, limit: int = 1000 ) -> List[SpanAnnotationORM]: diff --git a/backend/src/app/core/data/dto/document_tag.py b/backend/src/app/core/data/dto/document_tag.py index ffe9d8724..7a6039f3d 100644 --- a/backend/src/app/core/data/dto/document_tag.py +++ b/backend/src/app/core/data/dto/document_tag.py @@ -56,3 +56,8 @@ class SourceDocumentDocumentTagLink(BaseModel): class SourceDocumentDocumentTagMultiLink(BaseModel): source_document_ids: List[int] = Field(description="List of IDs of SourceDocuments") document_tag_ids: List[int] = 
Field(description="List of IDs of DocumentTags") + + +class SourceDocumentDocumentTagLinks(BaseModel): + source_document_id: int = Field(description="ID of SourceDocument") + document_tag_ids: List[int] = Field(description="List of IDs of DocumentTags") diff --git a/backend/src/app/core/data/dto/llm_job.py b/backend/src/app/core/data/dto/llm_job.py new file mode 100644 index 000000000..080980878 --- /dev/null +++ b/backend/src/app/core/data/dto/llm_job.py @@ -0,0 +1,178 @@ +from datetime import datetime +from enum import Enum +from typing import List, Literal, Optional, Union + +from pydantic import BaseModel, Field + +from app.core.data.dto.background_job_base import BackgroundJobStatus +from app.core.data.dto.dto_base import UpdateDTOBase +from app.core.data.dto.source_document_metadata import ( + SourceDocumentMetadataReadResolved, +) +from app.core.data.dto.span_annotation import SpanAnnotationReadResolved + + +class LLMJobType(str, Enum): + DOCUMENT_TAGGING = "DOCUMENT_TAGGING" + METADATA_EXTRACTION = "METADATA_EXTRACTION" + ANNOTATION = "ANNOTATION" + + +# Prompt template +class LLMPromptTemplates(BaseModel): + language: str = Field(description="The language of the prompt template") + system_prompt: str = Field(description="The system prompt to use for the job") + user_prompt: str = Field(description="The user prompt to use for the job") + + +# --- START PARAMETERS --- + + +class SpecificLLMJobParameters(BaseModel): + llm_job_type: LLMJobType = Field(description="The type of the LLMJob (what to llm)") + + +class DocumentBasedLLMJobParams(SpecificLLMJobParameters): + sdoc_ids: List[int] = Field(description="IDs of the source documents to analyse") + + +class DocumentTaggingLLMJobParams(DocumentBasedLLMJobParams): + llm_job_type: Literal[LLMJobType.DOCUMENT_TAGGING] + tag_ids: List[int] = Field( + description="IDs of the tags to use for the document tagging" + ) + + +class MetadataExtractionLLMJobParams(DocumentBasedLLMJobParams): + llm_job_type: 
Literal[LLMJobType.METADATA_EXTRACTION] + project_metadata_ids: List[int] = Field( + description="IDs of the project metadata to use for the metadata extraction" + ) + + +class AnnotationLLMJobParams(DocumentBasedLLMJobParams): + llm_job_type: Literal[LLMJobType.ANNOTATION] + code_ids: List[int] = Field( + description="IDs of the codes to use for the annotation" + ) + + +class LLMJobParameters(BaseModel): + llm_job_type: LLMJobType = Field(description="The type of the LLMJob (what to llm)") + project_id: int = Field(description="The ID of the Project to analyse") + prompts: List[LLMPromptTemplates] = Field( + description="The prompt templates to use for the job" + ) + specific_llm_job_parameters: Union[ + DocumentTaggingLLMJobParams, + MetadataExtractionLLMJobParams, + AnnotationLLMJobParams, + ] = Field( + description="Specific parameters for the LLMJob w.r.t it's type", + discriminator="llm_job_type", + ) + + +# --- END PARAMETERS --- + +# --- START RESULTS --- + + +class DocumentTaggingResult(BaseModel): + sdoc_id: int = Field(description="ID of the source document") + current_tag_ids: List[int] = Field( + description="IDs of the tags currently assigned to the document" + ) + suggested_tag_ids: List[int] = Field( + description="IDs of the tags suggested by the LLM to assign to the document" + ) + reasoning: str = Field(description="Reasoning for the tagging") + + +class DocumentTaggingLLMJobResult(BaseModel): + llm_job_type: Literal[LLMJobType.DOCUMENT_TAGGING] + results: List[DocumentTaggingResult] + + +class MetadataExtractionResult(BaseModel): + sdoc_id: int = Field(description="ID of the source document") + current_metadata: List[SourceDocumentMetadataReadResolved] = Field( + description="Current metadata" + ) + suggested_metadata: List[SourceDocumentMetadataReadResolved] = Field( + description="Suggested metadata" + ) + + +class MetadataExtractionLLMJobResult(BaseModel): + llm_job_type: Literal[LLMJobType.METADATA_EXTRACTION] + results: 
List[MetadataExtractionResult] + + +class AnnotationResult(BaseModel): + sdoc_id: int = Field(description="ID of the source document") + suggested_annotations: List[SpanAnnotationReadResolved] = Field( + description="Suggested annotations" + ) + + +class AnnotationLLMJobResult(BaseModel): + llm_job_type: Literal[LLMJobType.ANNOTATION] + results: List[AnnotationResult] + + +class LLMJobResult(BaseModel): + llm_job_type: LLMJobType = Field(description="The type of the LLMJob (what to llm)") + specific_llm_job_result: Union[ + DocumentTaggingLLMJobResult, + MetadataExtractionLLMJobResult, + AnnotationLLMJobResult, + ] = Field( + description="Specific result for the LLMJob w.r.t its type", + discriminator="llm_job_type", + ) + + +# --- END RESULTS --- + + +# Properties shared across all DTOs +class LLMJobBaseDTO(BaseModel): + status: BackgroundJobStatus = Field( + default=BackgroundJobStatus.WAITING, description="Status of the LLMJob" + ) + num_steps_finished: int = Field(description="Number of steps LLMJob has completed.") + num_steps_total: int = Field(description="Number of total steps.") + result: Optional[LLMJobResult] = Field( + default=None, description="Results of the LLMJob." + ) + + +# Properties to create +class LLMJobCreate(LLMJobBaseDTO): + parameters: LLMJobParameters = Field( + description="The parameters of the LLMJob that defines what to do!" + ) + + +# Properties to update +class LLMJobUpdate(BaseModel, UpdateDTOBase): + status: Optional[BackgroundJobStatus] = Field( + default=None, description="Status of the LLMJob" + ) + num_steps_finished: Optional[int] = Field( + default=None, description="Number of steps LLMJob has completed." + ) + result: Optional[LLMJobResult] = Field( + default=None, description="Result of the LLMJob." + ) + + +# Properties to read +class LLMJobRead(LLMJobBaseDTO): + id: str = Field(description="ID of the LLMJob") + parameters: LLMJobParameters = Field( + description="The parameters of the LLMJob that defines what to llm!"
+ ) + created: datetime = Field(description="Created timestamp of the LLMJob") + updated: datetime = Field(description="Updated timestamp of the LLMJob") diff --git a/backend/src/app/core/data/dto/project_metadata.py b/backend/src/app/core/data/dto/project_metadata.py index ac3667009..5ddbc77b9 100644 --- a/backend/src/app/core/data/dto/project_metadata.py +++ b/backend/src/app/core/data/dto/project_metadata.py @@ -21,6 +21,7 @@ class ProjectMetadataBaseDTO(BaseModel): doctype: DocType = Field( description="DOCTYPE of the SourceDocument this metadata refers to" ) + description: str = Field(description="Description of the ProjectMetadata") # Properties for creation @@ -31,7 +32,12 @@ class ProjectMetadataCreate(ProjectMetadataBaseDTO): # Properties for updating class ProjectMetadataUpdate(BaseModel, UpdateDTOBase): key: Optional[str] = Field(description="Key of the ProjectMetadata", default=None) - metatype: Optional[MetaType] = Field(description="Type of the ProjectMetadata") + metatype: Optional[MetaType] = Field( + description="Type of the ProjectMetadata", default=None + ) + description: Optional[str] = Field( + description="Description of the ProjectMetadata", default=None + ) # Properties for reading (as in ORM) diff --git a/backend/src/app/core/data/dto/source_document_metadata.py b/backend/src/app/core/data/dto/source_document_metadata.py index ba7ce7783..880d93496 100644 --- a/backend/src/app/core/data/dto/source_document_metadata.py +++ b/backend/src/app/core/data/dto/source_document_metadata.py @@ -111,6 +111,10 @@ class SourceDocumentMetadataUpdate(SourceDocumentMetadataBaseDTO, UpdateDTOBase) pass +class SourceDocumentMetadataBulkUpdate(SourceDocumentMetadataBaseDTO, UpdateDTOBase): + id: int = Field(description="ID of the SourceDocumentMetadata") + + # Properties for reading (as in ORM) class SourceDocumentMetadataRead(SourceDocumentMetadataBaseDTO): id: int = Field(description="ID of the SourceDocumentMetadata") @@ -144,3 +148,70 @@ def 
get_value(self) -> Union[str, int, datetime, bool, List, None]: case MetaType.LIST: return self.list_value return None + + @staticmethod + def with_value( + sdoc_metadata_id: int, + source_document_id: int, + project_metadata: ProjectMetadataRead, + value: str, + ) -> "SourceDocumentMetadataReadResolved": + match project_metadata.metatype: + case MetaType.STRING: + return SourceDocumentMetadataReadResolved( + id=sdoc_metadata_id, + str_value=str(value) if value is not None else "", + boolean_value=None, + date_value=None, + int_value=None, + list_value=None, + source_document_id=source_document_id, + project_metadata=project_metadata, + ) + case MetaType.NUMBER: + return SourceDocumentMetadataReadResolved( + id=sdoc_metadata_id, + str_value=None, + boolean_value=None, + date_value=None, + int_value=round(float(value)) if value is not None else 0, + list_value=None, + source_document_id=source_document_id, + project_metadata=project_metadata, + ) + case MetaType.DATE: + return SourceDocumentMetadataReadResolved( + id=sdoc_metadata_id, + str_value=None, + boolean_value=None, + date_value=value if value is not None else datetime.now(), + int_value=None, + list_value=None, + source_document_id=source_document_id, + project_metadata=project_metadata, + ) + case MetaType.BOOLEAN: + return SourceDocumentMetadataReadResolved( + id=sdoc_metadata_id, + str_value=None, + boolean_value=bool(value) if value is not None else False, + date_value=None, + int_value=None, + list_value=None, + source_document_id=source_document_id, + project_metadata=project_metadata, + ) + case MetaType.LIST: + list_value = value if value is not None else [] + if isinstance(list_value, str): + list_value = [list_value] + return SourceDocumentMetadataReadResolved( + id=sdoc_metadata_id, + str_value=None, + boolean_value=None, + date_value=None, + int_value=None, + list_value=list_value, + source_document_id=source_document_id, + project_metadata=project_metadata, + ) diff --git 
a/backend/src/app/core/data/dto/span_annotation.py b/backend/src/app/core/data/dto/span_annotation.py index 12a5c4ed2..c07ad8c3e 100644 --- a/backend/src/app/core/data/dto/span_annotation.py +++ b/backend/src/app/core/data/dto/span_annotation.py @@ -31,6 +31,13 @@ class SpanAnnotationCreateWithCodeId(SpanAnnotationBaseDTO): ) +class SpanAnnotationCreateBulkWithCodeId(SpanAnnotationBaseDTO): + span_text: str = Field(description="The SpanText the SpanAnnotation spans.") + code_id: int = Field(description="Code the SpanAnnotation refers to") + sdoc_id: int = Field(description="SourceDocument the SpanAnnotation refers to") + user_id: int = Field(description="User the SpanAnnotation belongs to") + + # Properties for updating class SpanAnnotationUpdate(BaseModel, UpdateDTOBase): current_code_id: int = Field(description="CurrentCode the SpanAnnotation refers to") diff --git a/backend/src/app/core/data/llm/llm_service.py b/backend/src/app/core/data/llm/llm_service.py new file mode 100644 index 000000000..ffc37091c --- /dev/null +++ b/backend/src/app/core/data/llm/llm_service.py @@ -0,0 +1,538 @@ +from datetime import datetime +from typing import Callable, Dict, List, Type, Union + +from loguru import logger +from sqlalchemy.orm import Session + +from app.core.data.crud.source_document import crud_sdoc +from app.core.data.crud.source_document_metadata import crud_sdoc_meta +from app.core.data.crud.user import SYSTEM_USER_ID +from app.core.data.dto.background_job_base import BackgroundJobStatus +from app.core.data.dto.code import CodeRead +from app.core.data.dto.llm_job import ( + AnnotationLLMJobResult, + AnnotationResult, + DocumentTaggingLLMJobResult, + DocumentTaggingResult, + LLMJobCreate, + LLMJobParameters, + LLMJobRead, + LLMJobResult, + LLMJobType, + LLMJobUpdate, + LLMPromptTemplates, + MetadataExtractionLLMJobResult, + MetadataExtractionResult, +) +from app.core.data.dto.source_document_metadata import ( + SourceDocumentMetadataReadResolved, +) +from 
app.core.data.dto.span_annotation import SpanAnnotationReadResolved +from app.core.data.llm.ollama_service import OllamaService +from app.core.data.llm.prompts.annotation_prompt_builder import AnnotationPromptBuilder +from app.core.data.llm.prompts.metadata_prompt_builder import MetadataPromptBuilder +from app.core.data.llm.prompts.prompt_builder import PromptBuilder +from app.core.data.llm.prompts.tagging_prompt_builder import TaggingPromptBuilder +from app.core.data.repo.repo_service import RepoService +from app.core.db.redis_service import RedisService +from app.core.db.sql_service import SQLService +from app.util.singleton_meta import SingletonMeta + + +class LLMJobPreparationError(Exception): + def __init__(self, cause: Union[Exception, str]) -> None: + super().__init__(f"Cannot prepare and create the LLMJob! {cause}") + + +class LLMJobAlreadyStartedOrDoneError(Exception): + def __init__(self, llm_job_id: str) -> None: + super().__init__(f"The LLMJob with ID {llm_job_id} already started or is done!") + + +class NoSuchLLMJobError(Exception): + def __init__(self, llm_job_id: str, cause: Exception) -> None: + super().__init__(f"There exists not LLMJob with ID {llm_job_id}! {cause}") + + +class UnsupportedLLMJobTypeError(Exception): + def __init__(self, llm_job_type: LLMJobType) -> None: + super().__init__(f"LLMJobType {llm_job_type} is not supported! 
") + + +class LLMService(metaclass=SingletonMeta): + def __new__(cls, *args, **kwargs): + cls.repo: RepoService = RepoService() + cls.redis: RedisService = RedisService() + cls.sqls: SQLService = SQLService() + cls.ollamas: OllamaService = OllamaService() + + # map from job_type to function + cls.llm_method_for_job_type: Dict[LLMJobType, Callable[..., LLMJobResult]] = { + LLMJobType.DOCUMENT_TAGGING: cls._llm_document_tagging, + LLMJobType.METADATA_EXTRACTION: cls._llm_metadata_extraction, + LLMJobType.ANNOTATION: cls._llm_annotation, + } + + # map from job_type to promt builder + cls.llm_prompt_builder_for_job_type: Dict[LLMJobType, Type[PromptBuilder]] = { + LLMJobType.DOCUMENT_TAGGING: TaggingPromptBuilder, + LLMJobType.METADATA_EXTRACTION: MetadataPromptBuilder, + LLMJobType.ANNOTATION: AnnotationPromptBuilder, + } + + return super(LLMService, cls).__new__(cls) + + def _assert_all_requested_data_exists(self, llm_params: LLMJobParameters) -> bool: + # TODO check all job type specific parameters + return True + + def prepare_llm_job(self, llm_params: LLMJobParameters) -> LLMJobRead: + if not self._assert_all_requested_data_exists(llm_params=llm_params): + raise LLMJobPreparationError( + cause="Not all requested data for the LLM job exists!" 
+ ) + + llmj_create = LLMJobCreate( + parameters=llm_params, + num_steps_total=len(llm_params.specific_llm_job_parameters.sdoc_ids), + num_steps_finished=0, + ) + try: + llmj_read = self.redis.store_llm_job(llm_job=llmj_create) + except Exception as e: + raise LLMJobPreparationError(cause=e) + + return llmj_read + + def get_llm_job(self, llm_job_id: str) -> LLMJobRead: + try: + llmj = self.redis.load_llm_job(key=llm_job_id) + except Exception as e: + raise NoSuchLLMJobError(llm_job_id=llm_job_id, cause=e) + + return llmj + + def get_all_llm_jobs(self, project_id: int) -> List[LLMJobRead]: + return self.redis.get_all_llm_jobs(project_id=project_id) + + def _update_llm_job(self, llm_job_id: str, update: LLMJobUpdate) -> LLMJobRead: + try: + llmj = self.redis.update_llm_job(key=llm_job_id, update=update) + except Exception as e: + raise NoSuchLLMJobError(llm_job_id=llm_job_id, cause=e) + return llmj + + def start_llm_job_sync(self, llm_job_id: str) -> LLMJobRead: + llmj = self.get_llm_job(llm_job_id=llm_job_id) + if llmj.status != BackgroundJobStatus.WAITING: + raise LLMJobAlreadyStartedOrDoneError(llm_job_id=llm_job_id) + + llmj = self._update_llm_job( + llm_job_id=llm_job_id, + update=LLMJobUpdate(status=BackgroundJobStatus.RUNNING), + ) + + try: + with self.sqls.db_session() as db: + # get the llm method based on the jobtype + llm_method = self.llm_method_for_job_type.get( + llmj.parameters.llm_job_type, None + ) + if llm_method is None: + raise UnsupportedLLMJobTypeError(llmj.parameters.llm_job_type) + + # execute the llm_method with the provided specific parameters + result = llm_method( + self=self, + db=db, + llm_job_id=llm_job_id, + prompts=llmj.parameters.prompts, + project_id=llmj.parameters.project_id, + **llmj.parameters.specific_llm_job_parameters.model_dump( + exclude={"llm_job_type"} + ), + ) + + llmj = self._update_llm_job( + llm_job_id=llm_job_id, + update=LLMJobUpdate(result=result, status=BackgroundJobStatus.FINISHED), + ) + + except Exception as e: 
+ logger.error(f"Cannot finish LLMJob: {e}") + self._update_llm_job( + llm_job_id=llm_job_id, + update=LLMJobUpdate(status=BackgroundJobStatus.ERROR), + ) + + return llmj + + def create_prompt_templates( + self, llm_job_params: LLMJobParameters + ) -> List[LLMPromptTemplates]: + with self.sqls.db_session() as db: + # get the llm method based on the jobtype + llm_prompt_builder = self.llm_prompt_builder_for_job_type.get( + llm_job_params.llm_job_type, None + ) + if llm_prompt_builder is None: + raise UnsupportedLLMJobTypeError(llm_job_params.llm_job_type) + + # execute the the prompt builder with the provided specific parameters + prompt_builder = llm_prompt_builder( + db=db, project_id=llm_job_params.project_id + ) + return prompt_builder.build_prompt_templates( + **llm_job_params.specific_llm_job_parameters.model_dump( + exclude={"llm_job_type"} + ) + ) + + def construct_prompt_dict( + self, prompts: List[LLMPromptTemplates], prompt_builder: PromptBuilder + ) -> Dict[str, Dict[str, str]]: + prompt_dict = {} + for prompt in prompts: + # validate prompts + if not prompt_builder.is_system_prompt_valid( + system_prompt=prompt.system_prompt + ): + raise ValueError("system prompt is not valid!") + if not prompt_builder.is_user_prompt_valid(user_prompt=prompt.user_prompt): + raise ValueError("User prompt is not valid!") + + prompt_dict[prompt.language] = { + "system_prompt": prompt.system_prompt, + "user_prompt": prompt.user_prompt, + } + return prompt_dict + + def _llm_document_tagging( + self, + db: Session, + llm_job_id: str, + prompts: List[LLMPromptTemplates], + project_id: int, + sdoc_ids: List[int], + tag_ids: List[int], + ) -> LLMJobResult: + logger.info(f"Starting LLMJob - Document Tagging, num docs: {len(sdoc_ids)}") + + prompt_builder = TaggingPromptBuilder(db=db, project_id=project_id) + + # build prompt dict (to validate and access prompts by language and system / user) + prompt_dict = self.construct_prompt_dict( + prompts=prompts, 
prompt_builder=prompt_builder + ) + + # read sdocs + sdoc_datas = crud_sdoc.read_with_data_batch(db=db, ids=sdoc_ids) + + # automatic document tagging + result: List[DocumentTaggingResult] = [] + for idx, sdoc_data in enumerate(sdoc_datas): + # get current tag ids + current_tag_ids = [ + tag.id for tag in crud_sdoc.read(db=db, id=sdoc_data.id).document_tags + ] + + # get language + language = crud_sdoc_meta.read_by_sdoc_and_key( + db=db, sdoc_id=sdoc_data.id, key="language" + ).str_value + logger.info(f"Processing SDOC id={sdoc_data.id}, lang={language}") + if language is None or language not in prompt_builder.supported_languages: + result.append( + DocumentTaggingResult( + sdoc_id=sdoc_data.id, + suggested_tag_ids=[], + current_tag_ids=current_tag_ids, + reasoning="Language not supported", + ) + ) + self._update_llm_job( + llm_job_id=llm_job_id, + update=LLMJobUpdate(num_steps_finished=idx + 1), + ) + continue + + # construct prompts + system_prompt = prompt_builder.build_system_prompt( + system_prompt_template=prompt_dict[language]["system_prompt"] + ) + user_prompt = prompt_builder.build_user_prompt( + user_prompt_template=prompt_dict[language]["user_prompt"], + document=sdoc_data.content, + ) + + # prompt the model + response = self.ollamas.chat( + system_prompt=system_prompt, user_prompt=user_prompt + ) + logger.info(f"Got chat response! Response={response}") + + # parse the response + tag_ids, reason = prompt_builder.parse_response( + language=language, response=response + ) + logger.info(f"Parsed the response! 
Tag IDs={tag_ids}, Reason={reason}") + + result.append( + DocumentTaggingResult( + sdoc_id=sdoc_data.id, + suggested_tag_ids=tag_ids, + current_tag_ids=current_tag_ids, + reasoning=reason, + ) + ) + self._update_llm_job( + llm_job_id=llm_job_id, + update=LLMJobUpdate(num_steps_finished=idx + 1), + ) + + return LLMJobResult( + llm_job_type=LLMJobType.DOCUMENT_TAGGING, + specific_llm_job_result=DocumentTaggingLLMJobResult( + llm_job_type=LLMJobType.DOCUMENT_TAGGING, results=result + ), + ) + + def _llm_metadata_extraction( + self, + db: Session, + llm_job_id: str, + prompts: List[LLMPromptTemplates], + project_id: int, + sdoc_ids: List[int], + project_metadata_ids: List[int], + ) -> LLMJobResult: + logger.info(f"Starting LLMJob - Metadata Extraction, num docs: {len(sdoc_ids)}") + + prompt_builder = MetadataPromptBuilder(db=db, project_id=project_id) + + # build prompt dict (to validate and access prompts by language and system / user) + prompt_dict = self.construct_prompt_dict( + prompts=prompts, prompt_builder=prompt_builder + ) + + # read sdocs + sdoc_datas = crud_sdoc.read_with_data_batch(db=db, ids=sdoc_ids) + # automatic metadata extraction + result: List[MetadataExtractionResult] = [] + for idx, sdoc_data in enumerate(sdoc_datas): + # get current metadata values + current_metadata = [ + SourceDocumentMetadataReadResolved.model_validate(metadata) + for metadata in crud_sdoc.read(db=db, id=sdoc_data.id).metadata_ + if metadata.project_metadata_id in project_metadata_ids + ] + current_metadata_dict = { + metadata.project_metadata.id: metadata for metadata in current_metadata + } + + # get language + language = crud_sdoc_meta.read_by_sdoc_and_key( + db=db, sdoc_id=sdoc_data.id, key="language" + ).str_value + logger.info(f"Processing SDOC id={sdoc_data.id}, lang={language}") + if language is None or language not in prompt_builder.supported_languages: + result.append( + MetadataExtractionResult( + sdoc_id=sdoc_data.id, + current_metadata=current_metadata, + 
suggested_metadata=[], + ) + ) + self._update_llm_job( + llm_job_id=llm_job_id, + update=LLMJobUpdate(num_steps_finished=idx + 1), + ) + continue + + # construct prompts + system_prompt = prompt_builder.build_system_prompt( + system_prompt_template=prompt_dict[language]["system_prompt"] + ) + user_prompt = prompt_builder.build_user_prompt( + user_prompt_template=prompt_dict[language]["user_prompt"], + document=sdoc_data.content, + ) + + # prompt the model + response = self.ollamas.chat( + system_prompt=system_prompt, user_prompt=user_prompt + ) + logger.info(f"Got chat response! Response={response}") + + # parse the response + parsed_response = prompt_builder.parse_response( + language=language, response=response + ) + + # create correct suggested metadata (map the parsed response to the current metadata) + suggested_metadata = [] + for project_metadata_id in project_metadata_ids: + current = current_metadata_dict.get(project_metadata_id) + suggestion = parsed_response.get(project_metadata_id) + if current is None or suggestion is None: + continue + + suggested_metadata.append( + SourceDocumentMetadataReadResolved.with_value( + sdoc_metadata_id=current.id, + source_document_id=current.source_document_id, + project_metadata=current.project_metadata, + value=suggestion, + ) + ) + logger.info(f"Parsed the response! 
suggested metadata={suggested_metadata}") + + result.append( + MetadataExtractionResult( + sdoc_id=sdoc_data.id, + current_metadata=current_metadata, + suggested_metadata=suggested_metadata, + ) + ) + self._update_llm_job( + llm_job_id=llm_job_id, + update=LLMJobUpdate(num_steps_finished=idx + 1), + ) + + return LLMJobResult( + llm_job_type=LLMJobType.METADATA_EXTRACTION, + specific_llm_job_result=MetadataExtractionLLMJobResult( + llm_job_type=LLMJobType.METADATA_EXTRACTION, results=result + ), + ) + + def _llm_annotation( + self, + db: Session, + llm_job_id: str, + prompts: List[LLMPromptTemplates], + project_id: int, + sdoc_ids: List[int], + code_ids: List[int], + ) -> LLMJobResult: + logger.info(f"Starting LLMJob - Annotation, num docs: {len(sdoc_ids)}") + + prompt_builder = AnnotationPromptBuilder(db=db, project_id=project_id) + project_codes = prompt_builder.codeids2code_dict + + # build prompt dict (to validate and access prompts by language and system / user) + prompt_dict = self.construct_prompt_dict( + prompts=prompts, prompt_builder=prompt_builder + ) + + # read sdocs + sdoc_datas = crud_sdoc.read_with_data_batch(db=db, ids=sdoc_ids) + + # automatic annotation + annotation_id = 0 + result: List[AnnotationResult] = [] + for idx, sdoc_data in enumerate(sdoc_datas): + # get language + language = crud_sdoc_meta.read_by_sdoc_and_key( + db=db, sdoc_id=sdoc_data.id, key="language" + ).str_value + logger.info(f"Processing SDOC id={sdoc_data.id}, lang={language}") + if language is None or language not in prompt_builder.supported_languages: + result.append( + AnnotationResult(sdoc_id=sdoc_data.id, suggested_annotations=[]) + ) + self._update_llm_job( + llm_job_id=llm_job_id, + update=LLMJobUpdate(num_steps_finished=idx + 1), + ) + continue + + # construct prompts + system_prompt = prompt_builder.build_system_prompt( + system_prompt_template=prompt_dict[language]["system_prompt"] + ) + user_prompt = prompt_builder.build_user_prompt( + 
user_prompt_template=prompt_dict[language]["user_prompt"], + document=sdoc_data.content, + ) + + # prompt the model + response = self.ollamas.chat( + system_prompt=system_prompt, user_prompt=user_prompt + ) + logger.info(f"Got chat response! Response={response}") + + # parse the response + parsed_response = prompt_builder.parse_response( + language=language, response=response + ) + + # validate the response and create the suggested annotation + suggested_annotations: List[SpanAnnotationReadResolved] = [] + for code_id, span_text in parsed_response: + # check if the code_id is valid + if code_id not in project_codes: + continue + + document_text = sdoc_data.content.lower() + annotation_text = span_text.lower() + + # find start and end character of the annotation_text in the document_text + start = document_text.find(annotation_text) + end = start + len(annotation_text) + if start == -1: + continue + + # find start and end token of the annotation_text in the document_tokens + # create a map of character offsets to token ids + document_token_map = {} # character offset -> token id + last_character_offset = 0 + for token_id, token_end in enumerate(sdoc_data.token_ends): + for i in range(last_character_offset, token_end): + document_token_map[i] = token_id + last_character_offset = token_end + + begin_token = document_token_map.get(start, -1) + end_token = document_token_map.get(end, -1) + if begin_token == -1 or end_token == -1: + continue + + # create the suggested annotation + suggested_annotations.append( + SpanAnnotationReadResolved( + id=annotation_id, + annotation_document_id=-1, + sdoc_id=sdoc_data.id, + user_id=SYSTEM_USER_ID, + begin=start, + end=end, + begin_token=begin_token, + end_token=end_token, + span_text=span_text, + code=CodeRead.model_validate(project_codes.get(code_id)), + created=datetime.now(), + updated=datetime.now(), + ) + ) + annotation_id += 1 + logger.info( + f"Parsed the response! 
suggested annotations={suggested_annotations}" + ) + + result.append( + AnnotationResult( + sdoc_id=sdoc_data.id, + suggested_annotations=suggested_annotations, + ) + ) + self._update_llm_job( + llm_job_id=llm_job_id, + update=LLMJobUpdate(num_steps_finished=idx + 1), + ) + + return LLMJobResult( + llm_job_type=LLMJobType.ANNOTATION, + specific_llm_job_result=AnnotationLLMJobResult( + llm_job_type=LLMJobType.ANNOTATION, results=result + ), + ) diff --git a/backend/src/app/core/data/llm/ollama_service.py b/backend/src/app/core/data/llm/ollama_service.py new file mode 100644 index 000000000..88ee3f515 --- /dev/null +++ b/backend/src/app/core/data/llm/ollama_service.py @@ -0,0 +1,62 @@ +from loguru import logger +from ollama import Client + +from app.util.singleton_meta import SingletonMeta +from config import conf + + +class OllamaService(metaclass=SingletonMeta): + def __new__(cls, *args, **kwargs): + if conf.ollama.enabled != "True": + # When running in tests, don't use the ray service at all + return super(OllamaService, cls).__new__(cls) + + try: + # Ollama Connection + ollamac = Client(host=f"{conf.ollama.host}:{conf.ollama.port}") + + # ensure connection to Ollama works + if not ollamac.list(): + raise Exception( + f"Cant connect to Ollama on {conf.ollama.host}:{conf.ollama.port}" + ) + + # ensure that the configured model is available + model = conf.ollama.model + available_models = [x["name"] for x in ollamac.list()["models"]] + if model not in available_models: + print(f"Model {model} is not available. Pulling it now.") + ollamac.pull(model) + print(f"Model {model} has been pulled successfully.") + available_models = [x["name"] for x in ollamac.list()["models"]] + assert ( + model in available_models + ), f"Model {model} is not available. 
Available models are: {available_models}" + + cls.__model = model + cls.__client = ollamac + + except Exception as e: + msg = f"Cannot instantiate OllamaService - Error '{e}'" + logger.error(msg) + raise SystemExit(msg) + + logger.info("Successfully established connection to Ollama!") + + return super(OllamaService, cls).__new__(cls) + + def chat(self, system_prompt: str, user_prompt: str) -> str: + response = self.__client.chat( + model=self.__model, + messages=[ + { + "role": "system", + "content": system_prompt.strip(), + }, + { + "role": "user", + "content": user_prompt.strip(), + }, + ], + ) + return response["message"]["content"].strip() diff --git a/backend/src/app/core/data/llm/prompts/annotation_prompt_builder.py b/backend/src/app/core/data/llm/prompts/annotation_prompt_builder.py new file mode 100644 index 000000000..6d6454d69 --- /dev/null +++ b/backend/src/app/core/data/llm/prompts/annotation_prompt_builder.py @@ -0,0 +1,124 @@ +import random +import re +from typing import Dict, List, Tuple + +from sqlalchemy.orm import Session + +from app.core.data.crud.project import crud_project +from app.core.data.llm.prompts.prompt_builder import PromptBuilder + +# ENGLISH + +en_prompt_template = """ +Please extract text passages from the provided document that are relevant to the following categories. The categories are: +{}. + +Please answer in this format. Not every category may be present in the text. There can be multiple relevant passages per category: +: +: +: + +e.g. +{} + +Document: + + +Remember, you have to extract text passages that are relevant to the categories verbatim, do not generate passages! +""" + + +# GERMAN + +de_prompt_template = """ +Bitte extrahiere Textpassagen aus dem gegebenen Dokument, die gut zu den folgenden Kategorien passen. Die Kategorien sind: +{}. + +Bitte anworte in diesem Format. Nicht alle Kategorien müssen im Text vorkommen. Es können mehrere Textpassagen pro Kategorie relevant sein: +: +: +: + +e.g. 
+{} + +Dokument: + + +Denke daran, dass du Textpassagen wörtlich extrahieren musst, die zu den Kategorien passen. Generiere keine neuen Textpassagen! +""" + + +class AnnotationPromptBuilder(PromptBuilder): + supported_languages = ["en", "de"] + prompt_templates = { + "en": en_prompt_template.strip(), + "de": de_prompt_template.strip(), + } + + def __init__(self, db: Session, project_id: int): + super().__init__(db, project_id) + + project = crud_project.read(db=db, id=project_id) + self.codes = project.codes + self.codename2id_dict = {code.name.lower(): code.id for code in self.codes} + self.codeids2code_dict = {code.id: code for code in self.codes} + + # get one example annotation per code + examples: Dict[int, str] = {} + for code in project.codes: + # get all annotations for the code + annotations = code.current_code.span_annotations + if len(annotations) == 0: + continue + random_annotation = random.choice(annotations) + examples[code.id] = f"{code.name}: {random_annotation.span_text.text}" + self.examples = examples + + def _build_example(self, language: str, code_ids: List[int]) -> str: + examples: List[str] = [] + for code_id in code_ids: + if code_id not in self.examples: + continue + examples.append(self.examples[code_id]) + + if len(examples) == 0: + # choose 3 random examples + examples.extend(random.sample(list(self.examples.values()), 3)) + + return "\n".join(examples) + + def _build_user_prompt_template( + self, language: str, code_ids: List[int], **kwargs + ) -> str: + task_data = "\n".join( + [ + f"{self.codeids2code_dict[code_id].name}: {self.codeids2code_dict[code_id].description}" + for code_id in code_ids + ] + ) + answer_example = self._build_example(language, code_ids) + return self.prompt_templates[language].format(task_data, answer_example) + + def parse_response(self, language: str, response: str) -> List[Tuple[int, str]]: + components = re.split(r"\n+", response) + + results: List[Tuple[int, str]] = [] + for component in components: + if 
":" not in component: + continue + + # extract the code_name and value + code_name, value = component.split(":", 1) + + # check if the code_name is valid + if code_name.lower() not in self.codename2id_dict: + continue + + # get the code + code_id = self.codename2id_dict[code_name.lower()] + value = value.strip() + + results.append((code_id, value)) + + return results diff --git a/backend/src/app/core/data/llm/prompts/metadata_prompt_builder.py b/backend/src/app/core/data/llm/prompts/metadata_prompt_builder.py new file mode 100644 index 000000000..e26870be8 --- /dev/null +++ b/backend/src/app/core/data/llm/prompts/metadata_prompt_builder.py @@ -0,0 +1,139 @@ +import re +from typing import Dict, List + +from sqlalchemy.orm import Session + +from app.core.data.crud.project import crud_project +from app.core.data.dto.project_metadata import ProjectMetadataRead +from app.core.data.llm.prompts.prompt_builder import PromptBuilder +from app.core.data.meta_type import MetaType + +# ENGLISH + +en_prompt_template = """ +Please extract the following information from the provided document. It is possible that not all information is contained in the document: +{}. + +Please answer in this format. If the information is not contained in the document, leave the field empty with "None": +{} + +e.g. +{} + +Document: + + +Remember, you MUST extract the information verbatim from the document, do not generate facts! +""" + +# GERMAN + +de_prompt_template = """ +Bitte extrahiere die folgenden Informationen aus dem Dokument. Es kann sein, dass nicht alle Informationen im Dokument enthalten sind: +{}. + +Bitte anworte in diesem Format. Wenn die Information nicht im Dokument enthalten ist, lasse das Feld leer mit "None": +{} + +e.g. +{} + +Dokument: + + +Denke daran, die Informationen MÜSSEN wörtlich aus dem Dokument extrahiert werden, generiere keine Fakten! 
+""" + + +class MetadataPromptBuilder(PromptBuilder): + supported_languages = ["en", "de"] + prompt_templates = { + "en": en_prompt_template.strip(), + "de": de_prompt_template.strip(), + } + + def __init__(self, db: Session, project_id: int): + super().__init__(db, project_id) + + project = crud_project.read(db=db, id=project_id) + self.project_metadata = [ + ProjectMetadataRead.model_validate(pm) for pm in project.metadata_ + ] + self.metadataid2metadata = { + metadata.id: metadata for metadata in self.project_metadata + } + self.metadataname2metadata = { + metadata.key.lower(): metadata for metadata in self.project_metadata + } + + def _build_answer_template(self, project_metadata_ids: List[int]) -> str: + # The example will be a list of metadata keys and some example values + answer_templates: Dict[MetaType, str] = { + MetaType.STRING: "", + MetaType.NUMBER: "", + MetaType.DATE: "", + MetaType.BOOLEAN: "", + MetaType.LIST: ", , ...", + } + + return "\n".join( + [ + f"{self.metadataid2metadata[pmid].key}: {answer_templates[self.metadataid2metadata[pmid].metatype]}" + for pmid in project_metadata_ids + ] + ) + + def _build_example(self, project_metadata_ids: List[int]) -> str: + # The example will be a list of metadata keys and some example values + example_values: Dict[MetaType, str] = { + MetaType.STRING: "relevant information here", + MetaType.NUMBER: "42", + MetaType.DATE: "2024-01-01", + MetaType.BOOLEAN: "True", + MetaType.LIST: "info1, info2, info3", + } + + return "\n".join( + [ + f"{self.metadataid2metadata[pmid].key}: {example_values[self.metadataid2metadata[pmid].metatype]}" + for pmid in project_metadata_ids + ] + ) + + def _build_user_prompt_template( + self, language: str, project_metadata_ids: List[int], **kwargs + ) -> str: + task_data = "\n".join( + [ + f"{self.metadataid2metadata[pmid].key} - {self.metadataid2metadata[pmid].description}" + for pmid in project_metadata_ids + ] + ) + answer_template = 
self._build_answer_template(project_metadata_ids) + answer_example = self._build_example(project_metadata_ids) + return self.prompt_templates[language].format( + task_data, answer_template, answer_example + ) + + def parse_response(self, language: str, response: str) -> Dict[int, str]: + components = re.split(r"\n+", response) + + results: Dict[int, str] = {} + for component in components: + if ":" not in component: + continue + + # extract the key and value + key, value = component.split(":", 1) + + # check if the key is valid + if key.lower() not in self.metadataname2metadata: + continue + + # get the metadata + proj_metadata = self.metadataname2metadata[key.lower()] + value = value.strip() + + results[proj_metadata.id] = value + + return results diff --git a/backend/src/app/core/data/llm/prompts/prompt_builder.py b/backend/src/app/core/data/llm/prompts/prompt_builder.py new file mode 100644 index 000000000..e8f45d9a5 --- /dev/null +++ b/backend/src/app/core/data/llm/prompts/prompt_builder.py @@ -0,0 +1,102 @@ +from typing import List + +from sqlalchemy.orm import Session + +from app.core.data.crud.project import crud_project +from app.core.data.dto.llm_job import LLMPromptTemplates + +# ENGLISH + +en_system_prompt_template = """ +You are a system to support the analysis of large amounts of text. You will always answer in the required format and use no other formatting than expected by the user! +""" + +# GERMAN + +de_system_prompt_template = """ +Du bist ein System zur Unterstützung bei der Analyse großer Textmengen. Du antwortest immer in dem geforderten Format und verwendest keine andere Formatierung als vom Benutzer erwartet! +""" + + +class PromptBuilder: + """ + Base class for building LLM prompts. + A system prompt template may contain the placeholders "" and "". + A user prompt template must contain the placeholder "". + + A user prompt template always consists of the same building blocks + 1. The task description, e.g. Please classify the documents ... 
+ 2. The categories to work with, e.g. Category 1 - Description 1, Category 2 - Description 2... + 3.1. Instructions on how to answer, e.g. Please answer in this format. The reasoning is optional. + 3.2. A generalized answer tempalte, e.g. Category: \n Reasoning: + 3.3. An example answer, e.g. Category: News\n Reasoning: Becase ... + 4. The document to work with + 5. Reiteration of the task, e.g. Remember, you have to classify the document into one of the provided categories, do not generate new categories! + + Consequently, we have the following building blocks: + + + + + + + + """ + + supported_languages = ["en", "de"] + system_prompt_templates = { + "en": en_system_prompt_template.strip(), + "de": de_system_prompt_template.strip(), + } + + def __init__(self, db: Session, project_id: int): + project = crud_project.read(db=db, id=project_id) + self.project_title = project.title + self.project_description = project.description + + # VALIDATION + + def is_system_prompt_valid(self, system_prompt: str) -> bool: + return True + + def is_user_prompt_valid(self, user_prompt: str) -> bool: + if "" in user_prompt: + return True + return False + + # PROMPT BUILDING + + def build_system_prompt(self, system_prompt_template: str) -> str: + system_prompt = system_prompt_template.replace( + "", self.project_title + ) + return system_prompt.replace("", self.project_description) + + def build_user_prompt(self, user_prompt_template: str, document: str) -> str: + return user_prompt_template.replace("", document) + + # PROMPT TEMPLATE BUILDING + + def _build_system_prompt_template(self, language: str) -> str: + return self.system_prompt_templates[language] + + def _build_user_prompt_template(self, language: str, **kwargs) -> str: + raise NotImplementedError() + + def build_prompt_templates(self, **kwargs) -> List[LLMPromptTemplates]: + # create the prompt templates for all supported languages + result: List[LLMPromptTemplates] = [] + for language in self.supported_languages: + 
result.append( + LLMPromptTemplates( + language=language, + system_prompt=self._build_system_prompt_template(language), + user_prompt=self._build_user_prompt_template(language, **kwargs), + ) + ) + return result + + # PARSING + + def parse_response(self, response: str): + raise NotImplementedError() diff --git a/backend/src/app/core/data/llm/prompts/tagging_prompt_builder.py b/backend/src/app/core/data/llm/prompts/tagging_prompt_builder.py new file mode 100644 index 000000000..e15ff6ad4 --- /dev/null +++ b/backend/src/app/core/data/llm/prompts/tagging_prompt_builder.py @@ -0,0 +1,139 @@ +import re +from typing import List, Tuple + +from sqlalchemy.orm import Session + +from app.core.data.crud.project import crud_project +from app.core.data.llm.prompts.prompt_builder import PromptBuilder + +# ENGLISH + +en_prompt_template = """ +Please classify the document into all appropriate categories below. Multiple or zero categories are possible: +{}. + +Please answer in this format. The reasoning is optional. +Categories: , , ... +Reasoning: + +e.g. +{} + +Document: + + +Remember, you have to classify the document into using the provided categories, do not generate new categories! +""" + +en_example_tempalate = """ +Categories: {} +Reasoning: This document is about {}. +""" + + +# GERMAN + +de_prompt_template = """ +Bitte klassifiziere das Dokument in alle passenden folgenden Kategorien. Es sind mehrere oder keine Kategorien möglich: +{}. + +Bitte anworte in diesem Format. Die Begründung ist optional. +Kategorien: , , ... +Begründung: + +e.g. +{} + +Dokument: + + +Denke daran, das Dokument MUSS mithilfe der gegebenen Kategorien klassifiziert werden, generiere keine neuen Kategorien! +""" + +de_example_tempalate = """ +Kategorien: {} +Begründung: Das Dokument handelt von {}. 
+""" + + +class TaggingPromptBuilder(PromptBuilder): + prompt_templates = { + "en": en_prompt_template.strip(), + "de": de_prompt_template.strip(), + } + example_templates = { + "en": en_example_tempalate.strip(), + "de": de_example_tempalate.strip(), + } + category_word = {"en": "Categories:", "de": "Kategorien:"} + reason_word = {"en": "Reasoning:", "de": "Begründung:"} + + def __init__(self, db: Session, project_id: int): + super().__init__(db, project_id) + + project = crud_project.read(db=db, id=project_id) + self.document_tags = project.document_tags + self.tagid2tag = {tag.id: tag for tag in self.document_tags} + self.tagname2id_dict = {tag.name.lower(): tag.id for tag in self.document_tags} + + def _build_example(self, language: str, tag_id: int) -> str: + tag = self.tagid2tag[tag_id] + + return self.example_templates[language].format(tag.name, tag.name) + + def _build_user_prompt_template( + self, language: str, tag_ids: List[int], **kwargs + ) -> str: + # create task data (the list of tags to use for classification) + task_data = "\n".join( + [ + f"{self.tagid2tag[tag_id].name} - {self.tagid2tag[tag_id].description}" + for tag_id in tag_ids + ] + ) + + # create answer example + answer_example = self._build_example(language, tag_ids[0]) + + return self.prompt_templates[language].format(task_data, answer_example) + + def parse_response(self, language: str, response: str) -> Tuple[List[int], str]: + if language not in self.category_word: + return [], f"Language '{language}' is not supported." + if language not in self.reason_word: + return [], f"Language '{language}' is not supported." 
+ + components = re.split(r"\n+", response) + + # check that the answer starts with expected category word + if not components[0].startswith(f"{self.category_word[language]}"): + return ( + [], + f"The answer has to start with '{self.category_word[language]}'.", + ) + + # extract the categories + comma_separated_categories = components[0].split(":")[1].strip() + if len(comma_separated_categories) == 0: + categories = [] + else: + categories = [ + category.strip().lower() + for category in comma_separated_categories.split(",") + ] + + # map the categories to their tag ids + categories = [ + self.tagname2id_dict[category] + for category in categories + if category in self.tagname2id_dict + ] + + # extract the reason if the answer has multiple lines + reason = "No reason was provided" + if len(components) > 1 and components[1].startswith( + f"{self.reason_word[language]}" + ): + reason = components[1].split(":")[1].strip() + + return categories, reason diff --git a/backend/src/app/core/data/orm/project_metadata.py b/backend/src/app/core/data/orm/project_metadata.py index 86ddb38ff..85e6daa01 100644 --- a/backend/src/app/core/data/orm/project_metadata.py +++ b/backend/src/app/core/data/orm/project_metadata.py @@ -16,6 +16,7 @@ class ProjectMetadataORM(ORMBase): metatype: Mapped[str] = mapped_column(String, nullable=False, index=False) read_only: Mapped[bool] = mapped_column(Boolean, nullable=False, index=False) doctype: Mapped[str] = mapped_column(String, nullable=False, index=False) + description: Mapped[str] = mapped_column(String, nullable=False, index=False) # one to many sdoc_metadata: Mapped[List["SourceDocumentMetadataORM"]] = relationship( diff --git a/backend/src/app/core/db/redis_service.py b/backend/src/app/core/db/redis_service.py index 488b73e9b..00101426f 100644 --- a/backend/src/app/core/db/redis_service.py +++ b/backend/src/app/core/db/redis_service.py @@ -20,6 +20,7 @@ ) from app.core.data.dto.export_job import ExportJobCreate, ExportJobRead, 
ExportJobUpdate from app.core.data.dto.feedback import FeedbackCreate, FeedbackRead +from app.core.data.dto.llm_job import LLMJobCreate, LLMJobRead, LLMJobUpdate from app.core.data.dto.trainer_job import ( TrainerJobCreate, TrainerJobRead, @@ -416,3 +417,64 @@ def get_all_feedbacks(self) -> List[FeedbackRead]: def get_all_feedbacks_of_user(self, user_id: int) -> List[FeedbackRead]: fbs = self.get_all_feedbacks() return [fb for fb in fbs if fb.user_id == user_id] + + def store_llm_job(self, llm_job: Union[LLMJobCreate, LLMJobRead]) -> LLMJobRead: + client = self._get_client("llm") + + if isinstance(llm_job, LLMJobCreate): + key = self._generate_random_key() + llmj = LLMJobRead( + id=key, + **llm_job.model_dump(), + created=datetime.now(), + updated=datetime.now(), + ) + elif isinstance(llm_job, LLMJobRead): + key = llm_job.id + llmj = llm_job + + if client.set(key.encode("utf-8"), llmj.model_dump_json()) != 1: + msg = "Cannot store LLMJob!" + logger.error(msg) + raise RuntimeError(msg) + + logger.debug(f"Successfully stored LLMJob {key}!") + + return llmj + + def get_all_llm_jobs(self, project_id: int) -> List[LLMJobRead]: + client = self._get_client("llm") + all_llm_jobs: List[LLMJobRead] = [ + self.load_llm_job(str(key, "utf-8")) for key in client.keys() + ] + return [job for job in all_llm_jobs if job.parameters.project_id == project_id] + + def load_llm_job(self, key: str) -> LLMJobRead: + client = self._get_client("llm") + llmj = client.get(key.encode("utf-8")) + if llmj is None: + msg = f"LLMJob with ID {key} does not exist!" 
+ logger.error(msg) + raise KeyError(msg) + + logger.debug(f"Successfully loaded LLMJob {key}") + return LLMJobRead.model_validate_json(llmj) + + def update_llm_job(self, key: str, update: LLMJobUpdate) -> LLMJobRead: + llmj = self.load_llm_job(key=key) + data = llmj.model_dump(exclude={"updated"}) + data.update(**update.model_dump(exclude_unset=True)) + llmj = LLMJobRead(**data, updated=datetime.now()) + llmj = self.store_llm_job(llm_job=llmj) + logger.debug(f"Updated LLMJob {key}") + return llmj + + def delete_llm_job(self, key: str) -> LLMJobRead: + llmj = self.load_llm_job(key=key) + client = self._get_client("llm") + if client.delete(key.encode("utf-8")) != 1: + msg = f"Cannot delete LLMJob {key}" + logger.error(msg) + raise RuntimeError(msg) + logger.debug(f"Deleted LLMJob {key}") + return llmj diff --git a/backend/src/app/core/startup.py b/backend/src/app/core/startup.py index cd9fb69c7..e02dd35f2 100644 --- a/backend/src/app/core/startup.py +++ b/backend/src/app/core/startup.py @@ -148,6 +148,16 @@ def __init_services__( SimSearchService(flush=reset_weaviate) + # import and init OllamaService + from app.core.data.llm.ollama_service import OllamaService + + OllamaService() + + # import and init LLMService + from app.core.data.llm.llm_service import LLMService + + LLMService() + def __create_system_user__() -> None: from app.core.data.crud.user import crud_user diff --git a/backend/src/configs/default.yaml b/backend/src/configs/default.yaml index 7b4256dab..70b319a64 100644 --- a/backend/src/configs/default.yaml +++ b/backend/src/configs/default.yaml @@ -86,6 +86,7 @@ redis: crawler: 3 trainer: 4 cota: 5 + llm: 6 logging: max_file_size: 500 # MB @@ -106,150 +107,184 @@ elasticsearch: docs: configs/default_sdoc_index_settings.json memos: configs/default_memo_index_settings.json +ollama: + enabled: ${oc.env:OLLAMA_ENABLED, True} + host: ${oc.env:OLLAMA_HOST, 127.0.0.1} + port: ${oc.env:OLLAMA_PORT, 9201} + model: ${oc.env:OLLAMA_MODEL, gemma2:9b-instruct-fp16} 
+ project_metadata: text_url: key: "url" metatype: "STRING" read_only: True doctype: "text" + description: "The URL of the document" text_language: key: "language" metatype: "STRING" read_only: True doctype: "text" + description: "The language of the document" text_keywords: key: "keywords" metatype: "LIST" read_only: True doctype: "text" + description: "Keywords extracted from the document" image_url: key: "url" metatype: "STRING" read_only: True doctype: "image" + description: "The URL of the image" image_keywords: key: "keywords" metatype: "LIST" read_only: True doctype: "image" + description: "Keywords extracted from the image's caption" image_caption: key: "caption" metatype: "STRING" read_only: True doctype: "image" + description: "The caption of the image, generated by an image captioning model" image_width: key: "width" metatype: "NUMBER" read_only: True doctype: "image" + description: "The width of the image in pixels" image_height: key: "height" metatype: "NUMBER" read_only: True doctype: "image" + description: "The height of the image in pixels" image_format: key: "format" metatype: "STRING" read_only: True doctype: "image" + description: "The format of the image" image_mode: key: "mode" metatype: "STRING" read_only: True doctype: "image" + description: "The mode of the image" audio_url: key: "url" metatype: "STRING" read_only: True doctype: "audio" + description: "The URL of the audio file" audio_word_level_transcriptions: key: "word_level_transcriptions" metatype: "STRING" read_only: True doctype: "audio" + description: "Word-level transcriptions of the audio file" audio_duration: key: "duration" metatype: "NUMBER" read_only: True doctype: "audio" + description: "The duration of the audio file in seconds" audio_format_name: key: "format_name" metatype: "LIST" read_only: True doctype: "audio" + description: "The format of the audio file" audio_format_long_name: key: "format_long_name" metatype: "STRING" read_only: True doctype: "audio" + description: 
"The long format name of the audio file" audio_size: key: "size" metatype: "NUMBER" read_only: True doctype: "audio" + description: "The size of the audio file in bytes" audio_bit_rate: key: "bit_rate" metatype: "NUMBER" read_only: True doctype: "audio" + description: "The bit rate of the audio file in bits per second" audio_tags: key: "tags" metatype: "STRING" read_only: True doctype: "audio" + description: "Tags extracted from the audio file" video_url: key: "url" metatype: "STRING" read_only: True doctype: "video" + description: "The URL of the video file" video_word_level_transcriptions: key: "word_level_transcriptions" metatype: "STRING" read_only: True doctype: "video" + description: "Word-level transcriptions of the video file" video_width: key: "width" metatype: "NUMBER" read_only: True doctype: "video" + description: "The width of the video in pixels" video_height: key: "height" metatype: "NUMBER" read_only: True doctype: "video" + description: "The height of the video in pixels" video_duration: key: "duration" metatype: "NUMBER" read_only: True doctype: "video" + description: "The duration of the video file in seconds" video_format_name: key: "format_name" metatype: "LIST" read_only: True doctype: "video" + description: "The format of the video file" video_format_long_name: key: "format_long_name" metatype: "STRING" read_only: True doctype: "video" + description: "The long format name of the video file" video_size: key: "size" metatype: "NUMBER" read_only: True doctype: "video" + description: "The size of the video file in bytes" video_bit_rate: key: "bit_rate" metatype: "NUMBER" read_only: True doctype: "video" + description: "The bit rate of the video file in bits per second" video_tags: key: "tags" metatype: "STRING" read_only: True doctype: "video" + description: "Tags extracted from the video file" system_codes: SYSTEM_CODE: diff --git a/backend/src/configs/default_localhost_dev.yaml b/backend/src/configs/default_localhost_dev.yaml index 
a0064cb4c..a7922750c 100644 --- a/backend/src/configs/default_localhost_dev.yaml +++ b/backend/src/configs/default_localhost_dev.yaml @@ -86,6 +86,7 @@ redis: crawler: 3 trainer: 4 cota: 5 + llm: 6 logging: max_file_size: 500 # MB @@ -106,150 +107,184 @@ elasticsearch: docs: configs/default_sdoc_index_settings.json memos: configs/default_memo_index_settings.json +ollama: + enabled: ${oc.env:OLLAMA_ENABLED, True} + host: ${oc.env:OLLAMA_HOST, 127.0.0.1} + port: ${oc.env:OLLAMA_PORT, 9201} + model: ${oc.env:OLLAMA_MODEL, gemma2:9b-instruct-fp16} + project_metadata: text_url: key: "url" metatype: "STRING" read_only: True doctype: "text" + description: "The URL of the document" text_language: key: "language" metatype: "STRING" read_only: True doctype: "text" + description: "The language of the document" text_keywords: key: "keywords" metatype: "LIST" read_only: True doctype: "text" + description: "Keywords extracted from the document" image_url: key: "url" metatype: "STRING" read_only: True doctype: "image" + description: "The URL of the image" image_keywords: key: "keywords" metatype: "LIST" read_only: True doctype: "image" + description: "Keywords extracted from the image's caption" image_caption: key: "caption" metatype: "STRING" read_only: True doctype: "image" + description: "The caption of the image, generated by an image captioning model" image_width: key: "width" metatype: "NUMBER" read_only: True doctype: "image" + description: "The width of the image in pixels" image_height: key: "height" metatype: "NUMBER" read_only: True doctype: "image" + description: "The height of the image in pixels" image_format: key: "format" metatype: "STRING" read_only: True doctype: "image" + description: "The format of the image" image_mode: key: "mode" metatype: "STRING" read_only: True doctype: "image" + description: "The mode of the image" audio_url: key: "url" metatype: "STRING" read_only: True doctype: "audio" + description: "The URL of the audio file" 
audio_word_level_transcriptions: key: "word_level_transcriptions" metatype: "STRING" read_only: True doctype: "audio" + description: "Word-level transcriptions of the audio file" audio_duration: key: "duration" metatype: "NUMBER" read_only: True doctype: "audio" + description: "The duration of the audio file in seconds" audio_format_name: key: "format_name" metatype: "LIST" read_only: True doctype: "audio" + description: "The format of the audio file" audio_format_long_name: key: "format_long_name" metatype: "STRING" read_only: True doctype: "audio" + description: "The long format name of the audio file" audio_size: key: "size" metatype: "NUMBER" read_only: True doctype: "audio" + description: "The size of the audio file in bytes" audio_bit_rate: key: "bit_rate" metatype: "NUMBER" read_only: True doctype: "audio" + description: "The bit rate of the audio file in bits per second" audio_tags: key: "tags" metatype: "STRING" read_only: True doctype: "audio" + description: "Tags extracted from the audio file" video_url: key: "url" metatype: "STRING" read_only: True doctype: "video" + description: "The URL of the video file" video_word_level_transcriptions: key: "word_level_transcriptions" metatype: "STRING" read_only: True doctype: "video" + description: "Word-level transcriptions of the video file" video_width: key: "width" metatype: "NUMBER" read_only: True doctype: "video" + description: "The width of the video in pixels" video_height: key: "height" metatype: "NUMBER" read_only: True doctype: "video" + description: "The height of the video in pixels" video_duration: key: "duration" metatype: "NUMBER" read_only: True doctype: "video" + description: "The duration of the video file in seconds" video_format_name: key: "format_name" metatype: "LIST" read_only: True doctype: "video" + description: "The format of the video file" video_format_long_name: key: "format_long_name" metatype: "STRING" read_only: True doctype: "video" + description: "The long format name of the 
video file" video_size: key: "size" metatype: "NUMBER" read_only: True doctype: "video" + description: "The size of the video file in bytes" video_bit_rate: key: "bit_rate" metatype: "NUMBER" read_only: True doctype: "video" + description: "The bit rate of the video file in bits per second" video_tags: key: "tags" metatype: "STRING" read_only: True doctype: "video" + description: "Tags extracted from the video file" system_codes: SYSTEM_CODE: diff --git a/backend/src/main.py b/backend/src/main.py index deaa35a73..9d9c51672 100644 --- a/backend/src/main.py +++ b/backend/src/main.py @@ -45,6 +45,7 @@ export, feedback, general, + llm, memo, prepro, project, @@ -282,6 +283,7 @@ def invalid_error_handler(_, exc: InvalidError): app.include_router(trainer.router) app.include_router(concept_over_time_analysis.router) app.include_router(timeline_analysis.router) +app.include_router(llm.router) def main() -> None: diff --git a/backend/src/migration/migrate.py b/backend/src/migration/migrate.py index d0751e57d..5adedadd4 100644 --- a/backend/src/migration/migrate.py +++ b/backend/src/migration/migrate.py @@ -178,6 +178,7 @@ def __create_or_get_project_metadata_keywords( read_only=True, doctype=doctype, project_id=project_id, + description="Keywords extracted from the document.", ), ).id @@ -358,6 +359,7 @@ def __migrate_add_default_metadata(db: Session): metatype=project_metadata["metatype"], read_only=project_metadata["read_only"], doctype=project_metadata["doctype"], + description=project_metadata["description"], ) __create_project_metadata_if_not_exists(db, create_dto) diff --git a/backend/src/test/conftest.py b/backend/src/test/conftest.py index 9f49869cb..4a1de1e83 100755 --- a/backend/src/test/conftest.py +++ b/backend/src/test/conftest.py @@ -23,6 +23,7 @@ from config import conf os.environ["RAY_ENABLED"] = "False" +os.environ["OLLAMA_ENABLED"] = "False" # Flo: just do it once. 
We have to check because if we start the main function, unvicorn will import this # file once more manually, so it would be executed twice. diff --git a/docker/.env.example b/docker/.env.example index b465409d3..04968c72e 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -3,7 +3,7 @@ COMPOSE_PROJECT_NAME=demo # outside of containers, # remove their profiles to disable their containers # add `dev_frontend` if you run the frontend externally -COMPOSE_PROFILES=backend,frontend,background,ray +COMPOSE_PROFILES=backend,frontend,background,ray,ollama # Which user and group to use for running processes # inside containers. @@ -75,6 +75,11 @@ ES_HOST=elasticsearch ES_PORT=9200 ES_MIN_HEALTH=50 +OLLAMA_ENABLED=True +OLLAMA_HOST=ollama +OLLAMA_PORT=11434 +OLLAMA_MODEL=gemma2:latest + RAY_ENABLED=True RAY_HOST=ray RAY_PORT=8000 @@ -140,3 +145,5 @@ RAY_API_EXPOSED=13134 RAY_DASHBOARD_EXPOSED=13135 WEAVIATE_EXPOSED=13241 + +OLLAMA_EXPOSED=13242 diff --git a/docker/.gitignore b/docker/.gitignore index d94434a75..2d73bd345 100644 --- a/docker/.gitignore +++ b/docker/.gitignore @@ -2,4 +2,5 @@ /backend_repo/* /models_cache/* /numba_cache/* +/ollama_cache/* .env diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index d5000f1c3..28056fdf2 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -108,6 +108,71 @@ services: profiles: - dev_frontend + ollama: + image: ollama/ollama:0.3.4 + ports: + - ${OLLAMA_EXPOSED:-19290}:11434 + environment: + - OLLAMA_KEEP_ALIVE=24h + tty: true + restart: unless-stopped + volumes: + - ./ollama_cache:/root/.ollama + deploy: + resources: + reservations: + devices: + - driver: nvidia + device_ids: ["0"] + capabilities: [gpu] + networks: + - dats_demo_network + profiles: + - ollama + + ray: + image: uhhlt/dats_ray:${DATS_RAY_DOCKER_VERSION:-debian_dev_latest} + command: /dats_code_ray/ray_model_worker_entrypoint.sh + user: ${UID:-1000}:${GID:-1000} + environment: + LOG_LEVEL: ${LOG_LEVEL:-info} + 
DATS_BACKEND_CONFIG: ${DATS_BACKEND_CONFIG:-/dats_code/src/configs/default.yaml} + HUGGINGFACE_HUB_CACHE: /models_cache + TRANSFORMERS_CACHE: /models_cache + TORCH_HOME: /models_cache + RAY_PROCESSING_DEVICE_SPACY: ${RAY_PROCESSING_DEVICE_SPACY:-cpu} + RAY_PROCESSING_DEVICE_WHISPER: ${RAY_PROCESSING_DEVICE_WHISPER:-cuda} + RAY_PROCESSING_DEVICE_DETR: ${RAY_PROCESSING_DEVICE_DETR:-cuda} + RAY_PROCESSING_DEVICE_VIT_GPT2: ${RAY_PROCESSING_DEVICE_VIT_GPT2:-cuda} + RAY_PROCESSING_DEVICE_BLIP2: ${RAY_PROCESSING_DEVICE_BLIP2:-cuda} + RAY_PROCESSING_DEVICE_CLIP: ${RAY_PROCESSING_DEVICE_CLIP:-cuda} + RAY_PROCESSING_DEVICE_COTA: ${RAY_PROCESSING_DEVICE_COTA:-cuda} + RAY_BLIP2_PRECISION_BIT: ${RAY_BLIP2_PRECISION_BIT:-32} + SHARED_REPO_ROOT: ${SHARED_REPO_ROOT:-/tmp/dats} + NUMBA_CACHE_DIR: /numba_cache + volumes: + - ../backend/src/app/preprocessing/ray_model_worker:/dats_code_ray + - ./spacy_models:/spacy_models + - ./backend_repo:${SHARED_REPO_ROOT:-/tmp/dats} + - ./models_cache:/models_cache + - ./numba_cache:/numba_cache + ports: + - "${RAY_API_EXPOSED:-8000}:8000" + - "${RAY_DASHBOARD_EXPOSED:-8265}:8265" + restart: always + shm_size: 12gb + networks: + - dats_demo_network + deploy: + resources: + reservations: + devices: + - driver: nvidia + device_ids: ["0"] + capabilities: [gpu] + profiles: + - ray + celery-background-jobs-worker: image: uhhlt/dats_backend:${DATS_BACKEND_DOCKER_VERSION:-debian_dev_latest} command: /dats_code/src/celery_background_jobs_worker_entrypoint.sh @@ -144,6 +209,10 @@ services: TORCH_HOME: /models_cache NUMBA_CACHE_DIR: /numba_cache SHARED_REPO_ROOT: ${SHARED_REPO_ROOT:-/tmp/dats} + OLLAMA_ENABLED: ${OLLAMA_ENABLED:-True} + OLLAMA_HOST: ${OLLAMA_HOST:-ollama} + OLLAMA_PORT: ${OLLAMA_PORT:-11434} + OLLAMA_MODEL: ${OLLAMA_MODEL:-gemma2:9b-instruct-fp16} volumes: - ../backend/src:/dats_code/src - ./backend_repo:${SHARED_REPO_ROOT:-/tmp/dats} @@ -153,62 +222,16 @@ services: - rabbitmq - redis - postgres + - ollama ports: - 
"${JUPYTER_BACKGROUND_JOBS_EXPOSED:-8880}:8888" - "${CELERY_DEBUG_PORT:-45678}:6900" restart: always - links: - - postgres - - redis - - rabbitmq networks: - dats_demo_network profiles: - background - ray: - image: uhhlt/dats_ray:${DATS_RAY_DOCKER_VERSION:-debian_dev_latest} - command: /dats_code_ray/ray_model_worker_entrypoint.sh - user: ${UID:-1000}:${GID:-1000} - environment: - LOG_LEVEL: ${LOG_LEVEL:-info} - DATS_BACKEND_CONFIG: ${DATS_BACKEND_CONFIG:-/dats_code/src/configs/default.yaml} - HUGGINGFACE_HUB_CACHE: /models_cache - TRANSFORMERS_CACHE: /models_cache - TORCH_HOME: /models_cache - RAY_PROCESSING_DEVICE_SPACY: ${RAY_PROCESSING_DEVICE_SPACY:-cpu} - RAY_PROCESSING_DEVICE_WHISPER: ${RAY_PROCESSING_DEVICE_WHISPER:-cuda} - RAY_PROCESSING_DEVICE_DETR: ${RAY_PROCESSING_DEVICE_DETR:-cuda} - RAY_PROCESSING_DEVICE_VIT_GPT2: ${RAY_PROCESSING_DEVICE_VIT_GPT2:-cuda} - RAY_PROCESSING_DEVICE_BLIP2: ${RAY_PROCESSING_DEVICE_BLIP2:-cuda} - RAY_PROCESSING_DEVICE_CLIP: ${RAY_PROCESSING_DEVICE_CLIP:-cuda} - RAY_PROCESSING_DEVICE_COTA: ${RAY_PROCESSING_DEVICE_COTA:-cuda} - RAY_BLIP2_PRECISION_BIT: ${RAY_BLIP2_PRECISION_BIT:-32} - SHARED_REPO_ROOT: ${SHARED_REPO_ROOT:-/tmp/dats} - NUMBA_CACHE_DIR: /numba_cache - volumes: - - ../backend/src/app/preprocessing/ray_model_worker:/dats_code_ray - - ./spacy_models:/spacy_models - - ./backend_repo:${SHARED_REPO_ROOT:-/tmp/dats} - - ./models_cache:/models_cache - - ./numba_cache:/numba_cache - ports: - - "${RAY_API_EXPOSED:-8000}:8000" - - "${RAY_DASHBOARD_EXPOSED:-8265}:8265" - restart: always - shm_size: 12gb - networks: - - dats_demo_network - deploy: - resources: - reservations: - devices: - - driver: nvidia - device_ids: ["0"] - capabilities: [gpu] - profiles: - - ray - dats-backend-api: image: uhhlt/dats_backend:${DATS_BACKEND_DOCKER_VERSION:-debian_dev_latest} command: /dats_code/src/backend_api_entrypoint.sh @@ -250,6 +273,10 @@ services: SYSTEM_USER_EMAIL: ${SYSTEM_USER_EMAIL} SYSTEM_USER_PASSWORD: ${SYSTEM_USER_PASSWORD} 
SHARED_REPO_ROOT: ${SHARED_REPO_ROOT:-/tmp/dats} + OLLAMA_ENABLED: ${OLLAMA_ENABLED:-True} + OLLAMA_HOST: ${OLLAMA_HOST:-ollama} + OLLAMA_PORT: ${OLLAMA_PORT:-11434} + OLLAMA_MODEL: ${OLLAMA_MODEL:-gemma2:9b-instruct-fp16} healthcheck: test: curl --fail "http://localhost:${API_PORT}" || exit 1 interval: 60s @@ -260,6 +287,7 @@ services: - ../backend/src:/dats_code/src - ./backend_repo:${SHARED_REPO_ROOT:-/tmp/dats} depends_on: + - celery-background-jobs-worker - elasticsearch - postgres - rabbitmq @@ -269,13 +297,6 @@ services: - "${API_EXPOSED:-5500}:${API_PORT}" - "${JUPYTER_API_EXPOSED:-8888}:8888" restart: always - links: - - postgres - - redis - - rabbitmq - - celery-background-jobs-worker - - elasticsearch - - weaviate networks: - dats_demo_network profiles: diff --git a/docker/monkey_patch_docker_compose_for_backend_tests.py b/docker/monkey_patch_docker_compose_for_backend_tests.py index 88882a517..b1485a62c 100644 --- a/docker/monkey_patch_docker_compose_for_backend_tests.py +++ b/docker/monkey_patch_docker_compose_for_backend_tests.py @@ -4,17 +4,23 @@ import yaml -with open("docker-compose.yml") as f: +with open("../docker/docker-compose.yml") as f: file = f.read() data = yaml.safe_load(file) disable_ray = len(sys.argv) > 1 and sys.argv[1] == "--disable_ray" +disable_ollama = len(sys.argv) > 2 and sys.argv[2] == "--disable_ollama" if disable_ray: # remove ray as it's too resource-intensive for CI data["services"].pop("ray", None) +if disable_ollama: + # remove ollama as it's too resource-intensive for CI + data["services"].pop("ollama", None) + data["services"]["celery-background-jobs-worker"]["depends_on"].remove("ollama") + for a in data["services"]: data["services"][a].pop("deploy", None) diff --git a/docker/setup-folders.sh b/docker/setup-folders.sh index 4fd99291f..75b26af5f 100755 --- a/docker/setup-folders.sh +++ b/docker/setup-folders.sh @@ -4,3 +4,4 @@ mkdir -p backend_repo mkdir -p models_cache mkdir -p spacy_models mkdir -p numba_cache +mkdir -p
ollama_cache diff --git a/frontend/src/api/LLMHooks.ts b/frontend/src/api/LLMHooks.ts new file mode 100644 index 000000000..648ee63b9 --- /dev/null +++ b/frontend/src/api/LLMHooks.ts @@ -0,0 +1,71 @@ +import { useMutation, useQuery } from "@tanstack/react-query"; +import queryClient from "../plugins/ReactQueryClient.ts"; +import { QueryKey } from "./QueryKey.ts"; +import { BackgroundJobStatus } from "./openapi/models/BackgroundJobStatus.ts"; +import { LLMJobRead } from "./openapi/models/LLMJobRead.ts"; +import { LlmService } from "./openapi/services/LlmService.ts"; + +const useStartLLMJob = () => + useMutation({ + mutationFn: LlmService.startLlmJob, + onSuccess: (job) => { + // force refetch of all llm jobs when adding a new one + queryClient.invalidateQueries({ queryKey: [QueryKey.PROJECT_LLM_JOBS, job.parameters.project_id] }); + }, + meta: { + successMessage: (data: LLMJobRead) => `Started LLM Job as a new background task (ID: ${data.id})`, + }, + }); + +const usePollLLMJob = (llmJobId: string | undefined, initialData: LLMJobRead | undefined) => { + return useQuery({ + queryKey: [QueryKey.LLM_JOB, llmJobId], + queryFn: () => + LlmService.getLlmJob({ + llmJobId: llmJobId!, + }), + enabled: !!llmJobId, + refetchInterval: (query) => { + if (!query.state.data) { + return 1000; + } + if (query.state.data.status) { + switch (query.state.data.status) { + case BackgroundJobStatus.ERRORNEOUS: + case BackgroundJobStatus.FINISHED: + return false; + case BackgroundJobStatus.WAITING: + case BackgroundJobStatus.RUNNING: + return 1000; + } + } + return false; + }, + initialData, + }); +}; + +const useGetAllLLMJobs = (projectId: number) => { + return useQuery({ + queryKey: [QueryKey.PROJECT_LLM_JOBS, projectId], + queryFn: () => + LlmService.getAllLlmJobs({ + projectId: projectId!, + }), + enabled: !!projectId, + }); +}; + +const useCreatePromptTemplates = () => + useMutation({ + mutationFn: LlmService.createPromptTemplates, + }); + +const LLMHooks = { + usePollLLMJob, + 
useStartLLMJob, + useGetAllLLMJobs, + useCreatePromptTemplates, +}; + +export default LLMHooks; diff --git a/frontend/src/api/QueryKey.ts b/frontend/src/api/QueryKey.ts index b1c8c6697..57ed8a0b7 100644 --- a/frontend/src/api/QueryKey.ts +++ b/frontend/src/api/QueryKey.ts @@ -15,8 +15,10 @@ export const QueryKey = { PROJECT_TAGS: "projectTags", // all crawler jobs of a project (by project id) PROJECT_CRAWLER_JOBS: "projectCrawlerJobs", - // all crawler jobs of a project (by project id) + // all prepro jobs of a project (by project id) PROJECT_PREPROCESSING_JOBS: "projectPreprocessingJobs", + // all llm jobs of a project (by project id) + PROJECT_LLM_JOBS: "projectLLMJobs", // all users USERS: "users", @@ -177,6 +179,9 @@ export const QueryKey = { // crawler (by crawler job id) CRAWLER_JOB: "crawlerJob", + // crawler (by llm job id) + LLM_JOB: "llmJob", + // tables SEARCH_TABLE: "search-document-table-data", }; diff --git a/frontend/src/api/SdocMetadataHooks.ts b/frontend/src/api/SdocMetadataHooks.ts index 632a0a0bb..89dbf4be9 100644 --- a/frontend/src/api/SdocMetadataHooks.ts +++ b/frontend/src/api/SdocMetadataHooks.ts @@ -11,8 +11,19 @@ const useUpdateMetadata = () => }, }); +const useUpdateBulkMetadata = () => + useMutation({ + mutationFn: SdocMetadataService.updateBulk, + onSuccess: (metadatas) => { + metadatas.forEach((metadata) => { + queryClient.invalidateQueries({ queryKey: [QueryKey.SDOC_METADATAS, metadata.source_document_id] }); + }); + }, + }); + const SdocMetadataHooks = { useUpdateMetadata, + useUpdateBulkMetadata, }; export default SdocMetadataHooks; diff --git a/frontend/src/api/SpanAnnotationHooks.ts b/frontend/src/api/SpanAnnotationHooks.ts index 36236a989..7db451066 100644 --- a/frontend/src/api/SpanAnnotationHooks.ts +++ b/frontend/src/api/SpanAnnotationHooks.ts @@ -81,6 +81,11 @@ const useCreateAnnotation = () => }, }); +const useCreateBulkAnnotations = () => + useMutation({ + mutationFn: SpanAnnotationService.addSpanAnnotationsBulk, + }); + 
const useGetAnnotation = (spanId: number | null | undefined) => useQuery({ queryKey: [QueryKey.SPAN_ANNOTATION, spanId], @@ -279,6 +284,7 @@ const useCreateMemo = () => const SpanAnnotationHooks = { useCreateAnnotation, + useCreateBulkAnnotations, useGetAnnotation, useGetByCodeAndUser, useUpdateSpan, diff --git a/frontend/src/api/TagHooks.ts b/frontend/src/api/TagHooks.ts index 41f7ba735..61b029809 100644 --- a/frontend/src/api/TagHooks.ts +++ b/frontend/src/api/TagHooks.ts @@ -43,6 +43,21 @@ const useDeleteTag = () => }, }); +const useBulkSetDocumentTags = () => + useMutation({ + mutationFn: DocumentTagService.setDocumentTagsBatch, + onSuccess: (_data, variables) => { + // we need to invalidate the document tags for every document that we updated + variables.requestBody.forEach((links) => { + queryClient.invalidateQueries({ queryKey: [QueryKey.SDOC_TAGS, links.source_document_id] }); + }); + queryClient.invalidateQueries({ queryKey: [QueryKey.SDOCS_BY_PROJECT_AND_FILTERS_SEARCH] }); + queryClient.invalidateQueries({ queryKey: [QueryKey.SEARCH_TAG_STATISTICS] }); // todo: zu unspezifisch! 
+ // Invalidate cache of tag statistics query + queryClient.invalidateQueries({ queryKey: [QueryKey.TAG_SDOC_COUNT] }); + }, + }); + const useBulkLinkDocumentTags = () => useMutation({ mutationFn: (variables: { projectId: number; requestBody: SourceDocumentDocumentTagMultiLink }) => @@ -172,6 +187,7 @@ const TagHooks = { useCreateTag, useUpdateTag, useDeleteTag, + useBulkSetDocumentTags, useBulkUpdateDocumentTags, useBulkLinkDocumentTags, useBulkUnlinkDocumentTags, diff --git a/frontend/src/api/openapi/models/AnnotationLLMJobParams.ts b/frontend/src/api/openapi/models/AnnotationLLMJobParams.ts new file mode 100644 index 000000000..74243fc61 --- /dev/null +++ b/frontend/src/api/openapi/models/AnnotationLLMJobParams.ts @@ -0,0 +1,15 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export type AnnotationLLMJobParams = { + llm_job_type: any; + /** + * IDs of the source documents to analyse + */ + sdoc_ids: Array; + /** + * IDs of the codes to use for the annotation + */ + code_ids: Array; +}; diff --git a/frontend/src/api/openapi/models/AnnotationLLMJobResult.ts b/frontend/src/api/openapi/models/AnnotationLLMJobResult.ts new file mode 100644 index 000000000..815f4a384 --- /dev/null +++ b/frontend/src/api/openapi/models/AnnotationLLMJobResult.ts @@ -0,0 +1,9 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +import type { AnnotationResult } from "./AnnotationResult"; +export type AnnotationLLMJobResult = { + llm_job_type: any; + results: Array; +}; diff --git a/frontend/src/api/openapi/models/AnnotationResult.ts b/frontend/src/api/openapi/models/AnnotationResult.ts new file mode 100644 index 000000000..c6b5b6749 --- /dev/null +++ b/frontend/src/api/openapi/models/AnnotationResult.ts @@ -0,0 +1,15 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* 
tslint:disable */ +/* eslint-disable */ +import type { SpanAnnotationReadResolved } from "./SpanAnnotationReadResolved"; +export type AnnotationResult = { + /** + * ID of the source document + */ + sdoc_id: number; + /** + * Suggested annotations + */ + suggested_annotations: Array; +}; diff --git a/frontend/src/api/openapi/models/DocumentTaggingLLMJobParams.ts b/frontend/src/api/openapi/models/DocumentTaggingLLMJobParams.ts new file mode 100644 index 000000000..33b09c0b0 --- /dev/null +++ b/frontend/src/api/openapi/models/DocumentTaggingLLMJobParams.ts @@ -0,0 +1,15 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export type DocumentTaggingLLMJobParams = { + llm_job_type: any; + /** + * IDs of the source documents to analyse + */ + sdoc_ids: Array; + /** + * IDs of the tags to use for the document tagging + */ + tag_ids: Array; +}; diff --git a/frontend/src/api/openapi/models/DocumentTaggingLLMJobResult.ts b/frontend/src/api/openapi/models/DocumentTaggingLLMJobResult.ts new file mode 100644 index 000000000..a0af50e40 --- /dev/null +++ b/frontend/src/api/openapi/models/DocumentTaggingLLMJobResult.ts @@ -0,0 +1,9 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +import type { DocumentTaggingResult } from "./DocumentTaggingResult"; +export type DocumentTaggingLLMJobResult = { + llm_job_type: any; + results: Array; +}; diff --git a/frontend/src/api/openapi/models/DocumentTaggingResult.ts b/frontend/src/api/openapi/models/DocumentTaggingResult.ts new file mode 100644 index 000000000..27233b004 --- /dev/null +++ b/frontend/src/api/openapi/models/DocumentTaggingResult.ts @@ -0,0 +1,22 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export type DocumentTaggingResult = { + /** + * ID of the source 
document + */ + sdoc_id: number; + /** + * IDs of the tags currently assigned to the document + */ + current_tag_ids: Array; + /** + * IDs of the tags suggested by the LLM to assign to the document + */ + suggested_tag_ids: Array; + /** + * Reasoning for the tagging + */ + reasoning: string; +}; diff --git a/frontend/src/api/openapi/models/LLMJobParameters.ts b/frontend/src/api/openapi/models/LLMJobParameters.ts new file mode 100644 index 000000000..1aca2dde9 --- /dev/null +++ b/frontend/src/api/openapi/models/LLMJobParameters.ts @@ -0,0 +1,27 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +import type { AnnotationLLMJobParams } from "./AnnotationLLMJobParams"; +import type { DocumentTaggingLLMJobParams } from "./DocumentTaggingLLMJobParams"; +import type { LLMJobType } from "./LLMJobType"; +import type { LLMPromptTemplates } from "./LLMPromptTemplates"; +import type { MetadataExtractionLLMJobParams } from "./MetadataExtractionLLMJobParams"; +export type LLMJobParameters = { + /** + * The type of the LLMJob (what to llm) + */ + llm_job_type: LLMJobType; + /** + * The ID of the Project to analyse + */ + project_id: number; + /** + * The prompt templates to use for the job + */ + prompts: Array; + /** + * Specific parameters for the LLMJob w.r.t it's type + */ + specific_llm_job_parameters: DocumentTaggingLLMJobParams | MetadataExtractionLLMJobParams | AnnotationLLMJobParams; +}; diff --git a/frontend/src/api/openapi/models/LLMJobRead.ts b/frontend/src/api/openapi/models/LLMJobRead.ts new file mode 100644 index 000000000..ceb2ab231 --- /dev/null +++ b/frontend/src/api/openapi/models/LLMJobRead.ts @@ -0,0 +1,41 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +import type { BackgroundJobStatus } from "./BackgroundJobStatus"; +import type { LLMJobParameters } from "./LLMJobParameters"; 
+import type { LLMJobResult } from "./LLMJobResult"; +export type LLMJobRead = { + /** + * Status of the LLMJob + */ + status?: BackgroundJobStatus; + /** + * Number of steps LLMJob has completed. + */ + num_steps_finished: number; + /** + * Number of total steps. + */ + num_steps_total: number; + /** + * Results of hte LLMJob. + */ + result?: LLMJobResult | null; + /** + * ID of the LLMJob + */ + id: string; + /** + * The parameters of the LLMJob that defines what to llm! + */ + parameters: LLMJobParameters; + /** + * Created timestamp of the LLMJob + */ + created: string; + /** + * Updated timestamp of the LLMJob + */ + updated: string; +}; diff --git a/frontend/src/api/openapi/models/LLMJobResult.ts b/frontend/src/api/openapi/models/LLMJobResult.ts new file mode 100644 index 000000000..c020e84b2 --- /dev/null +++ b/frontend/src/api/openapi/models/LLMJobResult.ts @@ -0,0 +1,18 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +import type { AnnotationLLMJobResult } from "./AnnotationLLMJobResult"; +import type { DocumentTaggingLLMJobResult } from "./DocumentTaggingLLMJobResult"; +import type { LLMJobType } from "./LLMJobType"; +import type { MetadataExtractionLLMJobResult } from "./MetadataExtractionLLMJobResult"; +export type LLMJobResult = { + /** + * The type of the LLMJob (what to llm) + */ + llm_job_type: LLMJobType; + /** + * Specific result for the LLMJob w.r.t it's type + */ + specific_llm_job_result: DocumentTaggingLLMJobResult | MetadataExtractionLLMJobResult | AnnotationLLMJobResult; +}; diff --git a/frontend/src/api/openapi/models/LLMJobType.ts b/frontend/src/api/openapi/models/LLMJobType.ts new file mode 100644 index 000000000..3f83e2693 --- /dev/null +++ b/frontend/src/api/openapi/models/LLMJobType.ts @@ -0,0 +1,9 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export 
enum LLMJobType { + DOCUMENT_TAGGING = "DOCUMENT_TAGGING", + METADATA_EXTRACTION = "METADATA_EXTRACTION", + ANNOTATION = "ANNOTATION", +} diff --git a/frontend/src/api/openapi/models/LLMPromptTemplates.ts b/frontend/src/api/openapi/models/LLMPromptTemplates.ts new file mode 100644 index 000000000..c5c0fb1e6 --- /dev/null +++ b/frontend/src/api/openapi/models/LLMPromptTemplates.ts @@ -0,0 +1,18 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export type LLMPromptTemplates = { + /** + * The language of the prompt template + */ + language: string; + /** + * The system prompt to use for the job + */ + system_prompt: string; + /** + * The user prompt to use for the job + */ + user_prompt: string; +}; diff --git a/frontend/src/api/openapi/models/MetadataExtractionLLMJobParams.ts b/frontend/src/api/openapi/models/MetadataExtractionLLMJobParams.ts new file mode 100644 index 000000000..328bf2a9f --- /dev/null +++ b/frontend/src/api/openapi/models/MetadataExtractionLLMJobParams.ts @@ -0,0 +1,15 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export type MetadataExtractionLLMJobParams = { + llm_job_type: any; + /** + * IDs of the source documents to analyse + */ + sdoc_ids: Array; + /** + * IDs of the project metadata to use for the metadata extraction + */ + project_metadata_ids: Array; +}; diff --git a/frontend/src/api/openapi/models/MetadataExtractionLLMJobResult.ts b/frontend/src/api/openapi/models/MetadataExtractionLLMJobResult.ts new file mode 100644 index 000000000..297199833 --- /dev/null +++ b/frontend/src/api/openapi/models/MetadataExtractionLLMJobResult.ts @@ -0,0 +1,9 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +import type { MetadataExtractionResult } from "./MetadataExtractionResult"; 
+export type MetadataExtractionLLMJobResult = { + llm_job_type: any; + results: Array; +}; diff --git a/frontend/src/api/openapi/models/MetadataExtractionResult.ts b/frontend/src/api/openapi/models/MetadataExtractionResult.ts new file mode 100644 index 000000000..8356b157a --- /dev/null +++ b/frontend/src/api/openapi/models/MetadataExtractionResult.ts @@ -0,0 +1,19 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +import type { SourceDocumentMetadataReadResolved } from "./SourceDocumentMetadataReadResolved"; +export type MetadataExtractionResult = { + /** + * ID of the source document + */ + sdoc_id: number; + /** + * Current metadata + */ + current_metadata: Array; + /** + * Suggested metadata + */ + suggested_metadata: Array; +}; diff --git a/frontend/src/api/openapi/models/ProjectMetadataCreate.ts b/frontend/src/api/openapi/models/ProjectMetadataCreate.ts index 3a3695ecc..d37e5bee1 100644 --- a/frontend/src/api/openapi/models/ProjectMetadataCreate.ts +++ b/frontend/src/api/openapi/models/ProjectMetadataCreate.ts @@ -21,6 +21,10 @@ export type ProjectMetadataCreate = { * DOCTYPE of the SourceDocument this metadata refers to */ doctype: DocType; + /** + * Description of the ProjectMetadata + */ + description: string; /** * Project the ProjectMetadata belongs to */ diff --git a/frontend/src/api/openapi/models/ProjectMetadataRead.ts b/frontend/src/api/openapi/models/ProjectMetadataRead.ts index 885d3039d..8850e4108 100644 --- a/frontend/src/api/openapi/models/ProjectMetadataRead.ts +++ b/frontend/src/api/openapi/models/ProjectMetadataRead.ts @@ -21,6 +21,10 @@ export type ProjectMetadataRead = { * DOCTYPE of the SourceDocument this metadata refers to */ doctype: DocType; + /** + * Description of the ProjectMetadata + */ + description: string; /** * ID of the ProjectMetadata */ diff --git a/frontend/src/api/openapi/models/ProjectMetadataUpdate.ts 
b/frontend/src/api/openapi/models/ProjectMetadataUpdate.ts index 4e5418347..484ef8915 100644 --- a/frontend/src/api/openapi/models/ProjectMetadataUpdate.ts +++ b/frontend/src/api/openapi/models/ProjectMetadataUpdate.ts @@ -11,5 +11,9 @@ export type ProjectMetadataUpdate = { /** * Type of the ProjectMetadata */ - metatype: MetaType | null; + metatype?: MetaType | null; + /** + * Description of the ProjectMetadata + */ + description?: string | null; }; diff --git a/frontend/src/api/openapi/models/SourceDocumentDocumentTagLinks.ts b/frontend/src/api/openapi/models/SourceDocumentDocumentTagLinks.ts new file mode 100644 index 000000000..59b29d2e8 --- /dev/null +++ b/frontend/src/api/openapi/models/SourceDocumentDocumentTagLinks.ts @@ -0,0 +1,14 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export type SourceDocumentDocumentTagLinks = { + /** + * ID of SourceDocument + */ + source_document_id: number; + /** + * List of IDs of DocumentTags + */ + document_tag_ids: Array; +}; diff --git a/frontend/src/api/openapi/models/SourceDocumentMetadataBulkUpdate.ts b/frontend/src/api/openapi/models/SourceDocumentMetadataBulkUpdate.ts new file mode 100644 index 000000000..64a0455f6 --- /dev/null +++ b/frontend/src/api/openapi/models/SourceDocumentMetadataBulkUpdate.ts @@ -0,0 +1,30 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export type SourceDocumentMetadataBulkUpdate = { + /** + * Int Value of the SourceDocumentMetadata + */ + int_value: number | null; + /** + * String Value of the SourceDocumentMetadata + */ + str_value: string | null; + /** + * Boolean Value of the SourceDocumentMetadata + */ + boolean_value: boolean | null; + /** + * Date Value of the SourceDocumentMetadata + */ + date_value: string | null; + /** + * List Value of the SourceDocumentMetadata + */ + list_value: Array | null; + 
/** + * ID of the SourceDocumentMetadata + */ + id: number; +}; diff --git a/frontend/src/api/openapi/models/SpanAnnotationCreateBulkWithCodeId.ts b/frontend/src/api/openapi/models/SpanAnnotationCreateBulkWithCodeId.ts new file mode 100644 index 000000000..647bd1abb --- /dev/null +++ b/frontend/src/api/openapi/models/SpanAnnotationCreateBulkWithCodeId.ts @@ -0,0 +1,38 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export type SpanAnnotationCreateBulkWithCodeId = { + /** + * Begin of the SpanAnnotation in characters + */ + begin: number; + /** + * End of the SpanAnnotation in characters + */ + end: number; + /** + * Begin of the SpanAnnotation in tokens + */ + begin_token: number; + /** + * End of the SpanAnnotation in tokens + */ + end_token: number; + /** + * The SpanText the SpanAnnotation spans. + */ + span_text: string; + /** + * Code the SpanAnnotation refers to + */ + code_id: number; + /** + * SourceDocument the SpanAnnotation refers to + */ + sdoc_id: number; + /** + * User the SpanAnnotation belongs to + */ + user_id: number; +}; diff --git a/frontend/src/api/openapi/services/DocumentTagService.ts b/frontend/src/api/openapi/services/DocumentTagService.ts index b0f3e6b4d..cb62b4b87 100644 --- a/frontend/src/api/openapi/services/DocumentTagService.ts +++ b/frontend/src/api/openapi/services/DocumentTagService.ts @@ -7,6 +7,7 @@ import type { DocumentTagRead } from "../models/DocumentTagRead"; import type { DocumentTagUpdate } from "../models/DocumentTagUpdate"; import type { MemoCreate } from "../models/MemoCreate"; import type { MemoRead } from "../models/MemoRead"; +import type { SourceDocumentDocumentTagLinks } from "../models/SourceDocumentDocumentTagLinks"; import type { SourceDocumentDocumentTagMultiLink } from "../models/SourceDocumentDocumentTagMultiLink"; import type { CancelablePromise } from "../core/CancelablePromise"; import { OpenAPI } from 
"../core/OpenAPI"; @@ -72,6 +73,26 @@ export class DocumentTagService { }, }); } + /** + * Sets SourceDocuments' tags to the provided tags + * @returns number Successful Response + * @throws ApiError + */ + public static setDocumentTagsBatch({ + requestBody, + }: { + requestBody: Array; + }): CancelablePromise { + return __request(OpenAPI, { + method: "PATCH", + url: "/doctag/bulk/set", + body: requestBody, + mediaType: "application/json", + errors: { + 422: `Validation Error`, + }, + }); + } /** * Returns the DocumentTag with the given ID. * @returns DocumentTagRead Successful Response diff --git a/frontend/src/api/openapi/services/LlmService.ts b/frontend/src/api/openapi/services/LlmService.ts new file mode 100644 index 000000000..d3b09949d --- /dev/null +++ b/frontend/src/api/openapi/services/LlmService.ts @@ -0,0 +1,82 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +import type { LLMJobParameters } from "../models/LLMJobParameters"; +import type { LLMJobRead } from "../models/LLMJobRead"; +import type { LLMPromptTemplates } from "../models/LLMPromptTemplates"; +import type { CancelablePromise } from "../core/CancelablePromise"; +import { OpenAPI } from "../core/OpenAPI"; +import { request as __request } from "../core/request"; +export class LlmService { + /** + * Returns the LLMJob for the given Parameters + * @returns LLMJobRead Successful Response + * @throws ApiError + */ + public static startLlmJob({ requestBody }: { requestBody: LLMJobParameters }): CancelablePromise { + return __request(OpenAPI, { + method: "POST", + url: "/llm", + body: requestBody, + mediaType: "application/json", + errors: { + 422: `Validation Error`, + }, + }); + } + /** + * Returns the LLMJob for the given ID if it exists + * @returns LLMJobRead Successful Response + * @throws ApiError + */ + public static getLlmJob({ llmJobId }: { llmJobId: string }): CancelablePromise { + return 
__request(OpenAPI, { + method: "GET", + url: "/llm/{llm_job_id}", + path: { + llm_job_id: llmJobId, + }, + errors: { + 422: `Validation Error`, + }, + }); + } + /** + * Returns all LLMJobRead for the given project ID if it exists + * @returns LLMJobRead Successful Response + * @throws ApiError + */ + public static getAllLlmJobs({ projectId }: { projectId: number }): CancelablePromise> { + return __request(OpenAPI, { + method: "GET", + url: "/llm/project/{project_id}", + path: { + project_id: projectId, + }, + errors: { + 422: `Validation Error`, + }, + }); + } + /** + * Returns the system and user prompt templates for the given llm task in all supported languages + * @returns LLMPromptTemplates Successful Response + * @throws ApiError + */ + public static createPromptTemplates({ + requestBody, + }: { + requestBody: LLMJobParameters; + }): CancelablePromise> { + return __request(OpenAPI, { + method: "POST", + url: "/llm/create_prompt_templates", + body: requestBody, + mediaType: "application/json", + errors: { + 422: `Validation Error`, + }, + }); + } +} diff --git a/frontend/src/api/openapi/services/SdocMetadataService.ts b/frontend/src/api/openapi/services/SdocMetadataService.ts index db72077d3..6c6417f27 100644 --- a/frontend/src/api/openapi/services/SdocMetadataService.ts +++ b/frontend/src/api/openapi/services/SdocMetadataService.ts @@ -2,6 +2,7 @@ /* istanbul ignore file */ /* tslint:disable */ /* eslint-disable */ +import type { SourceDocumentMetadataBulkUpdate } from "../models/SourceDocumentMetadataBulkUpdate"; import type { SourceDocumentMetadataCreate } from "../models/SourceDocumentMetadataCreate"; import type { SourceDocumentMetadataRead } from "../models/SourceDocumentMetadataRead"; import type { SourceDocumentMetadataReadResolved } from "../models/SourceDocumentMetadataReadResolved"; @@ -89,4 +90,24 @@ export class SdocMetadataService { }, }); } + /** + * Updates multiple metadata objects at once. 
+ * @returns SourceDocumentMetadataRead Successful Response + * @throws ApiError + */ + public static updateBulk({ + requestBody, + }: { + requestBody: Array; + }): CancelablePromise> { + return __request(OpenAPI, { + method: "PATCH", + url: "/sdocmeta/bulk/update", + body: requestBody, + mediaType: "application/json", + errors: { + 422: `Validation Error`, + }, + }); + } } diff --git a/frontend/src/api/openapi/services/SpanAnnotationService.ts b/frontend/src/api/openapi/services/SpanAnnotationService.ts index d108567e8..4ab8fe8fc 100644 --- a/frontend/src/api/openapi/services/SpanAnnotationService.ts +++ b/frontend/src/api/openapi/services/SpanAnnotationService.ts @@ -5,6 +5,7 @@ import type { CodeRead } from "../models/CodeRead"; import type { MemoCreate } from "../models/MemoCreate"; import type { MemoRead } from "../models/MemoRead"; +import type { SpanAnnotationCreateBulkWithCodeId } from "../models/SpanAnnotationCreateBulkWithCodeId"; import type { SpanAnnotationCreateWithCodeId } from "../models/SpanAnnotationCreateWithCodeId"; import type { SpanAnnotationRead } from "../models/SpanAnnotationRead"; import type { SpanAnnotationReadResolved } from "../models/SpanAnnotationReadResolved"; @@ -42,6 +43,34 @@ export class SpanAnnotationService { }, }); } + /** + * Creates a SpanAnnotations in Bulk + * @returns any Successful Response + * @throws ApiError + */ + public static addSpanAnnotationsBulk({ + requestBody, + resolve = true, + }: { + requestBody: Array; + /** + * If true, the current_code_id of the SpanAnnotation gets resolved and replaced by the respective Code entity + */ + resolve?: boolean; + }): CancelablePromise | Array> { + return __request(OpenAPI, { + method: "PUT", + url: "/span/bulk/create", + query: { + resolve: resolve, + }, + body: requestBody, + mediaType: "application/json", + errors: { + 422: `Validation Error`, + }, + }); + } /** * Returns the SpanAnnotation with the given ID. 
* @returns any Successful Response diff --git a/frontend/src/components/FormInputs/FormTextMultiline.tsx b/frontend/src/components/FormInputs/FormTextMultiline.tsx index a0a5d1651..bda5312cf 100644 --- a/frontend/src/components/FormInputs/FormTextMultiline.tsx +++ b/frontend/src/components/FormInputs/FormTextMultiline.tsx @@ -2,7 +2,7 @@ import { TextField, TextFieldProps } from "@mui/material"; import { Control, Controller, ControllerProps, FieldValues } from "react-hook-form"; interface FormTextMultilineProps extends Omit, "render"> { - textFieldProps?: Omit; + textFieldProps?: Omit; control: Control; } @@ -16,7 +16,15 @@ function FormTextMultiline({ } + render={({ field }) => ( + + )} control={control} /> ); diff --git a/frontend/src/components/LLMDialog/LLMAssistanceButton.tsx b/frontend/src/components/LLMDialog/LLMAssistanceButton.tsx new file mode 100644 index 000000000..a51df446c --- /dev/null +++ b/frontend/src/components/LLMDialog/LLMAssistanceButton.tsx @@ -0,0 +1,17 @@ +import SmartToyIcon from "@mui/icons-material/SmartToy"; +import { IconButton, Tooltip } from "@mui/material"; +import { useOpenLLMDialog } from "./useOpenLLMDialog.ts"; + +function LLMAssistanceButton({ sdocIds }: { sdocIds: number[] }) { + const openLLmDialog = useOpenLLMDialog(); + + return ( + + openLLmDialog({ selectedDocumentIds: sdocIds })}> + + + + ); +} + +export default LLMAssistanceButton; diff --git a/frontend/src/components/LLMDialog/LLMDialog.tsx b/frontend/src/components/LLMDialog/LLMDialog.tsx new file mode 100644 index 000000000..1929b66d3 --- /dev/null +++ b/frontend/src/components/LLMDialog/LLMDialog.tsx @@ -0,0 +1,86 @@ +import { ButtonProps, Dialog, DialogContent, DialogTitle, Step, StepLabel, Stepper } from "@mui/material"; +import { useMemo } from "react"; +import { LLMJobType } from "../../api/openapi/models/LLMJobType.ts"; +import { useAppSelector } from "../../plugins/ReduxHooks.ts"; +import AnnotationResultStep from 
"./steps/AnnotationResultStep/AnnotationResultStep.tsx"; +import CodeSelectionStep from "./steps/CodeSelectionStep.tsx"; +import DocumentTagResultStep from "./steps/DocumentTaggingResultStep/DocumentTagResultStep.tsx"; +import DocumentTagSelectionStep from "./steps/DocumentTagSelectionStep.tsx"; +import MetadataExtractionResultStep from "./steps/MetadataExtractionResultStep/MetadataExtractionResultStep.tsx"; +import MethodSelectionStep from "./steps/MethodSelectionStep.tsx"; +import ProjectMetadataSelectionStep from "./steps/ProjectMetadataSelectionStep.tsx"; +import PromptEditorStep from "./steps/PromptEditorStep.tsx"; +import StatusStep from "./steps/StatusStep.tsx"; + +export interface LLMDialogProps extends ButtonProps { + projectId: number; +} + +const title: Record = { + [LLMJobType.DOCUMENT_TAGGING]: "Document Tagging", + [LLMJobType.METADATA_EXTRACTION]: "Metadata Extraction", + [LLMJobType.ANNOTATION]: "Annotation", +}; + +const steps: Record = { + [LLMJobType.DOCUMENT_TAGGING]: ["Select method", "Select tags", "Edit prompts", "Wait", "View results"], + [LLMJobType.METADATA_EXTRACTION]: ["Select method", "Select metadata", "Edit prompts", "Wait", "View results"], + [LLMJobType.ANNOTATION]: ["Select method", "Select codes", "Edit prompts", "Wait", "View results"], +}; + +function LLMDialog({ projectId }: LLMDialogProps) { + // global client state (redux) + const open = useAppSelector((state) => state.dialog.isLLMDialogOpen); + const method = useAppSelector((state) => state.dialog.llmMethod); + const step = useAppSelector((state) => state.dialog.llmStep); + + // this defines the flow of the dialog + const contentDict: Record> = useMemo(() => { + return { + 0: { + [LLMJobType.DOCUMENT_TAGGING]: , + [LLMJobType.METADATA_EXTRACTION]: , + [LLMJobType.ANNOTATION]: , + }, + 1: { + [LLMJobType.DOCUMENT_TAGGING]: , + [LLMJobType.METADATA_EXTRACTION]: , + [LLMJobType.ANNOTATION]: , + }, + 2: { + [LLMJobType.DOCUMENT_TAGGING]: , + [LLMJobType.METADATA_EXTRACTION]: , + 
[LLMJobType.ANNOTATION]: , + }, + 3: { + [LLMJobType.DOCUMENT_TAGGING]: , + [LLMJobType.METADATA_EXTRACTION]: , + [LLMJobType.ANNOTATION]: , + }, + 4: { + [LLMJobType.DOCUMENT_TAGGING]: , + [LLMJobType.METADATA_EXTRACTION]: , + [LLMJobType.ANNOTATION]: , + }, + }; + }, [projectId]); + + return ( + + LLM Assistant {method && <> - {title[method]}} + + + + {steps[method || LLMJobType.DOCUMENT_TAGGING].map((label) => ( + + {label} + + ))} + + + {contentDict[step][method || LLMJobType.DOCUMENT_TAGGING]} + + ); +} + +export default LLMDialog; diff --git a/frontend/src/components/LLMDialog/LLMEvent.ts b/frontend/src/components/LLMDialog/LLMEvent.ts new file mode 100644 index 000000000..c5024e77e --- /dev/null +++ b/frontend/src/components/LLMDialog/LLMEvent.ts @@ -0,0 +1,6 @@ +import { LLMJobType } from "../../api/openapi/models/LLMJobType.ts"; + +export interface LLMAssistanceEvent { + method?: LLMJobType; + selectedDocumentIds: number[]; +} diff --git a/frontend/src/components/LLMDialog/steps/AnnotationResultStep/AnnotationResultStep.tsx b/frontend/src/components/LLMDialog/steps/AnnotationResultStep/AnnotationResultStep.tsx new file mode 100644 index 000000000..a23da9e7a --- /dev/null +++ b/frontend/src/components/LLMDialog/steps/AnnotationResultStep/AnnotationResultStep.tsx @@ -0,0 +1,167 @@ +import PlayCircleIcon from "@mui/icons-material/PlayCircle"; +import { LoadingButton, TabContext, TabList, TabPanel } from "@mui/lab"; +import { Box, Button, DialogActions, DialogContent, Tab, Typography } from "@mui/material"; +import { useEffect, useMemo, useState } from "react"; +import LLMHooks from "../../../../api/LLMHooks.ts"; +import { AnnotationLLMJobResult } from "../../../../api/openapi/models/AnnotationLLMJobResult.ts"; +import { CodeRead } from "../../../../api/openapi/models/CodeRead.ts"; +import { SpanAnnotationCreateBulkWithCodeId } from "../../../../api/openapi/models/SpanAnnotationCreateBulkWithCodeId.ts"; +import { SpanAnnotationReadResolved } from 
"../../../../api/openapi/models/SpanAnnotationReadResolved.ts"; +import SpanAnnotationHooks from "../../../../api/SpanAnnotationHooks.ts"; +import { useAuth } from "../../../../auth/useAuth.ts"; +import { useAppDispatch, useAppSelector } from "../../../../plugins/ReduxHooks.ts"; +import { CRUDDialogActions } from "../../../dialogSlice.ts"; +import SdocRenderer from "../../../SourceDocument/SdocRenderer.tsx"; +import LLMUtterance from "../LLMUtterance.tsx"; +import TextAnnotationValidator from "./TextAnnotationValidator.tsx"; + +function AnnotationResultStep() { + // get user + const { user } = useAuth(); + + // get the job + const llmJobId = useAppSelector((state) => state.dialog.llmJobId); + const llmJob = LLMHooks.usePollLLMJob(llmJobId, undefined); + + // we extract the codes from the job + const codesForSelection = useMemo(() => { + if (!llmJob.data || !llmJob.data.result) return []; + const annotationResults = (llmJob.data.result?.specific_llm_job_result as AnnotationLLMJobResult).results; + const annotations = annotationResults.reduce((acc, r) => { + acc.push(...r.suggested_annotations); + return acc; + }, []); + + return Object.values( + annotations.reduce>((acc, a) => { + acc[a.code.id] = a.code; + return acc; + }, {}), + ); + }, [llmJob.data]); + + // local state to manage tabs + const [tab, setTab] = useState(); + const handleChangeTab = (_: React.SyntheticEvent, newValue: string) => { + setTab(newValue); + }; + + // local state to manage annotations + const [annotations, setAnnotations] = useState>(); + const handleChangeAnnotations = (sdocId: number) => (annotations: SpanAnnotationReadResolved[]) => { + setAnnotations((prev) => { + return { + ...prev, + [sdocId]: annotations, + }; + }); + }; + + // init state + useEffect(() => { + if (llmJob.data) { + setTab((llmJob.data.result?.specific_llm_job_result as AnnotationLLMJobResult).results[0].sdoc_id.toString()); + setAnnotations( + (llmJob.data.result?.specific_llm_job_result as 
AnnotationLLMJobResult).results.reduce< + Record + >((acc, r) => { + acc[r.sdoc_id] = r.suggested_annotations; + return acc; + }, {}), + ); + } + }, [llmJob.data]); + + // actions + const dispatch = useAppDispatch(); + const handleClose = () => { + dispatch(CRUDDialogActions.closeLLMDialog()); + }; + + const createBulkAnnotationsMutation = SpanAnnotationHooks.useCreateBulkAnnotations(); + const handleApplySuggestedAnnotations = () => { + if (!annotations || !user) return; + + createBulkAnnotationsMutation.mutate( + { + requestBody: Object.entries(annotations).reduce((acc, [sdocId, sdocAnnos]) => { + const sdocIdInt = parseInt(sdocId); + for (const annotation of sdocAnnos) { + acc.push({ + sdoc_id: sdocIdInt, + code_id: annotation.code.id, + begin: annotation.begin, + end: annotation.end, + begin_token: annotation.begin_token, + end_token: annotation.end_token, + span_text: annotation.span_text, + user_id: user.id, + }); + } + return acc; + }, [] as SpanAnnotationCreateBulkWithCodeId[]), + }, + { + onSuccess: () => { + dispatch(CRUDDialogActions.closeLLMDialog()); + }, + }, + ); + }; + + return ( + <> + + + + Here are the results! My suggestions are highlighted in the documents. Now, you can decide what to do with + them. You can click on an annotation and either: + +
    +
  • Delete my suggestion
  • +
  • Change the code of my annotated text passage
  • +
+ Remember to look through all the documents. +
+ {llmJob.isSuccess && llmJob.data.result && tab && annotations && ( + + + + {Object.keys(annotations).map((sdocId) => ( + } value={sdocId} /> + ))} + + + {Object.entries(annotations).map(([sdocIdStr, annotations]) => { + const sdocId = parseInt(sdocIdStr); + return ( + + + + ); + })} + + )} +
+ + + } + loading={createBulkAnnotationsMutation.isPending} + loadingPosition="start" + onClick={handleApplySuggestedAnnotations} + > + Apply annotations! + + + + ); +} + +export default AnnotationResultStep; diff --git a/frontend/src/components/LLMDialog/steps/AnnotationResultStep/TextAnnotationValidationMenu.tsx b/frontend/src/components/LLMDialog/steps/AnnotationResultStep/TextAnnotationValidationMenu.tsx new file mode 100644 index 000000000..b0a8c9321 --- /dev/null +++ b/frontend/src/components/LLMDialog/steps/AnnotationResultStep/TextAnnotationValidationMenu.tsx @@ -0,0 +1,211 @@ +import DeleteIcon from "@mui/icons-material/Delete"; +import EditIcon from "@mui/icons-material/Edit"; +import { + Autocomplete, + Box, + createFilterOptions, + IconButton, + List, + ListItem, + ListItemText, + Popover, + PopoverPosition, + TextField, + Tooltip, + UseAutocompleteProps, +} from "@mui/material"; +import { forwardRef, useEffect, useImperativeHandle, useMemo, useState } from "react"; +import { CodeRead } from "../../../../api/openapi/models/CodeRead.ts"; +import { SpanAnnotationReadResolved } from "../../../../api/openapi/models/SpanAnnotationReadResolved.ts"; + +interface ICodeFilter extends CodeRead { + title: string; +} + +const filter = createFilterOptions(); + +export interface TextAnnotationValidationMenuProps { + codesForSelection: CodeRead[]; + onClose?: (reason?: "backdropClick" | "escapeKeyDown") => void; + onEdit: (annotationToEdit: SpanAnnotationReadResolved, newCode: CodeRead) => void; + onDelete: (annotationToDelete: SpanAnnotationReadResolved) => void; +} + +export interface TextAnnotationValidationMenuHandle { + open: (position: PopoverPosition, annotations?: SpanAnnotationReadResolved[] | undefined) => void; +} + +const TextAnnotationValidationMenu = forwardRef( + ({ codesForSelection, onClose, onEdit, onDelete }, ref) => { + // local client state + const [position, setPosition] = useState({ top: 0, left: 0 }); + const [isPopoverOpen, setIsPopoverOpen] = 
useState(false); + const [showCodeSelection, setShowCodeSelection] = useState(false); + const [isAutoCompleteOpen, setIsAutoCompleteOpen] = useState(false); + const [annotationsToEdit, setAnnotationsToEdit] = useState(undefined); + const [editingAnnotation, setEditingAnnotation] = useState(undefined); + const [autoCompleteValue, setAutoCompleteValue] = useState(null); + + // computed + const codeOptions: ICodeFilter[] = useMemo(() => { + return codesForSelection.map((c) => { + return { + ...c, + title: c.name, + }; + }); + }, [codesForSelection]); + + // exposed methods (via ref) + useImperativeHandle(ref, () => ({ + open: openCodeSelector, + })); + + // methods + const openCodeSelector = ( + position: PopoverPosition, + annotations: SpanAnnotationReadResolved[] | undefined = undefined, + ) => { + setEditingAnnotation(undefined); + setAnnotationsToEdit(annotations); + setShowCodeSelection(annotations === undefined); + setIsPopoverOpen(true); + setPosition(position); + }; + + const closeCodeSelector = (reason?: "backdropClick" | "escapeKeyDown") => { + setShowCodeSelection(false); + setIsPopoverOpen(false); + setIsAutoCompleteOpen(false); + setAutoCompleteValue(null); + if (onClose) onClose(reason); + }; + + // effects + // automatically open the autocomplete soon after the code selection is shown + useEffect(() => { + if (showCodeSelection) { + setTimeout(() => { + setIsAutoCompleteOpen(showCodeSelection); + }, 250); + } + }, [showCodeSelection]); + + // event handlers + const handleChange: UseAutocompleteProps["onChange"] = (_event, newValue) => { + if (!editingAnnotation) { + console.error("editingAnnotation is undefined"); + return; + } + + if (newValue === null) { + return; + } + + onEdit(editingAnnotation!, newValue); + closeCodeSelector(); + }; + + const handleEdit = (annotationToEdit: SpanAnnotationReadResolved, code: CodeRead) => { + setEditingAnnotation(annotationToEdit); + setAutoCompleteValue({ ...code, title: code.name }); + setShowCodeSelection(true); 
+ }; + + const handleDelete = (annotation: SpanAnnotationReadResolved) => { + onDelete(annotation); + closeCodeSelector(); + }; + + return ( + closeCodeSelector(reason)} + anchorPosition={position} + anchorReference="anchorPosition" + anchorOrigin={{ + vertical: "top", + horizontal: "left", + }} + transformOrigin={{ + vertical: "top", + horizontal: "left", + }} + > + {!showCodeSelection && annotationsToEdit ? ( + + {annotationsToEdit.map((annotation) => ( + + ))} + + ) : ( + <> + { + return filter(options, params); + }} + options={codeOptions} + getOptionLabel={(option) => { + // Value selected with enter, right from the input + if (typeof option === "string") { + return option; + } + return option.name; + }} + renderOption={(props, option) => ( +
  • + {" "} + {option.title} +
  • + )} + sx={{ width: 300 }} + renderInput={(params) => } + autoHighlight + selectOnFocus + clearOnBlur + handleHomeEndKeys + open={isAutoCompleteOpen} + onClose={(_event, reason) => reason === "escape" && closeCodeSelector("escapeKeyDown")} + /> + + )} +
    + ); + }, +); + +export default TextAnnotationValidationMenu; + +interface AnnotationMenuListItemProps { + code: CodeRead; + annotation: SpanAnnotationReadResolved; + handleDelete: (annotationToDelete: SpanAnnotationReadResolved) => void; + handleEdit: (annotationToEdit: SpanAnnotationReadResolved, newCode: CodeRead) => void; +} + +function AnnotationMenuListItem({ code, annotation, handleEdit, handleDelete }: AnnotationMenuListItemProps) { + return ( + + + + + handleDelete(annotation)}> + + + + + handleEdit(annotation, code)}> + + + + + ); +} diff --git a/frontend/src/components/LLMDialog/steps/AnnotationResultStep/TextAnnotationValidator.tsx b/frontend/src/components/LLMDialog/steps/AnnotationResultStep/TextAnnotationValidator.tsx new file mode 100644 index 000000000..cd254f112 --- /dev/null +++ b/frontend/src/components/LLMDialog/steps/AnnotationResultStep/TextAnnotationValidator.tsx @@ -0,0 +1,151 @@ +import { MouseEventHandler, useRef } from "react"; +import { CodeRead } from "../../../../api/openapi/models/CodeRead.ts"; +import { SourceDocumentWithDataRead } from "../../../../api/openapi/models/SourceDocumentWithDataRead.ts"; +import { SpanAnnotationReadResolved } from "../../../../api/openapi/models/SpanAnnotationReadResolved.ts"; +import SdocHooks from "../../../../api/SdocHooks.ts"; +import DocumentRenderer from "../../../../views/annotation/DocumentRenderer/DocumentRenderer.tsx"; +import useComputeTokenDataWithAnnotations from "../../../../views/annotation/DocumentRenderer/useComputeTokenDataWithAnnotations.ts"; +import TextAnnotationValidationMenu, { + TextAnnotationValidationMenuHandle, + TextAnnotationValidationMenuProps, +} from "./TextAnnotationValidationMenu.tsx"; +import "./validatorStyles.css"; + +interface TextAnnotatorValidatorSharedProps { + codesForSelection: CodeRead[]; + annotations: SpanAnnotationReadResolved[]; + handleChangeAnnotations: (annotations: SpanAnnotationReadResolved[]) => void; +} + +interface TextAnnotatorValidatorProps 
extends TextAnnotatorValidatorSharedProps { + sdocId: number; +} + +function TextAnnotationValidator({ + sdocId, + codesForSelection, + annotations, + handleChangeAnnotations, +}: TextAnnotatorValidatorProps) { + const sdoc = SdocHooks.useGetDocument(sdocId); + + if (sdoc.isSuccess) { + return ( + + ); + } + return null; +} + +interface TextAnnotatorValidatorWithSdocProps extends TextAnnotatorValidatorSharedProps { + sdoc: SourceDocumentWithDataRead; +} + +function TextAnnotationValidatorWithSdoc({ + sdoc, + codesForSelection, + annotations, + handleChangeAnnotations, +}: TextAnnotatorValidatorWithSdocProps) { + // local state + const menuRef = useRef(null); + + // computed + const { tokenData, annotationsPerToken, annotationMap } = useComputeTokenDataWithAnnotations({ + sdoc: sdoc, + annotations: annotations, + }); + + // actions + const handleMouseUp: MouseEventHandler = (event) => { + if (event.button === 2 || !tokenData || !annotationsPerToken || !annotationMap) return; + + // try to find a parent element that has the tok class, we go up 3 levels at maximum + let target: HTMLElement = event.target as HTMLElement; + let found = false; + for (let i = 0; i < 3; i++) { + if (target && target.classList.contains("tok") && target.childElementCount > 0) { + found = true; + break; + } + if (target.parentElement) { + target = target.parentElement; + } else { + break; + } + } + if (!found) return; + + event.preventDefault(); + + // get all annotations that span this token + const tokenIndex = parseInt(target.getAttribute("data-tokenid")!); + const annos = annotationsPerToken.get(tokenIndex); + + // open code selector if there are annotations + if (annos) { + // calculate position of the code selector (based on selection end) + const boundingBox = target.getBoundingClientRect(); + const position = { + left: boundingBox.left, + top: boundingBox.top + boundingBox.height, + }; + + // open code selector + menuRef.current!.open( + position, + annos.map((a) => 
annotationMap.get(a)!), + ); + } + }; + + const handleClose: TextAnnotationValidationMenuProps["onClose"] = () => {}; + + const handleEdit: TextAnnotationValidationMenuProps["onEdit"] = (annotationToEdit, newCode) => { + handleChangeAnnotations( + annotations.map((a) => { + if (a.id === annotationToEdit.id) { + return { + ...a, + code: newCode, + }; + } + return a; + }), + ); + }; + + const handleDelete: TextAnnotationValidationMenuProps["onDelete"] = (annotationToDelete) => { + handleChangeAnnotations(annotations.filter((a) => a.id !== annotationToDelete.id)); + }; + + return ( + <> + + + + ); +} + +export default TextAnnotationValidator; diff --git a/frontend/src/components/LLMDialog/steps/AnnotationResultStep/validatorStyles.css b/frontend/src/components/LLMDialog/steps/AnnotationResultStep/validatorStyles.css new file mode 100644 index 000000000..aba7c0c82 --- /dev/null +++ b/frontend/src/components/LLMDialog/steps/AnnotationResultStep/validatorStyles.css @@ -0,0 +1,3 @@ +[class*="span-"] { + cursor: pointer; +} diff --git a/frontend/src/components/LLMDialog/steps/CodeSelectionStep.tsx b/frontend/src/components/LLMDialog/steps/CodeSelectionStep.tsx new file mode 100644 index 000000000..c9bdf1af8 --- /dev/null +++ b/frontend/src/components/LLMDialog/steps/CodeSelectionStep.tsx @@ -0,0 +1,88 @@ +import PlayCircleIcon from "@mui/icons-material/PlayCircle"; +import { LoadingButton } from "@mui/lab"; +import { Box, Button, DialogActions, DialogContent, Typography } from "@mui/material"; +import { MRT_RowSelectionState } from "material-react-table"; +import { useState } from "react"; +import LLMHooks from "../../../api/LLMHooks.ts"; +import { CodeRead } from "../../../api/openapi/models/CodeRead.ts"; +import { LLMJobType } from "../../../api/openapi/models/LLMJobType.ts"; +import { useAppDispatch, useAppSelector } from "../../../plugins/ReduxHooks.ts"; +import CodeTable from "../../Code/CodeTable.tsx"; +import { CRUDDialogActions } from "../../dialogSlice.ts"; 
+import LLMUtterance from "./LLMUtterance.tsx"; + +function CodeSelectionStep({ projectId }: { projectId: number }) { + // local state + const [rowSelectionModel, setRowSelectionModel] = useState({}); + + // global state + const selectedDocuments = useAppSelector((state) => state.dialog.llmDocumentIds); + const dispatch = useAppDispatch(); + + // initiate next step (get the generated prompts) + const createPromptTemplatesMutation = LLMHooks.useCreatePromptTemplates(); + const handleNext = (codes: CodeRead[]) => () => { + createPromptTemplatesMutation.mutate( + { + requestBody: { + llm_job_type: LLMJobType.ANNOTATION, + project_id: projectId, + prompts: [], + specific_llm_job_parameters: { + llm_job_type: LLMJobType.ANNOTATION, + code_ids: codes.map((code) => code.id), + sdoc_ids: selectedDocuments, + }, + }, + }, + { + onSuccess(data) { + dispatch( + CRUDDialogActions.llmDialogGoToPromptEditor({ prompts: data, tags: [], metadata: [], codes: codes }), + ); + }, + }, + ); + }; + + return ( + <> + + + + You selected {selectedDocuments.length} document(s) for automatic annotation. Please select all codes that I + should use to annotate text passages. + + + + ( + + + + } + loading={createPromptTemplatesMutation.isPending} + loadingPosition="start" + disabled={props.selectedCodes.length === 0} + onClick={handleNext(props.selectedCodes)} + > + Next! 
+ + + )} + /> + + ); +} + +export default CodeSelectionStep; diff --git a/frontend/src/components/LLMDialog/steps/DocumentTagSelectionStep.tsx b/frontend/src/components/LLMDialog/steps/DocumentTagSelectionStep.tsx new file mode 100644 index 000000000..44df84d42 --- /dev/null +++ b/frontend/src/components/LLMDialog/steps/DocumentTagSelectionStep.tsx @@ -0,0 +1,86 @@ +import PlayCircleIcon from "@mui/icons-material/PlayCircle"; +import { LoadingButton } from "@mui/lab"; +import { Box, Button, DialogActions, DialogContent, Typography } from "@mui/material"; +import { MRT_RowSelectionState } from "material-react-table"; +import { useState } from "react"; +import LLMHooks from "../../../api/LLMHooks.ts"; +import { DocumentTagRead } from "../../../api/openapi/models/DocumentTagRead.ts"; +import { LLMJobType } from "../../../api/openapi/models/LLMJobType.ts"; +import { useAppDispatch, useAppSelector } from "../../../plugins/ReduxHooks.ts"; +import TagTable from "../../Tag/TagTable.tsx"; +import { CRUDDialogActions } from "../../dialogSlice.ts"; +import LLMUtterance from "./LLMUtterance.tsx"; + +function DocumentTagSelectionStep({ projectId }: { projectId: number }) { + // local state + const [rowSelectionModel, setRowSelectionModel] = useState({}); + + // global state + const selectedDocuments = useAppSelector((state) => state.dialog.llmDocumentIds); + const dispatch = useAppDispatch(); + + // initiate next step (get the generated prompts) + const createPromptTemplatesMutation = LLMHooks.useCreatePromptTemplates(); + const handleNext = (tags: DocumentTagRead[]) => () => { + createPromptTemplatesMutation.mutate( + { + requestBody: { + llm_job_type: LLMJobType.DOCUMENT_TAGGING, + project_id: projectId, + prompts: [], + specific_llm_job_parameters: { + llm_job_type: LLMJobType.DOCUMENT_TAGGING, + tag_ids: tags.map((tag) => tag.id), + sdoc_ids: selectedDocuments, + }, + }, + }, + { + onSuccess(data) { + dispatch(CRUDDialogActions.llmDialogGoToPromptEditor({ prompts: data, 
tags: tags, metadata: [], codes: [] })); + }, + }, + ); + }; + + return ( + <> + + + + You selected {selectedDocuments.length} document(s) for automatic document tagging. Please select all tags + that I should use to classify the documents. + + + + ( + + + + } + loading={createPromptTemplatesMutation.isPending} + loadingPosition="start" + disabled={props.selectedTags.length === 0} + onClick={handleNext(props.selectedTags)} + > + Next! + + + )} + /> + + ); +} + +export default DocumentTagSelectionStep; diff --git a/frontend/src/components/LLMDialog/steps/DocumentTaggingResultStep/DocumentTagResultStep.tsx b/frontend/src/components/LLMDialog/steps/DocumentTaggingResultStep/DocumentTagResultStep.tsx new file mode 100644 index 000000000..fd6a12dc3 --- /dev/null +++ b/frontend/src/components/LLMDialog/steps/DocumentTaggingResultStep/DocumentTagResultStep.tsx @@ -0,0 +1,90 @@ +import LabelIcon from "@mui/icons-material/Label"; +import { LoadingButton } from "@mui/lab"; +import { Button, DialogActions, DialogContent, Typography } from "@mui/material"; +import { useState } from "react"; +import LLMHooks from "../../../../api/LLMHooks.ts"; +import { DocumentTaggingLLMJobResult } from "../../../../api/openapi/models/DocumentTaggingLLMJobResult.ts"; +import ProjectHooks from "../../../../api/ProjectHooks.ts"; +import TagHooks from "../../../../api/TagHooks.ts"; +import { useAppDispatch, useAppSelector } from "../../../../plugins/ReduxHooks.ts"; +import { CRUDDialogActions } from "../../../dialogSlice.ts"; +import LLMUtterance from "../LLMUtterance.tsx"; +import { DocumentTaggingResultRow } from "./DocumentTaggingResultRow.ts"; +import DocumentTagResultStepTable from "./DocumentTagResultStepTable.tsx"; + +function DocumentTagResultStep({ projectId }: { projectId: number }) { + // local client state + const [rows, setRows] = useState([]); + + // global client state + const llmJobId = useAppSelector((state) => state.dialog.llmJobId); + const dispatch = useAppDispatch(); + + // 
global server state + const documentTags = ProjectHooks.useGetAllTags(projectId); + + // get the job + const llmJob = LLMHooks.usePollLLMJob(llmJobId, undefined); + + // actions + const handleClose = () => { + dispatch(CRUDDialogActions.closeLLMDialog()); + }; + + const applyTagsMutation = TagHooks.useBulkSetDocumentTags(); + const handleApplyNewTags = () => { + applyTagsMutation.mutate( + { + requestBody: rows.map((row) => ({ + source_document_id: row.sdocId, + document_tag_ids: row.merged_tags.map((tag) => tag.id), + })), + }, + { + onSuccess() { + dispatch(CRUDDialogActions.closeLLMDialog()); + }, + }, + ); + }; + + return ( + <> + + + + Here are the results! You can find my suggestions in the column Suggested Tags. Now, you decide what + to do with them: + +
      +
    • Use your current tags (discarding my suggestions)
    • +
    • Use my suggested tags (discarding the current tags)
    • +
    • Merge both your current tags and my suggested tags
    • +
    +
    +
    + {documentTags.isSuccess && llmJob.isSuccess && ( + + )} + + + } + onClick={handleApplyNewTags} + loading={applyTagsMutation.isPending} + loadingPosition="start" + > + Apply new tags + + + + ); +} + +export default DocumentTagResultStep; diff --git a/frontend/src/components/LLMDialog/steps/DocumentTaggingResultStep/DocumentTagResultStepTable.tsx b/frontend/src/components/LLMDialog/steps/DocumentTaggingResultStep/DocumentTagResultStepTable.tsx new file mode 100644 index 000000000..3e7cb4db3 --- /dev/null +++ b/frontend/src/components/LLMDialog/steps/DocumentTaggingResultStep/DocumentTagResultStepTable.tsx @@ -0,0 +1,221 @@ +import { Box, Button, Stack, Typography } from "@mui/material"; +import { + MaterialReactTable, + MRT_ColumnDef, + MRT_RowModel, + MRT_RowSelectionState, + MRT_ShowHideColumnsButton, + MRT_ToggleDensePaddingButton, + useMaterialReactTable, +} from "material-react-table"; +import { useEffect, useState } from "react"; +import { DocumentTaggingResult } from "../../../../api/openapi/models/DocumentTaggingResult.ts"; +import { DocumentTagRead } from "../../../../api/openapi/models/DocumentTagRead.ts"; +import SdocRenderer from "../../../SourceDocument/SdocRenderer.tsx"; +import TagRenderer from "../../../Tag/TagRenderer.tsx"; +import { DocumentTaggingResultRow } from "./DocumentTaggingResultRow.ts"; + +function CustomTagsRenderer({ tags }: { tags: DocumentTagRead[] }) { + if (tags.length === 0) { + return no tags; + } + return ( + + {tags.map((tag) => ( + + ))} + + ); +} + +const columns: MRT_ColumnDef[] = [ + { + id: "Filename", + header: "Document", + Cell: ({ row }) => , + }, + { + id: "CurrentTags", + header: "Current Tags", + Cell: ({ row }) => , + }, + { + id: "SuggestedTags", + header: "Suggested Tags", + Cell: ({ row }) => , + }, + { + id: "FinalTags", + header: "Final Tags", + Cell: ({ row }) => , + }, +]; + +function DocumentTagResultStepTable({ + data, + projectTags, + rows, + onUpdateRows, +}: { + data: DocumentTaggingResult[]; + 
projectTags: DocumentTagRead[]; + rows: DocumentTaggingResultRow[]; + onUpdateRows: React.Dispatch>; +}) { + // local state + const [rowSelectionModel, setRowSelectionModel] = useState({}); + const buttonsDisabled = Object.keys(rowSelectionModel).length === 0; + + // init rows + useEffect(() => { + const tagId2Tag = projectTags.reduce( + (acc, tag) => { + acc[tag.id] = tag; + return acc; + }, + {} as Record, + ); + + onUpdateRows( + data.map((result) => { + return { + sdocId: result.sdoc_id, + current_tags: result.current_tag_ids.map((tagId) => tagId2Tag[tagId]), + suggested_tags: result.suggested_tag_ids.map((tagId) => tagId2Tag[tagId]), + merged_tags: [...new Set([...result.current_tag_ids, ...result.suggested_tag_ids])].map( + (tagId) => tagId2Tag[tagId], + ), + reasoning: result.reasoning, + }; + }), + ); + }, [data, onUpdateRows, projectTags]); + + // actions + const applyCurrentTags = (selectedRows: MRT_RowModel) => () => { + onUpdateRows((rows) => { + const result = [...rows]; + selectedRows.rows.forEach((selectedRow) => { + result[selectedRow.index] = { + ...result[selectedRow.index], + merged_tags: result[selectedRow.index].current_tags, + }; + }); + return result; + }); + }; + + const applySuggestedTags = (selectedRows: MRT_RowModel) => () => { + onUpdateRows((rows) => { + const result = [...rows]; + selectedRows.rows.forEach((selectedRow) => { + result[selectedRow.index] = { + ...result[selectedRow.index], + merged_tags: result[selectedRow.index].suggested_tags, + }; + }); + return result; + }); + }; + + const applyMergeTags = (selectedRows: MRT_RowModel) => () => { + onUpdateRows((rows) => { + const result = [...rows]; + selectedRows.rows.forEach((selectedRow) => { + result[selectedRow.index] = { + ...result[selectedRow.index], + merged_tags: [ + ...new Set([...result[selectedRow.index].current_tags, ...result[selectedRow.index].suggested_tags]), + ], + }; + }); + return result; + }); + }; + + // table + const table = useMaterialReactTable({ + data: 
rows, + columns: columns, + getRowId: (row) => `${row.sdocId}`, + // state + state: { + rowSelection: rowSelectionModel, + }, + // selection + enableRowSelection: true, + positionToolbarAlertBanner: "bottom", + onRowSelectionChange: setRowSelectionModel, + // expansion + enableExpandAll: false, //disable expand all button + positionExpandColumn: "last", + muiExpandButtonProps: ({ row, table }) => ({ + onClick: () => table.setExpanded({ [row.id]: !row.getIsExpanded() }), //only 1 detail panel open at a time + sx: { + transform: row.getIsExpanded() ? "rotate(180deg)" : "rotate(90deg)", + transition: "transform 0.2s", + }, + }), + renderDetailPanel: ({ row }) => ( + + {row.original.reasoning} + + ), + localization: { + expand: "Explanation", + }, + // style + muiTablePaperProps: { + elevation: 0, + style: { height: "100%", display: "flex", flexDirection: "column" }, + }, + muiTableContainerProps: { + style: { flexGrow: 1 }, + }, + // virtualization (scrolling instead of pagination) + enablePagination: false, + enableRowVirtualization: true, + // hide columns per default + initialState: { + columnVisibility: { + id: false, + }, + }, + // toolbars + enableBottomToolbar: true, + renderTopToolbarCustomActions: ({ table }) => ( + + + Merging strategy: + + + + + + ), + renderToolbarInternalActions: ({ table }) => ( + + + + + ), + }); + + return ; +} + +export default DocumentTagResultStepTable; diff --git a/frontend/src/components/LLMDialog/steps/DocumentTaggingResultStep/DocumentTaggingResultRow.ts b/frontend/src/components/LLMDialog/steps/DocumentTaggingResultStep/DocumentTaggingResultRow.ts new file mode 100644 index 000000000..276f91bcd --- /dev/null +++ b/frontend/src/components/LLMDialog/steps/DocumentTaggingResultStep/DocumentTaggingResultRow.ts @@ -0,0 +1,9 @@ +import { DocumentTagRead } from "../../../../api/openapi/models/DocumentTagRead.ts"; + +export interface DocumentTaggingResultRow { + sdocId: number; + current_tags: Array; + suggested_tags: Array; + 
merged_tags: Array; + reasoning: string; +} diff --git a/frontend/src/components/LLMDialog/steps/LLMUtterance.tsx b/frontend/src/components/LLMDialog/steps/LLMUtterance.tsx new file mode 100644 index 000000000..58ab92c5a --- /dev/null +++ b/frontend/src/components/LLMDialog/steps/LLMUtterance.tsx @@ -0,0 +1,15 @@ +import SmartToyIcon from "@mui/icons-material/SmartToy"; +import { Box, Stack } from "@mui/material"; + +function LLMUtterance({ children }: { children?: React.ReactNode }) { + return ( + + + + {children} + + + ); +} + +export default LLMUtterance; diff --git a/frontend/src/components/LLMDialog/steps/MetadataExtractionResultStep/MetadataExtractionResultStep.tsx b/frontend/src/components/LLMDialog/steps/MetadataExtractionResultStep/MetadataExtractionResultStep.tsx new file mode 100644 index 000000000..759eb66ac --- /dev/null +++ b/frontend/src/components/LLMDialog/steps/MetadataExtractionResultStep/MetadataExtractionResultStep.tsx @@ -0,0 +1,39 @@ +import { DialogContent, Typography } from "@mui/material"; +import LLMHooks from "../../../../api/LLMHooks.ts"; +import { MetadataExtractionLLMJobResult } from "../../../../api/openapi/models/MetadataExtractionLLMJobResult.ts"; +import { useAppSelector } from "../../../../plugins/ReduxHooks.ts"; +import LLMUtterance from "../LLMUtterance.tsx"; +import MetadataExtractionResultStepTable from "./MetadataExtractionResultStepTable.tsx"; + +function MetadataExtractionResultStep() { + // get the job + const llmJobId = useAppSelector((state) => state.dialog.llmJobId); + const llmJob = LLMHooks.usePollLLMJob(llmJobId, undefined); + + return ( + <> + + + + Here are the results! You can find my suggestions in the columns marked with (suggested). Now, you + decide what to do with them: + +
      +
    • Use your current metadata values (discarding my suggestions)
    • +
    • Use my suggested metadata values (discarding the current value)
    • +
    + + Of course, you can decided individually for each document. Just click on the value you want to use. + +
    +
    + {llmJob.isSuccess && ( + + )} + + ); +} + +export default MetadataExtractionResultStep; diff --git a/frontend/src/components/LLMDialog/steps/MetadataExtractionResultStep/MetadataExtractionResultStepTable.tsx b/frontend/src/components/LLMDialog/steps/MetadataExtractionResultStep/MetadataExtractionResultStepTable.tsx new file mode 100644 index 000000000..42d5b435a --- /dev/null +++ b/frontend/src/components/LLMDialog/steps/MetadataExtractionResultStep/MetadataExtractionResultStepTable.tsx @@ -0,0 +1,340 @@ +import LabelIcon from "@mui/icons-material/Label"; +import { LoadingButton } from "@mui/lab"; +import { Box, Button, DialogActions, Stack, Typography } from "@mui/material"; +import { + MaterialReactTable, + MRT_ColumnDef, + MRT_RowModel, + MRT_RowSelectionState, + MRT_ShowHideColumnsButton, + MRT_ToggleDensePaddingButton, + useMaterialReactTable, +} from "material-react-table"; +import { useEffect, useMemo, useState } from "react"; +import { MetadataExtractionResult } from "../../../../api/openapi/models/MetadataExtractionResult.ts"; +import { ProjectMetadataRead } from "../../../../api/openapi/models/ProjectMetadataRead.ts"; +import { SourceDocumentMetadataBulkUpdate } from "../../../../api/openapi/models/SourceDocumentMetadataBulkUpdate.ts"; +import { SourceDocumentMetadataReadResolved } from "../../../../api/openapi/models/SourceDocumentMetadataReadResolved.ts"; +import SdocMetadataHooks from "../../../../api/SdocMetadataHooks.ts"; +import { useAppDispatch } from "../../../../plugins/ReduxHooks.ts"; +import { CRUDDialogActions } from "../../../dialogSlice.ts"; +import { SdocMetadataRendererWithData } from "../../../Metadata/SdocMetadataRenderer.tsx"; +import SdocRenderer from "../../../SourceDocument/SdocRenderer.tsx"; + +interface MetadataExtractionResultRow { + sdocId: number; + metadataDict: Record< + number, + { + currentValue: SourceDocumentMetadataReadResolved; + suggestedValue?: SourceDocumentMetadataReadResolved; + useSuggested: boolean; + } + >; 
+} + +function MetadataExtractionResultStepTable({ data }: { data: MetadataExtractionResult[] }) { + // local state + const [rowSelectionModel, setRowSelectionModel] = useState({}); + const buttonsDisabled = Object.keys(rowSelectionModel).length === 0; + + // map the data to project metadata result rows + const { rows2, projectMetadataDict } = useMemo(() => { + const projectMetadataDict: Record = {}; + const rows: MetadataExtractionResultRow[] = []; + for (const result of data) { + const currentMetadataDict: Record = result.current_metadata.reduce( + (acc, metadata) => { + acc[metadata.project_metadata.id] = metadata; + return acc; + }, + {} as Record, + ); + const suggestedMetadataDict: Record = + result.suggested_metadata.reduce( + (acc, metadata) => { + acc[metadata.project_metadata.id] = metadata; + return acc; + }, + {} as Record, + ); + + const row: MetadataExtractionResultRow = { + sdocId: result.sdoc_id, + metadataDict: {}, + }; + for (const projectMetadataId of Object.keys(currentMetadataDict)) { + const pmId = parseInt(projectMetadataId); + row.metadataDict[pmId] = { + currentValue: currentMetadataDict[pmId], + suggestedValue: suggestedMetadataDict[pmId], + useSuggested: true, + }; + projectMetadataDict[pmId] = currentMetadataDict[pmId].project_metadata; + } + rows.push(row); + } + return { rows2: rows, projectMetadataDict: projectMetadataDict }; + }, [data]); + + // init the rows + const [theRows, setTheRows] = useState([]); + useEffect(() => { + console.log("init rows!"); + setTheRows(rows2); + }, [rows2]); + + // actions + const handleSelectCell = (sdocId: number, projectMetadataId: number) => () => { + setTheRows((rows) => { + // flip the useSuggested flag + return rows.map((row) => { + if (row.sdocId === sdocId) { + return { + ...row, + metadataDict: { + ...row.metadataDict, + [projectMetadataId]: { + ...row.metadataDict[projectMetadataId], + useSuggested: !row.metadataDict[projectMetadataId].useSuggested, + }, + }, + }; + } + return row; + }); + }); 
+ }; + + const applyCurrentMetadata = (selectedRows: MRT_RowModel) => () => { + // for all the selectedRows, set the useSuggested flag to false + setTheRows((rows) => { + return rows.map((row) => { + if (selectedRows.rowsById[`${row.sdocId}`]) { + return { + ...row, + metadataDict: Object.keys(row.metadataDict).reduce( + (acc, key) => { + const pmId = parseInt(key); + acc[pmId] = { + ...row.metadataDict[pmId], + useSuggested: false, + }; + return acc; + }, + {} as MetadataExtractionResultRow["metadataDict"], + ), + }; + } + return row; + }); + }); + }; + + const applySuggestedMetadata = (selectedRows: MRT_RowModel) => () => { + // for all the selectedRows, set the useSuggested flag to false + setTheRows((rows) => { + return rows.map((row) => { + if (selectedRows.rowsById[`${row.sdocId}`]) { + return { + ...row, + metadataDict: Object.keys(row.metadataDict).reduce( + (acc, key) => { + const pmId = parseInt(key); + acc[pmId] = { + ...row.metadataDict[pmId], + useSuggested: true, + }; + return acc; + }, + {} as MetadataExtractionResultRow["metadataDict"], + ), + }; + } + return row; + }); + }); + }; + + // dialog actions + const dispatch = useAppDispatch(); + const handleClose = () => { + dispatch(CRUDDialogActions.closeLLMDialog()); + }; + + const updateBulkMetadataMutation = SdocMetadataHooks.useUpdateBulkMetadata(); + const handleUpdateBulkMetadata = () => { + // find all the metadata where the useSuggested flag is true + const metadataToUpdate: SourceDocumentMetadataBulkUpdate[] = theRows.reduce((acc, row) => { + for (const metadata of Object.values(row.metadataDict)) { + if (metadata.useSuggested && metadata.suggestedValue) { + acc.push({ + id: metadata.suggestedValue.id, + boolean_value: metadata.suggestedValue.boolean_value, + date_value: metadata.suggestedValue.date_value, + int_value: metadata.suggestedValue.int_value, + list_value: metadata.suggestedValue.list_value, + str_value: metadata.suggestedValue.str_value, + }); + } + } + return acc; + }, [] as 
SourceDocumentMetadataBulkUpdate[]); + + // update the metadata + updateBulkMetadataMutation.mutate( + { + requestBody: metadataToUpdate, + }, + { + onSuccess() { + dispatch(CRUDDialogActions.closeLLMDialog()); + }, + }, + ); + }; + + // columns + const columns = useMemo(() => { + const result: MRT_ColumnDef[] = [ + { + id: "Filename", + header: "Document", + Cell: ({ row }) => , + }, + ]; + + for (const projectMetadata of Object.values(projectMetadataDict)) { + result.push({ + id: `${projectMetadata.id.toString()}-current`, + header: `${projectMetadata.key} (current)`, + muiTableBodyCellProps: ({ row }) => { + const isSelected = !row.original.metadataDict[projectMetadata.id].useSuggested; + return { + sx: { + bgcolor: isSelected ? "success.light" : null, + color: isSelected ? "success.contrastText" : null, + "&:hover": { + bgcolor: isSelected ? "success.light" : "#9e9e9e", + }, + cursor: "pointer", + }, + onClick: handleSelectCell(row.original.sdocId, projectMetadata.id), + }; + }, + Cell: ({ row }) => { + const metadata = row.original.metadataDict[projectMetadata.id]; + return ; + }, + }); + result.push({ + id: `${projectMetadata.id.toString()}-suggestion`, + header: `${projectMetadata.key} (suggested)`, + muiTableBodyCellProps: ({ row }) => { + const isSelected = row.original.metadataDict[projectMetadata.id].useSuggested; + return { + sx: { + bgcolor: isSelected ? "success.light" : null, + color: isSelected ? "success.contrastText" : null, + "&:hover": { + bgcolor: isSelected ? "success.light" : "#9e9e9e", + }, + cursor: "pointer", + }, + onClick: handleSelectCell(row.original.sdocId, projectMetadata.id), + }; + }, + Cell: ({ row }) => { + const metadata = row.original.metadataDict[projectMetadata.id]; + return metadata.suggestedValue ? 
( + + ) : ( + <>empty + ); + }, + }); + } + + return result; + }, [projectMetadataDict]); + + // table + const table = useMaterialReactTable({ + data: theRows, + columns: columns, + getRowId: (row) => `${row.sdocId}`, + // state + state: { + rowSelection: rowSelectionModel, + }, + // selection + enableRowSelection: true, + positionToolbarAlertBanner: "bottom", + onRowSelectionChange: setRowSelectionModel, + // style + muiTablePaperProps: { + elevation: 0, + style: { height: "100%", display: "flex", flexDirection: "column" }, + }, + muiTableContainerProps: { + style: { flexGrow: 1 }, + }, + // virtualization (scrolling instead of pagination) + enablePagination: false, + enableRowVirtualization: true, + // hide columns per default + initialState: { + columnVisibility: { + id: false, + }, + }, + // toolbars + enableBottomToolbar: true, + renderTopToolbarCustomActions: ({ table }) => ( + + + Strategy: + + + + + ), + renderToolbarInternalActions: ({ table }) => ( + + + + + ), + renderBottomToolbarCustomActions: () => ( + + + + } + onClick={handleUpdateBulkMetadata} + loading={updateBulkMetadataMutation.isPending} + loadingPosition="start" + > + Update metadata + + + ), + }); + + return ; +} + +export default MetadataExtractionResultStepTable; diff --git a/frontend/src/components/LLMDialog/steps/MethodSelectionStep.tsx b/frontend/src/components/LLMDialog/steps/MethodSelectionStep.tsx new file mode 100644 index 000000000..72ac2c181 --- /dev/null +++ b/frontend/src/components/LLMDialog/steps/MethodSelectionStep.tsx @@ -0,0 +1,75 @@ +import { + Button, + Card, + CardActionArea, + CardContent, + DialogActions, + DialogContent, + Stack, + Typography, +} from "@mui/material"; +import { LLMJobType } from "../../../api/openapi/models/LLMJobType.ts"; +import { useAppDispatch } from "../../../plugins/ReduxHooks.ts"; +import { CRUDDialogActions } from "../../dialogSlice.ts"; +import LLMUtterance from "./LLMUtterance.tsx"; + +function MethodSelectionStep() { + const dispatch = 
useAppDispatch(); + const selectMethod = (method: LLMJobType) => () => { + dispatch(CRUDDialogActions.llmDialogGoToDataSelection({ method })); + }; + const handleClose = () => { + dispatch(CRUDDialogActions.closeLLMDialog()); + }; + + return ( + <> + + + How can I help you? + + + + + + + + + + + + ); +} + +interface MethodButtonProps { + onClick: () => void; + headline: string; + description: string; +} + +function MethodButton({ onClick, headline, description }: MethodButtonProps) { + return ( + + + +

    {headline}

    + {description} +
    +
    +
    + ); +} + +export default MethodSelectionStep; diff --git a/frontend/src/components/LLMDialog/steps/ProjectMetadataSelectionStep.tsx b/frontend/src/components/LLMDialog/steps/ProjectMetadataSelectionStep.tsx new file mode 100644 index 000000000..1d097e3bb --- /dev/null +++ b/frontend/src/components/LLMDialog/steps/ProjectMetadataSelectionStep.tsx @@ -0,0 +1,103 @@ +import PlayCircleIcon from "@mui/icons-material/PlayCircle"; +import { LoadingButton } from "@mui/lab"; +import { Box, Button, DialogActions, DialogContent, Typography } from "@mui/material"; +import { MRT_RowSelectionState } from "material-react-table"; +import { useMemo, useState } from "react"; +import LLMHooks from "../../../api/LLMHooks.ts"; +import ProjectHooks from "../../../api/ProjectHooks.ts"; +import { DocType } from "../../../api/openapi/models/DocType.ts"; +import { LLMJobType } from "../../../api/openapi/models/LLMJobType.ts"; +import { ProjectMetadataRead } from "../../../api/openapi/models/ProjectMetadataRead.ts"; +import { useAppDispatch, useAppSelector } from "../../../plugins/ReduxHooks.ts"; +import ProjectMetadataTable from "../../Metadata/ProjectMetadataTable.tsx"; +import { CRUDDialogActions } from "../../dialogSlice.ts"; +import LLMUtterance from "./LLMUtterance.tsx"; + +function ProjectMetadataSelectionStep({ projectId }: { projectId: number }) { + // local state + const [rowSelectionModel, setRowSelectionModel] = useState({}); + + // global state + const selectedDocuments = useAppSelector((state) => state.dialog.llmDocumentIds); + const dispatch = useAppDispatch(); + + // global server state + const projectMetadata = ProjectHooks.useGetMetadata(projectId); + const filteredProjectMetadata = useMemo(() => { + if (!projectMetadata.data) return []; + return projectMetadata.data.filter((metadata) => metadata.doctype === DocType.TEXT && metadata.read_only === false); + }, [projectMetadata.data]); + + // initiate next step (get the generated prompts) + const 
createPromptTemplatesMutation = LLMHooks.useCreatePromptTemplates(); + const handleNext = (projectMetadata: ProjectMetadataRead[]) => () => { + createPromptTemplatesMutation.mutate( + { + requestBody: { + llm_job_type: LLMJobType.METADATA_EXTRACTION, + project_id: projectId, + prompts: [], + specific_llm_job_parameters: { + llm_job_type: LLMJobType.METADATA_EXTRACTION, + project_metadata_ids: projectMetadata.map((metadata) => metadata.id), + sdoc_ids: selectedDocuments, + }, + }, + }, + { + onSuccess(data) { + dispatch( + CRUDDialogActions.llmDialogGoToPromptEditor({ + prompts: data, + tags: [], + metadata: projectMetadata, + codes: [], + }), + ); + }, + }, + ); + }; + + return ( + <> + + + + You selected {selectedDocuments.length} document(s) for automatic metadata extraction. Please select all + metadata that I should try to extract from the documents. + + + + ( + + + + } + loading={createPromptTemplatesMutation.isPending} + loadingPosition="start" + disabled={props.selectedProjectMetadata.length === 0} + onClick={handleNext(props.selectedProjectMetadata)} + > + Next! 
+ + + )} + /> + + ); +} + +export default ProjectMetadataSelectionStep; diff --git a/frontend/src/components/LLMDialog/steps/PromptEditorStep.tsx b/frontend/src/components/LLMDialog/steps/PromptEditorStep.tsx new file mode 100644 index 000000000..14fdde833 --- /dev/null +++ b/frontend/src/components/LLMDialog/steps/PromptEditorStep.tsx @@ -0,0 +1,171 @@ +import { ErrorMessage } from "@hookform/error-message"; +import PlayCircleIcon from "@mui/icons-material/PlayCircle"; +import { LoadingButton, TabContext, TabList, TabPanel } from "@mui/lab"; +import { Box, Button, DialogActions, DialogContent, Stack, Tab, Typography } from "@mui/material"; +import { useState } from "react"; +import { SubmitErrorHandler, SubmitHandler, useForm } from "react-hook-form"; +import LLMHooks from "../../../api/LLMHooks.ts"; +import { LLMPromptTemplates } from "../../../api/openapi/models/LLMPromptTemplates.ts"; +import { useAppDispatch, useAppSelector } from "../../../plugins/ReduxHooks.ts"; +import { CRUDDialogActions } from "../../dialogSlice.ts"; +import FormTextMultiline from "../../FormInputs/FormTextMultiline.tsx"; +import LLMUtterance from "./LLMUtterance.tsx"; + +type PromptEditorValues = { + systemPrompt: string; + userPrompt: string; +}; + +function PromptEditorStep({ projectId }: { projectId: number }) { + // global state + const tags = useAppSelector((state) => state.dialog.llmTags); + const metadata = useAppSelector((state) => state.dialog.llmMetadata); + const codes = useAppSelector((state) => state.dialog.llmCodes); + const method = useAppSelector((state) => state.dialog.llmMethod); + const sdocIds = useAppSelector((state) => state.dialog.llmDocumentIds); + const prompts = useAppSelector((state) => state.dialog.llmPrompts); + const dispatch = useAppDispatch(); + + // local state (to manage tabs) + const [tab, setTab] = useState(prompts[0].language); + const handleChangeTab = (_: React.SyntheticEvent, newValue: string) => { + setTab(newValue); + }; + + // react form 
handlers + const handleChangePrompt = (language: string) => (formData: PromptEditorValues) => { + dispatch( + CRUDDialogActions.updateLLMPrompts({ + language: language, + systemPrompt: formData.systemPrompt, + userPrompt: formData.userPrompt, + }), + ); + }; + + // start llm job + const startLLMJobMutation = LLMHooks.useStartLLMJob(); + const handleStartLLMJob = () => { + if (method === undefined) return; + + startLLMJobMutation.mutate( + { + requestBody: { + project_id: projectId, + prompts: prompts, + llm_job_type: method, + specific_llm_job_parameters: { + llm_job_type: method, + sdoc_ids: sdocIds, + tag_ids: tags.map((tag) => tag.id), + project_metadata_ids: metadata.map((m) => m.id), + code_ids: codes.map((code) => code.id), + }, + }, + }, + { + onSuccess: (data) => { + dispatch( + CRUDDialogActions.llmDialogGoToWaiting({ + jobId: data.id, + method: data.parameters.llm_job_type, + }), + ); + }, + }, + ); + }; + + return ( + <> + + + + These are my generated commands. Now is your last chance to edit them, before I get to work. + + + + + + {prompts.map((prompt) => ( + + ))} + + + {prompts.map((prompt) => ( + + + + ))} + + + + + } + loading={startLLMJobMutation.isPending} + loadingPosition="start" + onClick={handleStartLLMJob} + > + Start! 
+ + + + ); +} + +function PromptEditorStepForm({ + prompt, + handleSavePrompt, +}: { + prompt: LLMPromptTemplates; + handleSavePrompt: SubmitHandler; +}) { + // react form + const { + handleSubmit, + formState: { errors }, + control, + } = useForm({ + defaultValues: { + systemPrompt: prompt.system_prompt, + userPrompt: prompt.user_prompt, + }, + }); + + // react form handlers + const handleError: SubmitErrorHandler = (data) => console.error(data); + + return ( + + , + variant: "outlined", + minRows: 2, + onBlur: () => handleSubmit(handleSavePrompt, handleError)(), + }} + /> + , + variant: "outlined", + onBlur: () => handleSubmit(handleSavePrompt, handleError)(), + }} + /> + + ); +} + +export default PromptEditorStep; diff --git a/frontend/src/components/LLMDialog/steps/StatusStep.tsx b/frontend/src/components/LLMDialog/steps/StatusStep.tsx new file mode 100644 index 000000000..7e5a6802a --- /dev/null +++ b/frontend/src/components/LLMDialog/steps/StatusStep.tsx @@ -0,0 +1,76 @@ +import { Button, DialogActions, DialogContent, Stack, Typography } from "@mui/material"; +import LLMHooks from "../../../api/LLMHooks.ts"; +import { BackgroundJobStatus } from "../../../api/openapi/models/BackgroundJobStatus.ts"; +import { useAppDispatch, useAppSelector } from "../../../plugins/ReduxHooks.ts"; +import { CRUDDialogActions } from "../../dialogSlice.ts"; +import LinearProgressWithLabel from "../../LinearProgressWithLabel.tsx"; +import LLMUtterance from "./LLMUtterance.tsx"; + +function StatusStep() { + // global state + const llmJobId = useAppSelector((state) => state.dialog.llmJobId); + const dispatch = useAppDispatch(); + + // poll the job + const llmJob = LLMHooks.usePollLLMJob(llmJobId, undefined); + + const handleClose = () => { + dispatch(CRUDDialogActions.closeLLMDialog()); + }; + + const handleNext = () => { + if (llmJob.data && llmJob.data.status === BackgroundJobStatus.FINISHED && llmJob.data.result) { + dispatch(CRUDDialogActions.llmDialogGoToResult({ result: 
llmJob.data.result })); + } else { + console.error("Job is not finished yet."); + } + }; + + return ( + <> + + + + I am working hard! Please wait ... + + + + + {llmJob.isSuccess && llmJob.data.status === BackgroundJobStatus.FINISHED && ( + + + I am done with {llmJob.data.parameters.llm_job_type.toLowerCase()}. You can view the results now! + + + )} + + + This may take a while. You can close the dialog and come back later. You can find all active LLM jobs in + Project Settings > Background Tasks. + + + + + + + + + ); +} + +export default StatusStep; diff --git a/frontend/src/components/LLMDialog/useOpenLLMDialog.ts b/frontend/src/components/LLMDialog/useOpenLLMDialog.ts new file mode 100644 index 000000000..d1e4aea35 --- /dev/null +++ b/frontend/src/components/LLMDialog/useOpenLLMDialog.ts @@ -0,0 +1,17 @@ +import { useCallback } from "react"; +import { useAppDispatch } from "../../plugins/ReduxHooks.ts"; +import { CRUDDialogActions } from "../dialogSlice.ts"; +import { LLMAssistanceEvent } from "./LLMEvent.ts"; + +export const useOpenLLMDialog = () => { + const dispatch = useAppDispatch(); + + const openLLMDialog = useCallback( + (event: LLMAssistanceEvent) => { + dispatch(CRUDDialogActions.openLLMDialog({ event })); + }, + [dispatch], + ); + + return openLLMDialog; +}; diff --git a/frontend/src/components/Metadata/ProjectMetadataTable.tsx b/frontend/src/components/Metadata/ProjectMetadataTable.tsx new file mode 100644 index 000000000..be94a6f03 --- /dev/null +++ b/frontend/src/components/Metadata/ProjectMetadataTable.tsx @@ -0,0 +1,156 @@ +import { + MRT_ColumnDef, + MRT_RowSelectionState, + MRT_TableInstance, + MRT_TableOptions, + MaterialReactTable, + useMaterialReactTable, +} from "material-react-table"; +import { useMemo } from "react"; +import { ProjectMetadataRead } from "../../api/openapi/models/ProjectMetadataRead.ts"; +import ProjectHooks from "../../api/ProjectHooks.ts"; + +const columns: MRT_ColumnDef[] = [ + { + accessorKey: "id", + header: "ID", + }, 
+ { + accessorKey: "key", + header: "Metadata", + }, + { + accessorKey: "metatype", + header: "Type", + }, + { + accessorKey: "description", + header: "Description", + }, +]; + +export interface ProjectMetadataTableActionProps { + table: MRT_TableInstance; + selectedProjectMetadata: ProjectMetadataRead[]; +} + +interface SharedProjectMetadataTableProps { + projectId: number; + // selection + enableMultiRowSelection?: boolean; + rowSelectionModel: MRT_RowSelectionState; + onRowSelectionChange: MRT_TableOptions["onRowSelectionChange"]; + // toolbar + renderToolbarInternalActions?: (props: ProjectMetadataTableActionProps) => React.ReactNode; + renderTopToolbarCustomActions?: (props: ProjectMetadataTableActionProps) => React.ReactNode; + renderBottomToolbarCustomActions?: (props: ProjectMetadataTableActionProps) => React.ReactNode; +} + +interface ProjectMetadataTableProps extends SharedProjectMetadataTableProps { + projectMetadata?: ProjectMetadataRead[]; +} + +function ProjectMetadataTable(props: ProjectMetadataTableProps) { + const projectMetadata = props.projectMetadata; + if (projectMetadata) { + return ; + } else { + return ; + } +} + +function ProjectMetadataTableWithoutMetadata(props: SharedProjectMetadataTableProps) { + // global server state + const projectMetadata = ProjectHooks.useGetMetadata(props.projectId); + + if (projectMetadata.isSuccess) { + return ; + } else { + return null; + } +} + +interface ProjectMetadataTableContentProps extends SharedProjectMetadataTableProps { + projectMetadata: ProjectMetadataRead[]; +} + +function ProjectMetadataTableContent({ + projectMetadata, + enableMultiRowSelection = true, + rowSelectionModel, + onRowSelectionChange, + renderToolbarInternalActions, + renderTopToolbarCustomActions, + renderBottomToolbarCustomActions, +}: ProjectMetadataTableContentProps) { + // computed + const projectMetadataMap = useMemo(() => { + const projectMetadataMap = projectMetadata.reduce( + (acc, projectMetadata) => { + 
acc[projectMetadata.id.toString()] = projectMetadata; + return acc; + }, + {} as Record, + ); + + return projectMetadataMap; + }, [projectMetadata]); + + // table + const table = useMaterialReactTable({ + data: projectMetadata, + columns: columns, + getRowId: (row) => `${row.id}`, + // style + muiTablePaperProps: { + elevation: 0, + style: { height: "100%", display: "flex", flexDirection: "column" }, + }, + muiTableContainerProps: { + style: { flexGrow: 1 }, + }, + // state + state: { + rowSelection: rowSelectionModel, + }, + // virtualization (scrolling instead of pagination) + enablePagination: false, + enableRowVirtualization: true, + // selection + enableRowSelection: true, + enableMultiRowSelection, + onRowSelectionChange, + // toolbar + enableBottomToolbar: true, + renderTopToolbarCustomActions: renderTopToolbarCustomActions + ? (props) => + renderTopToolbarCustomActions({ + table: props.table, + selectedProjectMetadata: Object.keys(rowSelectionModel).map((mId) => projectMetadataMap[mId]), + }) + : undefined, + renderToolbarInternalActions: renderToolbarInternalActions + ? (props) => + renderToolbarInternalActions({ + table: props.table, + selectedProjectMetadata: Object.values(projectMetadataMap).filter((row) => rowSelectionModel[row.id]), + }) + : undefined, + renderBottomToolbarCustomActions: renderBottomToolbarCustomActions + ? 
(props) => + renderBottomToolbarCustomActions({ + table: props.table, + selectedProjectMetadata: Object.values(projectMetadataMap).filter((row) => rowSelectionModel[row.id]), + }) + : undefined, + // hide columns per default + initialState: { + columnVisibility: { + id: false, + }, + }, + }); + + return ; +} +export default ProjectMetadataTable; diff --git a/frontend/src/components/Metadata/SdocMetadataRenderer.tsx b/frontend/src/components/Metadata/SdocMetadataRenderer.tsx index 9882b7805..1e6a1f66f 100644 --- a/frontend/src/components/Metadata/SdocMetadataRenderer.tsx +++ b/frontend/src/components/Metadata/SdocMetadataRenderer.tsx @@ -28,7 +28,7 @@ function SdocMetadataRenderer({ sdocId, projectMetadataId }: SdocMetadataRendere } } -function SdocMetadataRendererWithData({ sdocMetadata }: { sdocMetadata: SourceDocumentMetadataReadResolved }) { +export function SdocMetadataRendererWithData({ sdocMetadata }: { sdocMetadata: SourceDocumentMetadataReadResolved }) { switch (sdocMetadata.project_metadata.metatype) { case MetaType.STRING: return <>{sdocMetadata.str_value ? 
sdocMetadata.str_value : empty}; diff --git a/frontend/src/components/ProjectSettings/backgroundtasks/LLMJobDetailListItem.tsx b/frontend/src/components/ProjectSettings/backgroundtasks/LLMJobDetailListItem.tsx new file mode 100644 index 000000000..51ec037ac --- /dev/null +++ b/frontend/src/components/ProjectSettings/backgroundtasks/LLMJobDetailListItem.tsx @@ -0,0 +1,40 @@ +import ExpandLess from "@mui/icons-material/ExpandLess"; +import ExpandMoreIcon from "@mui/icons-material/ExpandMore"; +import { Collapse, ListItemButton, ListItemText, Tooltip, Typography } from "@mui/material"; +import React from "react"; + +interface LLMJobdetailListItemProps { + detailKey: string; + detailValue: string; +} + +function LLMJobDetailListItem({ detailKey, detailValue }: LLMJobdetailListItemProps) { + // local state + const [expanded, setExpanded] = React.useState(false); + const handleExpandClick = () => { + setExpanded(!expanded); + }; + + return ( + <> + + + + + {detailKey} + + + {expanded ? : } + + + + + + {detailValue} + + + + ); +} + +export default LLMJobDetailListItem; diff --git a/frontend/src/components/ProjectSettings/backgroundtasks/LLMJobListItem.tsx b/frontend/src/components/ProjectSettings/backgroundtasks/LLMJobListItem.tsx new file mode 100644 index 000000000..650381df3 --- /dev/null +++ b/frontend/src/components/ProjectSettings/backgroundtasks/LLMJobListItem.tsx @@ -0,0 +1,116 @@ +import { TabContext, TabList, TabPanel } from "@mui/lab"; +import { Box, Button, Stack, Tab, TextField } from "@mui/material"; +import { useState } from "react"; +import LLMHooks from "../../../api/LLMHooks.ts"; +import { BackgroundJobStatus } from "../../../api/openapi/models/BackgroundJobStatus.ts"; +import { LLMJobRead } from "../../../api/openapi/models/LLMJobRead.ts"; +import { LLMPromptTemplates } from "../../../api/openapi/models/LLMPromptTemplates.ts"; +import { useAppDispatch } from "../../../plugins/ReduxHooks.ts"; +import { dateToLocaleString } from 
"../../../utils/DateUtils.ts"; +import { CRUDDialogActions } from "../../dialogSlice.ts"; +import BackgroundJobListItem from "./BackgroundJobListItem.tsx"; + +interface LLMJobListItemProps { + initialLLMJob: LLMJobRead; +} + +function LLMJobListItem({ initialLLMJob }: LLMJobListItemProps) { + // global server state (react-query) + const llmJob = LLMHooks.usePollLLMJob(initialLLMJob.id, initialLLMJob); + + // compute date sting + const createdDate = dateToLocaleString(llmJob.data!.created); + const updatedDate = dateToLocaleString(llmJob.data!.updated); + let subTitle = `${ + llmJob.data!.parameters.specific_llm_job_parameters.sdoc_ids.length + } documents, started at ${createdDate}`; + if (llmJob.data!.status === BackgroundJobStatus.FINISHED) { + subTitle += `, finished at ${updatedDate}`; + } else if (llmJob.data!.status === BackgroundJobStatus.ABORTED) { + subTitle += `, aborted at ${updatedDate}`; + } else if (llmJob.data!.status === BackgroundJobStatus.ERRORNEOUS) { + subTitle += `, failed at ${updatedDate}`; + } + + // actions + const dispatch = useAppDispatch(); + const handleViewResults = () => { + dispatch(CRUDDialogActions.closeProjectSettings()); + dispatch( + CRUDDialogActions.llmDialogGoToWaiting({ + jobId: initialLLMJob.id, + method: initialLLMJob.parameters.llm_job_type, + }), + ); + }; + + if (llmJob.isSuccess) { + return ( + + + {llmJob.data.status === BackgroundJobStatus.FINISHED ? ( + + ) : llmJob.data.status === BackgroundJobStatus.RUNNING ? 
( + + ) : null} + + + + ); + } else { + return null; + } +} + +function PromptViewer({ prompts }: { prompts: LLMPromptTemplates[] }) { + // tab state + const [tab, setTab] = useState(prompts[0].language); + const handleChangeTab = (_: React.SyntheticEvent, newValue: string) => { + setTab(newValue); + }; + + return ( + + + + {prompts.map((prompt) => ( + + ))} + + + {prompts.map((prompt) => ( + + + + + + + ))} + + ); +} + +export default LLMJobListItem; diff --git a/frontend/src/components/ProjectSettings/backgroundtasks/ProjectBackgroundTasks.tsx b/frontend/src/components/ProjectSettings/backgroundtasks/ProjectBackgroundTasks.tsx index e5e279c7b..375b6543c 100644 --- a/frontend/src/components/ProjectSettings/backgroundtasks/ProjectBackgroundTasks.tsx +++ b/frontend/src/components/ProjectSettings/backgroundtasks/ProjectBackgroundTasks.tsx @@ -1,14 +1,28 @@ import { Box, Divider, List, Toolbar, Typography } from "@mui/material"; import { useMemo } from "react"; import CrawlerHooks from "../../../api/CrawlerHooks.ts"; +import LLMHooks from "../../../api/LLMHooks.ts"; import PreProHooks from "../../../api/PreProHooks.ts"; import { BackgroundJobStatus } from "../../../api/openapi/models/BackgroundJobStatus.ts"; import { CrawlerJobRead } from "../../../api/openapi/models/CrawlerJobRead.ts"; +import { LLMJobRead } from "../../../api/openapi/models/LLMJobRead.ts"; import { PreprocessingJobRead } from "../../../api/openapi/models/PreprocessingJobRead.ts"; import { ProjectRead } from "../../../api/openapi/models/ProjectRead.ts"; import CrawlerJobListItem from "./CrawlerJobListItem.tsx"; +import LLMJobListItem from "./LLMJobListItem.tsx"; import PreProJobListItem from "./PreProJobListItem.tsx"; +// type guards +const isCrawlerJob = (job: CrawlerJobRead | PreprocessingJobRead | LLMJobRead): job is CrawlerJobRead => { + return "output_dir" in job; +}; +const isPreProJob = (job: CrawlerJobRead | PreprocessingJobRead | LLMJobRead): job is PreprocessingJobRead => { + return 
"payloads" in job; +}; +const isLLMJob = (job: CrawlerJobRead | PreprocessingJobRead | LLMJobRead): job is LLMJobRead => { + return "num_steps_finished" in job; +}; + interface ProjectBackgroundTasksProps { project: ProjectRead; } @@ -17,9 +31,10 @@ function ProjectBackgroundTasks({ project }: ProjectBackgroundTasksProps) { // global server state (react-query) const crawlerJobs = CrawlerHooks.useGetAllCrawlerJobs(project.id); const preProJobs = PreProHooks.useGetAllPreProJobs(project.id); + const llmJobs = LLMHooks.useGetAllLLMJobs(project.id); const backgroundJobsByStatus = useMemo(() => { - const result: Record = { + const result: Record = { [BackgroundJobStatus.WAITING]: [], [BackgroundJobStatus.RUNNING]: [], [BackgroundJobStatus.FINISHED]: [], @@ -27,7 +42,7 @@ function ProjectBackgroundTasks({ project }: ProjectBackgroundTasksProps) { [BackgroundJobStatus.ABORTED]: [], }; - if (!crawlerJobs.data && !preProJobs.data) { + if (!crawlerJobs.data && !preProJobs.data && !llmJobs.data) { return result; } @@ -43,16 +58,43 @@ function ProjectBackgroundTasks({ project }: ProjectBackgroundTasksProps) { result[job.status].push(job); } } + if (llmJobs.data) { + for (const job of llmJobs.data) { + if (!job.status) continue; + result[job.status].push(job); + } + } return result; - }, [crawlerJobs.data, preProJobs.data]); + }, [crawlerJobs.data, preProJobs.data, llmJobs.data]); + + // rendering + const renderBackgroundJobs = (status: BackgroundJobStatus) => { + return ( + <> + {backgroundJobsByStatus[status].map((job) => { + if (isCrawlerJob(job)) { + return ; + } else if (isPreProJob(job)) { + return ; + } else if (isLLMJob(job)) { + return ; + } else { + return null; + } + })} + {backgroundJobsByStatus[status].length === 0 && empty} + + ); + }; return ( <> - {(crawlerJobs.isLoading || preProJobs.isLoading) && <>Loading background jobs...} + {(crawlerJobs.isLoading || preProJobs.isLoading || llmJobs.isLoading) && <>Loading background jobs...} {crawlerJobs.isError && <>An 
error occurred while loading crawler jobs for project {project.id}...} {preProJobs.isError && <>An error occurred while loading preprocessing jobs for project {project.id}...} - {crawlerJobs.isSuccess && preProJobs.isSuccess && ( + {llmJobs.isError && <>An error occurred while loading llm jobs for project {project.id}...} + {crawlerJobs.isSuccess && preProJobs.isSuccess && llmJobs.isSuccess && ( <> @@ -61,18 +103,7 @@ function ProjectBackgroundTasks({ project }: ProjectBackgroundTasksProps) { - - {backgroundJobsByStatus[BackgroundJobStatus.WAITING].map((job) => { - if ("parameters" in job) { - return ; - } else if ("payloads" in job) { - return ; - } else { - return null; - } - })} - {backgroundJobsByStatus[BackgroundJobStatus.WAITING].length === 0 && empty} - + {renderBackgroundJobs(BackgroundJobStatus.WAITING)} Running @@ -80,18 +111,7 @@ function ProjectBackgroundTasks({ project }: ProjectBackgroundTasksProps) { - - {backgroundJobsByStatus[BackgroundJobStatus.RUNNING].map((job) => { - if ("parameters" in job) { - return ; - } else if ("payloads" in job) { - return ; - } else { - return null; - } - })} - {backgroundJobsByStatus[BackgroundJobStatus.RUNNING].length === 0 && empty} - + {renderBackgroundJobs(BackgroundJobStatus.RUNNING)} Finished @@ -99,18 +119,7 @@ function ProjectBackgroundTasks({ project }: ProjectBackgroundTasksProps) { - - {backgroundJobsByStatus[BackgroundJobStatus.FINISHED].map((job) => { - if ("parameters" in job) { - return ; - } else if ("payloads" in job) { - return ; - } else { - return null; - } - })} - {backgroundJobsByStatus[BackgroundJobStatus.FINISHED].length === 0 && empty} - + {renderBackgroundJobs(BackgroundJobStatus.FINISHED)} Aborted @@ -118,18 +127,7 @@ function ProjectBackgroundTasks({ project }: ProjectBackgroundTasksProps) { - - {backgroundJobsByStatus[BackgroundJobStatus.ABORTED].map((job) => { - if ("parameters" in job) { - return ; - } else if ("payloads" in job) { - return ; - } else { - return null; - } - })} - 
{backgroundJobsByStatus[BackgroundJobStatus.ABORTED].length === 0 && empty} - + {renderBackgroundJobs(BackgroundJobStatus.ABORTED)} Failed @@ -137,20 +135,7 @@ function ProjectBackgroundTasks({ project }: ProjectBackgroundTasksProps) { - - {backgroundJobsByStatus[BackgroundJobStatus.ERRORNEOUS].map((job) => { - if ("parameters" in job) { - return ; - } else if ("payloads" in job) { - return ; - } else { - return null; - } - })} - {backgroundJobsByStatus[BackgroundJobStatus.ERRORNEOUS].length === 0 && ( - empty - )} - + {renderBackgroundJobs(BackgroundJobStatus.ERRORNEOUS)} )} diff --git a/frontend/src/components/SourceDocument/DocumentInformation/MetadataCreateButton.tsx b/frontend/src/components/SourceDocument/DocumentInformation/MetadataCreateButton.tsx index 979eccb6f..017c5df3b 100644 --- a/frontend/src/components/SourceDocument/DocumentInformation/MetadataCreateButton.tsx +++ b/frontend/src/components/SourceDocument/DocumentInformation/MetadataCreateButton.tsx @@ -45,6 +45,7 @@ function MetadataCreateButton({ sdocId }: MetadataCreateButtonProps) { key: `${metaType.toLowerCase()} (new)`, project_id: projectId, read_only: false, + description: "Placeholder description", }, }, { diff --git a/frontend/src/components/SourceDocument/DocumentInformation/MetadataEditMenu.tsx b/frontend/src/components/SourceDocument/DocumentInformation/MetadataEditMenu.tsx index f45d1245d..dfbb2db7b 100644 --- a/frontend/src/components/SourceDocument/DocumentInformation/MetadataEditMenu.tsx +++ b/frontend/src/components/SourceDocument/DocumentInformation/MetadataEditMenu.tsx @@ -10,7 +10,9 @@ import { ListItemText, Popover, PopoverPosition, + Stack, TextField, + Tooltip, } from "@mui/material"; import React, { useCallback, useState } from "react"; import ProjectMetadataHooks from "../../../api/ProjectMetadataHooks.ts"; @@ -36,12 +38,18 @@ function MetadataEditMenu({ metadata }: MetadataEditMenuProps) { }); }; - // rename + // metadata name const [name, setName] = 
useState(metadata.project_metadata.key); const handleChangeName = (event: React.ChangeEvent) => { setName(event.target.value); }; + // metadata description + const [description, setDescription] = useState(metadata.project_metadata.description); + const handleChangeDescription = (event: React.ChangeEvent) => { + setDescription(event.target.value); + }; + // change type const [isTypeMenuOpen, setIsTypeMenuOpen] = useState(false); const [metatype, setMetatype] = useState(metadata.project_metadata.metatype); @@ -56,7 +64,11 @@ function MetadataEditMenu({ metadata }: MetadataEditMenuProps) { setPosition(undefined); // only update if data has changed! - if (metadata.project_metadata.metatype !== metatype || metadata.project_metadata.key !== name) { + if ( + metadata.project_metadata.metatype !== metatype || + metadata.project_metadata.key !== name || + metadata.project_metadata.description !== description + ) { const mutation = updateMutation.mutate; const actuallyMutate = () => mutation({ @@ -64,6 +76,7 @@ function MetadataEditMenu({ metadata }: MetadataEditMenuProps) { requestBody: { metatype: metatype, key: name, + description: description, }, }); if (metadata.project_metadata.metatype !== metatype) { @@ -107,15 +120,18 @@ function MetadataEditMenu({ metadata }: MetadataEditMenuProps) { return ( <> - + + + + + - + + + + setIsTypeMenuOpen(true)}> Type diff --git a/frontend/src/components/SourceDocument/DocumentInformation/MetadataTypeSelectorMenu.tsx b/frontend/src/components/SourceDocument/DocumentInformation/MetadataTypeSelectorMenu.tsx index e09e651dd..a3e8ccff3 100644 --- a/frontend/src/components/SourceDocument/DocumentInformation/MetadataTypeSelectorMenu.tsx +++ b/frontend/src/components/SourceDocument/DocumentInformation/MetadataTypeSelectorMenu.tsx @@ -72,7 +72,7 @@ function MetadataTypeSelectorMenu({ }} slotProps={{ paper: { - sx: { minHeight: "201px" }, + sx: { minHeight: "201px", width: "240px" }, }, }} renderOption={(props, option) => ( @@ -81,7 +81,7 @@ 
function MetadataTypeSelectorMenu({ {option} )} - sx={{ width: 230 }} + sx={{ width: 240 }} renderInput={(params) => ( )} diff --git a/frontend/src/components/dialogSlice.ts b/frontend/src/components/dialogSlice.ts index 4ee8d99c2..749b7febb 100644 --- a/frontend/src/components/dialogSlice.ts +++ b/frontend/src/components/dialogSlice.ts @@ -2,8 +2,14 @@ import { AlertProps } from "@mui/material"; import { PayloadAction, createSlice } from "@reduxjs/toolkit/react"; import { BBoxAnnotationReadResolvedCode } from "../api/openapi/models/BBoxAnnotationReadResolvedCode.ts"; import { CodeRead } from "../api/openapi/models/CodeRead.ts"; +import { DocumentTagRead } from "../api/openapi/models/DocumentTagRead.ts"; +import { LLMJobResult } from "../api/openapi/models/LLMJobResult.ts"; +import { LLMJobType } from "../api/openapi/models/LLMJobType.ts"; +import { LLMPromptTemplates } from "../api/openapi/models/LLMPromptTemplates.ts"; +import { ProjectMetadataRead } from "../api/openapi/models/ProjectMetadataRead.ts"; import { SnackbarEvent } from "../components/SnackbarDialog/SnackbarEvent.ts"; import { CodeCreateSuccessHandler } from "./Code/CodeCreateDialog.tsx"; +import { LLMAssistanceEvent } from "./LLMDialog/LLMEvent.ts"; interface DialogState { // tags @@ -31,6 +37,17 @@ interface DialogState { snackbarData: SnackbarEvent; // project settings isProjectSettingsOpen: boolean; + // llm dialog + isLLMDialogOpen: boolean; + llmMethod?: LLMJobType; + llmDocumentIds: number[]; + llmStep: number; + llmTags: DocumentTagRead[]; + llmMetadata: ProjectMetadataRead[]; + llmCodes: CodeRead[]; + llmPrompts: LLMPromptTemplates[]; + llmJobId?: string; + llmJobResult: LLMJobResult | null | undefined; } const initialState: DialogState = { @@ -63,6 +80,17 @@ const initialState: DialogState = { }, // project settings isProjectSettingsOpen: false, + // llm dialog + isLLMDialogOpen: false, + llmDocumentIds: [], + llmMethod: undefined, + llmStep: 0, + llmTags: [], + llmMetadata: [], + llmCodes: 
[], + llmPrompts: [], + llmJobId: undefined, + llmJobResult: undefined, }; export const dialogSlice = createSlice({ @@ -151,6 +179,91 @@ export const dialogSlice = createSlice({ closeProjectSettings: (state) => { state.isProjectSettingsOpen = false; }, + // Step 0: Select documents & open the dialog + openLLMDialog: (state, action: PayloadAction<{ event: LLMAssistanceEvent }>) => { + state.isLLMDialogOpen = true; + state.llmDocumentIds = action.payload.event.selectedDocumentIds; + state.llmMethod = action.payload.event.method; + state.llmStep = action.payload.event.method === undefined ? 0 : 1; + }, + // Step 1: Select method + llmDialogGoToDataSelection: (state, action: PayloadAction<{ method: LLMJobType }>) => { + state.llmMethod = action.payload.method; + state.llmStep = 1; + }, + // Step 2: Select tags, metadata, or codes + llmDialogGoToPromptEditor: ( + state, + action: PayloadAction<{ + prompts: LLMPromptTemplates[]; + tags: DocumentTagRead[]; + metadata: ProjectMetadataRead[]; + codes: CodeRead[]; + }>, + ) => { + state.llmStep = 2; + state.llmPrompts = action.payload.prompts; + state.llmTags = action.payload.tags; + state.llmMetadata = action.payload.metadata; + state.llmCodes = action.payload.codes; + }, + // Step 3: Edit the prompts + updateLLMPrompts: ( + state, + action: PayloadAction<{ language: string; systemPrompt: string; userPrompt: string }>, + ) => { + const updatedPrompts = state.llmPrompts.map((prompt) => { + if (prompt.language === action.payload.language) { + return { + ...prompt, + system_prompt: action.payload.systemPrompt, + user_prompt: action.payload.userPrompt, + }; + } + return prompt; + }); + state.llmPrompts = updatedPrompts.slice(); + }, + llmDialogGoToWaiting: (state, action: PayloadAction<{ jobId: string; method: LLMJobType }>) => { + state.isLLMDialogOpen = true; + state.llmStep = 3; + state.llmJobId = action.payload.jobId; + state.llmMethod = action.payload.method; + }, + // Step 4: Wait for the job to finish + 
llmDialogGoToResult: (state, action: PayloadAction<{ result: LLMJobResult }>) => { + state.llmJobResult = action.payload.result; + state.llmStep = 4; + }, + // close the dialog & reset + closeLLMDialog: (state) => { + state.isLLMDialogOpen = initialState.isLLMDialogOpen; + state.llmDocumentIds = initialState.llmDocumentIds; + state.llmMethod = initialState.llmMethod; + state.llmStep = initialState.llmStep; + state.llmTags = initialState.llmTags; + state.llmMetadata = initialState.llmMetadata; + state.llmCodes = initialState.llmCodes; + state.llmPrompts = initialState.llmPrompts; + state.llmJobId = initialState.llmJobId; + state.llmJobResult = initialState.llmJobResult; + }, + previousLLMDialogStep: (state) => { + state.llmStep -= 1; + if (state.llmStep < 0) { + state.llmStep = 0; + } + // user just selected the method, reset method selection + if (state.llmStep === 0) { + state.llmMethod = initialState.llmMethod; + // user just selected the data, reset data selection + } else if (state.llmStep === 1) { + state.llmPrompts = initialState.llmPrompts; + state.llmTags = initialState.llmTags; + state.llmMetadata = initialState.llmMetadata; + state.llmCodes = initialState.llmCodes; + } + }, }, }); diff --git a/frontend/src/index.css b/frontend/src/index.css index b63cd4d48..73173f750 100644 --- a/frontend/src/index.css +++ b/frontend/src/index.css @@ -51,3 +51,26 @@ code { .fixAlertBanner .MuiAlert-message { max-width: 100% !important; } + +.speech-bubble { + position: relative; + border: 2px solid #1976d2; + border-radius: 0.4em; + margin-left: 48px; + padding: 16px; +} + +.speech-bubble:after { + content: ""; + position: absolute; + left: 0; + top: 50%; + width: 0; + height: 0; + border: 1.469em solid transparent; + border-right-color: #1976d2; + border-left: 0; + border-top: 0; + margin-top: -0.737em; + margin-left: -1.469em; +} diff --git a/frontend/src/layouts/TwoBarLayout.tsx b/frontend/src/layouts/TwoBarLayout.tsx index e4b0329d2..86d3a97d7 100644 --- 
a/frontend/src/layouts/TwoBarLayout.tsx +++ b/frontend/src/layouts/TwoBarLayout.tsx @@ -4,6 +4,7 @@ import { Outlet, useParams } from "react-router-dom"; import CodeCreateDialog from "../components/Code/CodeCreateDialog.tsx"; import ConfirmationDialog from "../components/ConfirmationDialog/ConfirmationDialog.tsx"; import ExporterDialog from "../components/Exporter/ExporterDialog.tsx"; +import LLMDialog from "../components/LLMDialog/LLMDialog.tsx"; import MemoDialog from "../components/Memo/MemoDialog/MemoDialog.tsx"; import ProjectSettingsDialog from "../components/ProjectSettings/ProjectSettingsDialog.tsx"; import SnackbarDialog from "../components/SnackbarDialog/SnackbarDialog.tsx"; @@ -40,6 +41,7 @@ function TwoBarLayout() { + ); } else { diff --git a/frontend/src/openapi.json b/frontend/src/openapi.json index f1204e7ea..522752a9c 100644 --- a/frontend/src/openapi.json +++ b/frontend/src/openapi.json @@ -1490,6 +1490,40 @@ "security": [{ "OAuth2PasswordBearer": [] }] } }, + "/doctag/bulk/set": { + "patch": { + "tags": ["documentTag"], + "summary": "Sets SourceDocuments' tags to the provided tags", + "operationId": "set_document_tags_batch", + "requestBody": { + "content": { + "application/json": { + "schema": { + "items": { "$ref": "#/components/schemas/SourceDocumentDocumentTagLinks" }, + "type": "array", + "title": "Links" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { "type": "integer", "title": "Response Documenttag-Set Document Tags Batch" } + } + } + }, + "422": { + "description": "Validation Error", + "content": { "application/json": { "schema": { "$ref": "#/components/schemas/HTTPValidationError" } } } + } + }, + "security": [{ "OAuth2PasswordBearer": [] }] + } + }, "/doctag/{tag_id}": { "get": { "tags": ["documentTag"], @@ -2033,6 +2067,60 @@ } } }, + "/span/bulk/create": { + "put": { + "tags": ["spanAnnotation"], + "summary": "Creates a 
SpanAnnotations in Bulk", + "operationId": "add_span_annotations_bulk", + "security": [{ "OAuth2PasswordBearer": [] }], + "parameters": [ + { + "name": "resolve", + "in": "query", + "required": false, + "schema": { + "type": "boolean", + "title": "Resolve Code", + "description": "If true, the current_code_id of the SpanAnnotation gets resolved and replaced by the respective Code entity", + "default": true + }, + "description": "If true, the current_code_id of the SpanAnnotation gets resolved and replaced by the respective Code entity" + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { "$ref": "#/components/schemas/SpanAnnotationCreateBulkWithCodeId" }, + "title": "Spans" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "anyOf": [ + { "type": "array", "items": { "$ref": "#/components/schemas/SpanAnnotationRead" } }, + { "type": "array", "items": { "$ref": "#/components/schemas/SpanAnnotationReadResolved" } } + ], + "title": "Response Spanannotation-Add Span Annotations Bulk" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { "application/json": { "schema": { "$ref": "#/components/schemas/HTTPValidationError" } } } + } + } + } + }, "/span/{span_id}": { "get": { "tags": ["spanAnnotation"], @@ -3577,6 +3665,44 @@ } } }, + "/sdocmeta/bulk/update": { + "patch": { + "tags": ["sdocMetadata"], + "summary": "Updates multiple metadata objects at once.", + "operationId": "update_bulk", + "requestBody": { + "content": { + "application/json": { + "schema": { + "items": { "$ref": "#/components/schemas/SourceDocumentMetadataBulkUpdate" }, + "type": "array", + "title": "Metadatas" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "items": { "$ref": 
"#/components/schemas/SourceDocumentMetadataRead" }, + "type": "array", + "title": "Response Sdocmetadata-Update Bulk" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { "application/json": { "schema": { "$ref": "#/components/schemas/HTTPValidationError" } } } + } + }, + "security": [{ "OAuth2PasswordBearer": [] }] + } + }, "/feedback": { "get": { "tags": ["feedback"], @@ -5503,6 +5629,118 @@ } } } + }, + "/llm": { + "post": { + "tags": ["llm"], + "summary": "Returns the LLMJob for the given Parameters", + "operationId": "start_llm_job", + "requestBody": { + "content": { "application/json": { "schema": { "$ref": "#/components/schemas/LLMJobParameters" } } }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { "application/json": { "schema": { "$ref": "#/components/schemas/LLMJobRead" } } } + }, + "422": { + "description": "Validation Error", + "content": { "application/json": { "schema": { "$ref": "#/components/schemas/HTTPValidationError" } } } + } + }, + "security": [{ "OAuth2PasswordBearer": [] }] + } + }, + "/llm/{llm_job_id}": { + "get": { + "tags": ["llm"], + "summary": "Returns the LLMJob for the given ID if it exists", + "operationId": "get_llm_job", + "security": [{ "OAuth2PasswordBearer": [] }], + "parameters": [ + { + "name": "llm_job_id", + "in": "path", + "required": true, + "schema": { "type": "string", "title": "Llm Job Id" } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { "application/json": { "schema": { "$ref": "#/components/schemas/LLMJobRead" } } } + }, + "422": { + "description": "Validation Error", + "content": { "application/json": { "schema": { "$ref": "#/components/schemas/HTTPValidationError" } } } + } + } + } + }, + "/llm/project/{project_id}": { + "get": { + "tags": ["llm"], + "summary": "Returns all LLMJobRead for the given project ID if it exists", + "operationId": "get_all_llm_jobs", + "security": [{ 
"OAuth2PasswordBearer": [] }], + "parameters": [ + { + "name": "project_id", + "in": "path", + "required": true, + "schema": { "type": "integer", "title": "Project Id" } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { "$ref": "#/components/schemas/LLMJobRead" }, + "title": "Response Llm-Get All Llm Jobs" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { "application/json": { "schema": { "$ref": "#/components/schemas/HTTPValidationError" } } } + } + } + } + }, + "/llm/create_prompt_templates": { + "post": { + "tags": ["llm"], + "summary": "Returns the system and user prompt templates for the given llm task in all supported languages", + "operationId": "create_prompt_templates", + "requestBody": { + "content": { "application/json": { "schema": { "$ref": "#/components/schemas/LLMJobParameters" } } }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "items": { "$ref": "#/components/schemas/LLMPromptTemplates" }, + "type": "array", + "title": "Response Llm-Create Prompt Templates" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { "application/json": { "schema": { "$ref": "#/components/schemas/HTTPValidationError" } } } + } + }, + "security": [{ "OAuth2PasswordBearer": [] }] + } } }, "components": { @@ -5750,6 +5988,39 @@ "required": ["source_document_id", "user_id", "id", "created", "updated"], "title": "AnnotationDocumentRead" }, + "AnnotationLLMJobParams": { + "properties": { + "llm_job_type": { "const": "ANNOTATION", "title": "Llm Job Type" }, + "sdoc_ids": { + "items": { "type": "integer" }, + "type": "array", + "title": "Sdoc Ids", + "description": "IDs of the source documents to analyse" + }, + "code_ids": { + "items": { "type": "integer" }, + "type": "array", + "title": "Code Ids", + 
"description": "IDs of the codes to use for the annotation" + } + }, + "type": "object", + "required": ["llm_job_type", "sdoc_ids", "code_ids"], + "title": "AnnotationLLMJobParams" + }, + "AnnotationLLMJobResult": { + "properties": { + "llm_job_type": { "const": "ANNOTATION", "title": "Llm Job Type" }, + "results": { + "items": { "$ref": "#/components/schemas/AnnotationResult" }, + "type": "array", + "title": "Results" + } + }, + "type": "object", + "required": ["llm_job_type", "results"], + "title": "AnnotationLLMJobResult" + }, "AnnotationOccurrence": { "properties": { "annotation": { @@ -5771,6 +6042,20 @@ "required": ["annotation", "code", "sdoc", "text"], "title": "AnnotationOccurrence" }, + "AnnotationResult": { + "properties": { + "sdoc_id": { "type": "integer", "title": "Sdoc Id", "description": "ID of the source document" }, + "suggested_annotations": { + "items": { "$ref": "#/components/schemas/SpanAnnotationReadResolved" }, + "type": "array", + "title": "Suggested Annotations", + "description": "Suggested annotations" + } + }, + "type": "object", + "required": ["sdoc_id", "suggested_annotations"], + "title": "AnnotationResult" + }, "AnnotationTableRow": { "properties": { "id": { "type": "integer", "title": "Id", "description": "ID of the SpanAnnotation" }, @@ -6814,6 +7099,60 @@ "type": "object", "title": "DocumentTagUpdate" }, + "DocumentTaggingLLMJobParams": { + "properties": { + "llm_job_type": { "const": "DOCUMENT_TAGGING", "title": "Llm Job Type" }, + "sdoc_ids": { + "items": { "type": "integer" }, + "type": "array", + "title": "Sdoc Ids", + "description": "IDs of the source documents to analyse" + }, + "tag_ids": { + "items": { "type": "integer" }, + "type": "array", + "title": "Tag Ids", + "description": "IDs of the tags to use for the document tagging" + } + }, + "type": "object", + "required": ["llm_job_type", "sdoc_ids", "tag_ids"], + "title": "DocumentTaggingLLMJobParams" + }, + "DocumentTaggingLLMJobResult": { + "properties": { + 
"llm_job_type": { "const": "DOCUMENT_TAGGING", "title": "Llm Job Type" }, + "results": { + "items": { "$ref": "#/components/schemas/DocumentTaggingResult" }, + "type": "array", + "title": "Results" + } + }, + "type": "object", + "required": ["llm_job_type", "results"], + "title": "DocumentTaggingLLMJobResult" + }, + "DocumentTaggingResult": { + "properties": { + "sdoc_id": { "type": "integer", "title": "Sdoc Id", "description": "ID of the source document" }, + "current_tag_ids": { + "items": { "type": "integer" }, + "type": "array", + "title": "Current Tag Ids", + "description": "IDs of the tags currently assigned to the document" + }, + "suggested_tag_ids": { + "items": { "type": "integer" }, + "type": "array", + "title": "Suggested Tag Ids", + "description": "IDs of the tags suggested by the LLM to assign to the document" + }, + "reasoning": { "type": "string", "title": "Reasoning", "description": "Reasoning for the tagging" } + }, + "type": "object", + "required": ["sdoc_id", "current_tag_ids", "suggested_tag_ids", "reasoning"], + "title": "DocumentTaggingResult" + }, "ElasticSearchDocumentHit": { "properties": { "document_id": { "type": "integer", "title": "Document Id", "description": "The ID of the Document" }, @@ -7313,6 +7652,131 @@ "required": ["keyword", "filtered_count", "global_count"], "title": "KeywordStat" }, + "LLMJobParameters": { + "properties": { + "llm_job_type": { + "allOf": [{ "$ref": "#/components/schemas/LLMJobType" }], + "description": "The type of the LLMJob (what to llm)" + }, + "project_id": { "type": "integer", "title": "Project Id", "description": "The ID of the Project to analyse" }, + "prompts": { + "items": { "$ref": "#/components/schemas/LLMPromptTemplates" }, + "type": "array", + "title": "Prompts", + "description": "The prompt templates to use for the job" + }, + "specific_llm_job_parameters": { + "oneOf": [ + { "$ref": "#/components/schemas/DocumentTaggingLLMJobParams" }, + { "$ref": 
"#/components/schemas/MetadataExtractionLLMJobParams" }, + { "$ref": "#/components/schemas/AnnotationLLMJobParams" } + ], + "title": "Specific Llm Job Parameters", + "description": "Specific parameters for the LLMJob w.r.t it's type", + "discriminator": { + "propertyName": "llm_job_type", + "mapping": { + "ANNOTATION": "#/components/schemas/AnnotationLLMJobParams", + "DOCUMENT_TAGGING": "#/components/schemas/DocumentTaggingLLMJobParams", + "METADATA_EXTRACTION": "#/components/schemas/MetadataExtractionLLMJobParams" + } + } + } + }, + "type": "object", + "required": ["llm_job_type", "project_id", "prompts", "specific_llm_job_parameters"], + "title": "LLMJobParameters" + }, + "LLMJobRead": { + "properties": { + "status": { + "allOf": [{ "$ref": "#/components/schemas/BackgroundJobStatus" }], + "description": "Status of the LLMJob", + "default": "Waiting" + }, + "num_steps_finished": { + "type": "integer", + "title": "Num Steps Finished", + "description": "Number of steps LLMJob has completed." + }, + "num_steps_total": { "type": "integer", "title": "Num Steps Total", "description": "Number of total steps." }, + "result": { + "anyOf": [{ "$ref": "#/components/schemas/LLMJobResult" }, { "type": "null" }], + "description": "Results of hte LLMJob." + }, + "id": { "type": "string", "title": "Id", "description": "ID of the LLMJob" }, + "parameters": { + "allOf": [{ "$ref": "#/components/schemas/LLMJobParameters" }], + "description": "The parameters of the LLMJob that defines what to llm!" 
+ }, + "created": { + "type": "string", + "format": "date-time", + "title": "Created", + "description": "Created timestamp of the LLMJob" + }, + "updated": { + "type": "string", + "format": "date-time", + "title": "Updated", + "description": "Updated timestamp of the LLMJob" + } + }, + "type": "object", + "required": ["num_steps_finished", "num_steps_total", "id", "parameters", "created", "updated"], + "title": "LLMJobRead" + }, + "LLMJobResult": { + "properties": { + "llm_job_type": { + "allOf": [{ "$ref": "#/components/schemas/LLMJobType" }], + "description": "The type of the LLMJob (what to llm)" + }, + "specific_llm_job_result": { + "oneOf": [ + { "$ref": "#/components/schemas/DocumentTaggingLLMJobResult" }, + { "$ref": "#/components/schemas/MetadataExtractionLLMJobResult" }, + { "$ref": "#/components/schemas/AnnotationLLMJobResult" } + ], + "title": "Specific Llm Job Result", + "description": "Specific result for the LLMJob w.r.t it's type", + "discriminator": { + "propertyName": "llm_job_type", + "mapping": { + "ANNOTATION": "#/components/schemas/AnnotationLLMJobResult", + "DOCUMENT_TAGGING": "#/components/schemas/DocumentTaggingLLMJobResult", + "METADATA_EXTRACTION": "#/components/schemas/MetadataExtractionLLMJobResult" + } + } + } + }, + "type": "object", + "required": ["llm_job_type", "specific_llm_job_result"], + "title": "LLMJobResult" + }, + "LLMJobType": { + "type": "string", + "enum": ["DOCUMENT_TAGGING", "METADATA_EXTRACTION", "ANNOTATION"], + "title": "LLMJobType" + }, + "LLMPromptTemplates": { + "properties": { + "language": { "type": "string", "title": "Language", "description": "The language of the prompt template" }, + "system_prompt": { + "type": "string", + "title": "System Prompt", + "description": "The system prompt to use for the job" + }, + "user_prompt": { + "type": "string", + "title": "User Prompt", + "description": "The user prompt to use for the job" + } + }, + "type": "object", + "required": ["language", "system_prompt", 
"user_prompt"], + "title": "LLMPromptTemplates" + }, "ListOperator": { "type": "string", "enum": ["LIST_CONTAINS", "LIST_NOT_CONTAINS"], "title": "ListOperator" }, "LogicalOperator": { "type": "string", @@ -7409,6 +7873,59 @@ "title": "MemoUpdate" }, "MetaType": { "type": "string", "enum": ["STRING", "NUMBER", "DATE", "BOOLEAN", "LIST"], "title": "MetaType" }, + "MetadataExtractionLLMJobParams": { + "properties": { + "llm_job_type": { "const": "METADATA_EXTRACTION", "title": "Llm Job Type" }, + "sdoc_ids": { + "items": { "type": "integer" }, + "type": "array", + "title": "Sdoc Ids", + "description": "IDs of the source documents to analyse" + }, + "project_metadata_ids": { + "items": { "type": "integer" }, + "type": "array", + "title": "Project Metadata Ids", + "description": "IDs of the project metadata to use for the metadata extraction" + } + }, + "type": "object", + "required": ["llm_job_type", "sdoc_ids", "project_metadata_ids"], + "title": "MetadataExtractionLLMJobParams" + }, + "MetadataExtractionLLMJobResult": { + "properties": { + "llm_job_type": { "const": "METADATA_EXTRACTION", "title": "Llm Job Type" }, + "results": { + "items": { "$ref": "#/components/schemas/MetadataExtractionResult" }, + "type": "array", + "title": "Results" + } + }, + "type": "object", + "required": ["llm_job_type", "results"], + "title": "MetadataExtractionLLMJobResult" + }, + "MetadataExtractionResult": { + "properties": { + "sdoc_id": { "type": "integer", "title": "Sdoc Id", "description": "ID of the source document" }, + "current_metadata": { + "items": { "$ref": "#/components/schemas/SourceDocumentMetadataReadResolved" }, + "type": "array", + "title": "Current Metadata", + "description": "Current metadata" + }, + "suggested_metadata": { + "items": { "$ref": "#/components/schemas/SourceDocumentMetadataReadResolved" }, + "type": "array", + "title": "Suggested Metadata", + "description": "Suggested metadata" + } + }, + "type": "object", + "required": ["sdoc_id", "current_metadata", 
"suggested_metadata"], + "title": "MetadataExtractionResult" + }, "NumberOperator": { "type": "string", "enum": ["NUMBER_EQUALS", "NUMBER_NOT_EQUALS", "NUMBER_GT", "NUMBER_LT", "NUMBER_GTE", "NUMBER_LTE"], @@ -7583,6 +8100,11 @@ "allOf": [{ "$ref": "#/components/schemas/DocType" }], "description": "DOCTYPE of the SourceDocument this metadata refers to" }, + "description": { + "type": "string", + "title": "Description", + "description": "Description of the ProjectMetadata" + }, "project_id": { "type": "integer", "title": "Project Id", @@ -7590,7 +8112,7 @@ } }, "type": "object", - "required": ["key", "metatype", "doctype", "project_id"], + "required": ["key", "metatype", "doctype", "description", "project_id"], "title": "ProjectMetadataCreate" }, "ProjectMetadataRead": { @@ -7610,6 +8132,11 @@ "allOf": [{ "$ref": "#/components/schemas/DocType" }], "description": "DOCTYPE of the SourceDocument this metadata refers to" }, + "description": { + "type": "string", + "title": "Description", + "description": "Description of the ProjectMetadata" + }, "id": { "type": "integer", "title": "Id", "description": "ID of the ProjectMetadata" }, "project_id": { "type": "integer", @@ -7618,7 +8145,7 @@ } }, "type": "object", - "required": ["key", "metatype", "doctype", "id", "project_id"], + "required": ["key", "metatype", "doctype", "description", "id", "project_id"], "title": "ProjectMetadataRead" }, "ProjectMetadataUpdate": { @@ -7631,10 +8158,14 @@ "metatype": { "anyOf": [{ "$ref": "#/components/schemas/MetaType" }, { "type": "null" }], "description": "Type of the ProjectMetadata" + }, + "description": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Description", + "description": "Description of the ProjectMetadata" } }, "type": "object", - "required": ["metatype"], "title": "ProjectMetadataUpdate" }, "ProjectRead": { @@ -8015,6 +8546,24 @@ "required": ["column", "direction"], "title": "Sort[WordFrequencyColumns]" }, + "SourceDocumentDocumentTagLinks": { + 
"properties": { + "source_document_id": { + "type": "integer", + "title": "Source Document Id", + "description": "ID of SourceDocument" + }, + "document_tag_ids": { + "items": { "type": "integer" }, + "type": "array", + "title": "Document Tag Ids", + "description": "List of IDs of DocumentTags" + } + }, + "type": "object", + "required": ["source_document_id", "document_tag_ids"], + "title": "SourceDocumentDocumentTagLinks" + }, "SourceDocumentDocumentTagMultiLink": { "properties": { "source_document_ids": { @@ -8034,6 +8583,39 @@ "required": ["source_document_ids", "document_tag_ids"], "title": "SourceDocumentDocumentTagMultiLink" }, + "SourceDocumentMetadataBulkUpdate": { + "properties": { + "int_value": { + "anyOf": [{ "type": "integer" }, { "type": "null" }], + "title": "Int Value", + "description": "Int Value of the SourceDocumentMetadata" + }, + "str_value": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Str Value", + "description": "String Value of the SourceDocumentMetadata" + }, + "boolean_value": { + "anyOf": [{ "type": "boolean" }, { "type": "null" }], + "title": "Boolean Value", + "description": "Boolean Value of the SourceDocumentMetadata" + }, + "date_value": { + "anyOf": [{ "type": "string", "format": "date-time" }, { "type": "null" }], + "title": "Date Value", + "description": "Date Value of the SourceDocumentMetadata" + }, + "list_value": { + "anyOf": [{ "items": { "type": "string" }, "type": "array" }, { "type": "null" }], + "title": "List Value", + "description": "List Value of the SourceDocumentMetadata" + }, + "id": { "type": "integer", "title": "Id", "description": "ID of the SourceDocumentMetadata" } + }, + "type": "object", + "required": ["int_value", "str_value", "boolean_value", "date_value", "list_value", "id"], + "title": "SourceDocumentMetadataBulkUpdate" + }, "SourceDocumentMetadataCreate": { "properties": { "int_value": { @@ -8401,6 +8983,37 @@ ], "title": "SourceDocumentWithDataRead" }, + 
"SpanAnnotationCreateBulkWithCodeId": { + "properties": { + "begin": { "type": "integer", "title": "Begin", "description": "Begin of the SpanAnnotation in characters" }, + "end": { "type": "integer", "title": "End", "description": "End of the SpanAnnotation in characters" }, + "begin_token": { + "type": "integer", + "title": "Begin Token", + "description": "Begin of the SpanAnnotation in tokens" + }, + "end_token": { + "type": "integer", + "title": "End Token", + "description": "End of the SpanAnnotation in tokens" + }, + "span_text": { + "type": "string", + "title": "Span Text", + "description": "The SpanText the SpanAnnotation spans." + }, + "code_id": { "type": "integer", "title": "Code Id", "description": "Code the SpanAnnotation refers to" }, + "sdoc_id": { + "type": "integer", + "title": "Sdoc Id", + "description": "SourceDocument the SpanAnnotation refers to" + }, + "user_id": { "type": "integer", "title": "User Id", "description": "User the SpanAnnotation belongs to" } + }, + "type": "object", + "required": ["begin", "end", "begin_token", "end_token", "span_text", "code_id", "sdoc_id", "user_id"], + "title": "SpanAnnotationCreateBulkWithCodeId" + }, "SpanAnnotationCreateWithCodeId": { "properties": { "begin": { "type": "integer", "title": "Begin", "description": "Begin of the SpanAnnotation in characters" }, diff --git a/frontend/src/views/annotation/DocumentRenderer/useComputeTokenDataWithAnnotations.ts b/frontend/src/views/annotation/DocumentRenderer/useComputeTokenDataWithAnnotations.ts new file mode 100644 index 000000000..a55e97c1d --- /dev/null +++ b/frontend/src/views/annotation/DocumentRenderer/useComputeTokenDataWithAnnotations.ts @@ -0,0 +1,53 @@ +import { useMemo } from "react"; +import { SourceDocumentWithDataRead } from "../../../api/openapi/models/SourceDocumentWithDataRead.ts"; +import { SpanAnnotationReadResolved } from "../../../api/openapi/models/SpanAnnotationReadResolved.ts"; +import { IToken } from "./IToken.ts"; + +function 
useComputeTokenDataWithAnnotations({ + sdoc, + annotations, +}: { + sdoc: SourceDocumentWithDataRead; + annotations: SpanAnnotationReadResolved[]; +}) { + // computed + // todo: maybe implement with selector? + const tokenData: IToken[] | undefined = useMemo(() => { + const offsets = sdoc.token_character_offsets; + const texts = sdoc.tokens; + const result = texts.map((text, index) => ({ + beginChar: offsets[index][0], + endChar: offsets[index][1], + index, + text, + whitespace: offsets.length > index + 1 && offsets[index + 1][0] - offsets[index][1] > 0, + newLine: text.split("\n").length - 1, + })); + return result; + }, [sdoc]); + + // todo: maybe implement with selector? + // this map stores annotationId -> SpanAnnotationReadResolved + const annotationMap = useMemo(() => { + const result = new Map(); + annotations.forEach((a) => result.set(a.id, a)); + return result; + }, [annotations]); + + // this map stores tokenId -> spanAnnotationId[] + const annotationsPerToken = useMemo(() => { + const result = new Map(); + annotations.forEach((annotation) => { + for (let i = annotation.begin_token; i <= annotation.end_token - 1; i++) { + const tokenAnnotations = result.get(i) || []; + tokenAnnotations.push(annotation.id); + result.set(i, tokenAnnotations); + } + }); + return result; + }, [annotations]); + + return { tokenData, annotationsPerToken, annotationMap }; +} + +export default useComputeTokenDataWithAnnotations; diff --git a/frontend/src/views/search/DocumentSearch/SearchDocumentTable.tsx b/frontend/src/views/search/DocumentSearch/SearchDocumentTable.tsx index 4c763ad75..09da1900d 100644 --- a/frontend/src/views/search/DocumentSearch/SearchDocumentTable.tsx +++ b/frontend/src/views/search/DocumentSearch/SearchDocumentTable.tsx @@ -16,6 +16,7 @@ import { PaginatedElasticSearchDocumentHits } from "../../../api/openapi/models/ import { SearchColumns } from "../../../api/openapi/models/SearchColumns.ts"; import { useAuth } from "../../../auth/useAuth.ts"; import 
ReduxFilterDialog from "../../../components/FilterDialog/ReduxFilterDialog.tsx"; +import LLMAssistanceButton from "../../../components/LLMDialog/LLMAssistanceButton.tsx"; import SdocMetadataRenderer from "../../../components/Metadata/SdocMetadataRenderer.tsx"; import DeleteSdocsButton from "../../../components/SourceDocument/DeleteSdocsButton.tsx"; import DownloadSdocsButton from "../../../components/SourceDocument/DownloadSdocsButton.tsx"; @@ -261,6 +262,7 @@ function SearchDocumentTable({ projectId, data, isLoading, isFetching, isError } /> + )} diff --git a/tools/importer/README.md b/tools/importer/README.md index ff3beee4d..d538dee5c 100644 --- a/tools/importer/README.md +++ b/tools/importer/README.md @@ -18,6 +18,9 @@ pip install python-magic ## Usage ``` +# import klimaallgemein +python importer/dats_importer.py --input_dir /ltstorage/shares/projects/dwts/backend/src/dev_notebooks/data/KlimaAllgemein/json2 --backend_url http://localhost:19002/ --project_id 84 --tag_name klima --tag_description klima --is_json --filter_duplicate_files_before_upload --metadata_keys paper paper_db headline date --metadata_types STRING STRING STRING DATE --doctype text --content_key html --mime_type text/html + # import cnn python importer/dats_importer.py --input_dir /home/tfischer/Development/dats/data/cnn_crawl_fixed --backend_url http://localhost:10220/ --project_id 1 --tag_name cnn --tag_description cnn --is_json --filter_duplicate_files_before_upload --metadata_keys author published_date visited_date origin --metadata_types STRING DATE DATE STRING --doctype text --content_key text