Skip to content

Commit

Permalink
Merge branch 'main' of github.com:cohere-ai/cohere-toolkit into assis…
Browse files Browse the repository at this point in the history
…tants/deploy
  • Loading branch information
BeatrixCohere committed Jul 15, 2024
2 parents 0cf3abb + 9dfb1c5 commit 6c2312d
Show file tree
Hide file tree
Showing 9 changed files with 348 additions and 1 deletion.
11 changes: 11 additions & 0 deletions src/backend/crud/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@ def create_file(db: Session, file: File) -> File:
return file


def batch_create_files(db: Session, files: list[File]) -> list[File]:
    """
    Insert several files with a single commit.

    Args:
        db (Session): Database session.
        files (list[File]): File objects to add to the session.

    Returns:
        list[File]: The same list that was passed in; each element has been
            passed to db.refresh() after the commit.
    """
    db.add_all(files)
    db.commit()

    # Refresh every instance once the commit has gone through.
    for record in files:
        db.refresh(record)

    return files


def get_file(db: Session, file_id: str, user_id: str) -> File:
"""
Get a file by ID.
Expand Down
90 changes: 89 additions & 1 deletion src/backend/routers/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@
SEARCH_RELEVANCE_THRESHOLD,
extract_details_from_conversation,
)
from backend.services.file import get_file_content, validate_file_size
from backend.services.file import (
get_file_content,
validate_batch_file_size,
validate_file_size,
)

router = APIRouter(
prefix="/v1/conversations",
Expand Down Expand Up @@ -243,6 +247,7 @@ async def search_conversations(


# FILES
# TODO: Deprecate singular file upload once client uses batch upload endpoint
@router.post("/upload_file", response_model=UploadFile)
async def upload_file(
session: DBSessionDep,
Expand Down Expand Up @@ -324,6 +329,89 @@ async def upload_file(
return upload_file


@router.post("/batch_upload_file", response_model=list[UploadFile])
async def batch_upload_file(
    session: DBSessionDep,
    request: Request,
    conversation_id: str = Form(None),
    files: list[FastAPIUploadFile] = RequestFile(...),
) -> list[UploadFile]:
    """
    Uploads and creates a batch of File objects.
    If no conversation_id is provided, or no conversation is found for the
    given conversation_id and user, a new Conversation is created as well.

    Args:
        session (DBSessionDep): Database session.
        request (Request): Request object, passed to get_header_user_id to obtain the user ID.
        conversation_id (Optional[str]): Conversation ID passed as a form field.
        files (list[FastAPIUploadFile]): List of files to be uploaded.

    Returns:
        list[UploadFile]: List of uploaded files.

    Raises:
        HTTPException: If a conversation_id is given but no conversation is found
            and no user ID is available. Status code 400.
        HTTPException: If the files could not be created. Status code 500.
    """

    user_id = get_header_user_id(request)

    # Size validation happens before any conversation or file row is created.
    validate_batch_file_size(session, user_id, files)

    # Create new conversation
    if not conversation_id:
        conversation = conversation_crud.create_conversation(
            session,
            ConversationModel(user_id=user_id),
        )
    # Check for existing conversation
    else:
        conversation = conversation_crud.get_conversation(
            session, conversation_id, user_id
        )

        # Fail if user_id is not provided when the conversation does not exist
        if not conversation:
            if not user_id:
                raise HTTPException(
                    status_code=400,
                    detail="user_id is required if no valid conversation is provided.",
                )

            # Create new conversation
            conversation = conversation_crud.create_conversation(
                session,
                ConversationModel(user_id=user_id),
            )

    # TODO: check if file already exists in DB once we have files per agents

    # Build one File model per uploaded file; nothing is persisted in this loop.
    files_to_upload = []
    for file in files:
        content = await get_file_content(file)
        # Strip NUL characters from the content
        # (presumably not storable in the DB text column - TODO confirm).
        cleaned_content = content.replace("\x00", "")
        # Drop any non-ASCII characters from the filename.
        filename = file.filename.encode("ascii", "ignore").decode("utf-8")

        # Create File
        file_model = FileModel(
            user_id=conversation.user_id,
            conversation_id=conversation.id,
            file_name=filename,
            file_path=filename,
            file_size=file.size,
            file_content=cleaned_content,
        )
        files_to_upload.append(file_model)

    # Persist the whole batch in one call.
    try:
        uploaded_files = file_crud.batch_create_files(session, files_to_upload)
    except Exception as e:
        # Chain the original error so the root cause stays in the traceback.
        raise HTTPException(
            status_code=500, detail=f"Error while uploading file(s): {e}."
        ) from e

    return uploaded_files


@router.get("/{conversation_id}/files", response_model=list[ListFile])
async def list_files(
conversation_id: str, session: DBSessionDep, request: Request
Expand Down
31 changes: 31 additions & 0 deletions src/backend/services/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,34 @@ def validate_file_size(
status_code=400,
detail=f"Total file size exceeds the maximum allowed size of {MAX_TOTAL_FILE_SIZE} bytes.",
)


def validate_batch_file_size(
    session: DBSessionDep, user_id: str, files: list[FastAPIUploadFile]
) -> None:
    """Check a batch of uploads against the per-file and total size limits

    Args:
        session (DBSessionDep): Database session used to fetch the user's existing files
        user_id (str): The user ID
        files (list[FastAPIUploadFile]): The files to validate

    Raises:
        HTTPException: If any single file is larger than MAX_FILE_SIZE, or if the
            batch plus the user's existing files is larger than MAX_TOTAL_FILE_SIZE
    """
    # Per-file limit: reject on the first oversized file, summing sizes as we go.
    batch_bytes = 0
    for upload in files:
        size = upload.size
        if size > MAX_FILE_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"{upload.filename} exceeds the maximum allowed size of {MAX_FILE_SIZE} bytes.",
            )
        batch_bytes += size

    # Total limit: sizes already stored for this user plus the incoming batch.
    stored_files = file_crud.get_files_by_user_id(session, user_id)
    stored_bytes = sum(stored.file_size for stored in stored_files)

    if stored_bytes + batch_bytes <= MAX_TOTAL_FILE_SIZE:
        return

    raise HTTPException(
        status_code=400,
        detail=f"Total file size exceeds the maximum allowed size of {MAX_TOTAL_FILE_SIZE} bytes.",
    )
28 changes: 28 additions & 0 deletions src/backend/tests/crud/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,34 @@ def test_create_file(session, user):
assert file.user_id == file_data.user_id


def test_batch_create_files(session, user):
    """Files created through batch_create_files are returned by get_files."""
    file_data = File(
        file_name="test.txt",
        file_path="/tmp/test.txt",
        file_size=100,
        conversation_id="1",
        user_id="1",
    )
    file_data2 = File(
        file_name="test2.txt",
        file_path="/tmp/test2.txt",
        file_size=100,
        conversation_id="1",
        user_id="1",
    )

    files = file_crud.batch_create_files(session, [file_data, file_data2])
    assert len(files) == 2

    # NOTE(review): the lookup uses user.id while the rows use user_id="1";
    # this assumes the `user` fixture has id "1" - confirm against the fixture.
    files = file_crud.get_files(session, user.id)
    assert len(files) == 2
    # Set equality checks that BOTH names are present, not merely that every
    # name belongs to the allowed list.
    assert {file.file_name for file in files} == {"test.txt", "test2.txt"}
    assert all(file.conversation_id == "1" for file in files)
    assert all(file.user_id == "1" for file in files)


def test_get_file(session, user):
_ = get_factory("File", session).create(
id="1", file_name="test.txt", conversation_id="1", user_id=user.id
Expand Down
1 change: 1 addition & 0 deletions src/backend/tests/factories/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ class Meta:
conversation_id = factory.Faker("uuid4")
file_name = factory.Faker("file_name")
file_path = factory.Faker("file_path")
file_size = factory.Faker("random_int", min=1, max=20000000)
Loading

0 comments on commit 6c2312d

Please sign in to comment.