Skip to content

Commit

Permalink
Merge pull request #104 from tjmlabs/fix-paths-bug
Browse files Browse the repository at this point in the history
Prevent users from passing paths to the embeddings endpoint
  • Loading branch information
Jonathan-Adly authored Dec 4, 2024
2 parents c62caad + 142b66d commit 69c183d
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 10 deletions.
77 changes: 68 additions & 9 deletions web/api/tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,8 @@
from accounts.models import CustomUser
from api.middleware import add_slash
from api.models import Collection, Document, Page, PageEmbedding
from api.views import (
Bearer,
QueryFilter,
QueryIn,
filter_collections,
filter_documents,
filter_query,
router,
)
from api.views import (Bearer, QueryFilter, QueryIn, filter_collections,
filter_documents, filter_query, router)
from django.core.exceptions import ValidationError as DjangoValidationError
from django.core.files.uploadedfile import SimpleUploadedFile
from django.test import override_settings
Expand Down Expand Up @@ -1684,6 +1677,72 @@ async def test_create_embedding(async_client, user):
assert response.json()["data"] != []


async def test_create_embedding_invalid_input(async_client, user):
    """A local file path sent as an image input is rejected with a 422.

    The expected body pins the exact validation message, so the wording
    raised by the request schema and this test must change together.
    """
    request_body = {
        "task": "image",
        "input_data": ["/Users/user/Desktop/image.png"],
    }
    auth_headers = {"Authorization": f"Bearer {user.token}"}
    expected_body = {
        "detail": [
            {
                "type": "value_error",
                "loc": ["body", "payload"],
                "msg": "Value error, Each input must be a valid base64 string or a URL. Please use our Python SDK if you want to provide a file path.",
                "ctx": {
                    "error": "Each input must be a valid base64 string or a URL. Please use our Python SDK if you want to provide a file path."
                },
            }
        ]
    }

    response = await async_client.post(
        "/embeddings/",
        json=request_body,
        headers=auth_headers,
    )

    assert response.status_code == 422
    assert response.json() == expected_body


async def test_create_embedding_valid_url_service_down(async_client, user):
    """A URL image input is accepted, and an upstream 500 is reported as 503.

    The embeddings POST is patched to return a response with status 500;
    the endpoint is expected to answer 503.
    """
    task = "image"
    input_data = ["https://tourism.gov.in/sites/default/files/2019-04/dummy-pdf_2.pdf"]
    EMBEDDINGS_POST_PATH = "api.models.aiohttp.ClientSession.post"
    # Create a mock response object with status 500
    mock_response = AsyncMock()
    mock_response.status = 500
    # `json` is already an AsyncMock child, so its return_value is exactly
    # what `await response.json()` yields. Wrapping the payload in another
    # AsyncMock would hand the caller a mock object instead of the dict.
    mock_response.json.return_value = {"error": "Service Down"}
    # Mock the context manager __aenter__ to return the mock_response
    mock_response.__aenter__.return_value = mock_response
    # Patch the aiohttp.ClientSession.post method to return the mock_response
    with patch(EMBEDDINGS_POST_PATH, return_value=mock_response):
        response = await async_client.post(
            "/embeddings/",
            json={"task": task, "input_data": input_data},
            headers={"Authorization": f"Bearer {user.token}"},
        )
    assert response.status_code == 503


async def test_create_embedding_valid_base64_service_down(async_client, user):
    """A base64 image input is accepted, and an upstream 500 is reported as 503.

    The embeddings POST is patched to return a response with status 500;
    the endpoint is expected to answer 503.
    """
    task = "image"
    input_data = [
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII="
    ]
    EMBEDDINGS_POST_PATH = "api.models.aiohttp.ClientSession.post"
    # Create a mock response object with status 500
    mock_response = AsyncMock()
    mock_response.status = 500
    # `json` is already an AsyncMock child, so its return_value is exactly
    # what `await response.json()` yields. Wrapping the payload in another
    # AsyncMock would hand the caller a mock object instead of the dict.
    mock_response.json.return_value = {"error": "Service Down"}
    # Mock the context manager __aenter__ to return the mock_response
    mock_response.__aenter__.return_value = mock_response
    # Patch the aiohttp.ClientSession.post method to return the mock_response
    with patch(EMBEDDINGS_POST_PATH, return_value=mock_response):
        response = await async_client.post(
            "/embeddings/",
            json={"task": task, "input_data": input_data},
            headers={"Authorization": f"Bearer {user.token}"},
        )
    assert response.status_code == 503


async def test_create_embedding_service_down(async_client, user):
task = "query"
input_data = ["What is 1 + 1"]
Expand Down
23 changes: 22 additions & 1 deletion web/api/views.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import asyncio
import base64
import logging
import re
from enum import Enum
from typing import Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse

import aiohttp
from accounts.models import CustomUser
Expand All @@ -18,7 +20,8 @@
from ninja.security import HttpBearer
from pgvector.utils import HalfVector
from pydantic import Field, model_validator
from svix.api import ApplicationIn, EndpointIn, EndpointUpdate, MessageIn, SvixAsync
from svix.api import (ApplicationIn, EndpointIn, EndpointUpdate, MessageIn,
SvixAsync)
from typing_extensions import Self

from .models import Collection, Document, MaxSim, Page
Expand Down Expand Up @@ -1266,6 +1269,24 @@ class EmbeddingsIn(Schema):
input_data: List[str]
task: TaskEnum

@model_validator(mode="after")
def validate_input_data(self) -> Self:
    """Reject image inputs that are neither base64 strings nor URLs.

    Only applies when ``task`` is ``TaskEnum.image``; other tasks are
    returned unchanged. Each entry of ``input_data`` must either look
    like padded base64 (alphabet ``A-Za-z0-9+/``, at most two trailing
    ``=``, length a multiple of four) or parse as a URL with both a
    scheme and a network location.

    Returns:
        The validated model instance.

    Raises:
        ValueError: If any entry is neither base64 nor a URL, for
            example a local file path.
    """
    if self.task != TaskEnum.image:
        return self

    # Compiled once, outside the loop. fullmatch() is used instead of
    # match() with a "$" anchor because "$" also matches just before a
    # trailing newline, which would let "AAA\n" through as base64.
    base64_pattern = re.compile(r"[A-Za-z0-9+/]+={0,2}")

    for value in self.input_data:
        is_base64 = (
            len(value) % 4 == 0
            and base64_pattern.fullmatch(value) is not None
        )
        if is_base64:
            continue

        try:
            parsed = urlparse(value)
        except ValueError:
            # urlparse raises on malformed netlocs (e.g. an unbalanced
            # IPv6 bracket). Treat that as "not a URL" so the caller
            # always gets the same message below.
            is_url = False
        else:
            is_url = bool(parsed.scheme and parsed.netloc)

        if not is_url:
            raise ValueError(
                "Each input must be a valid base64 string or a URL. Please use our Python SDK if you want to provide a file path."
            )
    return self


class EmbeddingsOut(Schema):
_object: str
Expand Down

0 comments on commit 69c183d

Please sign in to comment.