
Commit

chore: Update code formatting and editor settings
Refined logic for time-based queries
deepnayak committed Jun 18, 2024
1 parent 3abfcdd commit 43c77e6
Showing 7 changed files with 196 additions and 91 deletions.
9 changes: 6 additions & 3 deletions .vscode/settings.json
@@ -1,11 +1,14 @@
{
"git.ignoreLimitWarning": true,
"editor.formatOnSave": true,
"flake8.args": ["--max-line-length=88"],
"flake8.args": [
"--max-line-length=88"
],
"[python]": {
"editor.codeActionsOnSave": {
"source.organizeImports.python": "explicit"
},
"editor.defaultFormatter": "ms-python.black-formatter"
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true,
}
}
}
41 changes: 26 additions & 15 deletions src/app.py
@@ -10,34 +10,45 @@
app = Flask("goat_nlp")

handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
handler.setFormatter(
logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
app.logger.addHandler(handler)
app.logger.setLevel(logging.INFO)


def construct_url(json_output):
base_url = "https://goat.genomehubs.org/"
endpoint = "search?"
suffix = "&result=taxon&summaryValues=count&taxonomy=ncbi&offset=0"
suffix += "&fields=assembly_level%2Cassembly_span%2Cgenome_size%2C"
suffix += "chromosome_number%2Chaploid_number&names=common_name&ranks="
suffix += "&includeEstimates=false&size=100"

if json_output['intent'] == 'count':
if json_output["intent"] == "count":
endpoint = "count?"
elif json_output['intent'] == 'record':
elif json_output["intent"] == "record":
endpoint = "record?"

params = []

if 'taxon' in json_output:
if "taxon" in json_output:
params.append(f"tax_tree(* {json_output['taxon']})")
if 'rank' in json_output:
if "rank" in json_output:
params.append(f"tax_rank({json_output['rank']})")
if 'field' in json_output:
if "field" in json_output:
params.append(f"{json_output['field']}")
if "time_frame_query" in json_output:
params.append(f"{json_output['time_frame_query']}")
suffix = "&result=assembly&summaryValues=count&taxonomy=ncbi&offset=0"
suffix += "&fields=assembly_level%2Cassembly_span%2Cgenome_size%2C"
suffix += "chromosome_number%2Chaploid_number&names=common_name&ranks="
suffix += "&includeEstimates=false&size=100"

query_string = " AND ".join(params)
return base_url + endpoint + "query=" + urllib.parse.quote_plus(query_string) + "&result=taxon&summaryValues=count&taxonomy=ncbi&offset=0&fields=assembly_level%2Cassembly_span%2Cgenome_size%2Cchromosome_number%2Chaploid_number&names=common_name&ranks=&includeEstimates=false&size=100"
return (
base_url + endpoint + "query=" + urllib.parse.quote_plus(query_string) + suffix
)


def chat_bot_rag(query):
@@ -51,28 +62,28 @@ def chat_bot_rag(query):
return construct_url(json.loads(query_engine.custom_query(query)))


@app.route('/')
@app.route("/")
def home():
return render_template('chat.html')
return render_template("chat.html")


@app.route('/rebuildIndex')
@app.route("/rebuildIndex")
def index():
load_index(force_reload=True)


@app.route('/chat', methods=['POST'])
@app.route("/chat", methods=["POST"])
def chat():
user_message = request.form['user_input']
user_message = request.form["user_input"]
bot_message = chat_bot_rag(user_message)

try:
bot_message = json.loads(bot_message)["url"]
except Exception:
pass

return jsonify({'response': str(bot_message)})
return jsonify({"response": str(bot_message)})


if __name__ == '__main__':
if __name__ == "__main__":
app.run(debug=True)
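
Note: the time-based refinement is easiest to see end to end. Below is a minimal standalone sketch (not part of this commit) of how a parsed query containing the new time_frame_query field becomes a GoaT search URL; the example dict and the Chiroptera query are hypothetical, but the string-building mirrors construct_url above.

import urllib.parse

# Hypothetical parser output for "bat assemblies updated since January 2021".
parsed = {
    "intent": "search",
    "taxon": "Chiroptera",
    "rank": "species",
    "time_frame_query": "last_updated>=2021-01-01",
}

params = []
if "taxon" in parsed:
    params.append(f"tax_tree(* {parsed['taxon']})")
if "rank" in parsed:
    params.append(f"tax_rank({parsed['rank']})")
if "time_frame_query" in parsed:
    # As in construct_url, a time filter also switches the result type
    # from taxon to assembly via the suffix override.
    params.append(parsed["time_frame_query"])

query_string = " AND ".join(params)
print("https://goat.genomehubs.org/search?query=" + urllib.parse.quote_plus(query_string))
# -> ...query=tax_tree%28%2A+Chiroptera%29+AND+tax_rank%28species%29+AND+last_updated%3E%3D2021-01-01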
33 changes: 16 additions & 17 deletions src/index.py
@@ -16,16 +16,16 @@

load_dotenv()
Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-base-en-v1.5")
Settings.llm = Ollama(model="codellama", base_url=os.
getenv("OLLAMA_HOST_URL", "http://127.0.0.1:11434"),
request_timeout=36000.0)
Settings.llm = Ollama(
model="codellama",
base_url=os.getenv("OLLAMA_HOST_URL", "http://127.0.0.1:11434"),
request_timeout=36000.0,
)
Settings.chunk_size = 256


def build_index(documents,
save_dir="query_index",
force=False):
'''
def build_index(documents, save_dir="query_index", force=False):
"""
Build the index from the given rich queries and save it in the specified
directory.
@@ -42,11 +42,9 @@ def build_index(documents,
Raises:
- FileNotFoundError: If the save directory does not exist and force is
set to False.
'''
"""
if not os.path.exists(save_dir) or force:
query_index = VectorStoreIndex(
documents
)
query_index = VectorStoreIndex(documents)
query_index.storage_context.persist(persist_dir=save_dir)
else:
query_index = load_index_from_storage(
@@ -57,7 +55,7 @@


def load_index(force_reload=False):
'''
"""
Load the index and query engine for the GoaT NLP system.
Parameters:
@@ -67,14 +65,15 @@ def load_index(force_reload=False):
Returns:
tuple: A tuple containing the index and query engine.
'''
"""
query_list = json.load(open("queries/script_generated_queries.json"))
for x in query_list:
x.pop('api_query', None)
question_store = {x['english_query']: x for x in query_list}
x.pop("api_query", None)
question_store = {x["english_query"]: x for x in query_list}

index = build_index([TextNode(text=x) for x in question_store.keys()],
force=force_reload)
index = build_index(
[TextNode(text=x) for x in question_store.keys()], force=force_reload
)
retriever = index.as_retriever(similarity_top_k=5)
synthesizer = get_response_synthesizer(response_mode="compact")

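
Note: for readers unfamiliar with the llama_index pattern used in build_index and load_index, here is a hedged, self-contained sketch of the same build/persist/reload cycle; the example questions and the demo_index directory are invented for illustration.

from llama_index.core import (
    Settings,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.schema import TextNode
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

# Same embedding model that index.py configures.
Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-base-en-v1.5")

nodes = [
    TextNode(text="How many bird genomes were assembled after 2022?"),
    TextNode(text="Show the genome size of the house mouse"),
]

# Build once, persist to disk, then reload -- the same flow as build_index.
index = VectorStoreIndex(nodes)
index.storage_context.persist(persist_dir="demo_index")
index = load_index_from_storage(
    StorageContext.from_defaults(persist_dir="demo_index")
)

retriever = index.as_retriever(similarity_top_k=2)
for hit in retriever.retrieve("genomes updated since 2022"):
    print(hit.score, hit.node.get_content())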
15 changes: 13 additions & 2 deletions src/prompt.py
@@ -1,7 +1,8 @@
from llama_index.core import PromptTemplate


QUERY_PROMPT = PromptTemplate('''We need to parse a query given by the user.
QUERY_PROMPT = PromptTemplate(
"""We need to parse a query given by the user.
The user is asking a genomics question, we need to parse the query into different parts.
Use the examples given below as reference:
@@ -19,14 +20,24 @@
{query_str}
------
Omit fields that are not required and remember that the examples given to you are
just basic samples, you may have to combine multiple fields to get the correct output.
Time related queries will always have the last_updated field in the time_frame_query
part.
You can use >= and <= operators to filter such data. e.g. last_updated>=2021-01-01
------
ONLY RETURN THE JSON OF THE FOLLOWING FORMAT AND NOTHING ELSE
{{
"taxon": "...",
"rank": "...",
"phrase": "...",
"intent": "...",
"field": "..."
"time_frame": "...",
"time_frame_query": "..."
...
}}
''')
"""
)
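
Note: as a concrete illustration of the JSON contract described in this prompt, the snippet below shows a plausible model response for a time-related question; the fields, the question, and the date are hypothetical, and the exact keys returned depend on the model.

import json

# A plausible response for "Which assemblies were updated in the last year?",
# asked on 2024-06-18.
llm_response = """{
    "phrase": "assemblies updated in the last year",
    "intent": "search",
    "time_frame": "last year",
    "time_frame_query": "last_updated>=2023-06-18"
}"""

parsed = json.loads(llm_response)
# construct_url in app.py appends this clause directly to the GoaT query string.
assert parsed["time_frame_query"] == "last_updated>=2023-06-18"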
22 changes: 12 additions & 10 deletions src/query_engine.py
@@ -8,7 +8,7 @@
from datetime import datetime


logger = logging.getLogger('goat_nlp.query_engine')
logger = logging.getLogger("goat_nlp.query_engine")


class GoaTAPIQueryEngine(CustomQueryEngine):
@@ -42,16 +42,18 @@ def custom_query(self, query_str: str):
"""
nodes = self.retriever.retrieve(query_str)

context_str = "\n\n".join([json.dumps(self.question_store[n.node.get_content()],
indent=2) for n in nodes])
current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
populated_prompt = self.qa_prompt.format(context_str=context_str,
query_str=query_str,
time=current_time)
logger.info(populated_prompt)
response = self.llm.complete(
populated_prompt
context_str = "\n\n".join(
[
json.dumps(self.question_store[n.node.get_content()], indent=2)
for n in nodes
]
)
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
populated_prompt = self.qa_prompt.format(
context_str=context_str, query_str=query_str, time=current_time
)
logger.info(populated_prompt)
response = self.llm.complete(populated_prompt)
logger.info(response)

return str(response)
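
Note: to make the prompt-population step in custom_query easier to follow, here is a small sketch with a stand-in template; the real QUERY_PROMPT in prompt.py carries the full parsing instructions and examples, but the format call works the same way.

from datetime import datetime

from llama_index.core import PromptTemplate

# Stand-in template with the same three placeholders custom_query fills in.
demo_prompt = PromptTemplate(
    "Current time: {time}\n"
    "Similar stored queries:\n{context_str}\n"
    "------\n"
    "User query: {query_str}"
)

populated = demo_prompt.format(
    context_str='{"english_query": "bird genomes assembled after 2022"}',
    query_str="bat assemblies updated since January 2021",
    time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
)
print(populated)  # this string is what gets passed to self.llm.complete(...)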
