-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Llama3 Integration and RAG pipeline for querying (#2)
* Initial commit containing Ollama and Llama3 code * Removed ngrok URL * Improved performance and integrated GoaT API entity lookup * Implemented changes requested in PR review * Minor changes * Minor changes
- Loading branch information
Showing
11 changed files
with
607 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
OLLAMA_HOST_URL=http://127.0.0.1:11434 | ||
RETRY_COUNT=3 | ||
GOAT_BASE_URL=https://goat.genomehubs.org/api/v2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
prompts/__pycache__ | ||
rich_query_index | ||
.DS_Store | ||
.env | ||
__pycache__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
|
||
# Installation Guide | ||
|
||
This guide provides step-by-step instructions to set up the project after cloning the repository. | ||
|
||
## Step 1: Install Miniconda | ||
|
||
Download and install Miniconda using the following commands: | ||
|
||
```bash | ||
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh > Miniconda3.sh | ||
chmod +x Miniconda3.sh | ||
./Miniconda3.sh | ||
``` | ||
|
||
## Step 2: Create a Conda Environment | ||
|
||
Create a new Conda environment with Python 3.12 and activate it: | ||
|
||
```bash | ||
conda create -y -n nlp python=3.12 | ||
conda activate nlp | ||
``` | ||
|
||
## Step 3: Clone the Repository | ||
|
||
Clone the repository using the specified branch: | ||
|
||
```bash | ||
git clone https://github.com/genomehubs/goat-nlp | ||
cd goat-nlp | ||
``` | ||
|
||
## Step 4: Install Python Dependencies | ||
|
||
Install the required Python packages using pip: | ||
|
||
```bash | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Step 5: Install Ollama | ||
|
||
Install Ollama using the provided script: | ||
|
||
```bash | ||
curl -fsSL https://ollama.com/install.sh | sh | ||
``` | ||
|
||
## Step 6: Run Ollama | ||
|
||
Run the Ollama application: | ||
|
||
```bash | ||
ollama run llama3 | ||
``` | ||
|
||
## Step 7: Start the Flask Application | ||
|
||
Set the necessary environment variables and start the Flask application: | ||
|
||
```bash | ||
export OLLAMA_HOST_URL=http://127.0.0.1:11434 | ||
export RETRY_COUNT=5 | ||
export GOAT_BASE_URL=https://goat.genomehubs.org/api/v2 | ||
python -m flask run | ||
``` | ||
|
||
The UI will be available at `http://localhost:5000/` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import sys | ||
from flask import Flask, request, render_template, jsonify | ||
from index import load_index, query_engine | ||
from query_reformulation import fetch_related_taxons | ||
import json | ||
import logging | ||
|
||
|
||
# Flask application serving the GoaT NLP chat UI and API endpoints.
app = Flask("goat_nlp")

# Log to stdout so records are visible to process supervisors / containers.
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter(
    '%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
app.logger.addHandler(handler)
app.logger.setLevel(logging.INFO)
|
||
|
||
def chat_bot_rag(query):
    """Answer *query* through the RAG pipeline.

    First resolves the entities mentioned in the query to taxon ids via the
    GoaT lookup, then runs the custom query engine with that context.
    """
    taxon_map = fetch_related_taxons(query)
    return query_engine.custom_query(query, taxon_map)
|
||
|
||
@app.route('/')
def home():
    """Serve the chat UI page."""
    return render_template('chat.html')
|
||
|
||
@app.route('/rebuildIndex')
def index():
    """Force-rebuild the vector index and refresh the query engine.

    Fixes two defects in the original: the view returned ``None`` (a 500
    in Flask), and it discarded the tuple returned by ``load_index`` so
    ``/chat`` kept using the stale engine.  Rebinding the module-level
    ``query_engine`` makes the rebuilt engine take effect immediately.
    """
    global query_engine
    _, query_engine = load_index(force_reload=True)
    return jsonify({'status': 'index rebuilt'})
|
||
|
||
@app.route('/chat', methods=['POST'])
def chat():
    """Handle a chat message: run the RAG pipeline and return JSON.

    The LLM is prompted to reply with ``{"url": ...}``; when it does, only
    the URL is returned, otherwise the raw model output is passed through.
    """
    user_message = request.form['user_input']
    bot_message = chat_bot_rag(user_message)

    try:
        bot_message = json.loads(bot_message)["url"]
    except (ValueError, KeyError, TypeError):
        # Best effort: the model did not return the expected JSON shape;
        # fall through with the raw text.  (Narrowed from a blanket
        # ``except Exception`` that could hide unrelated bugs.)
        pass

    return jsonify({'response': str(bot_message)})
|
||
|
||
if __name__ == '__main__':
    # Development entry point; use a proper WSGI server in production.
    app.run(debug=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from dotenv import load_dotenv | ||
from llama_index.core import VectorStoreIndex, StorageContext | ||
from llama_index.core import load_index_from_storage | ||
from llama_index.core import get_response_synthesizer | ||
from llama_index.core import SimpleDirectoryReader | ||
from llama_index.core import Settings | ||
from llama_index.llms.ollama import Ollama | ||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding | ||
import os | ||
|
||
from prompt import QUERY_PROMPT | ||
from query_engine import GoaTAPIQueryEngine | ||
|
||
|
||
# Load configuration from .env, then configure llama-index globals:
# a local HuggingFace embedding model and a llama3 model served by Ollama.
load_dotenv()
Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-base-en-v1.5")
Settings.llm = Ollama(model="llama3",
                      base_url=os.getenv("OLLAMA_HOST_URL",
                                         "http://127.0.0.1:11434"),
                      request_timeout=36000.0)  # generous: local inference can be slow
Settings.chunk_size = 256
|
||
|
||
def build_index(documents,
                save_dir="rich_query_index",
                force=False):
    '''
    Build the vector index over the given rich-query documents, or load a
    previously persisted one.

    Parameters:
    - documents (list): Rich queries to build the index from.
    - save_dir (str): Directory where the index is persisted.
      Defaults to "rich_query_index".
    - force (bool): If True, rebuild the index even when a persisted copy
      exists in save_dir. Defaults to False.
    Returns:
    - VectorStoreIndex: The built (or reloaded) index.
    '''
    must_build = force or not os.path.exists(save_dir)
    if must_build:
        built = VectorStoreIndex.from_documents(documents)
        built.storage_context.persist(persist_dir=save_dir)
        return built

    storage = StorageContext.from_defaults(persist_dir=save_dir)
    return load_index_from_storage(storage)
|
||
|
||
def load_index(force_reload=False):
    '''
    Load the index and build the query engine for the GoaT NLP system.

    Parameters:
        force_reload (bool): If True, rebuild the index from the source
            documents. Default is False.
    Returns:
        tuple: (index, query_engine)
    '''
    docs = SimpleDirectoryReader(
        "rich_queries"
    ).load_data()

    idx = build_index(docs, force=force_reload)

    engine = GoaTAPIQueryEngine(
        retriever=idx.as_retriever(similarity_top_k=3),
        response_synthesizer=get_response_synthesizer(response_mode="compact"),
        llm=Settings.llm,
        qa_prompt=QUERY_PROMPT,
    )

    return idx, engine
|
||
|
||
index, query_engine = load_index() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from llama_index.core import PromptTemplate | ||
|
||
|
||
# Prompt that converts a natural-language user question into a GoaT API URL.
# Template variables: context_str (few-shot examples retrieved from the
# index), time (current timestamp for time-relative queries),
# entity_taxon_map (entity -> taxon lookup results), query_str (the user's
# question).  The model is instructed to answer with JSON only.
QUERY_PROMPT = PromptTemplate('''We need to query a database that is exposed
by an API that has its own query syntax. I am giving you the query by the user,
you need to convert it to the API counter part. Use the examples given below as
reference:
{context_str}
------
The current date and time is {time}
Use this for any time related calculation
We have also fetched some related entities and their taxon id:
{entity_taxon_map}
Use the best matching result from this in the final output.
Query given by the user:
{query_str}
Return your response in a JSON of the following pattern:
{{
"url": ""
}}
I do not want any explanation, return ONLY the json
''')
|
||
# Prompt asking the LLM to map colloquial entity names in {query} to their
# scientific (family/species) counterparts; expects a JSON reply of the
# form {"entity": [...]} with no explanation.
ENTITY_PROMPT = '''The following query is given by the user:
{query}
We need to make an API call using this query.
For that we need to convert all the entities in this query to their
scientific counterparts (specifically their family/species name).
For e.g. cat/fox will be translated to Felidae, elephant to Elephantidae.
Return all entities and their converted form as a single list of strings in a JSON of the following format:
{{
"entity": ["", ""]
}}
I do not want any explanation, return ONLY the json
'''
|
||
|
||
def wrap_with_entity_prompt(query: str) -> str:
    """Embed *query* into the entity-extraction prompt template."""
    return ENTITY_PROMPT.format(query=query)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import logging | ||
from llama_index.llms.ollama import Ollama | ||
from llama_index.core.query_engine import CustomQueryEngine | ||
from llama_index.core.retrievers import BaseRetriever | ||
from llama_index.core.response_synthesizers import BaseSynthesizer | ||
from llama_index.core import PromptTemplate | ||
from datetime import datetime | ||
|
||
|
||
logger = logging.getLogger('goat_nlp.query_engine') | ||
|
||
class GoaTAPIQueryEngine(CustomQueryEngine):
    """
    Custom query engine for the GoaT API.

    Attributes:
        retriever (BaseRetriever): Retrieves few-shot example nodes.
        response_synthesizer (BaseSynthesizer): The synthesizer used to
            generate responses.
        llm (Ollama): The language model used for completion.
        qa_prompt (PromptTemplate): The template for the QA prompt.
    """

    retriever: BaseRetriever
    response_synthesizer: BaseSynthesizer
    llm: Ollama
    qa_prompt: PromptTemplate

    def custom_query(self, query_str: str, entity_taxon_map: dict):
        """
        Build the prompt from retrieved examples and run the LLM.

        Args:
            query_str (str): The user's query.
            entity_taxon_map (dict): Mapping of entities to taxon lookups.
        Returns:
            str: The raw completion produced by the language model.
        """
        example_nodes = self.retriever.retrieve(query_str)

        context_block = "\n\n".join(
            scored.node.get_content() for scored in example_nodes
        )
        now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

        prompt = self.qa_prompt.format(
            context_str=context_block,
            query_str=query_str,
            entity_taxon_map=entity_taxon_map,
            time=now,
        )
        logger.info(prompt)

        return str(self.llm.complete(prompt))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from llama_index.core import Settings | ||
from prompt import wrap_with_entity_prompt | ||
import os | ||
import json | ||
import requests | ||
import logging | ||
|
||
logger = logging.getLogger('goat_nlp.query_reformulation') | ||
|
||
|
||
def fetch_related_taxons(query: str):
    """
    Fetches related taxons for a given query.

    Asks the LLM to extract entities from *query* (as JSON), then resolves
    each entity to taxons via the GoaT lookup API.  The whole
    extract-parse-lookup sequence is retried up to RETRY_COUNT times
    (default 3) because the LLM occasionally returns malformed JSON.

    Args:
        query (str): The query for which related taxons need to be fetched.
    Returns:
        dict: A dictionary mapping entities to their corresponding taxons.
            Empty if every attempt failed.
    Example:
        >>> query = "find the number of assemblies for bat"
        >>> fetch_related_taxons(query)
        {'bat': ['Chiroptera', 'bat']}
    """
    entity_taxon_map = {}
    for attempt in range(int(os.getenv("RETRY_COUNT", 3))):
        try:
            llm_response = Settings.llm.complete(wrap_with_entity_prompt(query))
            entities = json.loads(llm_response.text)['entity']
            logger.info(entities)
            entity_taxon_map = goat_api_call_for_taxon(entities)
            break
        except Exception:
            # Still best-effort, but log the failure instead of silently
            # swallowing it — the original `pass` hid every error.
            logger.exception("entity extraction attempt %d failed", attempt + 1)
    return entity_taxon_map
|
||
|
||
def goat_api_call_for_taxon(entities: list) -> dict:
    """
    Makes an API call to retrieve taxons for a list of entities.

    Args:
        entities (list): A list of entities for which taxons need to be
            retrieved.
    Returns:
        dict: A dictionary mapping entities to their corresponding taxons.
            Entities whose lookup failed are omitted.
    Example:
        >>> entities = ["bat", "cat", "dog"]
        >>> goat_api_call_for_taxon(entities)
        {'bat': ['Chiroptera', 'bat'], 'cat': ['Felis', 'cat'],
         'dog': ['Canis', 'dog']}
    """
    base_url = os.getenv('GOAT_BASE_URL', 'https://goat.genomehubs.org/api/v2')
    entity_result_map = {}
    for entity in entities:
        try:
            # params= URL-encodes the search term (the original interpolated
            # it raw into the URL, breaking on spaces/special characters);
            # timeout= keeps an unresponsive API from hanging this request
            # forever (the original had no timeout).
            response = requests.get(
                base_url + "/lookup",
                params={"searchTerm": entity,
                        "result": "taxon",
                        "taxonomy": "ncbi"},
                timeout=30,
            )
            json_data = response.json() if response.status_code == 200 else None
            entity_result_map[entity] = [x["result"]
                                         for x in json_data['results']]
        except Exception:
            # Best-effort per entity, but record what went wrong.
            logger.exception("taxon lookup failed for %r", entity)
    return entity_result_map
Oops, something went wrong.