Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support for llama-cpp-python embedder and automatic selection when using llama as LLM #435

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import os
import string
import json
from typing import List
from itertools import combinations
from sklearn.feature_extraction.text import CountVectorizer
from langchain.embeddings.base import Embeddings
import httpx


class DumbEmbedder(Embeddings):
Expand Down Expand Up @@ -41,3 +44,21 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_query(self, text: str) -> List[float]:
    """Embed a single piece of text and return its embedding as a list of floats."""
    # Vectorize the text, cast counts to float, densify, and take the only row.
    sparse_vector = self.embedder.transform([text])
    dense_matrix = sparse_vector.astype(float).todense()
    return dense_matrix.tolist()[0]


class CustomOpenAIEmbeddings(Embeddings):
    """Embedder that talks to a self-hosted, OpenAI-compatible embeddings endpoint
    (e.g. llama-cpp-python's server).

    Parameters
    ----------
    url : str
        Base URL of the server; ``/v1/embeddings`` is appended to it.
    """

    def __init__(self, url):
        # Build the endpoint with explicit "/" joining: os.path.join would use
        # "\\" on Windows and mishandle trailing slashes — this is a URL, not a path.
        self.url = url.rstrip("/") + "/v1/embeddings"

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of texts; returns one embedding vector per input text."""
        # `json=` serializes the payload AND sets Content-Type: application/json
        # (passing a pre-dumped string via `data=` omits the header and is
        # deprecated in httpx). timeout=None: local models can be very slow.
        ret = httpx.post(self.url, json={"input": texts}, timeout=None)
        ret.raise_for_status()
        return [e["embedding"] for e in ret.json()["data"]]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single string of text and return its embedding vector."""
        ret = httpx.post(self.url, json={"input": text}, timeout=None)
        ret.raise_for_status()
        return ret.json()["data"][0]["embedding"]

11 changes: 8 additions & 3 deletions core/cat/factory/custom_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,12 @@ def _identifying_params(self) -> Mapping[str, Any]:


class CustomOpenAI(OpenAI):
url: str

def __init__(self, **kwargs):

model_kwargs = {
'repeat_penalty': kwargs.pop('repeat_penalty')
'repeat_penalty': kwargs.pop('repeat_penalty'),
'top_k': kwargs.pop('top_k')
}

stop = kwargs.pop('stop', None)
Expand All @@ -81,4 +83,7 @@ def __init__(self, **kwargs):
model_kwargs=model_kwargs,
**kwargs
)
self.openai_api_base = os.path.join(kwargs['url'], "v1")

self.url = kwargs['url']
self.openai_api_base = os.path.join(self.url, "v1")

16 changes: 14 additions & 2 deletions core/cat/factory/embedder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import langchain
from pydantic import PyObject, BaseSettings

from cat.factory.dumb_embedder import DumbEmbedder
from cat.factory.custom_embedder import DumbEmbedder, CustomOpenAIEmbeddings


# Base class to manage LLM configuration.
Expand Down Expand Up @@ -36,11 +36,22 @@ class EmbedderDumbConfig(EmbedderSettings):

class Config:
schema_extra = {
"name_human_readable": "Dumb Embedder",
"humanReadableName": "Dumb Embedder",
"description": "Configuration for default embedder. It encodes the pairs of characters",
}


class EmbedderLlamaCppConfig(EmbedderSettings):
    """Settings for the self-hosted llama-cpp-python embedder.

    Attributes:
        url: base URL of the llama-cpp-python server exposing /v1/embeddings.
    """
    url: str
    # Annotation form, not chained assignment: `_pyclass = PyObject = X` would
    # create a spurious `PyObject` class attribute (shadowing the pydantic
    # import) instead of declaring `_pyclass` with type `PyObject`.
    _pyclass: PyObject = CustomOpenAIEmbeddings

    class Config:
        schema_extra = {
            "humanReadableName": "Self-hosted llama-cpp-python embedder",
            "description": "Self-hosted llama-cpp-python embedder",
        }


class EmbedderOpenAIConfig(EmbedderSettings):
openai_api_key: str
model: str = "text-embedding-ada-002"
Expand Down Expand Up @@ -98,6 +109,7 @@ class Config:
SUPPORTED_EMDEDDING_MODELS = [
EmbedderDumbConfig,
EmbedderFakeConfig,
EmbedderLlamaCppConfig,
EmbedderOpenAIConfig,
EmbedderAzureOpenAIConfig,
EmbedderCohereConfig,
Expand Down
9 changes: 9 additions & 0 deletions core/cat/mad_hatter/core_plugin/hooks/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from langchain import HuggingFaceHub
from langchain.chat_models import AzureChatOpenAI
from cat.mad_hatter.decorators import hook
from cat.factory.custom_llm import CustomOpenAI


@hook(priority=0)
Expand Down Expand Up @@ -142,6 +143,14 @@ def get_language_embedder(cat) -> embedders.EmbedderSettings:
}
)

# Llama-cpp-python
elif type(cat._llm) in [CustomOpenAI]:
embedder = embedders.EmbedderLlamaCppConfig.get_embedder_from_config(
{
"url": cat._llm.url
}
)

else:
# If no embedder matches vendor, and no external embedder is configured, we use the DumbEmbedder.
# `This embedder is not a model properly trained
Expand Down
Loading