Skip to content

Commit

Permalink
Knowledge gaining orchestrator working implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
whimo committed Apr 26, 2024
1 parent efd1697 commit 1876740
Show file tree
Hide file tree
Showing 8 changed files with 203 additions and 105 deletions.
37 changes: 17 additions & 20 deletions examples/research_agent/question_generator.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
from typing import List, Optional, Dict, Any
from typing import Optional, Any
import json
from pathlib import Path

from langchain_core.language_models import BaseLanguageModel
from langchain_core.runnables import (
RunnablePassthrough,
RunnableLambda,
RunnableParallel,
)
from langchain_core.tools import Tool
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts import PromptTemplate

from langchain_core.pydantic_v1 import BaseModel, Field

# TODO: fallback interface if LlamaIndex is not available
from llama_index.core.graph_stores.types import GraphStore

from motleycrew.tool import MotleyTool
from motleycrew.common import LLMFramework
from motleycrew.common.llms import init_llm
from motleycrew.tool.question_insertion_tool import QuestionInsertionTool
from motleycrew.common.utils import print_passthrough
from motleycrew.storage import MotleyGraphStore

from question_struct import Question
from question_inserter import QuestionInsertionTool


default_prompt = PromptTemplate.from_template(
"""
You are a part of a team. The ultimate goal of your team is to
answer the following Question: '{question}'.\n
answer the following Question: '{question_text}'.\n
Your team has discovered some new text (delimited by ```) that may be relevant to your ultimate goal.
text: \n ``` {context} ``` \n
Your task is to ask new questions that may help your team achieve the ultimate goal.
Expand Down Expand Up @@ -57,7 +57,7 @@ class QuestionGeneratorTool(MotleyTool):
def __init__(
self,
query_tool: MotleyTool,
graph: GraphStore,
graph: MotleyGraphStore,
max_questions: int = 3,
llm: Optional[BaseLanguageModel] = None,
prompt: str | BasePromptTemplate = None,
Expand All @@ -76,14 +76,12 @@ def __init__(
class QuestionGeneratorToolInput(BaseModel):
"""Input for the Question Generator Tool."""

question: str = Field(
description="The input question for which to generate subquestions."
)
question: Question = Field(description="The input question for which to generate subquestions.")


def create_question_generator_langchain_tool(
query_tool: MotleyTool,
graph: GraphStore,
graph: MotleyGraphStore,
max_questions: int = 3,
llm: Optional[BaseLanguageModel] = None,
prompt: str | BasePromptTemplate = None,
Expand All @@ -98,14 +96,10 @@ def create_question_generator_langchain_tool(
elif isinstance(prompt, str):
prompt = PromptTemplate.from_template(prompt)

assert isinstance(
prompt, BasePromptTemplate
), "Prompt must be a string or a BasePromptTemplate"
assert isinstance(prompt, BasePromptTemplate), "Prompt must be a string or a BasePromptTemplate"

def partial_inserter(question: dict[str, str]):
out = QuestionInsertionTool(
graph=graph, question=question["question"]
).to_langchain_tool()
def partial_inserter(question: Question):
out = QuestionInsertionTool(graph=graph, question=question).to_langchain_tool()
return (out,)

def insert_questions(input_dict) -> None:
Expand All @@ -124,7 +118,10 @@ def insert_questions(input_dict) -> None:
}
| RunnableLambda(print_passthrough)
| {
"subquestions": prompt.partial(num_questions=max_questions) | llm,
"subquestions": RunnablePassthrough.assign(question_text=lambda x: x["question"]["question"].question)
| RunnableLambda(print_passthrough)
| prompt.partial(num_questions=max_questions)
| llm,
"question_inserter": RunnablePassthrough(),
}
| RunnableLambda(insert_questions)
Expand Down
75 changes: 75 additions & 0 deletions examples/research_agent/question_inserter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from typing import List

from pathlib import Path

from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import Tool

from motleycrew.storage import MotleyGraphStore
from motleycrew.tool import MotleyTool

from question_struct import Question


IS_SUBQUESTION_PREDICATE = "is_subquestion"


class QuestionInsertionTool(MotleyTool):
def __init__(self, question: Question, graph: MotleyGraphStore):

langchain_tool = create_question_insertion_langchain_tool(
name="Question Insertion Tool",
description="Insert a list of questions (supplied as a list of strings) into the graph.",
question=question,
graph=graph,
)

super().__init__(langchain_tool)


class QuestionInsertionToolInput(BaseModel):
"""Subquestions of the current question, to be inserted into the knowledge graph."""

questions: List[str] = Field(description="List of questions to be inserted into the knowledge graph.")


def create_question_insertion_langchain_tool(
name: str,
description: str,
question: Question,
graph: MotleyGraphStore,
):
def insert_questions(questions: list[str]) -> None:
for subquestion in questions:
subquestion_data = graph.create_entity(Question(question=subquestion).serialize())
subquestion_obj = Question.deserialize(subquestion_data)
graph.create_rel(from_id=question.id, to_id=subquestion_obj.id, predicate=IS_SUBQUESTION_PREDICATE)

return Tool.from_function(
func=insert_questions,
name=name,
description=description,
args_schema=QuestionInsertionToolInput,
)


if __name__ == "__main__":
import kuzu
from motleycrew.storage import MotleyKuzuGraphStore

here = Path(__file__).parent
db_path = here / "test1"
db = kuzu.Database(db_path)
graph_store = MotleyKuzuGraphStore(
db, node_table_schema={"question": "STRING", "answer": "STRING", "context": "STRING"}
)

question_data = graph_store.create_entity(Question(question="What is the capital of France?").serialize())
question = Question.deserialize(question_data)

children = ["What is the capital of France?", "What is the capital of Germany?"]
tool = QuestionInsertionTool(question=question, graph=graph_store)
tool.invoke({"questions": children})

print(f"docker run -p 8000:8000 -v {db_path}:/database --rm kuzudb/explorer: latest")
print("MATCH (A)-[r]->(B) RETURN *;")
35 changes: 35 additions & 0 deletions examples/research_agent/question_struct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Optional
from dataclasses import dataclass
import json


@dataclass
class Question:
id: Optional[int] = None
question: Optional[str] = None
answer: Optional[str] = None
context: Optional[list[str]] = None

def serialize(self):
data = {}

if self.id:
data["id"] = json.dumps(self.id)
if self.context:
data["question"] = json.dumps(self.question)
if self.context:
data["answer"] = json.dumps(self.answer)
if self.context:
data["context"] = json.dumps(self.context)

return data

@staticmethod
def deserialize(data: dict):
context_raw = data["context"]
if context_raw:
context = json.loads(context_raw)
else:
context = None

return Question(id=data["id"], question=data["question"], answer=data["answer"], context=context)
48 changes: 40 additions & 8 deletions examples/research_agent/research_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@
import kuzu

from langchain.prompts import PromptTemplate
from langchain.tools import Tool

from motleycrew import MotleyTool
from motleycrew.storage import MotleyKuzuGraphStore
from motleycrew.tool.llm_tool import LLMTool

from question_struct import Question
from question_generator import QuestionGeneratorTool
from question_generator import QuestionGeneratorToolInput

logging.basicConfig(stream=sys.stdout, level=logging.INFO)


Expand All @@ -27,20 +33,21 @@


class KnowledgeGainingOrchestrator:
def __init__(self, db_path: str):
def __init__(self, db_path: str, query_tool: MotleyTool):
self.db = kuzu.Database(db_path)
self.storage = MotleyKuzuGraphStore(
self.db, node_table_schema={"question": "STRING", "answer": "STRING", "context": "STRING"}
)

self.query_tool = query_tool
self.question_prioritization_tool = LLMTool(
name="question_prioritization_tool",
description="find the most important question",
prompt=QUESTION_PRIORITIZATION_TEMPLATE,
)
self.question_generation_tool = None
self.question_generation_tool = QuestionGeneratorTool(query_tool=query_tool, graph=self.storage)

def get_unanswered_questions(self, only_without_children: bool = False) -> list[dict]:
def get_unanswered_questions(self, only_without_children: bool = False) -> list[Question]:
if only_without_children:
query = "MATCH (n1:{}) WHERE n1.answer IS NULL AND NOT (n1)-[:{}]->(:{}) RETURN n1;".format(
self.storage.node_table_name, self.storage.rel_table_name, self.storage.node_table_name
Expand All @@ -49,7 +56,7 @@ def get_unanswered_questions(self, only_without_children: bool = False) -> list[
query = "MATCH (n1:{}) WHERE n1.answer IS NULL RETURN n1;".format(self.storage.node_table_name)

query_result = self.storage.run_query(query)
return [row[0] for row in query_result] # flatten
return [Question.deserialize(row[0]) for row in query_result]

def __call__(self, query: str, max_iter: int):
self.storage.create_entity({"question": query})
Expand All @@ -60,25 +67,50 @@ def __call__(self, query: str, max_iter: int):
unanswered_questions = self.get_unanswered_questions(only_without_children=True)
logging.info("Loaded unanswered questions: %s", unanswered_questions)

tool_input = "\n".join(f"{i}. {question}" for i, question in enumerate(unanswered_questions))
most_pertinent_question_raw = self.question_prioritization_tool.invoke(tool_input)
question_prioritization_tool_input = {
"unanswered_questions": "\n".join(
f"{i}. {question.question}" for i, question in enumerate(unanswered_questions)
),
"original_question": query,
}
most_pertinent_question_raw = self.question_prioritization_tool.invoke(
question_prioritization_tool_input
).content
logging.info("Most pertinent question according to the tool: %s", most_pertinent_question_raw)

i, most_pertinent_question_text = most_pertinent_question_raw.split(".", 1)
i = int(i)
assert i < len(unanswered_questions)

most_pertinent_question = unanswered_questions[i]
assert most_pertinent_question_text.strip() == most_pertinent_question["question"].strip()
assert most_pertinent_question_text.strip() == most_pertinent_question.question.strip()

logging.info("Generating new questions")
self.question_generation_tool.invoke({"question": most_pertinent_question})


if __name__ == "__main__":
from pathlib import Path
import shutil
from dotenv import load_dotenv

load_dotenv()
here = Path(__file__).parent
db_path = here / "research_db"
shutil.rmtree(db_path, ignore_errors=True)

orchestrator = KnowledgeGainingOrchestrator(db_path=str(db_path))
query_tool = MotleyTool.from_langchain_tool(
Tool.from_function(
func=lambda question: [
"Germany has consisted of many different states over the years",
"The capital of France has moved in 1815, from Lyons to Paris",
"France actually has two capitals, one in the north and one in the south",
],
name="Query Tool",
description="Query the library for relevant information.",
args_schema=QuestionGeneratorToolInput,
)
)

orchestrator = KnowledgeGainingOrchestrator(db_path=str(db_path), query_tool=query_tool)
orchestrator(query="Why did Arjuna kill his step-brother?", max_iter=5)
2 changes: 2 additions & 0 deletions motleycrew/storage/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .graph_store import MotleyGraphStore

from .kuzu_graph_store import MotleyKuzuGraphStore
29 changes: 29 additions & 0 deletions motleycrew/storage/graph_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from abc import ABC, abstractmethod
from typing import Optional, Any


class MotleyGraphStore(ABC):
@abstractmethod
def check_entity_exists(self, entity_id: int) -> bool:
pass

@abstractmethod
def get_entity(self, entity_id: int) -> Optional[dict]:
pass

@abstractmethod
def create_entity(self, entity: dict) -> dict:
"""Create a new entity and return it"""
pass

@abstractmethod
def create_rel(self, from_id: int, to_id: int, predicate: str) -> None:
pass

@abstractmethod
def delete_entity(self, entity_id: int) -> None:
"""Delete a given entity and its relations"""
pass

def set_property(self, entity_id: int, property_name: str, property_value: Any):
pass
8 changes: 5 additions & 3 deletions motleycrew/storage/kuzu_graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@

import kuzu

from motleycrew.storage import MotleyGraphStore

class MotleyKuzuGraphStore:

class MotleyKuzuGraphStore(MotleyGraphStore):
def __init__(
self,
database: Any,
Expand Down Expand Up @@ -81,11 +83,11 @@ def _dict_to_cypher_mapping_with_parameters(entity: dict) -> tuple[str, dict]:
cypher_mapping = cypher_mapping.rstrip(", ") + "}"
return cypher_mapping, parameters

def create_entity(self, entity: dict) -> int:
def create_entity(self, entity: dict) -> dict:
"""Create a new entity and return its id"""
cypher_mapping, parameters = MotleyKuzuGraphStore._dict_to_cypher_mapping_with_parameters(entity)
create_result = self.connection.execute(
"CREATE (n:{} {}) RETURN n.id".format(self.node_table_name, cypher_mapping), parameters=parameters
"CREATE (n:{} {}) RETURN n".format(self.node_table_name, cypher_mapping), parameters=parameters
)
assert create_result.has_next()
return create_result.get_next()[0]
Expand Down
Loading

0 comments on commit 1876740

Please sign in to comment.