diff --git a/backend/api/quivr_api/modules/assistant/controller/assistant_routes.py b/backend/api/quivr_api/modules/assistant/controller/assistant_routes.py index a53b92408b87..cf2e3207e6e3 100644 --- a/backend/api/quivr_api/modules/assistant/controller/assistant_routes.py +++ b/backend/api/quivr_api/modules/assistant/controller/assistant_routes.py @@ -22,6 +22,7 @@ upload_file_storage, ) from quivr_api.modules.user.entity.user_identity import UserIdentity +from quivr_api.modules.assistant.entity.task_entity import TaskMetadata logger = get_logger(__name__) @@ -83,7 +84,7 @@ async def create_task( raise HTTPException(status_code=400, detail=error) else: print("Assistant input is valid.") - notification_uuid = uuid4() + notification_uuid = f"{assistant.name}-{str(uuid4())[:8]}" # Process files dynamically for upload_file in files: @@ -96,12 +97,13 @@ async def create_task( raise HTTPException( status_code=500, detail=f"Failed to upload file to storage. {e}" ) - + task = CreateTask( assistant_id=input.id, assistant_name=assistant.name, - pretty_id=f"{assistant.name}-{str(notification_uuid)[:8]}", + pretty_id=notification_uuid, settings=input.model_dump(mode="json"), + task_metadata=TaskMetadata(input_files=[file.filename for file in files]).model_dump(mode="json") if files else None, # type: ignore ) task_created = await tasks_service.create_task(task, current_user.id) diff --git a/backend/api/quivr_api/modules/assistant/controller/assistants_definition.py b/backend/api/quivr_api/modules/assistant/controller/assistants_definition.py index a1445473a2ac..edaf09af8997 100644 --- a/backend/api/quivr_api/modules/assistant/controller/assistants_definition.py +++ b/backend/api/quivr_api/modules/assistant/controller/assistants_definition.py @@ -3,8 +3,8 @@ AssistantOutput, InputBoolean, InputFile, - InputSelectText, Inputs, + InputSelectText, Pricing, ) @@ -201,7 +201,10 @@ def validate_assistant_input( InputSelectText( key="DocumentsType", description="Select Documents Type", - options=["Etiquettes VS Cahier des charges", "Fiche Dev VS Cahier des charges"], + options=[ + "Etiquettes VS Cahier des charges", + "Fiche Dev VS Cahier des charges", + ], ), ], ), @@ -222,7 +225,9 @@ def validate_assistant_input( InputFile(key="Document 2", description="File description"), ], booleans=[ - InputBoolean(key="Hard-to-Read Document?", description="Boolean description"), + InputBoolean( + key="Hard-to-Read Document?", description="Boolean description" + ), ], select_texts=[ InputSelectText( diff --git a/backend/api/quivr_api/modules/assistant/dto/inputs.py b/backend/api/quivr_api/modules/assistant/dto/inputs.py index f0dcff5a7dd1..0847224dd2a4 100644 --- a/backend/api/quivr_api/modules/assistant/dto/inputs.py +++ b/backend/api/quivr_api/modules/assistant/dto/inputs.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Dict, List, Optional from uuid import UUID from pydantic import BaseModel, root_validator @@ -9,6 +9,7 @@ class CreateTask(BaseModel): assistant_id: int assistant_name: str settings: dict + task_metadata: Dict | None = None class BrainInput(BaseModel): diff --git a/backend/api/quivr_api/modules/assistant/entity/task_entity.py b/backend/api/quivr_api/modules/assistant/entity/task_entity.py index 28c2e5d9f29f..cbc8db994b80 100644 --- a/backend/api/quivr_api/modules/assistant/entity/task_entity.py +++ b/backend/api/quivr_api/modules/assistant/entity/task_entity.py @@ -1,10 +1,15 @@ from datetime import datetime -from typing import Dict +from typing import Dict, List, Optional from uuid import UUID +from pydantic import BaseModel from sqlmodel import JSON, TIMESTAMP, BigInteger, Column, Field, SQLModel, text +class TaskMetadata(BaseModel): + input_files: Optional[List[str]] = None + + class Task(SQLModel, table=True): __tablename__ = "tasks" # type: ignore @@ -30,6 +35,5 @@ class Task(SQLModel, table=True): ) settings: Dict = Field(default_factory=dict, sa_column=Column(JSON)) answer: str | None = Field(default=None) + task_metadata: Dict | None = Field(default=None, sa_column=Column(JSON)) - class Config: - arbitrary_types_allowed = True diff --git a/backend/api/quivr_api/modules/assistant/repository/tasks.py b/backend/api/quivr_api/modules/assistant/repository/tasks.py index 7c3d92a5db18..9d45ce7d53e5 100644 --- a/backend/api/quivr_api/modules/assistant/repository/tasks.py +++ b/backend/api/quivr_api/modules/assistant/repository/tasks.py @@ -3,7 +3,7 @@ from sqlalchemy import exc from sqlalchemy.ext.asyncio import AsyncSession -from sqlmodel import select, col +from sqlmodel import col, select from quivr_api.modules.assistant.dto.inputs import CreateTask from quivr_api.modules.assistant.entity.task_entity import Task @@ -25,6 +25,7 @@ async def create_task(self, task: CreateTask, user_id: UUID) -> Task: pretty_id=task.pretty_id, user_id=user_id, settings=task.settings, + task_metadata=task.task_metadata, # type: ignore ) self.session.add(task_to_create) await self.session.commit() @@ -41,7 +42,9 @@ async def get_task_by_id(self, task_id: UUID, user_id: UUID) -> Task: return response.one() async def get_tasks_by_user_id(self, user_id: UUID) -> Sequence[Task]: - query = select(Task).where(Task.user_id == user_id).order_by(col(Task.id).desc()) + query = ( + select(Task).where(Task.user_id == user_id).order_by(col(Task.id).desc()) + ) response = await self.session.exec(query) return response.all() diff --git a/backend/api/quivr_api/modules/brain/service/brain_service.py b/backend/api/quivr_api/modules/brain/service/brain_service.py index e5b403d8f03e..7b9da881c7ff 100644 --- a/backend/api/quivr_api/modules/brain/service/brain_service.py +++ b/backend/api/quivr_api/modules/brain/service/brain_service.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Dict +from typing import Dict, Optional, Tuple from uuid import UUID from fastapi import HTTPException diff --git a/backend/api/quivr_api/modules/chat/controller/chat/utils.py b/backend/api/quivr_api/modules/chat/controller/chat/utils.py index 5306f4ecbeb7..ce5e684221df 100644 --- a/backend/api/quivr_api/modules/chat/controller/chat/utils.py +++ b/backend/api/quivr_api/modules/chat/controller/chat/utils.py @@ -1,5 +1,5 @@ -import time import os +import time from enum import Enum from fastapi import HTTPException diff --git a/backend/api/quivr_api/modules/chat/controller/chat_routes.py b/backend/api/quivr_api/modules/chat/controller/chat_routes.py index a42d7fe7fb7a..f89e792c81fa 100644 --- a/backend/api/quivr_api/modules/chat/controller/chat_routes.py +++ b/backend/api/quivr_api/modules/chat/controller/chat_routes.py @@ -1,9 +1,10 @@ +import os from typing import Annotated, List, Optional from uuid import UUID -import os from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, Request from fastapi.responses import StreamingResponse +from quivr_core.config import RetrievalConfig from quivr_api.logger import get_logger from quivr_api.middlewares.auth import AuthBearer, get_current_user @@ -36,7 +37,6 @@ from quivr_api.modules.vector.service.vector_service import VectorService from quivr_api.utils.telemetry import maybe_send_telemetry from quivr_api.utils.uuid_generator import generate_uuid_from_string -from quivr_core.config import RetrievalConfig logger = get_logger(__name__) diff --git a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py index dcbb7a5b66d8..e08f3c0abcdb 100644 --- a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py +++ b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py @@ -2,8 +2,8 @@ from enum import Enum from typing import Any, Dict, List, Optional from uuid import UUID -from pydantic import BaseModel +from pydantic import BaseModel from quivr_core.models import KnowledgeStatus from sqlalchemy import JSON, TIMESTAMP, Column, text from sqlalchemy.ext.asyncio import AsyncAttrs diff --git a/backend/api/quivr_api/modules/knowledge/repository/storage.py b/backend/api/quivr_api/modules/knowledge/repository/storage.py index e53165e22282..ad35659dbbd0 100644 --- a/backend/api/quivr_api/modules/knowledge/repository/storage.py +++ b/backend/api/quivr_api/modules/knowledge/repository/storage.py @@ -86,4 +86,3 @@ async def remove_file(self, storage_path: str): except Exception as e: logger.error(e) raise e - diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py index a0f49a07ed30..7381b6e917dd 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py @@ -527,7 +527,9 @@ async def test_should_process_knowledge_prev_error( assert new.file_sha1 -@pytest.mark.skip(reason="Bug: UnboundLocalError: cannot access local variable 'response'") +@pytest.mark.skip( + reason="Bug: UnboundLocalError: cannot access local variable 'response'" +) @pytest.mark.asyncio(loop_scope="session") async def test_get_knowledge_storage_path(session: AsyncSession, test_data: TestData): _, [knowledge, _] = test_data diff --git a/backend/api/quivr_api/modules/misc/controller/misc_routes.py b/backend/api/quivr_api/modules/misc/controller/misc_routes.py index 590b3cd0e3aa..054798b34c18 100644 --- a/backend/api/quivr_api/modules/misc/controller/misc_routes.py +++ b/backend/api/quivr_api/modules/misc/controller/misc_routes.py @@ -1,9 +1,8 @@ - from fastapi import APIRouter, Depends, HTTPException from quivr_api.logger import get_logger from quivr_api.modules.dependencies import get_async_session -from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel import text +from sqlmodel.ext.asyncio.session import AsyncSession logger = get_logger(__name__) @@ -20,7 +19,6 @@ async def root(): @misc_router.get("/healthz", tags=["Health"]) async def healthz(session: AsyncSession = Depends(get_async_session)): - try: result = await session.execute(text("SELECT 1")) if not result: diff --git a/backend/api/quivr_api/modules/rag_service/rag_service.py b/backend/api/quivr_api/modules/rag_service/rag_service.py index c1f3ee7da6a3..4183524b0a52 100644 --- a/backend/api/quivr_api/modules/rag_service/rag_service.py +++ b/backend/api/quivr_api/modules/rag_service/rag_service.py @@ -2,7 +2,6 @@ import os from uuid import UUID, uuid4 -from quivr_api.utils.uuid_generator import generate_uuid_from_string from quivr_core.brain import Brain as BrainCore from quivr_core.chat import ChatHistory as ChatHistoryCore from quivr_core.config import LLMEndpointConfig, RetrievalConfig @@ -29,6 +28,7 @@ from quivr_api.modules.prompt.service.prompt_service import PromptService from quivr_api.modules.user.entity.user_identity import UserIdentity from quivr_api.modules.vector.service.vector_service import VectorService +from quivr_api.utils.uuid_generator import generate_uuid_from_string from quivr_api.vectorstore.supabase import CustomSupabaseVectorStore from .utils import generate_source diff --git a/backend/api/quivr_api/modules/rag_service/utils.py b/backend/api/quivr_api/modules/rag_service/utils.py index 068a2db28c5e..afc12082eac8 100644 --- a/backend/api/quivr_api/modules/rag_service/utils.py +++ b/backend/api/quivr_api/modules/rag_service/utils.py @@ -68,7 +68,7 @@ async def generate_source( try: file_name = doc.metadata["file_name"] file_path = await knowledge_service.get_knowledge_storage_path( - file_name=file_name, brain_id=brain_id + file_name=file_name, brain_id=brain_id ) if file_path in generated_urls: source_url = generated_urls[file_path] diff --git a/backend/api/quivr_api/modules/sync/repository/sync_user.py b/backend/api/quivr_api/modules/sync/repository/sync_user.py index c2de84cc4f41..1fbaca12cd0c 100644 --- a/backend/api/quivr_api/modules/sync/repository/sync_user.py +++ b/backend/api/quivr_api/modules/sync/repository/sync_user.py @@ -93,9 +93,7 @@ def get_syncs_user(self, user_id: UUID, sync_user_id: int | None = None): sync_user_id, ) query = ( - self.db.from_("syncs_user") - .select("*") - .eq("user_id", user_id) + self.db.from_("syncs_user").select("*").eq("user_id", user_id) # .neq("status", "REMOVED") ) if sync_user_id: @@ -170,9 +168,9 @@ def update_sync_user( ) state_str = json.dumps(state) - self.db.from_("syncs_user").update(sync_user_input.model_dump(exclude_unset=True)).eq( - "user_id", str(sync_user_id) - ).eq("state", state_str).execute() + self.db.from_("syncs_user").update( + sync_user_input.model_dump(exclude_unset=True) + ).eq("user_id", str(sync_user_id)).eq("state", state_str).execute() logger.info("Sync user updated successfully") def update_sync_user_status(self, sync_user_id: int, status: str): diff --git a/backend/api/quivr_api/modules/sync/tests/test_notion_service.py b/backend/api/quivr_api/modules/sync/tests/test_notion_service.py index d866a3d11733..526114c5ef8b 100644 --- a/backend/api/quivr_api/modules/sync/tests/test_notion_service.py +++ b/backend/api/quivr_api/modules/sync/tests/test_notion_service.py @@ -74,7 +74,9 @@ def handler(request): assert len(result) == 0 -@pytest.mark.skip(reason="Bug: httpx.ConnectError: [Errno -2] Name or service not known'") +@pytest.mark.skip( + reason="Bug: httpx.ConnectError: [Errno -2] Name or service not known'" +) @pytest.mark.asyncio(loop_scope="session") async def test_store_notion_pages_success( session: AsyncSession, diff --git a/backend/api/quivr_api/modules/sync/tests/test_syncutils.py b/backend/api/quivr_api/modules/sync/tests/test_syncutils.py index 767a944029e4..3c20f70d9679 100644 --- a/backend/api/quivr_api/modules/sync/tests/test_syncutils.py +++ b/backend/api/quivr_api/modules/sync/tests/test_syncutils.py @@ -271,7 +271,10 @@ async def test_process_sync_file_not_supported(syncutils: SyncUtils): sync_active=sync_active, ) -@pytest.mark.skip(reason="Bug: UnboundLocalError: cannot access local variable 'response'") + +@pytest.mark.skip( + reason="Bug: UnboundLocalError: cannot access local variable 'response'" +) @pytest.mark.asyncio(loop_scope="session") async def test_process_sync_file_noprev( monkeypatch, @@ -327,8 +330,8 @@ def _send_task(*args, **kwargs): assert created_km.file_sha1 is None assert created_km.created_at is not None assert created_km.metadata == {"sync_file_id": "1"} - assert len(created_km.brains)> 0 - assert created_km.brains[0]["brain_id"]== brain_1.brain_id + assert len(created_km.brains) > 0 + assert created_km.brains[0]["brain_id"] == brain_1.brain_id # Assert celery task in correct assert task["args"] == ("process_file_task",) @@ -345,8 +348,9 @@ def _send_task(*args, **kwargs): ) - -@pytest.mark.skip(reason="Bug: UnboundLocalError: cannot access local variable 'response'") +@pytest.mark.skip( + reason="Bug: UnboundLocalError: cannot access local variable 'response'" +) @pytest.mark.asyncio(loop_scope="session") async def test_process_sync_file_with_prev( monkeypatch, @@ -424,7 +428,7 @@ def _send_task(*args, **kwargs): assert created_km.created_at assert created_km.updated_at == created_km.created_at # new line assert created_km.metadata == {"sync_file_id": str(dbfiles[0].id)} - assert created_km.brains[0]["brain_id"]== brain_1.brain_id + assert created_km.brains[0]["brain_id"] == brain_1.brain_id # Check file content changed assert check_file_exists(str(brain_1.brain_id), sync_file.name) diff --git a/backend/api/quivr_api/modules/sync/utils/sync.py b/backend/api/quivr_api/modules/sync/utils/sync.py index bd9d205276e8..b7985b161bd2 100644 --- a/backend/api/quivr_api/modules/sync/utils/sync.py +++ b/backend/api/quivr_api/modules/sync/utils/sync.py @@ -818,7 +818,12 @@ async def aget_files( pages.append(page_info) if recursive: - sub_pages = await self.aget_files(credentials=credentials, sync_user_id=sync_user_id, folder_id=str(page.id), recursive=recursive) + sub_pages = await self.aget_files( + credentials=credentials, + sync_user_id=sync_user_id, + folder_id=str(page.id), + recursive=recursive, + ) pages.extend(sub_pages) return pages diff --git a/backend/core/MegaParse/megaparse/multimodal_convertor/megaparse_vision.py b/backend/core/MegaParse/megaparse/multimodal_convertor/megaparse_vision.py index 0395a16ff922..f9391881a6b2 100644 --- a/backend/core/MegaParse/megaparse/multimodal_convertor/megaparse_vision.py +++ b/backend/core/MegaParse/megaparse/multimodal_convertor/megaparse_vision.py @@ -1,13 +1,14 @@ +import asyncio +import base64 +import re from enum import Enum from io import BytesIO from pathlib import Path from typing import List + from langchain_core.messages import HumanMessage from langchain_openai import ChatOpenAI -import base64 from pdf2image import convert_from_path -import asyncio -import re # BASE_OCR_PROMPT = """ # Transcribe the content of this file into markdown. Be mindful of the formatting. diff --git a/backend/core/MegaParse/megaparse/utils.py b/backend/core/MegaParse/megaparse/utils.py index 7dea8352481d..b16f022ebe91 100644 --- a/backend/core/MegaParse/megaparse/utils.py +++ b/backend/core/MegaParse/megaparse/utils.py @@ -1,9 +1,11 @@ from docx.document import Document as DocumentObject +from docx.oxml.table import CT_Tbl +from docx.oxml.text.paragraph import CT_P +from docx.section import Section +from docx.section import _Footer as Footer +from docx.section import _Header as Header from docx.table import Table from docx.text.paragraph import Paragraph -from docx.section import Section, _Header as Header, _Footer as Footer -from docx.oxml.text.paragraph import CT_P -from docx.oxml.table import CT_Tbl def print_element(element): diff --git a/backend/core/MegaParse/tests/test_import.py b/backend/core/MegaParse/tests/test_import.py index 72e196c3a9af..840d7baf41e2 100644 --- a/backend/core/MegaParse/tests/test_import.py +++ b/backend/core/MegaParse/tests/test_import.py @@ -1,5 +1,4 @@ import pytest - from megaparse.Converter import MegaParse diff --git a/backend/core/examples/simple_question.py b/backend/core/examples/simple_question.py index b7732d3e2cbc..35ffe1d8291c 100644 --- a/backend/core/examples/simple_question.py +++ b/backend/core/examples/simple_question.py @@ -2,7 +2,6 @@ from quivr_core import Brain from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph - if __name__ == "__main__": with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as temp_file: diff --git a/backend/core/quivr_core/chat.py b/backend/core/quivr_core/chat.py index b8d3b1057774..458c7fafbaa5 100644 --- a/backend/core/quivr_core/chat.py +++ b/backend/core/quivr_core/chat.py @@ -1,7 +1,7 @@ +from copy import deepcopy from datetime import datetime -from typing import Any, Generator, Tuple, List +from typing import Any, Generator, List, Tuple from uuid import UUID, uuid4 -from copy import deepcopy from langchain_core.messages import AIMessage, HumanMessage diff --git a/backend/core/quivr_core/config.py b/backend/core/quivr_core/config.py index b974d3220ed7..4c5f9d8513a8 100644 --- a/backend/core/quivr_core/config.py +++ b/backend/core/quivr_core/config.py @@ -2,9 +2,9 @@ from enum import Enum from typing import Dict, List, Optional from uuid import UUID -from sqlmodel import SQLModel from megaparse.config import MegaparseConfig +from sqlmodel import SQLModel from quivr_core.base_config import QuivrBaseConfig from quivr_core.processor.splitter import SplitterConfig diff --git a/backend/core/quivr_core/prompts.py b/backend/core/quivr_core/prompts.py index fa30cb5b8490..48ec90a05e11 100644 --- a/backend/core/quivr_core/prompts.py +++ b/backend/core/quivr_core/prompts.py @@ -1,14 +1,14 @@ import datetime -from pydantic import ConfigDict, create_model -from langchain_core.prompts.base import BasePromptTemplate from langchain_core.prompts import ( ChatPromptTemplate, HumanMessagePromptTemplate, + MessagesPlaceholder, PromptTemplate, SystemMessagePromptTemplate, - MessagesPlaceholder, ) +from langchain_core.prompts.base import BasePromptTemplate +from pydantic import ConfigDict, create_model class CustomPromptsDict(dict): diff --git a/backend/core/quivr_core/quivr_rag_langgraph.py b/backend/core/quivr_core/quivr_rag_langgraph.py index 7a18f83a111c..12d0bea450ec 100644 --- a/backend/core/quivr_core/quivr_rag_langgraph.py +++ b/backend/core/quivr_core/quivr_rag_langgraph.py @@ -1,7 +1,7 @@ import logging +from enum import Enum from typing import Annotated, AsyncGenerator, Optional, Sequence, TypedDict from uuid import uuid4 -from enum import Enum # TODO(@aminediro): this is the only dependency to langchain package, we should remove it from langchain.retrievers import ContextualCompressionRetriever @@ -12,7 +12,7 @@ from langchain_core.messages import BaseMessage from langchain_core.messages.ai import AIMessageChunk from langchain_core.vectorstores import VectorStore -from langgraph.graph import START, END, StateGraph +from langgraph.graph import END, START, StateGraph from langgraph.graph.message import add_messages from quivr_core.chat import ChatHistory diff --git a/backend/worker/diff-assistant/quivr_diff_assistant/main_uc2.py b/backend/worker/diff-assistant/quivr_diff_assistant/main_uc2.py index 59cf4c979f37..6de22fd92c94 100644 --- a/backend/worker/diff-assistant/quivr_diff_assistant/main_uc2.py +++ b/backend/worker/diff-assistant/quivr_diff_assistant/main_uc2.py @@ -1,22 +1,21 @@ -import streamlit as st import asyncio from enum import Enum -from langchain_openai import ChatOpenAI -from langchain_core.language_models.chat_models import BaseChatModel -from langchain_core.output_parsers import StrOutputParser - import pandas as pd +import streamlit as st +from dotenv import load_dotenv +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.output_parsers import StrOutputParser +from langchain_openai import ChatOpenAI from llama_index.core import SimpleDirectoryReader, VectorStoreIndex from llama_index.core.node_parser import UnstructuredElementNodeParser -from llama_index.core.retrievers import RecursiveRetriever from llama_index.core.query_engine import RetrieverQueryEngine -from llama_index.llms.openai import OpenAI -from quivr_diff_assistant.use_case_3.parser import DeadlyParser +from llama_index.core.retrievers import RecursiveRetriever from llama_index.core.schema import Document +from llama_index.llms.openai import OpenAI from utils.utils import COMPARISON_PROMPT -from dotenv import load_dotenv +from quivr_diff_assistant.use_case_3.parser import DeadlyParser load_dotenv() @@ -80,7 +79,6 @@ class ComparisonTypes(str, Enum): def llm_comparator( document: str, cdc: str, llm: BaseChatModel, comparison_type: ComparisonTypes ): - chain = COMPARISON_PROMPT | llm | StrOutputParser() if comparison_type == ComparisonTypes.CDC_ETIQUETTE: @@ -99,7 +97,6 @@ def llm_comparator( async def test_main(): - cdc_doc = "/Users/jchevall/Coding/diff-assistant/data/Use case #2/Cas2-2-1_Mendiant Lait_QD PC F03 - FR Cahier des charges produit -rev 2021-v2.pdf" doc = "/Users/jchevall/Coding/diff-assistant/data/Use case #2/Cas2-2-1_Proposition étiquette Mendiant Lait croustillant.pdf" @@ -152,7 +149,6 @@ def get_document_path(doc): async def parse_documents(cdc_doc, doc, comparison_type: ComparisonTypes, llm): - parser = DeadlyParser() # Schedule the coroutines as tasks @@ -199,7 +195,6 @@ def main(): return with st.spinner("Processing files..."): - llm = ChatOpenAI( model="gpt-4o", temperature=0.1, diff --git a/backend/worker/diff-assistant/quivr_diff_assistant/main_uc3.py b/backend/worker/diff-assistant/quivr_diff_assistant/main_uc3.py index 2e51d4fb9c73..8f6084b81856 100644 --- a/backend/worker/diff-assistant/quivr_diff_assistant/main_uc3.py +++ b/backend/worker/diff-assistant/quivr_diff_assistant/main_uc3.py @@ -1,19 +1,19 @@ import asyncio -import streamlit as st -from langchain_openai import ChatOpenAI -import tempfile import os +import tempfile from enum import Enum -from use_case_3.diff_type import DiffResult, llm_comparator -from use_case_3.parser import DeadlyParser -from use_case_3.llm_reporter import redact_report from pathlib import Path + +import streamlit as st from diff_match_patch import diff_match_patch -from langchain_core.language_models.chat_models import BaseChatModel -from langchain_openai import ChatOpenAI # get environment variables from dotenv import load_dotenv +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_openai import ChatOpenAI +from use_case_3.diff_type import DiffResult, llm_comparator +from use_case_3.llm_reporter import redact_report +from use_case_3.parser import DeadlyParser load_dotenv() @@ -55,7 +55,7 @@ async def create_modification_report( print("using diff match patch") dmp = diff_match_patch() section_diffs = [] - for after_section, before_section in zip(text_after_sections, text_before_sections): + for after_section, before_section in zip(text_after_sections, text_before_sections, strict=False): main_diff: list[tuple[int, str]] = dmp.diff_main(after_section, before_section) section_diffs.append(DiffResult(main_diff)) diff --git a/backend/worker/diff-assistant/quivr_diff_assistant/use_case_3/diff_type.py b/backend/worker/diff-assistant/quivr_diff_assistant/use_case_3/diff_type.py index becc0e1ff912..28e18d9be6d7 100644 --- a/backend/worker/diff-assistant/quivr_diff_assistant/use_case_3/diff_type.py +++ b/backend/worker/diff-assistant/quivr_diff_assistant/use_case_3/diff_type.py @@ -1,4 +1,5 @@ from typing import List, Tuple + from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.prompts.prompt import PromptTemplate diff --git a/backend/worker/diff-assistant/quivr_diff_assistant/use_case_3/llm_reporter.py b/backend/worker/diff-assistant/quivr_diff_assistant/use_case_3/llm_reporter.py index c2adacfbd4e2..859a64d37a77 100644 --- a/backend/worker/diff-assistant/quivr_diff_assistant/use_case_3/llm_reporter.py +++ b/backend/worker/diff-assistant/quivr_diff_assistant/use_case_3/llm_reporter.py @@ -1,7 +1,7 @@ from typing import List -from langchain_core.prompts.prompt import PromptTemplate -from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.prompts.prompt import PromptTemplate from use_case_3.diff_type import DiffResult REPORT_PROMPT = PromptTemplate.from_template( diff --git a/backend/worker/diff-assistant/quivr_diff_assistant/use_case_3/parser.py b/backend/worker/diff-assistant/quivr_diff_assistant/use_case_3/parser.py index 82fa9113ec0c..98689cf59bfb 100644 --- a/backend/worker/diff-assistant/quivr_diff_assistant/use_case_3/parser.py +++ b/backend/worker/diff-assistant/quivr_diff_assistant/use_case_3/parser.py @@ -2,27 +2,21 @@ All of this needs to be in MegaParse, this is just a placeholder for now. """ +import base64 +import os from typing import List + +import cv2 +import numpy as np from doctr.io import DocumentFile +from doctr.io.elements import Document as doctrDocument from doctr.models import ocr_predictor +from doctr.models.predictor.pytorch import OCRPredictor from doctr.utils.common_types import AbstractFile -import os -from doctr.io.elements import Document as doctrDocument from langchain_core.documents import Document -from llama_index.core import SimpleDirectoryReader, VectorStoreIndex - -from numba import njit - - from langchain_core.language_models.chat_models import BaseChatModel -from megaparse import MegaParse # FIXME: @chloedia Version problems - - -from doctr.models.predictor.pytorch import OCRPredictor from langchain_core.messages import HumanMessage -import base64 -import numpy as np -import cv2 +from megaparse import MegaParse # FIXME: @chloedia Version problems os.environ["USE_TORCH"] = "1" @@ -64,7 +58,7 @@ async def deep_aparse( if llm: entire_content = "" print("ocr llm start") - for raw_result, img in zip(raw_results.pages, docs): + for raw_result, img in zip(raw_results.pages, docs, strict=False): if raw_result.render() == "": continue _, buffer = cv2.imencode(".png", img) @@ -125,7 +119,7 @@ def deep_parse( if llm: entire_content = "" print("ocr llm start") - for raw_result, img in zip(raw_results.pages, docs): + for raw_result, img in zip(raw_results.pages, docs, strict=False): if raw_result.render() == "": continue _, buffer = cv2.imencode(".png", img) diff --git a/backend/worker/diff-assistant/quivr_diff_assistant/utils/utils.py b/backend/worker/diff-assistant/quivr_diff_assistant/utils/utils.py index 448182a0718c..90ef92251d6d 100644 --- a/backend/worker/diff-assistant/quivr_diff_assistant/utils/utils.py +++ b/backend/worker/diff-assistant/quivr_diff_assistant/utils/utils.py @@ -1,13 +1,12 @@ from langchain_core.prompts.prompt import PromptTemplate - COMPARISON_PROMPT = PromptTemplate.from_template( template=""" You are provided with two texts and . You need to consider the information contained in \ and compare it with the corresponding information contained in . \ Keep in mind that contains non-relevant information for this task, and that in you \ should only focus on the information correspnding to the information contained in . \ - You need to report all the differences between the information contained in and . \ + You need to report all the differences between the information contained in and . \\ Your job is to parse these differences and create a clear, concise report. \ Organize the report by sections and provide a detailed explanation of each difference. \ Be specific on difference, it will be reviewed and verified by a highly-trained quality engineer. diff --git a/backend/worker/diff-assistant/tests/test_hello.py b/backend/worker/diff-assistant/tests/test_hello.py index 5a76870c7dd0..a8fb1175823f 100644 --- a/backend/worker/diff-assistant/tests/test_hello.py +++ b/backend/worker/diff-assistant/tests/test_hello.py @@ -1,4 +1,3 @@ -import pytest from use_case_3 import hello diff --git a/backend/worker/quivr_worker/assistants/assistants.py b/backend/worker/quivr_worker/assistants/assistants.py index 88db0f34f5c5..dfcc9b256542 100644 --- a/backend/worker/quivr_worker/assistants/assistants.py +++ b/backend/worker/quivr_worker/assistants/assistants.py @@ -1,6 +1,6 @@ import os -import time import random +import time from quivr_api.modules.assistant.services.tasks_service import TasksService from quivr_api.modules.upload.service.upload_file import ( @@ -22,7 +22,7 @@ async def process_assistant( await tasks_service.update_task(task_id, {"status": "processing"}) print(task) - + # Add a random delay of 10 to 20 seconds time.sleep(random.randint(10, 20)) diff --git a/backend/worker/quivr_worker/celery_worker.py b/backend/worker/quivr_worker/celery_worker.py index bc6588d65f25..8432434ca84d 100644 --- a/backend/worker/quivr_worker/celery_worker.py +++ b/backend/worker/quivr_worker/celery_worker.py @@ -32,8 +32,8 @@ from sqlmodel import Session, text from sqlmodel.ext.asyncio.session import AsyncSession -from quivr_worker.celery_monitor import is_being_executed from quivr_worker.assistants.assistants import process_assistant +from quivr_worker.celery_monitor import is_being_executed from quivr_worker.check_premium import check_is_premium from quivr_worker.process.process_s3_file import process_uploaded_file from quivr_worker.process.process_url import process_url_func