Merge pull request #985 from julep-ai/f/dsearch-queries
F/dsearch queries: Add doc search sql queries
Vedantsahai18 authored Dec 24, 2024
2 parents 358b60b + 23235f4 commit 34586d2
Showing 15 changed files with 325 additions and 334 deletions.
9 changes: 6 additions & 3 deletions agents-api/agents_api/queries/docs/__init__.py
@@ -9,6 +9,8 @@
- Deleting documents by their unique identifiers.
- Embedding document snippets for retrieval purposes.
- Searching documents by text.
- Searching documents by hybrid text and embedding.
- Searching documents by embedding.
The module interacts with other parts of the application, such as the agents and users modules, to provide a comprehensive document management system. Its role is crucial in enabling document search, retrieval, and management features within the context of agents and users.
@@ -21,15 +23,16 @@
from .delete_doc import delete_doc
from .get_doc import get_doc
from .list_docs import list_docs

# from .search_docs_by_embedding import search_docs_by_embedding
from .search_docs_by_embedding import search_docs_by_embedding
from .search_docs_by_text import search_docs_by_text
from .search_docs_hybrid import search_docs_hybrid

__all__ = [
"create_doc",
"delete_doc",
"get_doc",
"list_docs",
# "search_docs_by_embedding",
"search_docs_by_embedding",
"search_docs_by_text",
"search_docs_hybrid",
]
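
With the commented-out export re-enabled, all three search entry points are importable from the package; a minimal sketch of the resulting import surface (package path taken from the file listing above):

from agents_api.queries.docs import (
    search_docs_by_embedding,
    search_docs_by_text,
    search_docs_hybrid,
)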
28 changes: 18 additions & 10 deletions agents-api/agents_api/queries/docs/search_docs_by_embedding.py
@@ -1,25 +1,36 @@
from typing import Any, List, Literal
from uuid import UUID

import asyncpg
from beartype import beartype
from fastapi import HTTPException

from ...autogen.openapi_model import DocReference
from ..utils import pg_query, wrap_in_class
from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class

# Raw query for vector search
search_docs_by_embedding_query = """
SELECT * FROM search_by_vector(
$1, -- developer_id
$2::vector(1024), -- query_embedding
$3::text[], -- owner_types
$UUID_LIST::uuid[], -- owner_ids
$4, -- k
$5, -- confidence
$6 -- metadata_filter
$4::uuid[], -- owner_ids
$5, -- k
$6, -- confidence
$7 -- metadata_filter
)
"""


@rewrap_exceptions(
{
asyncpg.UniqueViolationError: partialclass(
HTTPException,
status_code=404,
detail="The specified developer does not exist.",
)
}
)
@wrap_in_class(
DocReference,
transform=lambda d: {
@@ -69,16 +80,13 @@ async def search_docs_by_embedding(
owner_types: list[str] = [owner[0] for owner in owners]
owner_ids: list[str] = [str(owner[1]) for owner in owners]

# NOTE: Manually replace uuids list coz asyncpg isnt sending it correctly
owner_ids_pg_str = f"ARRAY['{'\', \''.join(owner_ids)}']"
query = search_docs_by_embedding_query.replace("$UUID_LIST", owner_ids_pg_str)

return (
query,
search_docs_by_embedding_query,
[
developer_id,
query_embedding_str,
owner_types,
owner_ids,
k,
confidence,
metadata_filter,
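
The deleted $UUID_LIST splice was a workaround for array binding; the new $4::uuid[] placeholder relies on asyncpg encoding a Python list directly as a Postgres array. A minimal, self-contained sketch of that behaviour, with a hypothetical connection string and no project code involved:

import asyncio
from uuid import uuid4

import asyncpg


async def main() -> None:
    # Hypothetical DSN; point it at any reachable Postgres instance.
    conn = await asyncpg.connect("postgresql://localhost/postgres")
    owner_ids = [uuid4(), uuid4()]
    # asyncpg encodes a Python list of UUID objects as a uuid[] parameter,
    # so no manual ARRAY['...'] string building is required.
    rows = await conn.fetch("SELECT unnest($1::uuid[]) AS owner_id", owner_ids)
    print([row["owner_id"] for row in rows])
    await conn.close()


asyncio.run(main())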
16 changes: 7 additions & 9 deletions agents-api/agents_api/queries/docs/search_docs_by_text.py
@@ -8,15 +8,16 @@
from ...autogen.openapi_model import DocReference
from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class

# Raw query for text search
search_docs_text_query = """
SELECT * FROM search_by_text(
$1, -- developer_id
$2, -- query
$3, -- owner_types
$UUID_LIST::uuid[], -- owner_ids
$4, -- search_language
$5, -- k
$6 -- metadata_filter
$4, -- owner_ids
$5, -- search_language
$6, -- k
$7 -- metadata_filter
)
"""

@@ -74,16 +75,13 @@ async def search_docs_by_text(
owner_types: list[str] = [owner[0] for owner in owners]
owner_ids: list[str] = [str(owner[1]) for owner in owners]

# NOTE: Manually replace uuids list coz asyncpg isnt sending it correctly
owner_ids_pg_str = f"ARRAY['{'\', \''.join(owner_ids)}']"
query = search_docs_text_query.replace("$UUID_LIST", owner_ids_pg_str)

return (
query,
search_docs_text_query,
[
developer_id,
query,
owner_types,
owner_ids,
search_language,
k,
metadata_filter,
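
The search_by_text SQL function itself is not part of this diff; purely as a hypothetical sketch, this is the kind of Postgres full-text ranking such a function might wrap, with the search_language argument playing the role of the regconfig (the table and column names below are invented):

# Hypothetical sketch only; doc_snippets and content are assumed names.
example_text_search_sql = """
SELECT doc_id,
       ts_rank_cd(
           to_tsvector($2::regconfig, content),
           plainto_tsquery($2::regconfig, $1)
       ) AS rank
FROM doc_snippets
WHERE to_tsvector($2::regconfig, content) @@ plainto_tsquery($2::regconfig, $1)
ORDER BY rank DESC
LIMIT $3
"""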
237 changes: 94 additions & 143 deletions agents-api/agents_api/queries/docs/search_docs_hybrid.py
@@ -1,158 +1,109 @@
from typing import List, Literal
from typing import Any, List, Literal
from uuid import UUID

import asyncpg
from beartype import beartype

from ...autogen.openapi_model import Doc
from .search_docs_by_embedding import search_docs_by_embedding
from .search_docs_by_text import search_docs_by_text


def dbsf_normalize(scores: List[float]) -> List[float]:
"""
Example distribution-based normalization: clamp each score
from (mean - 3*stddev) to (mean + 3*stddev) and scale to 0..1
"""
import statistics

if len(scores) < 2:
return scores
m = statistics.mean(scores)
sd = statistics.pstdev(scores) # population std
if sd == 0:
return scores
upper = m + 3 * sd
lower = m - 3 * sd

def clamp_scale(v):
c = min(upper, max(lower, v))
return (c - lower) / (upper - lower)

return [clamp_scale(s) for s in scores]
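# Illustrative walk-through with made-up scores: dbsf_normalize([0.1, 0.4, 0.5, 0.9])
# has mean 0.475 and population stddev of about 0.286, so the clamp window is
# roughly (-0.383, 1.333); mapping each score through (s - lower) / (upper - lower)
# gives approximately [0.28, 0.46, 0.51, 0.75], preserving order on a 0..1 scale.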


@beartype
def fuse_results(
text_docs: List[Doc], embedding_docs: List[Doc], alpha: float
) -> List[Doc]:
"""
Merges text search results (descending by text rank) with
embedding results (descending by closeness or inverse distance).
alpha ~ how much to weigh the embedding score
"""
# Suppose we stored each doc's "distance" from the embedding query, and
# for text search we store a rank or negative distance. We'll unify them:
# Make up a dictionary of doc_id -> text_score, doc_id -> embed_score
# For example, text_score = -distance if you want bigger = better
text_scores = {}
embed_scores = {}
for doc in text_docs:
# If you had "rank", you might store doc.distance = rank
# For demo, let's assume doc.distance is negative... up to you
text_scores[doc.id] = float(-doc.distance if doc.distance else 0)

for doc in embedding_docs:
# Lower distance => better, so we do embed_score = -distance
embed_scores[doc.id] = float(-doc.distance if doc.distance else 0)

# Normalize them
text_vals = list(text_scores.values())
embed_vals = list(embed_scores.values())
text_vals_norm = dbsf_normalize(text_vals)
embed_vals_norm = dbsf_normalize(embed_vals)

# Map them back
t_keys = list(text_scores.keys())
for i, key in enumerate(t_keys):
text_scores[key] = text_vals_norm[i]
e_keys = list(embed_scores.keys())
for i, key in enumerate(e_keys):
embed_scores[key] = embed_vals_norm[i]

# Gather all doc IDs
all_ids = set(text_scores.keys()) | set(embed_scores.keys())

# Weighted sum => combined
out = []
for doc_id in all_ids:
# text and embed might be missing doc_id => 0
t_score = text_scores.get(doc_id, 0)
e_score = embed_scores.get(doc_id, 0)
combined = alpha * e_score + (1 - alpha) * t_score
# We'll store final "distance" as -(combined) so bigger combined => smaller distance
out.append((doc_id, combined))

# Sort descending by combined
out.sort(key=lambda x: x[1], reverse=True)

# Convert to doc objects. We can pick from text_docs or embedding_docs or whichever is found.
# If present in both, we can merge fields. For simplicity, just pick from text_docs then fallback embedding_docs.

# Create a quick ID->doc map
text_map = {d.id: d for d in text_docs}
embed_map = {d.id: d for d in embedding_docs}

final_docs = []
for doc_id, score in out:
doc = text_map.get(doc_id) or embed_map.get(doc_id)
doc = doc.model_copy() # or copy if you are using Pydantic
doc.distance = float(-score) # so a higher combined => smaller distance
final_docs.append(doc)
return final_docs
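# Illustrative numbers for the weighted fusion above: with alpha = 0.5, a doc
# whose normalized text score is 0.8 and embedding score is 0.4 gets
# combined = 0.5 * 0.4 + 0.5 * 0.8 = 0.6 and is returned with distance = -0.6,
# so larger combined scores sort ahead of smaller ones.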


from fastapi import HTTPException
from sqlglot import parse_one

from ...autogen.openapi_model import DocReference
from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class

# Raw query for hybrid search
search_docs_hybrid_query = """
SELECT * FROM search_hybrid(
$1, -- developer_id
$2, -- text_query
$3::vector(1024), -- embedding
$4::text[], -- owner_types
$5::uuid[], -- owner_ids
$6, -- k
$7, -- alpha
$8, -- confidence
$9, -- metadata_filter
$10 -- search_language
)
"""


@rewrap_exceptions(
{
asyncpg.UniqueViolationError: partialclass(
HTTPException,
status_code=404,
detail="The specified developer does not exist.",
)
}
)
@wrap_in_class(
DocReference,
transform=lambda d: {
"owner": {
"id": d["owner_id"],
"role": d["owner_type"],
},
"metadata": d.get("metadata", {}),
**d,
},
)
@pg_query
@beartype
async def search_docs_hybrid(
developer_id: UUID,
owners: list[tuple[Literal["user", "agent"], UUID]],
text_query: str = "",
embedding: List[float] = None,
k: int = 10,
alpha: float = 0.5,
owner_type: Literal["user", "agent", "org"] | None = None,
owner_id: UUID | None = None,
) -> List[Doc]:
metadata_filter: dict[str, Any] = {},
search_language: str = "english",
confidence: float = 0.5,
) -> tuple[str, list]:
"""
Hybrid text-and-embedding doc search. We get top-K from each approach,
then fuse them client-side. Adjust concurrency or approach as you like.
"""
# We'll dispatch two queries in parallel
# (One full-text, one embedding-based) each limited to K
tasks = []
if text_query.strip():
tasks.append(
search_docs_by_text(
developer_id=developer_id,
query=text_query,
k=k,
owner_type=owner_type,
owner_id=owner_id,
)
)
else:
tasks.append([]) # no text results if query is empty
if embedding and any(embedding):
tasks.append(
search_docs_by_embedding(
developer_id=developer_id,
query_embedding=embedding,
k=k,
owner_type=owner_type,
owner_id=owner_id,
)
)
else:
tasks.append([])

# Run concurrently (or sequentially, if you prefer)
# If you have a 'run_concurrently' from your old code, you can do:
# text_results, embed_results = await run_concurrently([task1, task2])
# Otherwise just do them in parallel with e.g. asyncio.gather:
from asyncio import gather

text_results, embed_results = await gather(*tasks)
Parameters:
developer_id (UUID): The unique identifier for the developer.
text_query (str): The text query to search for.
embedding (List[float]): The embedding to search for.
k (int): The number of results to return.
alpha (float): The weight for the embedding results.
owner_type (Literal["user", "agent", "org"] | None): The type of the owner.
owner_id (UUID | None): The ID of the owner.
Returns:
tuple[str, list]: The SQL query and parameters for the search.
"""

# fuse them
fused = fuse_results(text_results, embed_results, alpha)
# Then pick top K overall
return fused[:k]
if k < 1:
raise HTTPException(status_code=400, detail="k must be >= 1")

if not text_query and not embedding:
raise HTTPException(status_code=400, detail="Empty query provided")

if not embedding:
raise HTTPException(status_code=400, detail="Empty embedding provided")

# Convert query_embedding to a string
embedding_str = f"[{', '.join(map(str, embedding))}]"

# Extract owner types and IDs
owner_types: list[str] = [owner[0] for owner in owners]
owner_ids: list[str] = [str(owner[1]) for owner in owners]

return (
search_docs_hybrid_query,
[
developer_id,
text_query,
embedding_str,
owner_types,
owner_ids,
k,
alpha,
confidence,
metadata_filter,
search_language,
],
)
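
One detail worth calling out in the new parameter handling: the embedding list is serialized to pgvector's bracketed text form before the query casts it to vector(1024). A short illustration with made-up values:

embedding = [0.12, -0.03, 0.5]  # illustrative values only; the real query expects 1024 dims
embedding_str = f"[{', '.join(map(str, embedding))}]"
assert embedding_str == "[0.12, -0.03, 0.5]"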
10 changes: 10 additions & 0 deletions agents-api/agents_api/queries/tools/__init__.py
@@ -18,3 +18,13 @@
from .list_tools import list_tools
from .patch_tool import patch_tool
from .update_tool import update_tool

__all__ = [
"create_tools",
"delete_tool",
"get_tool",
"get_tool_args_from_metadata",
"list_tools",
"patch_tool",
"update_tool",
]