Skip to content

Commit

Permalink
support nlpcloud, fixes
Browse files (browse the repository at this point in the history)
riccardobl committed Apr 21, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 40d87cb commit 1dfb615
Showing 3 changed files with 32 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@ ENV HF_HOME="/cache/hugginface"
ENV MODEL="intfloat/multilingual-e5-base"
ENV ADD_MARKERS_TO_SENTENCES="true"
ENV OPENAI_API_KEY=""

ENV NLP_CLOUD_API_KEY=""
VOLUME /cache

CMD ["python", "-u", "main.py"]
2 changes: 1 addition & 1 deletion Dockerfile.cuda
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@ ENV MODEL="intfloat/multilingual-e5-base"
ENV MAX_TEXT_LENGTH=512
ENV ADD_MARKERS_TO_SENTENCES="true"
ENV OPENAI_API_KEY=""

ENV NLP_CLOUD_API_KEY=""
VOLUME /cache


42 changes: 30 additions & 12 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -15,19 +15,22 @@
from sentence_transformers.quantization import quantize_embeddings
import tiktoken
from openai import OpenAI
import nlpcloud
import numpy as np


class Runner (JobRunner):
openai = None
nlpcloud = None
def __init__(self, filters, meta, template, sockets):
super().__init__(filters, meta, template, sockets)
self.device = int(os.getenv('TRANSFORMERS_DEVICE', "-1"))
self.cachePath = os.getenv('CACHE_PATH', os.path.join(os.path.dirname(__file__), "cache"))
now = time.time()
self.modelName = os.getenv('MODEL', "intfloat/multilingual-e5-base")
self.maxTextLength = os.getenv('MAX_TEXT_LENGTH', 512)

if self.modelName.startswith("openai:"):
if self.modelName.startswith("nlpcloud:"):
self.nlpcloud = nlpcloud.Client(self.modelName.replace("nlpcloud:",""), os.getenv('NLP_CLOUD_API_KEY'))
elif self.modelName.startswith("openai:"):
self.log("Using OpenAI API "+ self.modelName)
self.openai = OpenAI()
self.openaiModelName = self.modelName.replace("openai:","")
@@ -70,20 +73,32 @@ def encode(self, sentences):
with open(cache_file, "rb") as f:
out.append(pickle.load(f))
# use openai for encoding
if self.openai:
encoded = []
for i in range(len(to_encode)):
response = self.openai.embeddings.create(
input=to_encode[i],
if self.nlpcloud :
if len(to_encode)>0:
embeddingsvs=self.nlpcloud.embeddings(to_encode).embeddings
encoded = []
for i in range(len(embeddingsvs)):
embeddings = embeddingsvs[i]
embeddings = np.array(embeddings)
encoded.append(embeddings)
elif self.openai:
if len(to_encode)>0:
embeddingsvs=self.openai.embeddings.create(
input=to_encode,
model=self.openaiModelName
)
embeddings = response.data[0].embedding
embeddings = np.array(embeddings)
encoded.append(embeddings)
encoded = []
for i in range(len(to_encode)):
embeddings = embeddingsvs.data[i].embedding
embeddings = np.array(embeddings)
encoded.append(embeddings)


# TODO: more apis?
# Use local model
else:
encoded = self.pipe.encode(to_encode)
if len(to_encode)>0:
encoded = self.pipe.encode(to_encode)

for i in range(len(to_encode_index)):
out[to_encode_index[i]] = encoded[i]
@@ -123,6 +138,7 @@ def getParamValue(key,default=None):
data = jin.data
data_type = jin.type
marker = jin.marker
self.log("Use data: "+data)
if marker != "query": marker="passage"
if data_type == "text":
sentences.append([data,marker])
@@ -138,6 +154,7 @@ def getParamValue(key,default=None):
raise Exception("Unsupported data type: "+data_type)

# Check local cache
self.log("Check cache...")
cacheId = str( self.modelName) + str(outputFormat) + str(max_tokens) + str(overlap) + str(quantize) + "".join([sentences[i][0] + ":" + sentences[i][1] for i in range(len(sentences))])
cacheId = hashlib.sha256(cacheId.encode("utf-8")).hexdigest()
cacheFile = os.path.join(self.cachePath, cacheId+".dat")
@@ -146,6 +163,7 @@ def getParamValue(key,default=None):
return f.read()

# Split long sentences
self.log("Split long sentences...")
sentences_chunks=[]
for sentence in sentences:
self.split(sentence[0], max_tokens, overlap, sentence[1], sentences_chunks)

0 comments on commit 1dfb615

Please sign in to comment.