Added reranking via model_gateway class #253

Merged (3 commits) on Jun 25, 2024
Changes from 2 commits
12 changes: 6 additions & 6 deletions backend/Dockerfile
@@ -33,12 +33,12 @@ RUN if [ "${ADD_VECTORDB}" = "1" ]; then python3 -m pip install --use-pep517 --n
ARG ADD_PRISMA=0
RUN if [ "${ADD_PRISMA}" = "1" ]; then prisma version; fi

# TODO (chiragjn): These should be removed from here and directly added as environment variables
# Temporary addition until templates have been updated using build args as environment variables
ARG ADD_RERANKER_SVC_URL=""
ENV RERANKER_SVC_URL=${ADD_RERANKER_SVC_URL}
ARG ADD_EMBEDDING_SVC_URL=""
ENV EMBEDDING_SVC_URL=${ADD_EMBEDDING_SVC_URL}
# TODO: Remove these when templates inject env vars
ARG MODELS_CONFIG_PATH
ENV MODELS_CONFIG_PATH=${MODELS_CONFIG_PATH}

ARG INFINITY_API_KEY
ENV INFINITY_API_KEY=${INFINITY_API_KEY}

# Copy the project files
COPY . /app
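These build args only bake in environment variables; the backend reads them at startup. A minimal sketch of that handoff, assuming plain os.environ access (the real backend.settings module may use a settings framework instead):

import os

# Minimal sketch: the values set by ARG/ENV above surface as ordinary
# environment variables inside the container.
MODELS_CONFIG_PATH = os.environ.get("MODELS_CONFIG_PATH", "")
INFINITY_API_KEY = os.environ.get("INFINITY_API_KEY", "")

if not MODELS_CONFIG_PATH:
    raise RuntimeError("MODELS_CONFIG_PATH must point to the models config YAML")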
43 changes: 42 additions & 1 deletion backend/modules/model_gateway/model_gateway.py
@@ -11,6 +11,8 @@
from backend.settings import settings
from backend.types import ModelConfig, ModelProviderConfig, ModelType

from .reranker_svc import InfinityRerankerSvc


class ModelGateway:
provider_configs: List[ModelProviderConfig]
@@ -20,8 +22,9 @@ def __init__(self):
logger.info(f"Loading models config from {settings.MODELS_CONFIG_PATH}")
with open(settings.MODELS_CONFIG_PATH) as f:
data = yaml.safe_load(f)
print(data)
logger.info(f"Loaded models config: {data}")
_providers = data.get("model_providers") or []

# parse the json data into a list of ModelProviderConfig objects
self.provider_configs = [
ModelProviderConfig.parse_obj(item) for item in _providers
@@ -32,6 +35,9 @@ def __init__(self):
# load embedding models
self.embedding_models: List[ModelConfig] = []

# load reranker models
self.reranker_models: List[ModelConfig] = []

for provider_config in self.provider_configs:
if provider_config.api_key_env_var and not os.environ.get(
provider_config.api_key_env_var
@@ -65,12 +71,27 @@
)
)

for model_id in provider_config.reranking_model_ids:
model_name = f"{provider_config.provider_name}/{model_id}"
self.model_name_to_provider_config[model_name] = provider_config

# Register the model as a reranker model
self.reranker_models.append(
ModelConfig(
name=f"{provider_config.provider_name}/{model_id}",
type=ModelType.reranking,
)
)

def get_embedding_models(self) -> List[ModelConfig]:
return self.embedding_models

def get_llm_models(self) -> List[ModelConfig]:
return self.llm_models

def get_reranker_models(self) -> List[ModelConfig]:
return self.reranker_models

def get_embedder_from_model_config(self, model_name: str) -> Embeddings:
if model_name not in self.model_name_to_provider_config:
raise ValueError(f"Model {model_name} not registered in the model gateway.")
@@ -116,5 +137,25 @@ def get_llm_from_model_config(
base_url=model_provider_config.base_url,
)

def get_reranker_from_model_config(self, model_name: str, top_k: int = 3):
if model_name not in self.model_name_to_provider_config:
raise ValueError(f"Model {model_name} not registered in the model gateway.")

model_provider_config: ModelProviderConfig = self.model_name_to_provider_config[
model_name
]
if not model_provider_config.api_key_env_var:
api_key = "EMPTY"
else:
api_key = os.environ.get(model_provider_config.api_key_env_var, "")
model_id = "/".join(model_name.split("/")[1:])

return InfinityRerankerSvc(
model=model_id,
api_key=api_key,
base_url=model_provider_config.base_url,
top_k=top_k,
)


model_gateway = ModelGateway()
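For reference, here is an illustrative models config entry that would register the reranker used in the payload examples further down, together with the lookup the controller now performs. The YAML keys (provider_name, base_url, api_key_env_var, reranking_model_ids) and the server URL are inferred from the attributes read in this diff and are assumptions, not copied from the repository.

import yaml

from backend.modules.model_gateway.model_gateway import model_gateway

# Illustrative provider entry; keys inferred from the attributes read above.
EXAMPLE_MODELS_CONFIG = yaml.safe_load(
    """
model_providers:
  - provider_name: local-infinity
    base_url: http://infinity:7997   # hypothetical Infinity server URL
    api_key_env_var: INFINITY_API_KEY
    reranking_model_ids:
      - mixedbread-ai/mxbai-rerank-xsmall-v1
"""
)

# Models are registered as "<provider_name>/<model_id>", so (assuming the
# config above is what MODELS_CONFIG_PATH points at) the reranker resolves as:
reranker = model_gateway.get_reranker_from_model_config(
    model_name="local-infinity/mixedbread-ai/mxbai-rerank-xsmall-v1",
    top_k=5,
)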
@@ -17,8 +17,9 @@ class InfinityRerankerSvc(BaseDocumentCompressor):
"""

model: str
top_k: int = 3
url = settings.RERANKER_SVC_URL
top_k: int
base_url: str
api_key: str

def compress_documents(
self,
@@ -36,8 +37,13 @@
"model": self.model,
}

headers = {
"Authorization": f"Bearer {settings.INFINITY_API_KEY}",
"Content-Type": "application/json",
}

reranked_docs = requests.post(
self.url.rstrip("/") + "/rerank", json=payload
self.base_url.rstrip("/") + "/rerank", headers=headers, json=payload
).json()

"""
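Since the compressor now receives base_url and api_key explicitly, it can be exercised directly against a running Infinity server. A minimal sketch, assuming a local server URL and the standard LangChain compress_documents(documents, query) signature:

from langchain_core.documents import Document

from backend.modules.model_gateway.reranker_svc import InfinityRerankerSvc

compressor = InfinityRerankerSvc(
    model="mixedbread-ai/mxbai-rerank-xsmall-v1",
    top_k=2,
    base_url="http://localhost:7997",  # hypothetical local Infinity endpoint
    api_key="EMPTY",  # the gateway falls back to "EMPTY" when no api_key_env_var is set
)

docs = [
    Document(page_content="Reranking reorders retrieved chunks by relevance."),
    Document(page_content="Unrelated text about container build arguments."),
]

# Posts to <base_url>/rerank with a Bearer token and returns the top_k
# documents reordered by the reranker's relevance scores.
reranked = compressor.compress_documents(documents=docs, query="What does a reranker do?")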
2 changes: 1 addition & 1 deletion backend/modules/query_controllers/__init__.py
@@ -4,5 +4,5 @@
)
from backend.modules.query_controllers.query_controller import register_query_controller

register_query_controller("default", BasicRAGQueryController)
register_query_controller("basic-rag", BasicRAGQueryController)
register_query_controller("multimodal", MultiModalRAGQueryController)
46 changes: 14 additions & 32 deletions backend/modules/query_controllers/example/controller.py
@@ -22,28 +22,18 @@
GENERATION_TIMEOUT_SEC,
ExampleQueryInput,
)
from backend.modules.rerankers.reranker_svc import InfinityRerankerSvc

# from backend.modules.rerankers.reranker_svc import InfinityRerankerSvc
from backend.modules.vector_db.client import VECTOR_STORE_CLIENT
from backend.server.decorators import post, query_controller
from backend.settings import settings
from backend.types import Collection, ModelConfig

EXAMPLES = {
"vector-store-similarity": QUERY_WITH_VECTOR_STORE_RETRIEVER_PAYLOAD,
"contextual-compression-similarity": QUERY_WITH_CONTEXTUAL_COMPRESSION_RETRIEVER_PAYLOAD,
"contextual-compression-multi-query-similarity": QUERY_WITH_CONTEXTUAL_COMPRESSION_MULTI_QUERY_RETRIEVER_SIMILARITY_PAYLOAD,
}

if settings.RERANKER_SVC_URL:
EXAMPLES.update(
{
"contextual-compression-similarity": QUERY_WITH_CONTEXTUAL_COMPRESSION_RETRIEVER_PAYLOAD,
}
)
EXAMPLES.update(
{
"contextual-compression-multi-query-similarity": QUERY_WITH_CONTEXTUAL_COMPRESSION_MULTI_QUERY_RETRIEVER_SIMILARITY_PAYLOAD,
}
)


@query_controller("/basic-rag")
class BasicRAGQueryController:
@@ -111,26 +101,18 @@ def _get_contextual_compression_retriever(self, vector_store, retriever_config):
Get the contextual compression retriever
"""
try:
if settings.RERANKER_SVC_URL:
retriever = self._get_vector_store_retriever(
vector_store, retriever_config
)
logger.info("Using MxBaiRerankerSmall th' service...")
compressor = InfinityRerankerSvc(
top_k=retriever_config.top_k,
model=retriever_config.compressor_model_name,
)
retriever = self._get_vector_store_retriever(vector_store, retriever_config)
logger.info("Using MxBaiRerankerSmall th' service...")

compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor, base_retriever=retriever
)
compressor = model_gateway.get_reranker_from_model_config(
model_name=retriever_config.compressor_model_name,
top_k=retriever_config.top_k,
)
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor, base_retriever=retriever
)

return compression_retriever
else:
raise HTTPException(
status_code=500,
detail="Reranker service is not available",
)
return compression_retriever
except Exception as e:
logger.error(f"Error in getting contextual compression retriever: {e}")
raise HTTPException(
22 changes: 8 additions & 14 deletions backend/modules/query_controllers/example/payload.py
@@ -84,8 +84,7 @@
"prompt_template": PROMPT,
"retriever_name": "contextual-compression",
"retriever_config": {
"compressor_model_provider": "mixedbread-ai",
"compressor_model_name": "mixedbread-ai/mxbai-rerank-xsmall-v1",
"compressor_model_name": "local-infinity/mixedbread-ai/mxbai-rerank-xsmall-v1",
"top_k": 5,
"search_type": "similarity",
"search_kwargs": {"k": 10},
@@ -113,8 +112,7 @@
"prompt_template": PROMPT,
"retriever_name": "contextual-compression",
"retriever_config": {
"compressor_model_provider": "mixedbread-ai",
"compressor_model_name": "mixedbread-ai/mxbai-rerank-xsmall-v1",
"compressor_model_name": "local-infinity/mixedbread-ai/mxbai-rerank-xsmall-v1",
"top_k": 5,
"search_type": "mmr",
"search_kwargs": {
@@ -130,7 +128,7 @@
"description": """
Requires k and fetch_k in search kwargs for mmr.
search_type can either be similarity or mmr or similarity_score_threshold.
Currently only support for mixedbread-ai/mxbai-rerank-xsmall-v1 reranker is added.""",
Currently only support for local-infinity/mixedbread-ai/mxbai-rerank-xsmall-v1 reranker is added.""",
"value": QUERY_WITH_CONTEXTUAL_COMPRESSION_RETRIEVER_SEARCH_TYPE_MMR,
}

@@ -147,8 +145,7 @@
"prompt_template": PROMPT,
"retriever_name": "contextual-compression",
"retriever_config": {
"compressor_model_provider": "mixedbread-ai",
"compressor_model_name": "mixedbread-ai/mxbai-rerank-xsmall-v1",
"compressor_model_name": "local-infinity/mixedbread-ai/mxbai-rerank-xsmall-v1",
"top_k": 5,
"search_type": "similarity_score_threshold",
"search_kwargs": {"score_threshold": 0.7},
@@ -161,7 +158,7 @@
"description": """
Requires score_threshold float (0~1) in search kwargs for similarity search.
search_type can either be similarity or mmr or similarity_score_threshold.
Currently only support for mixedbread-ai/mxbai-rerank-xsmall-v1 reranker is added""",
Currently only support for local-infinity/mixedbread-ai/mxbai-rerank-xsmall-v1 reranker is added""",
"value": QUERY_WITH_CONTEXTUAL_COMPRESSION_RETRIEVER_SEARCH_TYPE_SIMILARITY_WITH_SCORE,
}

@@ -273,8 +270,7 @@
"prompt_template": PROMPT,
"retriever_name": "contextual-compression-multi-query",
"retriever_config": {
"compressor_model_provider": "mixedbread-ai",
"compressor_model_name": "mixedbread-ai/mxbai-rerank-xsmall-v1",
"compressor_model_name": "local-infinity/mixedbread-ai/mxbai-rerank-xsmall-v1",
"top_k": 5,
"search_type": "mmr",
"search_kwargs": {
@@ -309,8 +305,7 @@
"prompt_template": PROMPT,
"retriever_name": "contextual-compression-multi-query",
"retriever_config": {
"compressor_model_provider": "mixedbread-ai",
"compressor_model_name": "mixedbread-ai/mxbai-rerank-xsmall-v1",
"compressor_model_name": "local-infinity/mixedbread-ai/mxbai-rerank-xsmall-v1",
"top_k": 5,
"search_type": "similarity",
"search_kwargs": {"k": 10},
@@ -343,8 +338,7 @@
"prompt_template": PROMPT,
"retriever_name": "contextual-compression-multi-query",
"retriever_config": {
"compressor_model_provider": "mixedbread-ai",
"compressor_model_name": "mixedbread-ai/mxbai-rerank-xsmall-v1",
"compressor_model_name": "local-infinity/mixedbread-ai/mxbai-rerank-xsmall-v1",
"top_k": 5,
"search_type": "similarity_score_threshold",
"search_kwargs": {"score_threshold": 0.7},
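The new compressor_model_name values carry two slashes because the gateway treats the first path segment as the provider name and the remainder as the model id forwarded to the reranker service (see get_reranker_from_model_config above). A small sketch of that naming convention, using the name from these payloads:

model_name = "local-infinity/mixedbread-ai/mxbai-rerank-xsmall-v1"

# The provider prefix selects the ModelProviderConfig; everything after the
# first "/" is the model id passed to the reranker service unchanged.
provider_name, _, model_id = model_name.partition("/")

assert provider_name == "local-infinity"
assert model_id == "mixedbread-ai/mxbai-rerank-xsmall-v1"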
23 changes: 1 addition & 22 deletions backend/modules/query_controllers/example/types.py
@@ -73,10 +73,6 @@ class MultiQueryRetrieverConfig(VectorStoreRetrieverConfig):


class ContextualCompressionRetrieverConfig(VectorStoreRetrieverConfig):
compressor_model_provider: str = Field(
title="provider of the compressor model",
)

compressor_model_name: str = Field(
title="model name of the compressor",
)
@@ -85,14 +81,7 @@ class ContextualCompressionRetrieverConfig(VectorStoreRetrieverConfig):
title="Top K docs to collect post compression",
)

allowed_compressor_model_providers: ClassVar[Collection[str]] = ("mixedbread-ai",)

@validator("compressor_model_provider")
def validate_retriever_type(cls, value) -> Dict:
assert (
value in cls.allowed_compressor_model_providers
), f"Compressor model of {value} not allowed. Valid values are: {cls.allowed_compressor_model_providers}"
return value
allowed_compressor_model_providers: ClassVar[Collection[str]]


class ContextualCompressionMultiQueryRetrieverConfig(
@@ -101,10 +90,6 @@ class ContextualCompressionMultiQueryRetrieverConfig(
pass


class LordOfRetrievers(ContextualCompressionRetrieverConfig, MultiQueryRetrieverConfig):
pass


class ExampleQueryInput(BaseModel):
"""
Model for Query input.
@@ -137,7 +122,6 @@ class ExampleQueryInput(BaseModel):
"multi-query",
"contextual-compression",
"contextual-compression-multi-query",
"lord-of-the-retrievers",
)

stream: Optional[bool] = Field(title="Stream the results", default=False)
@@ -170,9 +154,4 @@ def validate_retriever_type(cls, values: Dict) -> Dict:
**values.get("retriever_config")
)

elif retriever_name == "lord-of-the-retrievers":
values["retriever_config"] = LordOfRetrievers(
**values.get("retriever_config")
)

return values
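With compressor_model_provider removed, a retriever config only needs the fully qualified model name. A minimal sketch of building such a config; the field names mirror the payload examples above, and search_type/search_kwargs are assumed to be inherited from VectorStoreRetrieverConfig, which is not shown in this diff:

from backend.modules.query_controllers.example.types import (
    ContextualCompressionRetrieverConfig,
)

config = ContextualCompressionRetrieverConfig(
    compressor_model_name="local-infinity/mixedbread-ai/mxbai-rerank-xsmall-v1",
    top_k=5,
    search_type="similarity",   # assumed field from VectorStoreRetrieverConfig
    search_kwargs={"k": 10},    # assumed field from VectorStoreRetrieverConfig
)
print(config.compressor_model_name, config.top_k)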