Skip to content

Commit

Permalink
WIP: research agent orchestrator
Browse files Browse the repository at this point in the history
  • Loading branch information
whimo committed Apr 26, 2024
1 parent bb691e8 commit 068f5f4
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 88 deletions.
84 changes: 84 additions & 0 deletions examples/research_agent/research_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import logging
import sys
import kuzu

from langchain.prompts import PromptTemplate

from motleycrew.storage import MotleyKuzuGraphStore
from motleycrew.tool.llm_tool import LLMTool

logging.basicConfig(stream=sys.stdout, level=logging.INFO)


QUESTION_PRIORITIZATION_TEMPLATE = PromptTemplate(
template=(
"You are provided with the following list of questions:"
" {unanswered_questions} \n"
" Your task is to choose one question from the above list"
" that is the most pertinent to the following query:\n"
" '{original_question}' \n"
" Respond with one question out of the provided list of questions."
" Return the questions as it is without any edits."
" Format your response like:\n"
" #. question"
),
input_variables=["unanswered_questions", "original_question"],
)


class KnowledgeGainingOrchestrator:
def __init__(self, db_path: str):
self.db = kuzu.Database(db_path)
self.storage = MotleyKuzuGraphStore(
self.db, node_table_schema={"question": "STRING", "answer": "STRING", "context": "STRING"}
)

self.question_prioritization_tool = LLMTool(
name="question_prioritization_tool",
description="find the most important question",
prompt=QUESTION_PRIORITIZATION_TEMPLATE,
)
self.question_generation_tool = None

def get_unanswered_questions(self, only_without_children: bool = False) -> list[dict]:
if only_without_children:
query = "MATCH (n1:{}) WHERE n1.answer IS NULL AND NOT (n1)-[:{}]->(:{}) RETURN n1;".format(
self.storage.node_table_name, self.storage.rel_table_name, self.storage.node_table_name
)
else:
query = "MATCH (n1:{}) WHERE n1.answer IS NULL RETURN n1;".format(self.storage.node_table_name)

query_result = self.storage.run_query(query)
return [row[0] for row in query_result] # flatten

def __call__(self, query: str, max_iter: int):
self.storage.create_entity({"question": query})

for iter_n in range(max_iter):
logging.info("====== Iteration %s of %s ======", iter_n, max_iter)

unanswered_questions = self.get_unanswered_questions(only_without_children=True)
logging.info("Loaded unanswered questions: %s", unanswered_questions)

tool_input = "\n".join(f"{i}. {question}" for i, question in enumerate(unanswered_questions))
most_pertinent_question_raw = self.question_prioritization_tool.invoke(tool_input)
logging.info("Most pertinent question according to the tool: %s", most_pertinent_question_raw)

i, most_pertinent_question_text = most_pertinent_question_raw.split(".", 1)
assert i < len(unanswered_questions)

most_pertinent_question = unanswered_questions[i]
assert most_pertinent_question_text.strip() == most_pertinent_question["question"].strip()

logging.info("Generating new questions")


if __name__ == "__main__":
from pathlib import Path
import shutil

here = Path(__file__).parent
db_path = here / "research_db"
shutil.rmtree(db_path, ignore_errors=True)

orchestrator = KnowledgeGainingOrchestrator(db_path=str(db_path))
2 changes: 1 addition & 1 deletion motleycrew/storage/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .kuzu_graph_store import MotleyQuestionGraphStore
from .kuzu_graph_store import MotleyKuzuGraphStore
185 changes: 98 additions & 87 deletions motleycrew/storage/kuzu_graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,22 @@

from typing import Any, Dict, List, Optional

import json
import kuzu


class MotleyQuestionGraphStore:
IS_SUBQUESTION_PREDICATE = "IS_SUBQUESTION"

class MotleyKuzuGraphStore:
def __init__(
self,
database: Any,
node_table_name: str = "question",
node_table_schema: dict[str, str],
node_table_name: str = "entity",
rel_table_name: str = "links",
**kwargs: Any,
) -> None:
self.database = database
self.connection = kuzu.Connection(database)
self.node_table_schema = node_table_schema
self.node_table_name = node_table_name
self.rel_table_name = rel_table_name
self.init_schema()
Expand All @@ -28,10 +29,13 @@ def init_schema(self) -> None:
"""Initialize schema if the tables do not exist."""
node_tables = self.connection._get_node_table_names()
if self.node_table_name not in node_tables:
self.connection.execute(
"CREATE NODE TABLE %s (ID SERIAL, question STRING, answer STRING, context STRING[], PRIMARY KEY(ID))"
% self.node_table_name
node_table_schema_expr = ", ".join(
["id SERIAL"]
+ [f"{name} {datatype}" for name, datatype in self.node_table_schema.items()]
+ ["PRIMARY KEY(id)"]
)
self.connection.execute("CREATE NODE TABLE {} ({})".format(self.node_table_name, node_table_schema_expr))

rel_tables = self.connection._get_rel_table_names()
rel_tables = [rel_table["name"] for rel_table in rel_tables]
if self.rel_table_name not in rel_tables:
Expand All @@ -45,121 +49,117 @@ def init_schema(self) -> None:
def client(self) -> Any:
return self.connection

def check_question_exists(self, question_id: int) -> bool:
def check_entity_exists(self, entity_id: int) -> bool:
is_exists_result = self.connection.execute(
"MATCH (n:%s) WHERE n.ID = $question_id RETURN n.ID" % self.node_table_name,
{"question_id": question_id},
"MATCH (n:%s) WHERE n.id = $entity_id RETURN n.id" % self.node_table_name,
{"entity_id": entity_id},
)
return is_exists_result.has_next()

def get_question(self, question_id: int) -> Optional[dict]:
def get_entity(self, entity_id: int) -> Optional[dict]:
query = """
MATCH (n1:%s)
WHERE n1.ID = $question_id
WHERE n1.id = $entity_id
RETURN n1;
"""
prepared_statement = self.connection.prepare(query % self.node_table_name)
query_result = self.connection.execute(prepared_statement, {"question_id": question_id})
query_result = self.connection.execute(prepared_statement, {"entity_id": entity_id})

if query_result.has_next():
row = query_result.get_next()
return row[0]

def get_subquestions(self, question_id: int) -> List[int]:
query = """
MATCH (n1:%s)-[r:%s]->(n2:%s)
WHERE n1.ID = $question_id
AND r.predicate = $is_subquestion_predicate
RETURN n2.ID;
"""
prepared_statement = self.connection.prepare(
query % (self.node_table_name, self.rel_table_name, self.node_table_name)
)
query_result = self.connection.execute(
prepared_statement,
{
"question_id": question_id,
"is_subquestion_predicate": MotleyQuestionGraphStore.IS_SUBQUESTION_PREDICATE,
},
)
retval = []
while query_result.has_next():
row = query_result.get_next()
retval.append(row[0])
return retval
item = row[0]
return item

def create_question(self, question: str) -> int:
def create_entity(self, entity: dict) -> int:
"""Create a new entity and return its id"""
create_result = self.connection.execute(
"CREATE (n:%s {question: $question}) " "RETURN n.ID" % self.node_table_name,
{"question": question},
"CREATE (n:{} $entity) RETURN n.id".format(self.node_table_name),
{"entity": entity},
)
assert create_result.has_next()
return create_result.get_next()[0]

def create_subquestion(self, question_id: int, subquestion: str) -> int:
def create_subquestion_rel(connection: Any, question_id: int, subquestion_id: int) -> None:
connection.execute(
(
"MATCH (n1:{}), (n2:{}) WHERE n1.ID = $question_id AND n2.ID = $subquestion_id "
"CREATE (n1)-[r:{} {{predicate: $is_subquestion_predicate}}]->(n2)"
).format(self.node_table_name, self.node_table_name, self.rel_table_name),
{
"question_id": question_id,
"subquestion_id": subquestion_id,
"is_subquestion_predicate": MotleyQuestionGraphStore.IS_SUBQUESTION_PREDICATE,
},
)

if not self.check_question_exists(question_id):
raise Exception(f"No question with id {question_id}")

subquestion_id = self.create_question(subquestion)
create_subquestion_rel(self.connection, question_id=question_id, subquestion_id=subquestion_id)
return subquestion_id
def create_rel(self, from_id: int, to_id: int, predicate: str) -> None:
self.connection.execute(
(
"MATCH (n1:{}), (n2:{}) WHERE n1.id = $from_id AND n2.id = $to_id "
"CREATE (n1)-[r:{} {{predicate: $predicate}}]->(n2)"
).format(self.node_table_name, self.node_table_name, self.rel_table_name),
{
"from_id": from_id,
"to_id": to_id,
"predicate": predicate,
},
)

def delete_question(self, question_id: int) -> None:
"""Deletes question and its relations."""
def delete_entity(self, entity_id: int) -> None:
"""Delete a given entity and its relations"""

def delete_rels(connection: Any, question_id: int) -> None:
def delete_rels(connection: Any, entity_id: int) -> None:
# Undirected relation removal is not supported for some reason
connection.execute(
"MATCH (n1:{})-[r:{}]->(n2:{}) WHERE n1.ID = $question_id DELETE r".format(
self.node_table_name, self.rel_table_name, self.node_table_name
"MATCH (n1:{})-[r:{}]->(n2:{}) WHERE n1.id = $entity_id DELETE r;"
"MATCH (n1:{})<-[r:{}]-(n2:{}) WHERE n1.id = $entity_id DELETE r".format(
self.node_table_name,
self.rel_table_name,
self.node_table_name,
self.node_table_name,
self.rel_table_name,
self.node_table_name,
),
{"question_id": question_id},
{"entity_id": entity_id},
)
connection.execute(
"MATCH (n1:{})<-[r:{}]-(n2:{}) WHERE n1.ID = $question_id DELETE r".format(
"MATCH (n1:{})<-[r:{}]-(n2:{}) WHERE n1.id = $entity_id DELETE r".format(
self.node_table_name, self.rel_table_name, self.node_table_name
),
{"question_id": question_id},
{"entity_id": entity_id},
)

def delete_question(connection: Any, question_id: int) -> None:
def delete_entity(connection: Any, entity_id: int) -> None:
connection.execute(
"MATCH (n:%s) WHERE n.ID = $question_id DELETE n" % self.node_table_name,
{"question_id": question_id},
"MATCH (n:%s) WHERE n.id = $entity_id DELETE n" % self.node_table_name,
{"entity_id": entity_id},
)

delete_rels(self.connection, question_id)
delete_question(self.connection, question_id)
delete_rels(self.connection, entity_id)
delete_entity(self.connection, entity_id)

def set_property(self, entity_id: int, property_name: str, property_value: Any):
query = """
MATCH (n1:{})
WHERE n1.id = $entity_id
SET n1.{} = $property_value;
"""
prepared_statement = self.connection.prepare(query.format(self.node_table_name, property_name))
self.connection.execute(prepared_statement, {"entity_id": entity_id, "property_value": property_value})

def run_query(self, query: str, parameters: Optional[dict] = None) -> list[list]:
"""Run a Cypher query and return the results"""
query_result = self.connection.execute(query=query, parameters=parameters)
retval = []
while query_result.has_next():
retval.append(query_result.get_next())
return retval

@classmethod
def from_persist_dir(
cls,
persist_dir: str,
node_table_schema: dict[str, str],
node_table_name: str = "entity",
rel_table_name: str = "links",
) -> "MotleyQuestionGraphStore":
) -> "MotleyKuzuGraphStore":
"""Load from persist dir."""
try:
import kuzu
except ImportError:
raise ImportError("Please install kuzu: pip install kuzu")
database = kuzu.Database(persist_dir)
return cls(database, node_table_name, rel_table_name)
return cls(database, node_table_schema, node_table_name, rel_table_name)

@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "MotleyQuestionGraphStore":
def from_dict(cls, config_dict: Dict[str, Any]) -> "MotleyKuzuGraphStore":
"""Initialize graph store from configuration dictionary.
Args:
Expand All @@ -173,25 +173,36 @@ def from_dict(cls, config_dict: Dict[str, Any]) -> "MotleyQuestionGraphStore":

if __name__ == "__main__":
from pathlib import Path
import shutil

here = Path(__file__).parent
db_path = here / "test1"
shutil.rmtree(db_path, ignore_errors=True)
db = kuzu.Database(str(db_path))
graph_store = MotleyQuestionGraphStore(db)
graph_store = MotleyKuzuGraphStore(
db, node_table_schema={"question": "STRING", "answer": "STRING", "context": "STRING"}
)

IS_SUBQUESTION_PREDICATE = "is_subquestion"

q1_id = graph_store.create_entity({"question": "q1"})
assert graph_store.get_entity(q1_id)["question"] == "q1"

q1_id = graph_store.create_question("q1")
assert graph_store.get_question(q1_id)["question"] == "q1"
q2_id = graph_store.create_entity({"question": "q2"})
q3_id = graph_store.create_entity({"question": "q3"})
q4_id = graph_store.create_entity({"question": "q4"})
graph_store.create_rel(q1_id, q2_id, IS_SUBQUESTION_PREDICATE)
graph_store.create_rel(q1_id, q3_id, IS_SUBQUESTION_PREDICATE)
graph_store.create_rel(q3_id, q4_id, IS_SUBQUESTION_PREDICATE)

q2_id = graph_store.create_subquestion(q1_id, "q2")
q3_id = graph_store.create_subquestion(q1_id, "q3")
q4_id = graph_store.create_subquestion(q3_id, "q4")
graph_store.delete_entity(q4_id)
assert graph_store.get_entity(q4_id) is None

assert set(graph_store.get_subquestions(q1_id)) == {q2_id, q3_id}
assert set(graph_store.get_subquestions(q3_id)) == {q4_id}
graph_store.set_property(q2_id, property_name="answer", property_value="a2")
graph_store.set_property(q3_id, property_name="", property_value=["c3_1", "c3_2"])

graph_store.delete_question(q4_id)
assert graph_store.get_question(q4_id) is None
assert not graph_store.get_subquestions(q3_id)
assert graph_store.get_entity(q2_id)["answer"] == "a2"
assert graph_store.get_entity(q3_id)["context"] == ["c3_1", "c3_2"]

print(f"docker run -p 8000:8000 -v {db_path}:/database --rm kuzudb/explorer: latest")
print("MATCH (A)-[r]->(B) RETURN *;")

0 comments on commit 068f5f4

Please sign in to comment.