Added reranking via model_gateway class (#253)
* Added reranking via model_gateway class

* Added Auth for infinity API

---------

Co-authored-by: Abhishek Choudhary <[email protected]>
S1LV3RJ1NX and innoavator authored Jun 25, 2024
1 parent 394b851 commit e08207c
Showing 18 changed files with 180 additions and 239 deletions.
12 changes: 6 additions & 6 deletions backend/Dockerfile
@@ -33,12 +33,12 @@ RUN if [ "${ADD_VECTORDB}" = "1" ]; then python3 -m pip install --use-pep517 --n
ARG ADD_PRISMA=0
RUN if [ "${ADD_PRISMA}" = "1" ]; then prisma version; fi

# TODO (chiragjn): These should be removed from here and directly added as environment variables
# Temporary addition until templates have been updated using build args as environment variables
ARG ADD_RERANKER_SVC_URL=""
ENV RERANKER_SVC_URL=${ADD_RERANKER_SVC_URL}
ARG ADD_EMBEDDING_SVC_URL=""
ENV EMBEDDING_SVC_URL=${ADD_EMBEDDING_SVC_URL}
# TODO: Remove these when templates inject env vars
ARG MODELS_CONFIG_PATH
ENV MODELS_CONFIG_PATH=${MODELS_CONFIG_PATH}

ARG INFINITY_API_KEY
ENV INFINITY_API_KEY=${INFINITY_API_KEY}

# Copy the project files
COPY . /app
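
Note (not part of the diff): the two ARG/ENV pairs above only take effect if the backend settings read them from the environment. A minimal sketch of what that is assumed to look like, given the settings.MODELS_CONFIG_PATH and settings.INFINITY_API_KEY usages further down; backend/settings.py itself is not shown in this commit:

# Hypothetical sketch -- the real backend/settings.py is not part of this diff.
# Field names mirror the settings.MODELS_CONFIG_PATH and settings.INFINITY_API_KEY
# usages in model_gateway.py and reranker_svc.py; pydantic v1 BaseSettings is assumed.
from pydantic import BaseSettings

class Settings(BaseSettings):
    MODELS_CONFIG_PATH: str = ""
    INFINITY_API_KEY: str = ""

settings = Settings()  # picks up the values exported via ENV in the Dockerfile
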
43 changes: 42 additions & 1 deletion backend/modules/model_gateway/model_gateway.py
@@ -11,6 +11,8 @@
from backend.settings import settings
from backend.types import ModelConfig, ModelProviderConfig, ModelType

from .reranker_svc import InfinityRerankerSvc


class ModelGateway:
provider_configs: List[ModelProviderConfig]
@@ -20,8 +22,9 @@ def __init__(self):
logger.info(f"Loading models config from {settings.MODELS_CONFIG_PATH}")
with open(settings.MODELS_CONFIG_PATH) as f:
data = yaml.safe_load(f)
print(data)
logger.info(f"Loaded models config: {data}")
_providers = data.get("model_providers") or []

# parse the json data into a list of ModelProviderConfig objects
self.provider_configs = [
ModelProviderConfig.parse_obj(item) for item in _providers
@@ -32,6 +35,9 @@ def __init__(self):
# load embedding models
self.embedding_models: List[ModelConfig] = []

# load reranker models
self.reranker_models: List[ModelConfig] = []

for provider_config in self.provider_configs:
if provider_config.api_key_env_var and not os.environ.get(
provider_config.api_key_env_var
@@ -65,12 +71,27 @@
)
)

for model_id in provider_config.reranking_model_ids:
model_name = f"{provider_config.provider_name}/{model_id}"
self.model_name_to_provider_config[model_name] = provider_config

# Register the model as a reranker model
self.reranker_models.append(
ModelConfig(
name=f"{provider_config.provider_name}/{model_id}",
type=ModelType.reranking,
)
)

def get_embedding_models(self) -> List[ModelConfig]:
return self.embedding_models

def get_llm_models(self) -> List[ModelConfig]:
return self.llm_models

def get_reranker_models(self) -> List[ModelConfig]:
return self.reranker_models

def get_embedder_from_model_config(self, model_name: str) -> Embeddings:
if model_name not in self.model_name_to_provider_config:
raise ValueError(f"Model {model_name} not registered in the model gateway.")
@@ -116,5 +137,25 @@ def get_llm_from_model_config(
base_url=model_provider_config.base_url,
)

def get_reranker_from_model_config(self, model_name: str, top_k: int = 3):
if model_name not in self.model_name_to_provider_config:
raise ValueError(f"Model {model_name} not registered in the model gateway.")

model_provider_config: ModelProviderConfig = self.model_name_to_provider_config[
model_name
]
if not model_provider_config.api_key_env_var:
api_key = "EMPTY"
else:
api_key = os.environ.get(model_provider_config.api_key_env_var, "")
model_id = "/".join(model_name.split("/")[1:])

return InfinityRerankerSvc(
model=model_id,
api_key=api_key,
base_url=model_provider_config.base_url,
top_k=top_k,
)


model_gateway = ModelGateway()
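
A minimal usage sketch for the new reranker path (not part of this commit): it assumes MODELS_CONFIG_PATH points at a YAML file whose model_providers list includes a provider named local-infinity with the reranker listed under reranking_model_ids, and it assumes the module-level singleton is importable from backend.modules.model_gateway.model_gateway; provider and model names are illustrative.

# Hypothetical usage sketch -- provider, model name, and import path are assumptions.
from backend.modules.model_gateway.model_gateway import model_gateway

# Rerankers registered from the config, as ModelConfig objects with
# type=ModelType.reranking.
print(model_gateway.get_reranker_models())

# Resolve a reranker; the gateway forwards the provider's base_url and the
# value of its api_key_env_var (or "EMPTY" if none is configured) to
# InfinityRerankerSvc.
reranker = model_gateway.get_reranker_from_model_config(
    model_name="local-infinity/mixedbread-ai/mxbai-rerank-xsmall-v1",
    top_k=5,
)
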
backend/modules/rerankers/reranker_svc.py → backend/modules/model_gateway/reranker_svc.py
@@ -17,8 +17,9 @@ class InfinityRerankerSvc(BaseDocumentCompressor):
"""

model: str
top_k: int = 3
url = settings.RERANKER_SVC_URL
top_k: int
base_url: str
api_key: str

def compress_documents(
self,
@@ -36,8 +37,13 @@ def compress_documents(
"model": self.model,
}

headers = {
"Authorization": f"Bearer {settings.INFINITY_API_KEY}",
"Content-Type": "application/json",
}

reranked_docs = requests.post(
self.url.rstrip("/") + "/rerank", json=payload
self.base_url.rstrip("/") + "/rerank", headers=headers, json=payload
).json()

"""
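
A direct-call sketch for the updated compressor (not part of this commit): the constructor fields come from the diff above; the import path assumes reranker_svc.py now lives under backend/modules/model_gateway/ (as the relative import in model_gateway.py suggests), the Document import path depends on the installed LangChain version, and a running Infinity server is assumed at base_url.

# Hypothetical sketch -- endpoint, API key, and import paths are assumptions.
from langchain_core.documents import Document

from backend.modules.model_gateway.reranker_svc import InfinityRerankerSvc

reranker = InfinityRerankerSvc(
    model="mixedbread-ai/mxbai-rerank-xsmall-v1",
    top_k=1,
    base_url="http://localhost:7997",  # assumed local Infinity endpoint
    api_key="EMPTY",                   # what the gateway passes when no api_key_env_var is set
)

docs = [
    Document(page_content="Infinity serves embedding and reranking models."),
    Document(page_content="The weather today is sunny."),
]

# POSTs to {base_url}/rerank with a Bearer Authorization header and returns
# the top_k documents ranked against the query.
top_docs = reranker.compress_documents(documents=docs, query="What does Infinity do?")
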
2 changes: 1 addition & 1 deletion backend/modules/query_controllers/__init__.py
@@ -4,5 +4,5 @@
)
from backend.modules.query_controllers.query_controller import register_query_controller

register_query_controller("default", BasicRAGQueryController)
register_query_controller("basic-rag", BasicRAGQueryController)
register_query_controller("multimodal", MultiModalRAGQueryController)
46 changes: 14 additions & 32 deletions backend/modules/query_controllers/example/controller.py
@@ -22,28 +22,18 @@
GENERATION_TIMEOUT_SEC,
ExampleQueryInput,
)
from backend.modules.rerankers.reranker_svc import InfinityRerankerSvc

# from backend.modules.rerankers.reranker_svc import InfinityRerankerSvc
from backend.modules.vector_db.client import VECTOR_STORE_CLIENT
from backend.server.decorators import post, query_controller
from backend.settings import settings
from backend.types import Collection, ModelConfig

EXAMPLES = {
"vector-store-similarity": QUERY_WITH_VECTOR_STORE_RETRIEVER_PAYLOAD,
"contextual-compression-similarity": QUERY_WITH_CONTEXTUAL_COMPRESSION_RETRIEVER_PAYLOAD,
"contextual-compression-multi-query-similarity": QUERY_WITH_CONTEXTUAL_COMPRESSION_MULTI_QUERY_RETRIEVER_SIMILARITY_PAYLOAD,
}

if settings.RERANKER_SVC_URL:
EXAMPLES.update(
{
"contextual-compression-similarity": QUERY_WITH_CONTEXTUAL_COMPRESSION_RETRIEVER_PAYLOAD,
}
)
EXAMPLES.update(
{
"contextual-compression-multi-query-similarity": QUERY_WITH_CONTEXTUAL_COMPRESSION_MULTI_QUERY_RETRIEVER_SIMILARITY_PAYLOAD,
}
)


@query_controller("/basic-rag")
class BasicRAGQueryController:
@@ -111,26 +101,18 @@ def _get_contextual_compression_retriever(self, vector_store, retriever_config):
Get the contextual compression retriever
"""
try:
if settings.RERANKER_SVC_URL:
retriever = self._get_vector_store_retriever(
vector_store, retriever_config
)
logger.info("Using MxBaiRerankerSmall th' service...")
compressor = InfinityRerankerSvc(
top_k=retriever_config.top_k,
model=retriever_config.compressor_model_name,
)
retriever = self._get_vector_store_retriever(vector_store, retriever_config)
logger.info("Using MxBaiRerankerSmall th' service...")

compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor, base_retriever=retriever
)
compressor = model_gateway.get_reranker_from_model_config(
model_name=retriever_config.compressor_model_name,
top_k=retriever_config.top_k,
)
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor, base_retriever=retriever
)

return compression_retriever
else:
raise HTTPException(
status_code=500,
detail="Reranker service is not available",
)
return compression_retriever
except Exception as e:
logger.error(f"Error in getting contextual compression retriever: {e}")
raise HTTPException(
22 changes: 8 additions & 14 deletions backend/modules/query_controllers/example/payload.py
@@ -84,8 +84,7 @@
"prompt_template": PROMPT,
"retriever_name": "contextual-compression",
"retriever_config": {
"compressor_model_provider": "mixedbread-ai",
"compressor_model_name": "mixedbread-ai/mxbai-rerank-xsmall-v1",
"compressor_model_name": "local-infinity/mixedbread-ai/mxbai-rerank-xsmall-v1",
"top_k": 5,
"search_type": "similarity",
"search_kwargs": {"k": 10},
@@ -113,8 +112,7 @@
"prompt_template": PROMPT,
"retriever_name": "contextual-compression",
"retriever_config": {
"compressor_model_provider": "mixedbread-ai",
"compressor_model_name": "mixedbread-ai/mxbai-rerank-xsmall-v1",
"compressor_model_name": "local-infinity/mixedbread-ai/mxbai-rerank-xsmall-v1",
"top_k": 5,
"search_type": "mmr",
"search_kwargs": {
@@ -130,7 +128,7 @@
"description": """
Requires k and fetch_k in search kwargs for mmr.
search_type can either be similarity or mmr or similarity_score_threshold.
Currently only support for mixedbread-ai/mxbai-rerank-xsmall-v1 reranker is added.""",
Currently only support for local-infinity/mixedbread-ai/mxbai-rerank-xsmall-v1 reranker is added.""",
"value": QUERY_WITH_CONTEXTUAL_COMPRESSION_RETRIEVER_SEARCH_TYPE_MMR,
}

@@ -147,8 +145,7 @@
"prompt_template": PROMPT,
"retriever_name": "contextual-compression",
"retriever_config": {
"compressor_model_provider": "mixedbread-ai",
"compressor_model_name": "mixedbread-ai/mxbai-rerank-xsmall-v1",
"compressor_model_name": "local-infinity/mixedbread-ai/mxbai-rerank-xsmall-v1",
"top_k": 5,
"search_type": "similarity_score_threshold",
"search_kwargs": {"score_threshold": 0.7},
@@ -161,7 +158,7 @@
"description": """
Requires score_threshold float (0~1) in search kwargs for similarity search.
search_type can either be similarity or mmr or similarity_score_threshold.
Currently only support for mixedbread-ai/mxbai-rerank-xsmall-v1 reranker is added""",
Currently only support for local-infinity/mixedbread-ai/mxbai-rerank-xsmall-v1 reranker is added""",
"value": QUERY_WITH_CONTEXTUAL_COMPRESSION_RETRIEVER_SEARCH_TYPE_SIMILARITY_WITH_SCORE,
}

@@ -273,8 +270,7 @@
"prompt_template": PROMPT,
"retriever_name": "contextual-compression-multi-query",
"retriever_config": {
"compressor_model_provider": "mixedbread-ai",
"compressor_model_name": "mixedbread-ai/mxbai-rerank-xsmall-v1",
"compressor_model_name": "local-infinity/mixedbread-ai/mxbai-rerank-xsmall-v1",
"top_k": 5,
"search_type": "mmr",
"search_kwargs": {
@@ -309,8 +305,7 @@
"prompt_template": PROMPT,
"retriever_name": "contextual-compression-multi-query",
"retriever_config": {
"compressor_model_provider": "mixedbread-ai",
"compressor_model_name": "mixedbread-ai/mxbai-rerank-xsmall-v1",
"compressor_model_name": "local-infinity/mixedbread-ai/mxbai-rerank-xsmall-v1",
"top_k": 5,
"search_type": "similarity",
"search_kwargs": {"k": 10},
@@ -343,8 +338,7 @@
"prompt_template": PROMPT,
"retriever_name": "contextual-compression-multi-query",
"retriever_config": {
"compressor_model_provider": "mixedbread-ai",
"compressor_model_name": "mixedbread-ai/mxbai-rerank-xsmall-v1",
"compressor_model_name": "local-infinity/mixedbread-ai/mxbai-rerank-xsmall-v1",
"top_k": 5,
"search_type": "similarity_score_threshold",
"search_kwargs": {"score_threshold": 0.7},
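
In short, the payload examples above drop the separate compressor_model_provider key; the provider is now carried as a prefix of compressor_model_name and resolved by the model gateway. A representative retriever_config after this change, with values taken from the examples above:

# Representative retriever_config after this change. The "local-infinity/"
# prefix identifies the provider; the gateway uses it to look up the
# provider's base_url and API key.
retriever_config = {
    "compressor_model_name": "local-infinity/mixedbread-ai/mxbai-rerank-xsmall-v1",
    "top_k": 5,
    "search_type": "similarity",
    "search_kwargs": {"k": 10},
}
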
23 changes: 1 addition & 22 deletions backend/modules/query_controllers/example/types.py
@@ -73,10 +73,6 @@ class MultiQueryRetrieverConfig(VectorStoreRetrieverConfig):


class ContextualCompressionRetrieverConfig(VectorStoreRetrieverConfig):
compressor_model_provider: str = Field(
title="provider of the compressor model",
)

compressor_model_name: str = Field(
title="model name of the compressor",
)
@@ -85,14 +81,7 @@ class ContextualCompressionRetrieverConfig(VectorStoreRetrieverConfig):
title="Top K docs to collect post compression",
)

allowed_compressor_model_providers: ClassVar[Collection[str]] = ("mixedbread-ai",)

@validator("compressor_model_provider")
def validate_retriever_type(cls, value) -> Dict:
assert (
value in cls.allowed_compressor_model_providers
), f"Compressor model of {value} not allowed. Valid values are: {cls.allowed_compressor_model_providers}"
return value
allowed_compressor_model_providers: ClassVar[Collection[str]]


class ContextualCompressionMultiQueryRetrieverConfig(
@@ -101,10 +90,6 @@ class ContextualCompressionMultiQueryRetrieverConfig(
pass


class LordOfRetrievers(ContextualCompressionRetrieverConfig, MultiQueryRetrieverConfig):
pass


class ExampleQueryInput(BaseModel):
"""
Model for Query input.
@@ -137,7 +122,6 @@ class ExampleQueryInput(BaseModel):
"multi-query",
"contextual-compression",
"contextual-compression-multi-query",
"lord-of-the-retrievers",
)

stream: Optional[bool] = Field(title="Stream the results", default=False)
@@ -170,9 +154,4 @@ def validate_retriever_type(cls, values: Dict) -> Dict:
**values.get("retriever_config")
)

elif retriever_name == "lord-of-the-retrievers":
values["retriever_config"] = LordOfRetrievers(
**values.get("retriever_config")
)

return values