Skip to content

Commit

Permalink
Llama3 Integration and RAG pipeline for querying (#2)
Browse files Browse the repository at this point in the history
* Initial commit containing Ollama and Llama3 code

* Removed ngrok URL

* Improved performance and integrated GoaT API entity lookup

* Implemented changes requested in PR review

* Minor changes

* Minor changes
  • Loading branch information
deepnayak authored Jun 15, 2024
1 parent b54dac0 commit 99c3669
Show file tree
Hide file tree
Showing 11 changed files with 607 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .env.dist
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
OLLAMA_HOST_URL=http://127.0.0.1:11434
RETRY_COUNT=3
GOAT_BASE_URL=https://goat.genomehubs.org/api/v2
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
prompts/__pycache__
rich_query_index
.DS_Store
.env
__pycache__
70 changes: 70 additions & 0 deletions INSTALL.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@

# Installation Guide

This guide provides step-by-step instructions to set up the project after cloning the repository.

## Step 1: Install Miniconda

Download and install Miniconda using the following commands:

```bash
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh > Miniconda3.sh
chmod +x Miniconda3.sh
./Miniconda3.sh
```

## Step 2: Create a Conda Environment

Create a new Conda environment with Python 3.12 and activate it:

```bash
conda create -y -n nlp python=3.12
conda activate nlp
```

## Step 3: Clone the Repository

Clone the repository and change into the project directory:

```bash
git clone https://github.com/genomehubs/goat-nlp
cd goat-nlp
```

## Step 4: Install Python Dependencies

Install the required Python packages using pip:

```bash
pip install -r requirements.txt
```

## Step 5: Install Ollama

Install Ollama using the provided script:

```bash
curl -fsSL https://ollama.com/install.sh | sh
```

## Step 6: Run Ollama

Run the Ollama application:

```bash
ollama run llama3
```

## Step 7: Start the Flask Application

Set the necessary environment variables and start the Flask application:

```bash
export OLLAMA_HOST_URL=http://127.0.0.1:11434
export RETRY_COUNT=5
export GOAT_BASE_URL=https://goat.genomehubs.org/api/v2
python -m flask run
```

The UI will be available at `http://localhost:5000/`

49 changes: 49 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import sys
from flask import Flask, request, render_template, jsonify
from index import load_index, query_engine
from query_reformulation import fetch_related_taxons
import json
import logging


app = Flask("goat_nlp")

# Mirror application logs to stdout so they are visible in the console /
# container logs alongside Flask's own output.
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(
    logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
app.logger.addHandler(handler)
app.logger.setLevel(logging.INFO)


def chat_bot_rag(query):
    """Answer *query* via the RAG pipeline.

    First resolves the entities mentioned in the query to taxon ids through
    the GoaT lookup API, then hands both the query and the lookup results to
    the custom query engine.
    """
    taxon_lookup = fetch_related_taxons(query)
    return query_engine.custom_query(query, taxon_lookup)


@app.route('/')
def home():
    """Serve the chat UI page."""
    return render_template('chat.html')


@app.route('/rebuildIndex')
def index():
    """Force a rebuild of the vector index.

    Returns a small JSON acknowledgement; the original returned None, which
    makes Flask raise "view function did not return a valid response".
    """
    load_index(force_reload=True)
    return jsonify({'status': 'index rebuilt'})


@app.route('/chat', methods=['POST'])
def chat():
    """Handle one chat message and return ``{"response": ...}`` as JSON."""
    raw_reply = chat_bot_rag(request.form['user_input'])

    # The model is prompted to answer with {"url": "..."}; surface just the
    # URL when the reply parses, otherwise fall back to the raw model output.
    try:
        reply = json.loads(raw_reply)["url"]
    except Exception:
        reply = raw_reply

    return jsonify({'response': str(reply)})


if __name__ == '__main__':
    # Development entry point only; use a proper WSGI server in production.
    app.run(debug=True)
87 changes: 87 additions & 0 deletions index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from dotenv import load_dotenv
from llama_index.core import VectorStoreIndex, StorageContext
from llama_index.core import load_index_from_storage
from llama_index.core import get_response_synthesizer
from llama_index.core import SimpleDirectoryReader
from llama_index.core import Settings
from llama_index.llms.ollama import Ollama
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import os

from prompt import QUERY_PROMPT
from query_engine import GoaTAPIQueryEngine


load_dotenv()

# Embedding model used to vectorise the example rich queries.
Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-base-en-v1.5")
# Local Ollama server hosting llama3; generous timeout since generation can
# be very slow on CPU-only hosts.
Settings.llm = Ollama(
    model="llama3",
    base_url=os.getenv("OLLAMA_HOST_URL", "http://127.0.0.1:11434"),
    request_timeout=36000.0,
)
Settings.chunk_size = 256


def build_index(documents,
                save_dir="rich_query_index",
                force=False):
    '''
    Build (or reload) the vector index over the rich example queries.

    Parameters:
    - documents (list): Documents to index.
    - save_dir (str): Directory where the index is persisted.
      Defaults to "rich_query_index".
    - force (bool): When True, rebuild the index from `documents` even if a
      persisted copy already exists. Defaults to False.

    Returns:
    - VectorStoreIndex: The loaded or freshly built index.
    '''
    if os.path.exists(save_dir) and not force:
        # Reuse the persisted index instead of re-embedding everything.
        idx = load_index_from_storage(
            StorageContext.from_defaults(persist_dir=save_dir)
        )
    else:
        idx = VectorStoreIndex.from_documents(documents)
        idx.storage_context.persist(persist_dir=save_dir)
    return idx


def load_index(force_reload=False):
    '''
    Load the vector index and construct the GoaT query engine.

    Parameters:
        force_reload (bool): When True, rebuild the index from the
            "rich_queries" directory instead of loading the persisted copy.
            Default is False.

    Returns:
        tuple: (index, query_engine).
    '''
    docs = SimpleDirectoryReader("rich_queries").load_data()
    idx = build_index(docs, force=force_reload)

    engine = GoaTAPIQueryEngine(
        retriever=idx.as_retriever(similarity_top_k=3),
        response_synthesizer=get_response_synthesizer(response_mode="compact"),
        llm=Settings.llm,
        qa_prompt=QUERY_PROMPT,
    )

    return idx, engine


# Built once at import time so `from index import query_engine` works for the
# Flask app; NOTE this triggers embedding/index I/O on first import.
index, query_engine = load_index()
49 changes: 49 additions & 0 deletions prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from llama_index.core import PromptTemplate


QUERY_PROMPT = PromptTemplate('''We need to query a database that is exposed
by an API that has its own query syntax. I am giving you the query by the user,
you need to convert it to the API counter part. Use the examples given below as
reference:
{context_str}
------
The current date and time is {time}
Use this for any time related calculation
We have also fetched some related entities and their taxon id:
{entity_taxon_map}
Use the best matching result from this in the final output.
Query given by the user:
{query_str}
Return your response in a JSON of the following pattern:
{{
"url": ""
}}
I do not want any explanation, return ONLY the json
''')

# Prompt asking the model to extract entities from the user's query and map
# them to scientific (family/species) names, replied as JSON {"entity": [...]}.
ENTITY_PROMPT = '''The following query is given by the user:
{query}
We need to make an API call using this query.
For that we need to convert all the entities in this query to their
scientific counterparts (specifically their family/species name).
For e.g. cat/fox will be translated to Felidae, elephant to Elephantidae.
Return all entities and their converted form as a single list of strings in a JSON of the following format:
{{
    "entity": ["", ""]
}}
I do not want any explanation, return ONLY the json
'''


def wrap_with_entity_prompt(query: str) -> str:
    """Return ENTITY_PROMPT with *query* substituted into the template."""
    return ENTITY_PROMPT.format(query=query)
52 changes: 52 additions & 0 deletions query_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import logging
from llama_index.llms.ollama import Ollama
from llama_index.core.query_engine import CustomQueryEngine
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.response_synthesizers import BaseSynthesizer
from llama_index.core import PromptTemplate
from datetime import datetime


logger = logging.getLogger('goat_nlp.query_engine')

class GoaTAPIQueryEngine(CustomQueryEngine):
    """
    Custom query engine that converts a user query into a GoaT API request.

    Attributes:
        retriever (BaseRetriever): Retrieves example nodes similar to the
            user query.
        response_synthesizer (BaseSynthesizer): Synthesizer used to generate
            responses.
        llm (Ollama): Language model used for completion.
        qa_prompt (PromptTemplate): Template the query is rendered into.
    """

    retriever: BaseRetriever
    response_synthesizer: BaseSynthesizer
    llm: Ollama
    qa_prompt: PromptTemplate

    def custom_query(self, query_str: str, entity_taxon_map: dict):
        """
        Render the prompt with retrieved examples and ask the LLM.

        Args:
            query_str (str): The user's query.
            entity_taxon_map (dict): Entity -> taxon lookup results.

        Returns:
            str: The raw LLM completion.
        """
        # Pull the closest example queries to ground the LLM's translation.
        retrieved_nodes = self.retriever.retrieve(query_str)
        examples = "\n\n".join(n.node.get_content() for n in retrieved_nodes)

        prompt = self.qa_prompt.format(
            context_str=examples,
            query_str=query_str,
            entity_taxon_map=entity_taxon_map,
            time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        )
        logger.info(prompt)

        return str(self.llm.complete(prompt))
75 changes: 75 additions & 0 deletions query_reformulation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from llama_index.core import Settings
from prompt import wrap_with_entity_prompt
import os
import json
import requests
import logging

logger = logging.getLogger('goat_nlp.query_reformulation')


def fetch_related_taxons(query: str):
    """
    Fetches related taxons for a given query.

    Asks the LLM to extract entities from the query, then resolves each
    entity through the GoaT lookup API. Retries up to RETRY_COUNT times
    because the LLM is not guaranteed to return valid JSON.

    Args:
        query (str): The query for which related taxons need to be fetched.

    Returns:
        dict: A dictionary mapping entities to their corresponding taxons.
            Empty if every attempt failed.

    Example:
        >>> query = "find the number of assemblies for bat"
        >>> fetch_related_taxons(query)
        {'bat': ['Chiroptera', 'bat']}
    """
    entity_taxon_map = {}
    for attempt in range(int(os.getenv("RETRY_COUNT", 3))):
        try:
            llm_response = Settings.llm.complete(wrap_with_entity_prompt(query))
            entities = json.loads(llm_response.text)['entity']
            logger.info(entities)
            entity_taxon_map = goat_api_call_for_taxon(entities)
            break
        except Exception:
            # Best-effort retry, but record why each attempt failed instead
            # of silently swallowing the error (original did `pass`).
            logger.exception("Taxon fetch attempt %d failed", attempt + 1)
    return entity_taxon_map


def goat_api_call_for_taxon(entities: list) -> dict:
    """
    Makes an API call to retrieve taxons for a list of entities.

    Args:
        entities (list): A list of entities for which taxons need to be
                         retrieved.

    Returns:
        dict: A dictionary mapping entities to their corresponding taxons.
            Entities whose lookup failed are omitted.

    Example:
        >>> entities = ["bat", "cat", "dog"]
        >>> goat_api_call_for_taxon(entities)
        {'bat': ['Chiroptera', 'bat'], 'cat': ['Felis', 'cat'],
        'dog': ['Canis', 'dog']}
    """
    base_url = os.getenv('GOAT_BASE_URL', 'https://goat.genomehubs.org/api/v2')
    entity_result_map = {}
    for entity in entities:
        try:
            # `params=` URL-encodes the search term (spaces etc.); the
            # original interpolated it raw into the URL. A timeout is
            # required so a stalled API call cannot hang the request forever.
            response = requests.get(
                f"{base_url}/lookup",
                params={"searchTerm": entity,
                        "result": "taxon",
                        "taxonomy": "ncbi"},
                timeout=30,
            )
            json_data = response.json() if response and response.status_code == 200 else None
            entity_result_map[entity] = [x["result"] for x in json_data['results']]
        except Exception:
            # Skip entities that fail, but leave a trace for debugging.
            logger.warning("Taxon lookup failed for %r", entity, exc_info=True)
    return entity_result_map
Loading

0 comments on commit 99c3669

Please sign in to comment.