Elastic Search #656

Open
wants to merge 10 commits into main
3 changes: 2 additions & 1 deletion docs/README_LangChain.md
@@ -15,10 +15,11 @@
or for CPU FAISS database, run:
```bash
pip install -r reqs_optional/requirements_optional_faiss_cpu.txt
```

-or for Weaviate, run:
+or for Weaviate or Elasticsearch, run:
```bash
pip install -r reqs_optional/requirements_optional_langchain.txt
```
+Note: the Elasticsearch client and server should share the same major version.
## Supported Datatypes

Open-source data types are supported; .msg is not supported due to its GPL-3 requirement. Other meta types support other types inside them. Special support for some behaviors is provided by the UI itself.
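The version note above can be checked at runtime. A minimal sketch, not part of this diff, assuming an unauthenticated local server on `http://localhost:9200`:

```python
# Sketch: verify that the installed elasticsearch client and the running
# server agree on the major version (server URL is an assumption).
from importlib.metadata import version

from elasticsearch import Elasticsearch

es = Elasticsearch("http://localhost:9200")  # assumed local, unauthenticated server
server_major = int(es.info()["version"]["number"].split(".")[0])
client_major = int(version("elasticsearch").split(".")[0])
if client_major != server_major:
    raise RuntimeError(f"client {client_major}.x does not match server {server_major}.x")
```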
4 changes: 4 additions & 0 deletions reqs_optional/requirements_optional_langchain.txt
@@ -45,3 +45,7 @@
pip-licenses==4.3.0

# weaviate vector db
weaviate-client==3.22.1

+# elasticsearch vector db
+elasticsearch==8.9.0

60 changes: 56 additions & 4 deletions src/gpt_langchain.py
@@ -57,7 +57,7 @@
from langchain.chains.question_answering import load_qa_chain
from langchain.docstore.document import Document
from langchain import PromptTemplate, HuggingFaceTextGenInference
-from langchain.vectorstores import Chroma
+from langchain.vectorstores import Chroma, ElasticVectorSearch


def get_db(sources, use_openai_embedding=False, db_type='faiss',
@@ -124,6 +124,16 @@ def get_db(sources, use_openai_embedding=False, db_type='faiss',
db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type,
use_openai_embedding=use_openai_embedding,
hf_embedding_model=hf_embedding_model)
+    elif db_type == 'elasticsearch':
+        index_name = collection_name.lower()
+        if os.getenv('ELASTICSEARCH_HOST', None):
+            db = _create_local_elasticsearch_client(embedding, index_name)
+        else:
+            db = ElasticVectorSearch(
+                elasticsearch_url="http://localhost:9200",
+                index_name=index_name,
+                embedding=embedding
+            )
else:
raise RuntimeError("No such db_type=%s" % db_type)
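For reference, a hedged usage sketch of the new branch: with `ELASTICSEARCH_HOST` set, `get_db` builds a client from environment variables via `_create_local_elasticsearch_client`; otherwise it falls back to an unauthenticated `http://localhost:9200`. The host value, the sample document, and the `collection_name` argument below are illustrative assumptions, not taken from this PR:

```python
import os

from langchain.docstore.document import Document
from gpt_langchain import get_db

# Illustrative input; get_db embeds and indexes these documents.
sources = [Document(page_content="hello world", metadata={"source": "a.txt"})]

os.environ["ELASTICSEARCH_HOST"] = "es.example.internal"  # hypothetical host
db = get_db(sources, db_type='elasticsearch', collection_name='MyDocs')
```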

@@ -218,9 +228,34 @@ def add_to_db(db, sources, db_type='faiss',
clear_embedding(db)
# save here is for migration, in case old db directory without embedding saved
save_embed(db, use_openai_embedding, hf_embedding_model)
+    elif db_type == 'elasticsearch':
+        if db.client.indices.exists(index=db.index_name):
+            field = ""
+            if avoid_dup_by_file:
+                field = "source"
+            if avoid_dup_by_content:
+                field = "hashid"
+            # Perform the aggregation query
+            response = db.client.search(index=db.index_name, body={
+                "size": 0,
+                "aggs": {
+                    "unique_sources": {
+                        "terms": {
+                            "field": f"metadata.{field}.keyword",
+                            "size": 10000  # Set size to a high value to retrieve all unique sources
+                        }
+                    }
+                }
+            })
+            # Extract the unique sources from the response and drop already-indexed documents
+            unique_sources = [bucket["key"] for bucket in response["aggregations"]["unique_sources"]["buckets"]]
+            sources = [x for x in sources if x.metadata[field] not in unique_sources]
+        num_new_sources = len(sources)
+        if num_new_sources == 0:
+            return db, num_new_sources, []
+        db.add_documents(documents=sources)
else:
raise RuntimeError("No such db_type=%s" % db_type)

new_sources_metadata = [x.metadata for x in sources]

return db, num_new_sources, new_sources_metadata
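The dedup step above rests on a terms aggregation over a keyword sub-field of the document metadata. A standalone sketch of the same query, where the index name, the `metadata.source` field, and the local unauthenticated server are all assumptions:

```python
from elasticsearch import Elasticsearch

es = Elasticsearch("http://localhost:9200")  # assumed local server
index_name = "mydocs"  # illustrative

if es.indices.exists(index=index_name):
    response = es.search(index=index_name, body={
        "size": 0,  # return aggregation buckets only, no hits
        "aggs": {
            "unique_sources": {
                "terms": {
                    "field": "metadata.source.keyword",
                    "size": 10000,  # large enough to cover all distinct values
                }
            }
        },
    })
    already_indexed = {b["key"] for b in response["aggregations"]["unique_sources"]["buckets"]}
    print(already_indexed)
```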
@@ -2762,7 +2797,7 @@ def get_chain(query=None,
if top_k_docs == -1:
top_k_docs = len(db_documents)
# similar to langchain's chroma's _results_to_docs_and_scores
-docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
+docs_with_score = [(Document(page_content=str(result[0]), metadata=result[1] or {}), 0)
for result in zip(db_documents, db_metadatas)]

# order documents
@@ -2797,7 +2832,8 @@ def get_chain(query=None,
tokens = [len(llm.pipeline.tokenizer(x[0].page_content)['input_ids']) for x in docs_with_score]
template_tokens = len(llm.pipeline.tokenizer(template)['input_ids'])
elif inference_server in ['openai', 'openai_chat'] or use_openai_model or db_type in ['faiss',
-                                                                                      'weaviate']:
+                                                                                      'weaviate',
+                                                                                      'elasticsearch']:
# use tiktoken for faiss since embedding called differently
tokens = [llm.get_num_tokens(x[0].page_content) for x in docs_with_score]
template_tokens = llm.get_num_tokens(template)
@@ -3104,5 +3140,21 @@ def _create_local_weaviate_client():
return None


+def _create_local_elasticsearch_client(embedding, index_name):
+    ELASTICSEARCH_HOST = os.getenv('ELASTICSEARCH_HOST', "localhost")
+    ELASTICSEARCH_PORT = os.getenv('ELASTICSEARCH_PORT', "9200")
+    ELASTICSEARCH_USERNAME = os.getenv('ELASTICSEARCH_USERNAME')
+    ELASTICSEARCH_PASSWORD = os.getenv('ELASTICSEARCH_PASSWORD')
+    ELASTICSEARCH_SCHEMA = os.getenv('ELASTICSEARCH_SCHEMA', "https")
+    if ELASTICSEARCH_USERNAME is not None and ELASTICSEARCH_PASSWORD is not None:
+        ELASTICSEARCH_URL = f"{ELASTICSEARCH_SCHEMA}://{ELASTICSEARCH_USERNAME}:{ELASTICSEARCH_PASSWORD}@{ELASTICSEARCH_HOST}:{ELASTICSEARCH_PORT}"
+    else:
+        ELASTICSEARCH_URL = f"{ELASTICSEARCH_SCHEMA}://{ELASTICSEARCH_HOST}:{ELASTICSEARCH_PORT}"
+
+    elastic_vector_search = ElasticVectorSearch(embedding=embedding, elasticsearch_url=ELASTICSEARCH_URL,
+                                                index_name=index_name)
+    return elastic_vector_search
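A usage sketch for the helper above, assuming a secured cluster reachable over basic auth; all values, the embedding model, and the index name are illustrative. With username and password unset, the URL degrades to bare `scheme://host:port`:

```python
import os

from langchain.embeddings import HuggingFaceEmbeddings
from gpt_langchain import _create_local_elasticsearch_client

# Illustrative settings; the env var names match the ones read above.
os.environ["ELASTICSEARCH_HOST"] = "es.example.internal"
os.environ["ELASTICSEARCH_USERNAME"] = "elastic"
os.environ["ELASTICSEARCH_PASSWORD"] = "changeme"

embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
db = _create_local_elasticsearch_client(embedding, index_name="mydocs")
docs = db.similarity_search("What is in my documents?", k=4)
```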


if __name__ == '__main__':
pass
2 changes: 1 addition & 1 deletion src/gradio_runner.py
@@ -2832,7 +2832,7 @@ def _update_user_db(file,
if dbs is None:
dbs = {}
assert isinstance(dbs, dict), "Wrong type for dbs: %s" % str(type(dbs))
-# assert db_type in ['faiss', 'chroma'], "db_type %s not supported" % db_type
+# assert db_type in ['faiss', 'chroma', 'elasticsearch'], "db_type %s not supported" % db_type
from gpt_langchain import add_to_db, get_db, path_to_docs
# handle case of list of temp buffer
if isinstance(file, list) and len(file) > 0 and hasattr(file[0], 'name'):