Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support custom LLMs in research agent #86

Merged
merged 1 commit into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
619 changes: 219 additions & 400 deletions examples/Multi-step research agent.ipynb

Large diffs are not rendered by default.

26 changes: 17 additions & 9 deletions examples/research_agent/research_agent_main.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
from pathlib import Path
import shutil
import os
import platform
import shutil
from pathlib import Path

import kuzu
from dotenv import load_dotenv

from motleycrew import MotleyCrew
from motleycrew.storage import MotleyKuzuGraphStore
from motleycrew.common import configure_logging
from motleycrew.applications.research_agent.question_task import QuestionTask
from motleycrew.applications.research_agent.answer_task import AnswerTask

from motleycrew.applications.research_agent.question_task import QuestionTask
from motleycrew.common import LLMFramework, configure_logging
from motleycrew.common.llms import init_llm
from motleycrew.storage import MotleyKuzuGraphStore
from motleycrew.tools.simple_retriever_tool import SimpleRetrieverTool


WORKING_DIR = Path(__file__).parent
if "Dropbox" in WORKING_DIR.parts and platform.system() == "Windows":
# On Windows, kuzu has file locking issues with Dropbox
Expand All @@ -31,21 +30,30 @@


def main():
llm = init_llm(
llm_framework=LLMFramework.LANGCHAIN
) # throughout this project, we use LangChain's LLM wrappers

load_dotenv()
configure_logging(verbose=True)

shutil.rmtree(DB_PATH)

# You can pass any LlamaIndex embedding to the retriever tool, default is OpenAI's text-embedding-ada-002
query_tool = SimpleRetrieverTool(DATA_DIR, PERSIST_DIR, return_strings_only=True)

db = kuzu.Database(DB_PATH)
graph_store = MotleyKuzuGraphStore(db)
crew = MotleyCrew(graph_store=graph_store)

question_task = QuestionTask(
crew=crew, question=QUESTION, query_tool=query_tool, max_iter=MAX_ITER
crew=crew,
question=QUESTION,
query_tool=query_tool,
max_iter=MAX_ITER,
llm=llm,
)
answer_task = AnswerTask(answer_length=ANSWER_LENGTH, crew=crew)
answer_task = AnswerTask(answer_length=ANSWER_LENGTH, crew=crew, llm=llm)

question_task >> answer_task

Expand Down
4 changes: 3 additions & 1 deletion motleycrew/applications/research_agent/answer_task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List, Optional

from langchain_core.runnables import Runnable
from langchain_core.language_models import BaseLanguageModel

from motleycrew.applications.research_agent.question import Question
from motleycrew.applications.research_agent.question_answerer import AnswerSubQuestionTool
Expand All @@ -21,6 +22,7 @@ def __init__(
self,
crew: MotleyCrew,
answer_length: int = 1000,
llm: Optional[BaseLanguageModel] = None,
):
super().__init__(
name="AnswerTask",
Expand All @@ -30,7 +32,7 @@ def __init__(
)
self.answer_length = answer_length
self.answerer = AnswerSubQuestionTool(
graph=self.graph_store, answer_length=self.answer_length
graph=self.graph_store, answer_length=self.answer_length, llm=llm
)

def get_next_unit(self) -> QuestionAnsweringTaskUnit | None:
Expand Down
15 changes: 9 additions & 6 deletions motleycrew/applications/research_agent/question_answerer.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from typing import Optional

from langchain.prompts import PromptTemplate
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import (
RunnablePassthrough,
RunnableLambda,
chain,
)
from langchain_core.runnables import RunnableLambda, RunnablePassthrough, chain
from langchain_core.tools import Tool

from motleycrew.applications.research_agent.question import Question
from motleycrew.common.utils import print_passthrough
from motleycrew.storage import MotleyGraphStore
from motleycrew.tools import MotleyTool, LLMTool
from motleycrew.tools import LLMTool, MotleyTool

_default_prompt = PromptTemplate.from_template(
"""
Expand All @@ -37,11 +36,13 @@ def __init__(
graph: MotleyGraphStore,
answer_length: int,
prompt: str | BasePromptTemplate = None,
llm: Optional[BaseLanguageModel] = None,
):
langchain_tool = create_answer_question_langchain_tool(
graph=graph,
answer_length=answer_length,
prompt=prompt,
llm=llm,
)

super().__init__(langchain_tool)
Expand Down Expand Up @@ -70,6 +71,7 @@ def create_answer_question_langchain_tool(
graph: MotleyGraphStore,
answer_length: int,
prompt: str | BasePromptTemplate = None,
llm: Optional[BaseLanguageModel] = None,
) -> Tool:
if prompt is None:
prompt = _default_prompt
Expand All @@ -78,6 +80,7 @@ def create_answer_question_langchain_tool(
prompt=prompt.partial(answer_length=str(answer_length)),
name="Question answerer",
description="Tool to answer a question from notes and sub-questions",
llm=llm,
)
"""
Gets a valid question node ID, question, and context as input dict
Expand Down
4 changes: 3 additions & 1 deletion motleycrew/applications/research_agent/question_task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List, Optional

from langchain_core.language_models import BaseLanguageModel
from langchain_core.runnables import Runnable

from motleycrew.common import logger
Expand All @@ -26,6 +27,7 @@ def __init__(
crew: MotleyCrew,
max_iter: int = 10,
allow_async_units: bool = False,
llm: Optional[BaseLanguageModel] = None,
name: str = "QuestionTask",
):
super().__init__(
Expand All @@ -41,7 +43,7 @@ def __init__(
self.graph_store.insert_node(self.question)
self.question_prioritization_tool = QuestionPrioritizerTool()
self.question_generation_tool = QuestionGeneratorTool(
query_tool=query_tool, graph=self.graph_store
query_tool=query_tool, graph=self.graph_store, llm=llm
)

def get_next_unit(self) -> QuestionGenerationTaskUnit | None:
Expand Down
16 changes: 12 additions & 4 deletions motleycrew/tools/simple_retriever_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
load_index_from_storage,
)
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.embeddings import BaseEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding

from motleycrew.applications.research_agent.question import Question
Expand All @@ -26,6 +27,7 @@ def __init__(
return_strings_only: bool = False,
return_direct: bool = False,
exceptions_to_reflect: Optional[List[Exception]] = None,
embeddings: Optional[BaseEmbedding] = None,
):
"""
Args:
Expand All @@ -34,7 +36,7 @@ def __init__(
return_strings_only: Whether to return only the text of the retrieved documents.
"""
tool = make_retriever_langchain_tool(
data_dir, persist_dir, return_strings_only=return_strings_only
data_dir, persist_dir, return_strings_only=return_strings_only, embeddings=embeddings
)
super().__init__(
tool=tool, return_direct=return_direct, exceptions_to_reflect=exceptions_to_reflect
Expand All @@ -49,9 +51,15 @@ class RetrieverToolInput(BaseModel, arbitrary_types_allowed=True):
)


def make_retriever_langchain_tool(data_dir, persist_dir, return_strings_only: bool = False):
text_embedding_model = "text-embedding-ada-002"
embeddings = OpenAIEmbedding(model=text_embedding_model)
def make_retriever_langchain_tool(
data_dir,
persist_dir,
return_strings_only: bool = False,
embeddings: Optional[BaseEmbedding] = None,
):
if embeddings is None:
text_embedding_model = "text-embedding-ada-002"
embeddings = OpenAIEmbedding(model=text_embedding_model)

if not os.path.exists(persist_dir):
# load the documents and create the index
Expand Down
Loading