Merge pull request #732 from tedcochran/issue-731-embeddings-not-callable

Created a function for embeddings options
PromtEngineer authored Feb 4, 2024
2 parents 8450efc + e4383d0 commit c70d068
Showing 2 changed files with 46 additions and 43 deletions.
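
Since the same selection logic now appears verbatim in both files, a natural follow-up would be to hoist it into a shared module. Below is a minimal sketch of that refactor, not part of this commit: the module name utils.py is hypothetical, the device is passed as an explicit parameter instead of relying on an enclosing compute_device variable, and EMBEDDING_MODEL_NAME is assumed importable from constants.py.

    # utils.py -- hypothetical shared module (not part of this commit).
    # Assumes EMBEDDING_MODEL_NAME is defined in constants.py.
    from langchain.embeddings import (
        HuggingFaceBgeEmbeddings,
        HuggingFaceEmbeddings,
        HuggingFaceInstructEmbeddings,
    )

    from constants import EMBEDDING_MODEL_NAME


    def get_embeddings(device_type="cpu"):
        """Pick an embedding wrapper based on the configured model name."""
        if "instructor" in EMBEDDING_MODEL_NAME:
            return HuggingFaceInstructEmbeddings(
                model_name=EMBEDDING_MODEL_NAME,
                model_kwargs={"device": device_type},
                embed_instruction="Represent the document for retrieval:",
                query_instruction="Represent the question for retrieving supporting documents:",
            )
        if "bge" in EMBEDDING_MODEL_NAME:
            return HuggingFaceBgeEmbeddings(
                model_name=EMBEDDING_MODEL_NAME,
                model_kwargs={"device": device_type},
                query_instruction="Represent this sentence for searching relevant passages:",
            )
        return HuggingFaceEmbeddings(
            model_name=EMBEDDING_MODEL_NAME,
            model_kwargs={"device": device_type},
        )

Both scripts could then share a single import, so future embedding models only need to be wired up in one place.
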
46 changes: 23 additions & 23 deletions ingest.py
@@ -160,29 +160,29 @@ def main(device_type):
     their respective huggingface repository, project page or github repository.
     """
 
-    if "instructor" in EMBEDDING_MODEL_NAME:
-        return HuggingFaceInstructEmbeddings(
-            model_name=EMBEDDING_MODEL_NAME,
-            model_kwargs={"device": compute_device},
-            embed_instruction='Represent the document for retrieval:',
-            query_instruction='Represent the question for retrieving supporting documents:'
-        )
-
-    elif "bge" in EMBEDDING_MODEL_NAME:
-        query_instruction = 'Represent this sentence for searching relevant passages:'
-
-        return HuggingFaceBgeEmbeddings(
-            model_name=EMBEDDING_MODEL_NAME,
-            model_kwargs={"device": compute_device},
-            query_instruction='Represent this sentence for searching relevant passages:'
-        )
-
-    else:
-
-        return HuggingFaceEmbeddings(
-            model_name=EMBEDDING_MODEL_NAME,
-            model_kwargs={"device": compute_device},
-        )
+    def get_embeddings():
+        if "instructor" in EMBEDDING_MODEL_NAME:
+            return HuggingFaceInstructEmbeddings(
+                model_name=EMBEDDING_MODEL_NAME,
+                model_kwargs={"device": compute_device},
+                embed_instruction='Represent the document for retrieval:',
+                query_instruction='Represent the question for retrieving supporting documents:'
+            )
+
+        elif "bge" in EMBEDDING_MODEL_NAME:
+            return HuggingFaceBgeEmbeddings(
+                model_name=EMBEDDING_MODEL_NAME,
+                model_kwargs={"device": compute_device},
+                query_instruction='Represent this sentence for searching relevant passages:'
+            )
+
+        else:
+            return HuggingFaceEmbeddings(
+                model_name=EMBEDDING_MODEL_NAME,
+                model_kwargs={"device": compute_device},
+            )
+    embeddings = get_embeddings()
+    logging.info(f"Loaded embeddings from {EMBEDDING_MODEL_NAME}")
 
     db = Chroma.from_documents(
         texts,
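
The diff is truncated just as the embeddings object is consumed. For context, a minimal sketch of the full ingestion call follows, assuming texts holds the split document chunks and PERSIST_DIRECTORY comes from constants.py; the persist_directory keyword is an assumption about the truncated call, not something visible in this diff.

    # Embed the chunked documents and persist the Chroma index to disk.
    # `texts`, `get_embeddings`, and PERSIST_DIRECTORY are assumed from the
    # surrounding ingest.py code.
    from langchain.vectorstores import Chroma

    embeddings = get_embeddings()
    db = Chroma.from_documents(
        texts,
        embeddings,
        persist_directory=PERSIST_DIRECTORY,
    )
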
43 changes: 23 additions & 20 deletions run_localGPT.py
@@ -126,27 +126,30 @@ def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
     their respective huggingface repository, project page or github repository.
     """
 
-    if "instructor" in EMBEDDING_MODEL_NAME:
-        return HuggingFaceInstructEmbeddings(
-            model_name=EMBEDDING_MODEL_NAME,
-            model_kwargs={"device": compute_device},
-            embed_instruction='Represent the document for retrieval:',
-            query_instruction='Represent the question for retrieving supporting documents:'
-        )
-
-    elif "bge" in EMBEDDING_MODEL_NAME:
-        return HuggingFaceBgeEmbeddings(
-            model_name=EMBEDDING_MODEL_NAME,
-            model_kwargs={"device": compute_device},
-            query_instruction='Represent this sentence for searching relevant passages:'
-        )
-
-    else:
-        return HuggingFaceEmbeddings(
-            model_name=EMBEDDING_MODEL_NAME,
-            model_kwargs={"device": compute_device},
-        )
+    def get_embeddings():
+        if "instructor" in EMBEDDING_MODEL_NAME:
+            return HuggingFaceInstructEmbeddings(
+                model_name=EMBEDDING_MODEL_NAME,
+                model_kwargs={"device": compute_device},
+                embed_instruction='Represent the document for retrieval:',
+                query_instruction='Represent the question for retrieving supporting documents:'
+            )
+
+        elif "bge" in EMBEDDING_MODEL_NAME:
+            return HuggingFaceBgeEmbeddings(
+                model_name=EMBEDDING_MODEL_NAME,
+                model_kwargs={"device": compute_device},
+                query_instruction='Represent this sentence for searching relevant passages:'
+            )
+
+        else:
+            return HuggingFaceEmbeddings(
+                model_name=EMBEDDING_MODEL_NAME,
+                model_kwargs={"device": compute_device},
+            )
+    embeddings = get_embeddings()
+    logging.info(f"Loaded embeddings from {EMBEDDING_MODEL_NAME}")
 
     # load the vectorstore
     db = Chroma(
         persist_directory=PERSIST_DIRECTORY,
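
On the query side, the persisted index must be reopened with the same embedding function that was used at ingest time; a mismatched model would make the similarity scores meaningless. A minimal sketch of the rest of this truncated call, assuming PERSIST_DIRECTORY matches the value used by ingest.py:

    # Reopen the persisted Chroma store with the same embedding model and
    # expose it as a retriever for the downstream QA chain.
    from langchain.vectorstores import Chroma

    embeddings = get_embeddings()
    db = Chroma(
        persist_directory=PERSIST_DIRECTORY,
        embedding_function=embeddings,
    )
    retriever = db.as_retriever()
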
