
Commit

integrate Infinity embedding, add batch to Ollama embedding, retrieve use async
etwk committed Aug 20, 2024
1 parent 0675dc4 commit 42ec7aa
Showing 9 changed files with 305 additions and 195 deletions.
1 change: 0 additions & 1 deletion infra/env.d/check
@@ -3,7 +3,6 @@
EMBEDDING_BASE_URL=http://ollama:11434
EMBEDDING_MODEL_DEPLOY=api
EMBEDDING_MODEL_NAME=jina/jina-embeddings-v2-base-en
INDEX_CHUNK_SIZES=[2048, 512, 128]
-THREAD_BUILD_INDEX=12

LLM_MODEL_NAME=google/gemma-2-27b-it
OPENAI_API_KEY=<CHANGE_ME>
1 change: 1 addition & 0 deletions requirements.base.txt
@@ -1,6 +1,7 @@
aiohttp
dspy-ai==2.4.13
fastapi
+httpx[http2]
llama-index==0.10.65
llama-index-postprocessor-jinaai-rerank==0.1.7
openai
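Note: the new httpx[http2] extra pulls in the h2 package; constructing an httpx client with http2=True, as the embedding clients below do, raises an ImportError without it. A quick sanity check (the target URL is only an example):

import httpx

# fails with ImportError if the http2 extra (h2 package) is missing
client = httpx.Client(http2=True)
resp = client.get("https://example.com")
print(resp.http_version)  # "HTTP/2" if negotiated, otherwise "HTTP/1.1"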
22 changes: 22 additions & 0 deletions src/_types.py
@@ -0,0 +1,22 @@
import json

# reference: https://github.com/ollama/ollama-python/blob/main/ollama/_types.py
class ResponseError(Exception):
"""
Common class for response errors.
"""

def __init__(self, error: str, status_code: int = -1):
try:
# try to parse content as JSON and extract 'error'
# fallback to raw content if JSON parsing fails
error = json.loads(error).get('error', error)
except json.JSONDecodeError:
...

super().__init__(error)
self.error = error
'Reason for the error.'

self.status_code = status_code
'HTTP status code of the response.'
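For illustration, a minimal sketch of how ResponseError behaves with a JSON error body versus plain text (both payloads are made up):

# JSON body: the 'error' field is extracted
err = ResponseError('{"error": "model not found"}', status_code=404)
print(err.error, err.status_code)  # model not found 404

# non-JSON body: the raw content is kept as-is
err = ResponseError("upstream timeout", status_code=504)
print(err.error)  # upstream timeout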
2 changes: 2 additions & 0 deletions src/integrations/__init__.py
@@ -0,0 +1,2 @@
from .infinity_embedding import InfinityEmbedding
from .ollama_embedding import OllamaEmbedding
132 changes: 132 additions & 0 deletions src/integrations/infinity_embedding.py
@@ -0,0 +1,132 @@
# reference: https://github.com/ollama/ollama-python/blob/main/ollama/_client.py

import httpx
from typing import Any, List
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.bridge.pydantic import PrivateAttr
from tenacity import retry, stop_after_attempt, wait_fixed

import utils
from _types import ResponseError

DEFAULT_INFINITY_BASE_URL = "http://localhost:7997"

class InfinityEmbedding(BaseEmbedding):
"""Class for Infinity embeddings.
Using retry here cause one failed request could crash the whole embedding process.
Args:
api_key (str): Server API key.
model_name (str): Model for embedding.
base_url (str): Infinity url. Defaults to http://localhost:7997.
"""

_aclient: httpx.AsyncClient = PrivateAttr()
_client: httpx.Client = PrivateAttr()
_settings: dict = PrivateAttr()
_url: str = PrivateAttr()

def __init__(
self,
model_name: str,
api_key: str = "key",
base_url: str = DEFAULT_INFINITY_BASE_URL,
http2: bool = True,
follow_redirects: bool = True,
timeout: Any = None,
**kwargs: Any,
) -> None:
super().__init__(
model_name=model_name,
**kwargs,
)

self._settings = {
'follow_redirects': follow_redirects,
'headers': {
'Content-Type': 'application/json',
'Accept': 'application/json',
'Authorization': f"Bearer {api_key}",
},
'http2': http2,
'timeout': timeout,
}

        # plain string join; os.path.join is not safe for URLs
        self._url = f"{base_url.rstrip('/')}/embeddings"

@classmethod
def class_name(cls) -> str:
return "InfinityEmbedding"

def _get_client(self, _async: bool = False):
"""Set and return httpx sync or async client"""
if _async:
if not hasattr(self, "_aclient"):
self._aclient = httpx.AsyncClient(**self._settings)
return self._aclient
else:
if not hasattr(self, "_client"):
self._client = httpx.Client(**self._settings)
return self._client

def _process_response(self, response: httpx.Response) -> List[List[float]]:
embeddings = [item['embedding'] for item in response.json()['data']]
return embeddings

@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5), before_sleep=utils.retry_log_warning, reraise=True)
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get text embeddings."""
client = self._get_client()
response = client.request(
'POST',
self._url,
json={
"input": texts,
"model": self.model_name,
},
)

try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
raise ResponseError(e.response.text, e.response.status_code) from None

return self._process_response(response)

@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5), before_sleep=utils.retry_log_warning, reraise=True)
async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Asynchronously get text embeddings."""
client = self._get_client(_async=True)
response = await client.request(
'POST',
self._url,
json={
"input": texts,
"model": self.model_name,
},
)

try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
raise ResponseError(e.response.text, e.response.status_code) from None

return self._process_response(response)

def _get_query_embedding(self, query: str) -> List[float]:
"""Get query embedding."""
return self._get_text_embeddings([query])[0]

async def _aget_query_embedding(self, query: str) -> List[float]:
"""The asynchronous version of _get_query_embedding."""
        return (await self._aget_text_embeddings([query]))[0]

def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding."""
return self._get_text_embeddings([text])[0]

async def _aget_text_embedding(self, text: str) -> List[float]:
"""Asynchronously get text embedding."""
        return (await self._aget_text_embeddings([text]))[0]
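A minimal usage sketch, assuming src/ is on the import path and an Infinity server is running locally with the named model (both are assumptions, not part of this commit):

import asyncio
from integrations import InfinityEmbedding

embed = InfinityEmbedding(
    model_name="jinaai/jina-embeddings-v2-base-en",  # hypothetical model
    base_url="http://localhost:7997",
)

# sync: embed one query
vec = embed.get_query_embedding("what is async io?")

# async: embed a batch of texts via the BaseEmbedding batch helper
vecs = asyncio.run(embed.aget_text_embedding_batch(["first passage", "second passage"]))
print(len(vec), len(vecs))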
133 changes: 133 additions & 0 deletions src/integrations/ollama_embedding.py
@@ -0,0 +1,133 @@
# reference: https://github.com/ollama/ollama-python/blob/main/ollama/_client.py

import httpx
from typing import Any, List
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.bridge.pydantic import PrivateAttr
from tenacity import retry, stop_after_attempt, wait_fixed

import utils
from _types import ResponseError

DEFAULT_OLLAMA_BASE_URL = "http://localhost:11434"

class OllamaEmbedding(BaseEmbedding):
"""Class for Ollama embeddings.
Using retry here cause one failed request could crash the whole embedding process.
Args:
api_key (str): Server API key.
model_name (str): Model for embedding.
base_url (str): Ollama url. Defaults to http://localhost:7997.
"""

_aclient: httpx.AsyncClient = PrivateAttr()
_client: httpx.Client = PrivateAttr()
_settings: dict = PrivateAttr()
_url: str = PrivateAttr()

def __init__(
self,
model_name: str,
api_key: str = "key",
base_url: str = DEFAULT_OLLAMA_BASE_URL,
http2: bool = True,
follow_redirects: bool = True,
timeout: Any = None,
**kwargs: Any,
) -> None:
super().__init__(
model_name=model_name,
**kwargs,
)

self._settings = {
'follow_redirects': follow_redirects,
'headers': {
'Content-Type': 'application/json',
'Accept': 'application/json',
'Authorization': f"Bearer {api_key}",
},
'http2': http2,
'timeout': timeout,
}

        # plain string join; os.path.join is not safe for URLs
        self._url = f"{base_url.rstrip('/')}/api/embed"

@classmethod
def class_name(cls) -> str:
return "OllamaEmbedding"

def _get_client(self, _async: bool = False):
"""Set and return httpx sync or async client"""
if _async:
if not hasattr(self, "_aclient"):
self._aclient = httpx.AsyncClient(**self._settings)
return self._aclient
else:
if not hasattr(self, "_client"):
self._client = httpx.Client(**self._settings)
return self._client

def _process_response(self, response: httpx.Response) -> List[List[float]]:
embeddings = response.json()['embeddings']
return embeddings

@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5), before_sleep=utils.retry_log_warning, reraise=True)
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get text embeddings."""
client = self._get_client()
response = client.request(
'POST',
self._url,
json={
"input": texts,
"model": self.model_name,
},
)

try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
raise ResponseError(e.response.text, e.response.status_code) from None

return self._process_response(response)

# TODO: debug `Event loop is closed`
@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5), before_sleep=utils.retry_log_warning, reraise=True)
async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Asynchronously get text embeddings."""
client = self._get_client(_async=True)
response = await client.request(
'POST',
self._url,
json={
"input": texts,
"model": self.model_name,
},
)

try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
raise ResponseError(e.response.text, e.response.status_code) from None

return self._process_response(response)

def _get_query_embedding(self, query: str) -> List[float]:
"""Get query embedding."""
return self._get_text_embeddings([query])[0]

async def _aget_query_embedding(self, query: str) -> List[float]:
"""The asynchronous version of _get_query_embedding."""
        return (await self._aget_text_embeddings([query]))[0]

def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding."""
return self._get_text_embeddings([text])[0]

async def _aget_text_embedding(self, text: str) -> List[float]:
"""Asynchronously get text embedding."""
        return (await self._aget_text_embeddings([text]))[0]
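A sketch of wiring the batched client into llama-index, using the model and URL from infra/env.d/check; embed_batch_size is a BaseEmbedding field that controls how many texts go into each POST to /api/embed (the value 32 is an arbitrary choice):

from llama_index.core import Settings
from integrations import OllamaEmbedding

Settings.embed_model = OllamaEmbedding(
    model_name="jina/jina-embeddings-v2-base-en",
    base_url="http://ollama:11434",
    embed_batch_size=32,  # batch size for each /api/embed request
)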