Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
StanGirard committed Sep 25, 2024
1 parent d123495 commit 33002fd
Show file tree
Hide file tree
Showing 35 changed files with 112 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
upload_file_storage,
)
from quivr_api.modules.user.entity.user_identity import UserIdentity
from quivr_api.modules.assistant.entity.task_entity import TaskMetadata

logger = get_logger(__name__)

Expand Down Expand Up @@ -83,7 +84,7 @@ async def create_task(
raise HTTPException(status_code=400, detail=error)
else:
print("Assistant input is valid.")
notification_uuid = uuid4()
notification_uuid = f"{assistant.name}-{str(uuid4())[:8]}"

# Process files dynamically
for upload_file in files:
Expand All @@ -96,12 +97,13 @@ async def create_task(
raise HTTPException(
status_code=500, detail=f"Failed to upload file to storage. {e}"
)

task = CreateTask(
assistant_id=input.id,
assistant_name=assistant.name,
pretty_id=f"{assistant.name}-{str(notification_uuid)[:8]}",
pretty_id=notification_uuid,
settings=input.model_dump(mode="json"),
task_metadata=TaskMetadata(input_files=[file.filename for file in files]).model_dump(mode="json") if files else None, # type: ignore
)

task_created = await tasks_service.create_task(task, current_user.id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
AssistantOutput,
InputBoolean,
InputFile,
InputSelectText,
Inputs,
InputSelectText,
Pricing,
)

Expand Down Expand Up @@ -201,7 +201,10 @@ def validate_assistant_input(
InputSelectText(
key="DocumentsType",
description="Select Documents Type",
options=["Etiquettes VS Cahier des charges", "Fiche Dev VS Cahier des charges"],
options=[
"Etiquettes VS Cahier des charges",
"Fiche Dev VS Cahier des charges",
],
),
],
),
Expand All @@ -222,7 +225,9 @@ def validate_assistant_input(
InputFile(key="Document 2", description="File description"),
],
booleans=[
InputBoolean(key="Hard-to-Read Document?", description="Boolean description"),
InputBoolean(
key="Hard-to-Read Document?", description="Boolean description"
),
],
select_texts=[
InputSelectText(
Expand Down
3 changes: 2 additions & 1 deletion backend/api/quivr_api/modules/assistant/dto/inputs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Dict, List, Optional
from uuid import UUID

from pydantic import BaseModel, root_validator
Expand All @@ -9,6 +9,7 @@ class CreateTask(BaseModel):
assistant_id: int
assistant_name: str
settings: dict
task_metadata: Dict | None = None


class BrainInput(BaseModel):
Expand Down
10 changes: 7 additions & 3 deletions backend/api/quivr_api/modules/assistant/entity/task_entity.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from datetime import datetime
from typing import Dict
from typing import Dict, List, Optional
from uuid import UUID

from pydantic import BaseModel
from sqlmodel import JSON, TIMESTAMP, BigInteger, Column, Field, SQLModel, text


class TaskMetadata(BaseModel):
    """Optional metadata attached to an assistant task.

    The task-creation route builds this model from the uploaded files and
    serializes it with ``model_dump(mode="json")`` before storing it in the
    task's ``task_metadata`` field.
    """

    # Filenames of the files uploaded with the task; None when not provided.
    input_files: Optional[List[str]] = None


class Task(SQLModel, table=True):
__tablename__ = "tasks" # type: ignore

Expand All @@ -30,6 +35,5 @@ class Task(SQLModel, table=True):
)
settings: Dict = Field(default_factory=dict, sa_column=Column(JSON))
answer: str | None = Field(default=None)
task_metadata: Dict | None = Field(default=None, sa_column=Column(JSON))

class Config:
arbitrary_types_allowed = True
7 changes: 5 additions & 2 deletions backend/api/quivr_api/modules/assistant/repository/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from sqlalchemy import exc
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select, col
from sqlmodel import col, select

from quivr_api.modules.assistant.dto.inputs import CreateTask
from quivr_api.modules.assistant.entity.task_entity import Task
Expand All @@ -25,6 +25,7 @@ async def create_task(self, task: CreateTask, user_id: UUID) -> Task:
pretty_id=task.pretty_id,
user_id=user_id,
settings=task.settings,
task_metadata=task.task_metadata, # type: ignore
)
self.session.add(task_to_create)
await self.session.commit()
Expand All @@ -41,7 +42,9 @@ async def get_task_by_id(self, task_id: UUID, user_id: UUID) -> Task:
return response.one()

async def get_tasks_by_user_id(self, user_id: UUID) -> Sequence[Task]:
query = select(Task).where(Task.user_id == user_id).order_by(col(Task.id).desc())
query = (
select(Task).where(Task.user_id == user_id).order_by(col(Task.id).desc())
)
response = await self.session.exec(query)
return response.all()

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Tuple, Dict
from typing import Dict, Optional, Tuple
from uuid import UUID

from fastapi import HTTPException
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
import os
import time
from enum import Enum

from fastapi import HTTPException
Expand Down
4 changes: 2 additions & 2 deletions backend/api/quivr_api/modules/chat/controller/chat_routes.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
from typing import Annotated, List, Optional
from uuid import UUID
import os

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, Request
from fastapi.responses import StreamingResponse
from quivr_core.config import RetrievalConfig

from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user
Expand Down Expand Up @@ -36,7 +37,6 @@
from quivr_api.modules.vector.service.vector_service import VectorService
from quivr_api.utils.telemetry import maybe_send_telemetry
from quivr_api.utils.uuid_generator import generate_uuid_from_string
from quivr_core.config import RetrievalConfig

logger = get_logger(__name__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel

from pydantic import BaseModel
from quivr_core.models import KnowledgeStatus
from sqlalchemy import JSON, TIMESTAMP, Column, text
from sqlalchemy.ext.asyncio import AsyncAttrs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,3 @@ async def remove_file(self, storage_path: str):
except Exception as e:
logger.error(e)
raise e

Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,9 @@ async def test_should_process_knowledge_prev_error(
assert new.file_sha1


@pytest.mark.skip(reason="Bug: UnboundLocalError: cannot access local variable 'response'")
@pytest.mark.skip(
reason="Bug: UnboundLocalError: cannot access local variable 'response'"
)
@pytest.mark.asyncio(loop_scope="session")
async def test_get_knowledge_storage_path(session: AsyncSession, test_data: TestData):
_, [knowledge, _] = test_data
Expand Down
4 changes: 1 addition & 3 deletions backend/api/quivr_api/modules/misc/controller/misc_routes.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@

from fastapi import APIRouter, Depends, HTTPException
from quivr_api.logger import get_logger
from quivr_api.modules.dependencies import get_async_session
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel import text
from sqlmodel.ext.asyncio.session import AsyncSession

logger = get_logger(__name__)

Expand All @@ -20,7 +19,6 @@ async def root():

@misc_router.get("/healthz", tags=["Health"])
async def healthz(session: AsyncSession = Depends(get_async_session)):

try:
result = await session.execute(text("SELECT 1"))
if not result:
Expand Down
2 changes: 1 addition & 1 deletion backend/api/quivr_api/modules/rag_service/rag_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os
from uuid import UUID, uuid4

from quivr_api.utils.uuid_generator import generate_uuid_from_string
from quivr_core.brain import Brain as BrainCore
from quivr_core.chat import ChatHistory as ChatHistoryCore
from quivr_core.config import LLMEndpointConfig, RetrievalConfig
Expand All @@ -29,6 +28,7 @@
from quivr_api.modules.prompt.service.prompt_service import PromptService
from quivr_api.modules.user.entity.user_identity import UserIdentity
from quivr_api.modules.vector.service.vector_service import VectorService
from quivr_api.utils.uuid_generator import generate_uuid_from_string
from quivr_api.vectorstore.supabase import CustomSupabaseVectorStore

from .utils import generate_source
Expand Down
2 changes: 1 addition & 1 deletion backend/api/quivr_api/modules/rag_service/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ async def generate_source(
try:
file_name = doc.metadata["file_name"]
file_path = await knowledge_service.get_knowledge_storage_path(
file_name=file_name, brain_id=brain_id
file_name=file_name, brain_id=brain_id
)
if file_path in generated_urls:
source_url = generated_urls[file_path]
Expand Down
10 changes: 4 additions & 6 deletions backend/api/quivr_api/modules/sync/repository/sync_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,7 @@ def get_syncs_user(self, user_id: UUID, sync_user_id: int | None = None):
sync_user_id,
)
query = (
self.db.from_("syncs_user")
.select("*")
.eq("user_id", user_id)
self.db.from_("syncs_user").select("*").eq("user_id", user_id)
# .neq("status", "REMOVED")
)
if sync_user_id:
Expand Down Expand Up @@ -170,9 +168,9 @@ def update_sync_user(
)

state_str = json.dumps(state)
self.db.from_("syncs_user").update(sync_user_input.model_dump(exclude_unset=True)).eq(
"user_id", str(sync_user_id)
).eq("state", state_str).execute()
self.db.from_("syncs_user").update(
sync_user_input.model_dump(exclude_unset=True)
).eq("user_id", str(sync_user_id)).eq("state", state_str).execute()
logger.info("Sync user updated successfully")

def update_sync_user_status(self, sync_user_id: int, status: str):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def handler(request):
assert len(result) == 0


@pytest.mark.skip(reason="Bug: httpx.ConnectError: [Errno -2] Name or service not known'")
@pytest.mark.skip(
reason="Bug: httpx.ConnectError: [Errno -2] Name or service not known'"
)
@pytest.mark.asyncio(loop_scope="session")
async def test_store_notion_pages_success(
session: AsyncSession,
Expand Down
16 changes: 10 additions & 6 deletions backend/api/quivr_api/modules/sync/tests/test_syncutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,10 @@ async def test_process_sync_file_not_supported(syncutils: SyncUtils):
sync_active=sync_active,
)

@pytest.mark.skip(reason="Bug: UnboundLocalError: cannot access local variable 'response'")

@pytest.mark.skip(
reason="Bug: UnboundLocalError: cannot access local variable 'response'"
)
@pytest.mark.asyncio(loop_scope="session")
async def test_process_sync_file_noprev(
monkeypatch,
Expand Down Expand Up @@ -327,8 +330,8 @@ def _send_task(*args, **kwargs):
assert created_km.file_sha1 is None
assert created_km.created_at is not None
assert created_km.metadata == {"sync_file_id": "1"}
assert len(created_km.brains)> 0
assert created_km.brains[0]["brain_id"]== brain_1.brain_id
assert len(created_km.brains) > 0
assert created_km.brains[0]["brain_id"] == brain_1.brain_id

# Assert celery task is correct
assert task["args"] == ("process_file_task",)
Expand All @@ -345,8 +348,9 @@ def _send_task(*args, **kwargs):
)



@pytest.mark.skip(reason="Bug: UnboundLocalError: cannot access local variable 'response'")
@pytest.mark.skip(
reason="Bug: UnboundLocalError: cannot access local variable 'response'"
)
@pytest.mark.asyncio(loop_scope="session")
async def test_process_sync_file_with_prev(
monkeypatch,
Expand Down Expand Up @@ -424,7 +428,7 @@ def _send_task(*args, **kwargs):
assert created_km.created_at
assert created_km.updated_at == created_km.created_at # new line
assert created_km.metadata == {"sync_file_id": str(dbfiles[0].id)}
assert created_km.brains[0]["brain_id"]== brain_1.brain_id
assert created_km.brains[0]["brain_id"] == brain_1.brain_id

# Check file content changed
assert check_file_exists(str(brain_1.brain_id), sync_file.name)
Expand Down
7 changes: 6 additions & 1 deletion backend/api/quivr_api/modules/sync/utils/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,12 @@ async def aget_files(
pages.append(page_info)

if recursive:
sub_pages = await self.aget_files(credentials=credentials, sync_user_id=sync_user_id, folder_id=str(page.id), recursive=recursive)
sub_pages = await self.aget_files(
credentials=credentials,
sync_user_id=sync_user_id,
folder_id=str(page.id),
recursive=recursive,
)
pages.extend(sub_pages)
return pages

Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import asyncio
import base64
import re
from enum import Enum
from io import BytesIO
from pathlib import Path
from typing import List

from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI
import base64
from pdf2image import convert_from_path
import asyncio
import re

# BASE_OCR_PROMPT = """
# Transcribe the content of this file into markdown. Be mindful of the formatting.
Expand Down
8 changes: 5 additions & 3 deletions backend/core/MegaParse/megaparse/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from docx.document import Document as DocumentObject
from docx.oxml.table import CT_Tbl
from docx.oxml.text.paragraph import CT_P
from docx.section import Section
from docx.section import _Footer as Footer
from docx.section import _Header as Header
from docx.table import Table
from docx.text.paragraph import Paragraph
from docx.section import Section, _Header as Header, _Footer as Footer
from docx.oxml.text.paragraph import CT_P
from docx.oxml.table import CT_Tbl


def print_element(element):
Expand Down
1 change: 0 additions & 1 deletion backend/core/MegaParse/tests/test_import.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest

from megaparse.Converter import MegaParse


Expand Down
1 change: 0 additions & 1 deletion backend/core/examples/simple_question.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from quivr_core import Brain
from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph


if __name__ == "__main__":
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as temp_file:
Expand Down
4 changes: 2 additions & 2 deletions backend/core/quivr_core/chat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from copy import deepcopy
from datetime import datetime
from typing import Any, Generator, Tuple, List
from typing import Any, Generator, List, Tuple
from uuid import UUID, uuid4
from copy import deepcopy

from langchain_core.messages import AIMessage, HumanMessage

Expand Down
2 changes: 1 addition & 1 deletion backend/core/quivr_core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from enum import Enum
from typing import Dict, List, Optional
from uuid import UUID
from sqlmodel import SQLModel

from megaparse.config import MegaparseConfig
from sqlmodel import SQLModel

from quivr_core.base_config import QuivrBaseConfig
from quivr_core.processor.splitter import SplitterConfig
Expand Down
Loading

0 comments on commit 33002fd

Please sign in to comment.