From 27ad3ae1e882e736cc753d12f17a23d5f46b132b Mon Sep 17 00:00:00 2001
From: Jett Scythe
Date: Tue, 1 Aug 2023 14:37:33 -0300
Subject: [PATCH 1/3] Add support for ElasticSearch

---
 docs/README_LangChain.md                     |  3 +-
 .../requirements_optional_langchain.txt      |  4 +++
 src/gpt_langchain.py                         | 32 +++++++++++++++++--
 src/gradio_runner.py                         |  2 +-
 4 files changed, 37 insertions(+), 4 deletions(-)

diff --git a/docs/README_LangChain.md b/docs/README_LangChain.md
index e1a77e8b6..1c379e99f 100644
--- a/docs/README_LangChain.md
+++ b/docs/README_LangChain.md
@@ -15,10 +15,11 @@ or for CPU FAISS database, run:
 pip install -r reqs_optional/requirements_optional_faiss_cpu.txt
 ```
 
-or for Weaviate, run:
+or for Weaviate or Elasticsearch, run:
 ```bash
 pip install -r reqs_optional/requirements_optional_langchain.txt
 ```
+Note: the Elasticsearch client and server must have the same major version.
 
 ## Supported Datatypes
 
 Open-source data types are supported, .msg is not supported due to GPL-3 requirement. Other meta types support other types inside them. Special support for some behaviors is provided by the UI itself.
diff --git a/reqs_optional/requirements_optional_langchain.txt b/reqs_optional/requirements_optional_langchain.txt
index 161f98362..66457c769 100644
--- a/reqs_optional/requirements_optional_langchain.txt
+++ b/reqs_optional/requirements_optional_langchain.txt
@@ -40,3 +40,7 @@ pip-licenses==4.3.0
 
 # weaviate vector db
 weaviate-client==3.22.1
+
+# elasticsearch vector db
+elasticsearch==7.15.0
+
diff --git a/src/gpt_langchain.py b/src/gpt_langchain.py
index e8c16d0f8..7d171ca86 100644
--- a/src/gpt_langchain.py
+++ b/src/gpt_langchain.py
@@ -57,7 +57,7 @@ from langchain.chains.question_answering import load_qa_chain
 from langchain.docstore.document import Document
 from langchain import PromptTemplate, HuggingFaceTextGenInference
-from langchain.vectorstores import Chroma
+from langchain.vectorstores import Chroma, ElasticVectorSearch
 
 
 def get_db(sources, use_openai_embedding=False, db_type='faiss',
@@ -117,6 +117,16 @@ def get_db(sources, use_openai_embedding=False, db_type='faiss',
         db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type,
                                                               use_openai_embedding=use_openai_embedding,
                                                               hf_embedding_model=hf_embedding_model)
+    elif db_type == 'elasticsearch':
+        index_name = collection_name.lower()
+        if os.getenv('ELASTICSEARCH_HOST', None):
+            db = _create_local_elasticsearch_client(embedding, index_name)
+        else:
+            db = ElasticVectorSearch(
+                elasticsearch_url="http://localhost:9200",
+                index_name=index_name,
+                embedding=embedding
+            )
     else:
         raise RuntimeError("No such db_type=%s" % db_type)
 
@@ -195,6 +205,8 @@ def add_to_db(db, sources, db_type='faiss',
         db.persist()
         clear_embedding(db)
         save_embed(db, use_openai_embedding, hf_embedding_model)
+    elif db_type == 'elasticsearch':
+        db.add_documents(documents=sources)
     else:
         raise RuntimeError("No such db_type=%s" % db_type)
 
@@ -2527,7 +2539,8 @@ def get_chain(query=None,
             tokens = [len(llm.pipeline.tokenizer(x[0].page_content)['input_ids']) for x in docs_with_score]
             template_tokens = len(llm.pipeline.tokenizer(template)['input_ids'])
         elif inference_server in ['openai', 'openai_chat'] or use_openai_model or db_type in ['faiss',
-                                                                                              'weaviate']:
+                                                                                              'weaviate',
+                                                                                              'elasticsearch']:
             # use ticktoken for faiss since embedding called differently
             tokens = [llm.get_num_tokens(x[0].page_content) for x in docs_with_score]
             template_tokens = llm.get_num_tokens(template)
@@ -2834,5 +2847,20 @@ def _create_local_weaviate_client():
         return None
 
 
+def _create_local_elasticsearch_client(embedding, index_name):
+    ELASTICSEARCH_HOST = os.getenv('ELASTICSEARCH_HOST', "localhost")
+    ELASTICSEARCH_PORT = os.getenv('ELASTICSEARCH_PORT', "9200")
+    ELASTICSEARCH_USERNAME = os.getenv('ELASTICSEARCH_USERNAME')
+    ELASTICSEARCH_PASSWORD = os.getenv('ELASTICSEARCH_PASSWORD')
+    ELASTICSEARCH_SCHEMA = os.getenv('ELASTICSEARCH_SCHEMA', "https")
+    if ELASTICSEARCH_USERNAME is not None and ELASTICSEARCH_PASSWORD is not None:
+        ELASTICSEARCH_URL = f"{ELASTICSEARCH_SCHEMA}://{ELASTICSEARCH_USERNAME}:{ELASTICSEARCH_PASSWORD}@{ELASTICSEARCH_HOST}:{ELASTICSEARCH_PORT}"
+    else:
+        ELASTICSEARCH_URL = f"{ELASTICSEARCH_SCHEMA}://{ELASTICSEARCH_HOST}:{ELASTICSEARCH_PORT}"
+
+    elastic_vector_search = ElasticVectorSearch(embedding=embedding, elasticsearch_url=ELASTICSEARCH_URL, index_name=index_name)
+    return elastic_vector_search
+
+
 if __name__ == '__main__':
     pass
diff --git a/src/gradio_runner.py b/src/gradio_runner.py
index 2afe46f8e..e339e42fa 100644
--- a/src/gradio_runner.py
+++ b/src/gradio_runner.py
@@ -2747,7 +2747,7 @@ def _update_user_db(file,
     if dbs is None:
         dbs = {}
     assert isinstance(dbs, dict), "Wrong type for dbs: %s" % str(type(dbs))
-    # assert db_type in ['faiss', 'chroma'], "db_type %s not supported" % db_type
+    # assert db_type in ['faiss', 'chroma', 'elasticsearch'], "db_type %s not supported" % db_type
     from gpt_langchain import add_to_db, get_db, path_to_docs
     # handle case of list of temp buffer
     if isinstance(file, list) and len(file) > 0 and hasattr(file[0], 'name'):

From acab75ec6244b6dc6e6044b1e884657e557c83a8 Mon Sep 17 00:00:00 2001
From: Jett Scythe
Date: Wed, 2 Aug 2023 16:18:22 -0300
Subject: [PATCH 2/3] Don't save duplicate files to ES

---
 .../requirements_optional_langchain.txt |  2 +-
 src/gpt_langchain.py                    | 27 +++++++++++++++++--
 2 files changed, 26 insertions(+), 3 deletions(-)

diff --git a/reqs_optional/requirements_optional_langchain.txt b/reqs_optional/requirements_optional_langchain.txt
index 66457c769..9cf5fdea0 100644
--- a/reqs_optional/requirements_optional_langchain.txt
+++ b/reqs_optional/requirements_optional_langchain.txt
@@ -42,5 +42,5 @@ pip-licenses==4.3.0
 weaviate-client==3.22.1
 
 # elasticsearch vector db
-elasticsearch==7.15.0
+elasticsearch==8.9.0
 
diff --git a/src/gpt_langchain.py b/src/gpt_langchain.py
index fb1655d8c..606caa4ff 100644
--- a/src/gpt_langchain.py
+++ b/src/gpt_langchain.py
@@ -206,10 +206,32 @@ def add_to_db(db, sources, db_type='faiss',
             clear_embedding(db)
         save_embed(db, use_openai_embedding, hf_embedding_model)
     elif db_type == 'elasticsearch':
+        field = ""
+        if avoid_dup_by_file:
+            field = "source"
+        if avoid_dup_by_content:
+            field = "hashid"
+        # Perform the aggregation query
+        response = db.client.search(index=db.index_name, body={
+            "size": 0,
+            "aggs": {
+                "unique_sources": {
+                    "terms": {
+                        "field": f"metadata.{field}.keyword",
+                        "size": 10000  # Set size to a high value to retrieve all unique sources
+                    }
+                }
+            }
+        })
+        # Extract the unique sources from the response
+        unique_sources = [bucket["key"] for bucket in response["aggregations"]["unique_sources"]["buckets"]]
+        sources = [x for x in sources if x.metadata[f"{field}"] not in unique_sources]
+        num_new_sources = len(sources)
+        if num_new_sources == 0:
+            return db, num_new_sources, []
         db.add_documents(documents=sources)
     else:
         raise RuntimeError("No such db_type=%s" % db_type)
-    new_sources_metadata = [x.metadata for x in sources]
     return db, num_new_sources, new_sources_metadata
 
 
@@ -2858,7 +2880,8 @@ def _create_local_elasticsearch_client(embedding, index_name):
     else:
         ELASTICSEARCH_URL = f"{ELASTICSEARCH_SCHEMA}://{ELASTICSEARCH_HOST}:{ELASTICSEARCH_PORT}"
 
-    elastic_vector_search = ElasticVectorSearch(embedding=embedding, elasticsearch_url=ELASTICSEARCH_URL, index_name=index_name)
+    elastic_vector_search = ElasticVectorSearch(embedding=embedding, elasticsearch_url=ELASTICSEARCH_URL,
+                                                index_name=index_name)
     return elastic_vector_search
 
 

From a486a07d2e61ce7c0a888af42e3917ef5c8e12c7 Mon Sep 17 00:00:00 2001
From: Jett Scythe
Date: Wed, 2 Aug 2023 17:33:38 -0300
Subject: [PATCH 3/3] check if index exists, ensure Document Content is a str

---
 src/gpt_langchain.py | 45 ++++++++++++++++++++++----------------------
 1 file changed, 23 insertions(+), 22 deletions(-)

diff --git a/src/gpt_langchain.py b/src/gpt_langchain.py
index d381cdba6..24ab121de 100644
--- a/src/gpt_langchain.py
+++ b/src/gpt_langchain.py
@@ -229,29 +229,30 @@ def add_to_db(db, sources, db_type='faiss',
         # save here is for migration, in case old db directory without embedding saved
         save_embed(db, use_openai_embedding, hf_embedding_model)
     elif db_type == 'elasticsearch':
-        field = ""
-        if avoid_dup_by_file:
-            field = "source"
-        if avoid_dup_by_content:
-            field = "hashid"
-        # Perform the aggregation query
-        response = db.client.search(index=db.index_name, body={
-            "size": 0,
-            "aggs": {
-                "unique_sources": {
-                    "terms": {
-                        "field": f"metadata.{field}.keyword",
-                        "size": 10000  # Set size to a high value to retrieve all unique sources
+        if db.client.indices.exists(index=db.index_name):
+            field = ""
+            if avoid_dup_by_file:
+                field = "source"
+            if avoid_dup_by_content:
+                field = "hashid"
+            # Perform the aggregation query
+            response = db.client.search(index=db.index_name, body={
+                "size": 0,
+                "aggs": {
+                    "unique_sources": {
+                        "terms": {
+                            "field": f"metadata.{field}.keyword",
+                            "size": 10000  # Set size to a high value to retrieve all unique sources
+                        }
                     }
                 }
-            }
-        })
-        # Extract the unique sources from the response
-        unique_sources = [bucket["key"] for bucket in response["aggregations"]["unique_sources"]["buckets"]]
-        sources = [x for x in sources if x.metadata[f"{field}"] not in unique_sources]
-        num_new_sources = len(sources)
-        if num_new_sources == 0:
-            return db, num_new_sources, []
+            })
+            # Extract the unique sources from the response
+            unique_sources = [bucket["key"] for bucket in response["aggregations"]["unique_sources"]["buckets"]]
+            sources = [x for x in sources if x.metadata[f"{field}"] not in unique_sources]
+            num_new_sources = len(sources)
+            if num_new_sources == 0:
+                return db, num_new_sources, []
         db.add_documents(documents=sources)
     else:
         raise RuntimeError("No such db_type=%s" % db_type)
@@ -2628,7 +2629,7 @@ def get_chain(query=None,
         if top_k_docs == -1:
             top_k_docs = len(db_documents)
         # similar to langchain's chroma's _results_to_docs_and_scores
-        docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
+        docs_with_score = [(Document(page_content=str(result[0]), metadata=result[1] or {}), 0)
                            for result in zip(db_documents, db_metadatas)]
 
         # order documents
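
For reviewers, a minimal sketch of the code path this series enables, exercised outside of h2oGPT. It assumes a reachable Elasticsearch node on localhost:9200 and the `elasticsearch` pin from PATCH 2/3; `HuggingFaceEmbeddings` with an example model stands in for the embedding object h2oGPT normally supplies, and the index name is a throwaway example. The `ElasticVectorSearch` arguments mirror the ones used by `_create_local_elasticsearch_client` above.

```python
import os

from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import ElasticVectorSearch

# Environment variables read by _create_local_elasticsearch_client() in this series
# (example values for a local, unsecured node).
os.environ.setdefault("ELASTICSEARCH_HOST", "localhost")
os.environ.setdefault("ELASTICSEARCH_PORT", "9200")
os.environ.setdefault("ELASTICSEARCH_SCHEMA", "http")

# Example embedding model; h2oGPT passes its own embedding object here.
embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

url = (f"{os.environ['ELASTICSEARCH_SCHEMA']}://"
       f"{os.environ['ELASTICSEARCH_HOST']}:{os.environ['ELASTICSEARCH_PORT']}")

# Same constructor call the patch makes, with a throwaway index name.
db = ElasticVectorSearch(embedding=embedding, elasticsearch_url=url, index_name="userdata")

# add_to_db() ends up calling add_documents() with langchain Document objects,
# so page_content must be a plain str (the reason for the str(...) cast in PATCH 3/3).
db.add_documents(documents=[
    Document(page_content="hello elasticsearch", metadata={"source": "example.txt"}),
])

print(db.similarity_search("hello", k=1))
```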