diff --git a/web/api/tests/tests.py b/web/api/tests/tests.py index 1624991..d6c7d13 100644 --- a/web/api/tests/tests.py +++ b/web/api/tests/tests.py @@ -1346,6 +1346,44 @@ async def test_search_documents(async_client, user, collection, document): assert response.json() != [] +async def test_search_image(async_client, user, collection, document): + response = await async_client.post( + "/search-image/", + json={ + "img_base64": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=", + "top_k": 1, + }, + headers={"Authorization": f"Bearer {user.token}"}, + ) + assert response.status_code == 200 + assert response.json() != [] + + +async def test_search_image_invalid_base64(async_client, user, collection, document): + response = await async_client.post( + "/search-image/", + json={ + "img_base64": "Hello", + "top_k": 1, + }, + headers={"Authorization": f"Bearer {user.token}"}, + ) + + assert response.status_code == 422 + assert response.json() == { + "detail": [ + { + "type": "value_error", + "loc": ["body", "payload"], + "msg": "Value error, Provided 'base64' is not valid. Please provide a valid base64 string.", + "ctx": { + "error": "Provided 'base64' is not valid. Please provide a valid base64 string." + }, + } + ] + } + + async def test_filter_collections(async_client, user, collection, document): response = await async_client.post( "/filter/", @@ -1453,6 +1491,22 @@ async def test_search_filter_collection_name(async_client, user, collection, doc assert response.json() != [] +async def test_search_image_filter_collection_name( + async_client, user, collection, document +): + response = await async_client.post( + "/search-image/", + json={ + "img_base64": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=", + "top_k": 1, + "collection_name": collection.name, + }, + headers={"Authorization": f"Bearer {user.token}"}, + ) + assert response.status_code == 200 + assert response.json() != [] + + async def test_search_filter_key_equals(async_client, user, search_filter_fixture): collection, document_1, document_2 = search_filter_fixture @@ -2126,6 +2180,31 @@ async def test_embedding_service_down_query(async_client, user): assert response.status_code == 503 +async def test_embedding_service_down_search_image(async_client, user): + EMBEDDINGS_POST_PATH = "api.views.aiohttp.ClientSession.post" + # Create a mock response object with status 500 + mock_response = AsyncMock() + mock_response.status = 500 + mock_response.json.return_value = AsyncMock(return_value={"error": "Service Down"}) + + # Mock the context manager __aenter__ to return the mock_response + mock_response.__aenter__.return_value = mock_response + + # Patch the aiohttp.ClientSession.post method to return the mock_response + with patch(EMBEDDINGS_POST_PATH, return_value=mock_response): + # Perform the POST request to trigger embed_document + response = await async_client.post( + "/search-image/", + json={ + "img_base64": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=", + "top_k": 1, + }, + headers={"Authorization": f"Bearer {user.token}"}, + ) + + assert response.status_code == 503 + + async def test_embed_document_arxiv_await(async_client, user): response = await async_client.post( "/documents/upsert-document/", diff --git a/web/api/views.py b/web/api/views.py index 1214a10..f6c1292 100644 --- a/web/api/views.py +++ b/web/api/views.py @@ -915,6 +915,27 @@ class QueryIn(Schema): query_filter: Optional[QueryFilter] = None +class SearchImageIn(Schema): + img_base64: str + collection_name: Optional[str] = "all" + top_k: Optional[int] = 3 + query_filter: Optional[QueryFilter] = None + + @model_validator(mode="after") + def base64(self) -> Self: + # Validate base64 + base64_pattern = r"^[A-Za-z0-9+/]+={0,2}$" + is_base64 = ( + re.match(base64_pattern, self.img_base64) and len(self.img_base64) % 4 == 0 + ) + + if not is_base64: + raise ValueError( + "Provided 'base64' is not valid. Please provide a valid base64 string." + ) + return self + + class PageOutQuery(Schema): collection_name: str collection_id: int @@ -933,6 +954,10 @@ class QueryOut(Schema): results: List[PageOutQuery] +class SearchImageOut(Schema): + results: List[PageOutQuery] + + @router.post( "/search/", tags=["search"], @@ -1037,6 +1062,110 @@ async def search( return 200, QueryOut(query=payload.query, results=formatted_results) +@router.post( + "/search-image/", + tags=["search"], + auth=Bearer(), + response={200: SearchImageOut, 503: GenericError}, +) +async def search_image( + request: Request, payload: SearchImageIn +) -> Tuple[int, SearchImageOut] | Tuple[int, GenericError]: + """ + Search for pages similar to a given image. + + This endpoint allows the user to search for pages similar to a given image. + The search is performed across all documents in the specified collection. + + Args: + request: The HTTP request object, which includes the user information. + payload (SearchImageIn): The input data for the search, which includes the image in base64 format and collection ID. + + Returns: + SearchImageOut: The search results, a list of similar pages. + + Raises: + HttpError: If the collection does not exist or the img_base64 is invalid. + + Example: + POST /search-image/ + { + "img_base64": "base64_string", + "collection_name": "my_collection", + "top_k": 3, + "query_filter": { + "on": "document", + "key": "breed", + "value": "collie", + "lookup": "contains" + } + } + """ + image_embeddings = await get_image_embeddings(payload.img_base64) + if not image_embeddings: + return 503, GenericError( + detail="Failed to get embeddings from the embeddings service" + ) + query_length = len(image_embeddings) # we need this for normalization + + # we want to cast the embeddings to halfvec + casted_image_embeddings = [ + HalfVector(embedding).to_text() for embedding in image_embeddings + ] + + # building the query: + + # 1. filter the pages based on the collection_id and the query_filter + base_query = await filter_query(payload, request.auth) + + # 2. annotate the query with the max sim score + # maxsim needs 2 arrays of embeddings, one for the pages and one for the query + pages_query = ( + base_query.annotate(page_embeddings=ArrayAgg("embeddings__embedding")) + .annotate(max_sim=MaxSim("page_embeddings", casted_image_embeddings)) + .order_by("-max_sim")[: payload.top_k or 3] + ) + # 3. execute the query + results = pages_query.values( + "id", + "page_number", + "img_base64", + "document__id", + "document__name", + "document__metadata", + "document__collection__id", + "document__collection__name", + "document__collection__metadata", + "max_sim", + ) + # Normalization + normalization_factor = query_length + + # Format the results + formatted_results = [ + PageOutQuery( + collection_name=row["document__collection__name"], + collection_id=row["document__collection__id"], + collection_metadata=( + row["document__collection__metadata"] + if row["document__collection__metadata"] + else {} + ), + document_name=row["document__name"], + document_id=row["document__id"], + document_metadata=( + row["document__metadata"] if row["document__metadata"] else {} + ), + page_number=row["page_number"], + raw_score=row["max_sim"], + normalized_score=row["max_sim"] / normalization_factor, + img_base64=row["img_base64"], + ) + async for row in results + ] + return 200, SearchImageOut(results=formatted_results) + + @router.post( "/filter/", tags=["filter"], @@ -1145,7 +1274,34 @@ async def get_query_embeddings(query: str) -> List: return out["output"]["data"][0]["embedding"] -async def filter_query(payload: QueryIn, user: CustomUser) -> QuerySet[Page]: +async def get_image_embeddings(img_base64: str) -> List: + EMBEDDINGS_URL = settings.ALWAYS_ON_EMBEDDINGS_URL + embed_token = settings.EMBEDDINGS_URL_TOKEN + headers = {"Authorization": f"Bearer {embed_token}"} + payload = { + "input": { + "task": "image", + "input_data": [img_base64], + } + } + async with aiohttp.ClientSession() as session: + async with session.post( + EMBEDDINGS_URL, json=payload, headers=headers + ) as response: + if response.status != 200: + logger.error( + f"Failed to get embeddings from the embeddings service: {response.status}" + ) + return [] + out = await response.json() + # returning a dynamic array of embeddings, each of which is a list of 128 floats + # example: [[0.1, 0.2, 0.3, ...], [0.4, 0.5, 0.6, ...]] + return out["output"]["data"][0]["embedding"] + + +async def filter_query( + payload: Union[QueryIn, SearchImageIn], user: CustomUser +) -> QuerySet[Page]: base_query = Page.objects.select_related("document__collection") if payload.collection_name == "all": base_query = base_query.filter(document__collection__owner=user)