Skip to content

Commit

Permalink
Merge pull request #130 from tjmlabs/image-search
Browse files Browse the repository at this point in the history
Add search-image endpoint
  • Loading branch information
Jonathan-Adly authored Jan 7, 2025
2 parents 444f7f6 + ea5dc3f commit e0456ad
Show file tree
Hide file tree
Showing 2 changed files with 236 additions and 1 deletion.
79 changes: 79 additions & 0 deletions web/api/tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1346,6 +1346,44 @@ async def test_search_documents(async_client, user, collection, document):
assert response.json() != []


async def test_search_image(async_client, user, collection, document):
    """Search with a valid base64-encoded PNG and expect at least one hit."""
    response = await async_client.post(
        "/search-image/",
        json={
            "img_base64": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=",
            "top_k": 1,
        },
        headers={"Authorization": f"Bearer {user.token}"},
    )
    assert response.status_code == 200
    # The 200 response schema wraps the hits in a "results" key, so the decoded
    # body is a dict; comparing the whole body to [] could never fail.
    # NOTE(review): assumes the `document` fixture provides at least one page
    # with embeddings - confirm against the fixture definition.
    assert response.json()["results"] != []


async def test_search_image_invalid_base64(async_client, user, collection, document):
    """A malformed ``img_base64`` value is rejected with a 422 validation error."""
    # "Hello" has a length that is not a multiple of 4, so the schema
    # validator refuses it before the endpoint body runs.
    request_body = {"img_base64": "Hello", "top_k": 1}
    auth_headers = {"Authorization": f"Bearer {user.token}"}
    expected_body = {
        "detail": [
            {
                "type": "value_error",
                "loc": ["body", "payload"],
                "msg": "Value error, Provided 'base64' is not valid. Please provide a valid base64 string.",
                "ctx": {
                    "error": "Provided 'base64' is not valid. Please provide a valid base64 string."
                },
            }
        ]
    }

    response = await async_client.post(
        "/search-image/", json=request_body, headers=auth_headers
    )

    assert response.status_code == 422
    assert response.json() == expected_body


async def test_filter_collections(async_client, user, collection, document):
response = await async_client.post(
"/filter/",
Expand Down Expand Up @@ -1453,6 +1491,22 @@ async def test_search_filter_collection_name(async_client, user, collection, doc
assert response.json() != []


async def test_search_image_filter_collection_name(
    async_client, user, collection, document
):
    """Search by image restricted to a named collection and expect a hit."""
    response = await async_client.post(
        "/search-image/",
        json={
            "img_base64": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=",
            "top_k": 1,
            "collection_name": collection.name,
        },
        headers={"Authorization": f"Bearer {user.token}"},
    )
    assert response.status_code == 200
    # The 200 response schema wraps the hits in a "results" key, so the decoded
    # body is a dict; comparing the whole body to [] could never fail.
    # NOTE(review): assumes the `document` fixture provides at least one page
    # with embeddings in `collection` - confirm against the fixture definition.
    assert response.json()["results"] != []


async def test_search_filter_key_equals(async_client, user, search_filter_fixture):
collection, document_1, document_2 = search_filter_fixture

Expand Down Expand Up @@ -2126,6 +2180,31 @@ async def test_embedding_service_down_query(async_client, user):
assert response.status_code == 503


async def test_embedding_service_down_search_image(async_client, user):
    """The endpoint answers 503 when the embeddings service returns HTTP 500."""
    EMBEDDINGS_POST_PATH = "api.views.aiohttp.ClientSession.post"
    # Create a mock response object with status 500
    mock_response = AsyncMock()
    mock_response.status = 500
    # `json` is itself an AsyncMock attribute, so awaiting it yields
    # `return_value` directly; wrapping the payload in another AsyncMock would
    # make `await response.json()` return a mock instead of the dict.
    mock_response.json.return_value = {"error": "Service Down"}

    # Mock the context manager __aenter__ to return the mock_response
    mock_response.__aenter__.return_value = mock_response

    # Patch the aiohttp.ClientSession.post method to return the mock_response
    with patch(EMBEDDINGS_POST_PATH, return_value=mock_response):
        # Perform the POST request so the view asks for image embeddings
        response = await async_client.post(
            "/search-image/",
            json={
                "img_base64": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=",
                "top_k": 1,
            },
            headers={"Authorization": f"Bearer {user.token}"},
        )

    assert response.status_code == 503


async def test_embed_document_arxiv_await(async_client, user):
response = await async_client.post(
"/documents/upsert-document/",
Expand Down
158 changes: 157 additions & 1 deletion web/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,27 @@ class QueryIn(Schema):
query_filter: Optional[QueryFilter] = None


class SearchImageIn(Schema):
    # Request body for POST /search-image/.
    # img_base64: the query image as a padded, standard-alphabet base64 string.
    img_base64: str
    # "all" makes filter_query search every collection owned by the caller.
    collection_name: Optional[str] = "all"
    # Maximum number of pages to return; the endpoint falls back to 3 when falsy.
    top_k: Optional[int] = 3
    query_filter: Optional[QueryFilter] = None

    @model_validator(mode="after")
    def base64(self) -> Self:
        """Reject payloads whose ``img_base64`` is not well-formed base64.

        Returns:
            The validated model instance.

        Raises:
            ValueError: If the string contains characters outside the standard
                base64 alphabet, has misplaced padding, or has a length that
                is not a multiple of 4.
        """
        # Validate base64
        base64_pattern = r"[A-Za-z0-9+/]+={0,2}"
        # fullmatch is required here: with re.match, a trailing "$" also
        # matches just before a final newline, so a value such as "AAA\n"
        # (length 4) would be accepted.
        is_base64 = (
            len(self.img_base64) % 4 == 0
            and re.fullmatch(base64_pattern, self.img_base64) is not None
        )

        if not is_base64:
            raise ValueError(
                "Provided 'base64' is not valid. Please provide a valid base64 string."
            )
        return self


class PageOutQuery(Schema):
collection_name: str
collection_id: int
Expand All @@ -933,6 +954,10 @@ class QueryOut(Schema):
results: List[PageOutQuery]


class SearchImageOut(Schema):
    # Response body for POST /search-image/: the matching pages, in the order
    # the endpoint produced them.
    results: List[PageOutQuery]


@router.post(
"/search/",
tags=["search"],
Expand Down Expand Up @@ -1037,6 +1062,110 @@ async def search(
return 200, QueryOut(query=payload.query, results=formatted_results)


@router.post(
    "/search-image/",
    tags=["search"],
    auth=Bearer(),
    response={200: SearchImageOut, 503: GenericError},
)
async def search_image(
    request: Request, payload: SearchImageIn
) -> Tuple[int, SearchImageOut] | Tuple[int, GenericError]:
    """
    Search for pages similar to a given image.

    The image is embedded by the embeddings service, and the candidate pages
    (selected by collection name and optional query filter) are ranked by
    their MaxSim score against those embeddings.

    Args:
        request: The HTTP request object, which includes the user information.
        payload (SearchImageIn): The image in base64 format, the collection
            name, the number of results wanted and an optional query filter.

    Returns:
        A ``(200, SearchImageOut)`` tuple with the most similar pages, or a
        ``(503, GenericError)`` tuple when no embeddings could be obtained.

    Example:
        POST /search-image/
        {
            "img_base64": "base64_string",
            "collection_name": "my_collection",
            "top_k": 3,
            "query_filter": {
                "on": "document",
                "key": "breed",
                "value": "collie",
                "lookup": "contains"
            }
        }
    """
    query_vectors = await get_image_embeddings(payload.img_base64)
    if not query_vectors:
        return 503, GenericError(
            detail="Failed to get embeddings from the embeddings service"
        )

    # The number of query vectors is the divisor used to normalize raw scores.
    vector_count = len(query_vectors)

    # MaxSim receives the query vectors in halfvec text form.
    halfvec_texts = [HalfVector(vector).to_text() for vector in query_vectors]

    # Step 1: narrow the pages down by collection and query_filter.
    candidate_pages = await filter_query(payload, request.auth)

    # Step 2: aggregate each page's embeddings and score them against the query.
    with_embeddings = candidate_pages.annotate(
        page_embeddings=ArrayAgg("embeddings__embedding")
    )
    scored_pages = with_embeddings.annotate(
        max_sim=MaxSim("page_embeddings", halfvec_texts)
    )

    # Step 3: keep the best-scoring pages (3 when top_k is falsy).
    limit = payload.top_k or 3
    top_pages = scored_pages.order_by("-max_sim")[:limit]

    rows = top_pages.values(
        "id",
        "page_number",
        "img_base64",
        "document__id",
        "document__name",
        "document__metadata",
        "document__collection__id",
        "document__collection__name",
        "document__collection__metadata",
        "max_sim",
    )

    # Step 4: iterate the rows asynchronously and shape them for the response.
    hits = []
    async for row in rows:
        raw_score = row["max_sim"]
        hits.append(
            PageOutQuery(
                collection_name=row["document__collection__name"],
                collection_id=row["document__collection__id"],
                # Falsy metadata is reported as an empty dict.
                collection_metadata=row["document__collection__metadata"] or {},
                document_name=row["document__name"],
                document_id=row["document__id"],
                document_metadata=row["document__metadata"] or {},
                page_number=row["page_number"],
                raw_score=raw_score,
                normalized_score=raw_score / vector_count,
                img_base64=row["img_base64"],
            )
        )
    return 200, SearchImageOut(results=hits)


@router.post(
"/filter/",
tags=["filter"],
Expand Down Expand Up @@ -1145,7 +1274,34 @@ async def get_query_embeddings(query: str) -> List:
return out["output"]["data"][0]["embedding"]


async def filter_query(payload: QueryIn, user: CustomUser) -> QuerySet[Page]:
async def get_image_embeddings(img_base64: str) -> List:
    """Request embeddings for one base64 image from the embeddings service.

    Args:
        img_base64: The image to embed, as a base64 string.

    Returns:
        The embeddings found at ``output.data[0].embedding`` in the service
        response, or an empty list when the service answers with a non-200
        status, cannot be reached, or returns a body that cannot be decoded.
    """
    EMBEDDINGS_URL = settings.ALWAYS_ON_EMBEDDINGS_URL
    embed_token = settings.EMBEDDINGS_URL_TOKEN
    headers = {"Authorization": f"Bearer {embed_token}"}
    payload = {
        "input": {
            "task": "image",
            "input_data": [img_base64],
        }
    }
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(
                EMBEDDINGS_URL, json=payload, headers=headers
            ) as response:
                if response.status != 200:
                    logger.error(
                        "Failed to get embeddings from the embeddings service: %s",
                        response.status,
                    )
                    return []
                out = await response.json()
    except aiohttp.ClientError as exc:
        # Connection failures and undecodable bodies are reported the same way
        # as a bad status, so callers can answer 503 instead of crashing.
        logger.error("Failed to reach the embeddings service: %s", exc)
        return []
    # returning a dynamic array of embeddings, each of which is a list of 128 floats
    # example: [[0.1, 0.2, 0.3, ...], [0.4, 0.5, 0.6, ...]]
    return out["output"]["data"][0]["embedding"]


async def filter_query(
payload: Union[QueryIn, SearchImageIn], user: CustomUser
) -> QuerySet[Page]:
base_query = Page.objects.select_related("document__collection")
if payload.collection_name == "all":
base_query = base_query.filter(document__collection__owner=user)
Expand Down

0 comments on commit e0456ad

Please sign in to comment.