Skip to content

Commit

Permalink
Improved query pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
deepnayak committed Jul 4, 2024
1 parent d9faf5c commit 88e5736
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 106 deletions.
51 changes: 48 additions & 3 deletions src/agent/component_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
INDEX_PROMPT,
INTENT_PROMPT,
RANK_PROMPT,
RECORD_PROMPT,
TIME_PROMPT,
)

Expand Down Expand Up @@ -120,12 +121,12 @@ def construct_query(input: str, state: Dict[str, Any]):
entities += f"* {entity['singular_form']},"
entities += f"* {entity['plural_form']},"

entities.removesuffix(",")
entities = entities.removesuffix(",")

query = ""

if entities != "":
query += f"tax_tree({entities}) AND "
query += f"tax_name({entities}) AND "

if state["rank"]["rank"] != "":
query += f"tax_rank({state['rank']['rank']}) AND "
Expand All @@ -151,7 +152,7 @@ def construct_query(input: str, state: Dict[str, Any]):
+ f'{attribute["value"]} AND '
)

query.removesuffix(" AND ")
query = query.removesuffix(" AND ")

state["query"] = query

Expand All @@ -170,3 +171,47 @@ def construct_url(input: str, state: Dict[str, Any]):
state["final_url"] = (
base_url + endpoint + "query=" + urllib.parse.quote(state["query"]) + suffix
)


def identify_record(input: str, state: Dict[str, Any]):
entities = ""
for entity in state["entity"]["entities"]:
entities += f"{entity['singular_form']},"
entities += f"{entity['plural_form']},"
entities += f"{entity['scientific_name']},"
entities += f"* {entity['singular_form']},"
entities += f"* {entity['plural_form']},"

query_url = (
"https://goat.genomehubs.org/api/v2/search?query="
+ urllib.parse.quote(f"tax_name({entities})")
+ "&result=taxon"
)

response = requests.get(query_url)
response_parsed = response.json()
cleaned_taxons = []

for res in response_parsed["results"]:
cleaned_taxons.append(
{
"taxon_id": res["result"]["taxon_id"],
"taxon_rank": res["result"]["taxon_rank"],
"scientific_name": res["result"]["scientific_name"],
"taxon_names": res["result"]["taxon_names"],
}
)

taxon_response = Settings.llm.complete(
RECORD_PROMPT.format(query=input, results=json.dumps(cleaned_taxons, indent=4))
).text
state["record"] = json.loads(extract_json_str(taxon_response))

if "taxon_id" not in state["record"] or "explanation" not in state["record"]:
raise ValueError("Invalid response from model at record identification stage.")

state["final_url"] = (
"https://goat.genomehubs.org/record?recordId="
+ str(state["record"]["taxon_id"])
+ f"&result={state['index']['classification']}"
)
15 changes: 14 additions & 1 deletion src/agent/query_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
identify_index,
identify_intent,
identify_rank,
identify_record,
identify_time_frame,
)
from agent.goat_query_component import GoatQueryComponent
Expand All @@ -24,8 +25,20 @@
"time": GoatQueryComponent(fn=identify_time_frame),
"query": GoatQueryComponent(fn=construct_query),
"url": GoatQueryComponent(fn=construct_url),
"record": GoatQueryComponent(fn=identify_record),
}
)


qp.add_chain(["intent", "index", "entity", "rank", "attribute", "time", "query", "url"])
qp.add_chain(["intent", "index", "entity"])

qp.add_link(
"entity",
"record",
condition_fn=lambda x: x["state"]["intent"]["intent"] == "record",
)
qp.add_link(
"entity", "rank", condition_fn=lambda x: x["state"]["intent"]["intent"] != "record"
)
qp.add_chain(["rank", "attribute", "time", "query", "url"])
# qp.add_chain(["intent", "index", "entity", "rank", "time", "query", "url"])
15 changes: 7 additions & 8 deletions src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,20 @@
import llama_index.core
import phoenix as px
from flask import Flask, render_template, request
from llama_index.core import Settings
from llama_index.llms.ollama import Ollama
from openinference.instrumentation.llama_index import LlamaIndexInstrumentor
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.trace.export import SimpleSpanProcessor

from agent.query_pipeline import qp
from index import load_index

# agent_worker = QueryPipelineAgentWorker(qp)
# agent = agent_worker.as_agent(callback_manager=CallbackManager([]), verbose=True)
Settings.llm = Ollama(
model="llama3",
base_url=os.getenv("OLLAMA_HOST_URL", "http://127.0.0.1:11434"),
request_timeout=36000.0,
)

px.launch_app()
llama_index.core.set_global_handler("arize_phoenix")
Expand All @@ -39,11 +43,6 @@ def home():
return render_template("chat.html")


@app.route("/rebuildIndex")
def index():
load_index(force_reload=True)


@app.route("/chat", methods=["POST"])
def chat():
# agent.reset()
Expand Down
91 changes: 0 additions & 91 deletions src/index.py

This file was deleted.

40 changes: 37 additions & 3 deletions src/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,44 @@
}}
If there are no attributes in the query, return an empty list.
REMEMBER: The attributes list in your response must be filled ONLY if
**REMEMBER:** The attributes list in your response must be filled **ONLY** if
the user has explicitly mentioned that attribute in the query.
DO NOT assume that an attribute is IMPLIED in the query.
IN MOST CASES, YOUR RESPONSE WILL BE AN EMPTY LIST.
This means that "ebp_date" (or any other attribute) will not be included in
the list unless the phrase "ebp_date" is given in the query by the user.
**DO NOT** assume that an attribute is **IMPLIED** in the query.
**IN MOST CASES, YOUR RESPONSE WILL BE AN EMPTY LIST.**
```json
"""
)

RECORD_PROMPT = PromptTemplate(
"""
You are an intelligent assistant who **ONLY ANSWERS IN JSON FORMAT**.
A user is trying to query a genomics database.
We have already queried the database against the user's query.
We have a set of results from the database, we need to pick the best match.
The query by the user is as follows:
`{query}`
The results from the database are as follows:
{results}
You need to return the best taxon from the results in the following JSON format:
{{
"taxon_id": "...",
"explanation": "..."
}}
The taxon_id HAS TO BE AN INTEGER.
```json
"""
Expand Down

0 comments on commit 88e5736

Please sign in to comment.