-
Notifications
You must be signed in to change notification settings - Fork 105
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Prebuilt code! * format * added agent executor prebuilt * cr * lint * readme * lint
- Loading branch information
1 parent
d040843
commit f9b25ca
Showing
13 changed files
with
675 additions
and
365 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
import { AgentAction, AgentFinish } from "@langchain/core/agents"; | ||
import { BaseMessage } from "@langchain/core/messages"; | ||
import { Runnable, RunnableLambda } from "@langchain/core/runnables"; | ||
import { StructuredTool } from "@langchain/core/tools"; | ||
import { ToolExecutor } from "./tool_executor.js"; | ||
import { StateGraph, StateGraphArgs } from "../graph/state.js"; | ||
import { END } from "../index.js"; | ||
import { Pregel } from "../pregel/index.js"; | ||
|
||
interface AgentStateBase { | ||
agentOutcome?: AgentAction | AgentFinish; | ||
steps: Array<[AgentAction, string]>; | ||
} | ||
|
||
interface AgentState extends AgentStateBase { | ||
input: string; | ||
chatHistory?: BaseMessage[]; | ||
} | ||
|
||
// eslint-disable-next-line @typescript-eslint/no-explicit-any | ||
type AgentChannels<T> = StateGraphArgs<Array<any> | T>["channels"]; | ||
|
||
// eslint-disable-next-line @typescript-eslint/no-explicit-any | ||
function _getAgentState<T extends Array<any> = Array<any>>( | ||
inputSchema?: AgentChannels<T> | ||
): AgentChannels<T> { | ||
if (!inputSchema) { | ||
return { | ||
input: { | ||
value: null, | ||
}, | ||
agentOutcome: { | ||
value: null, | ||
}, | ||
steps: { | ||
value: (x, y) => x.concat(y), | ||
default: () => [], | ||
}, | ||
}; | ||
} else { | ||
return inputSchema; | ||
} | ||
} | ||
|
||
export function createAgentExecutor< | ||
// eslint-disable-next-line @typescript-eslint/no-explicit-any | ||
T extends Array<any> = Array<any> | ||
>({ | ||
agentRunnable, | ||
tools, | ||
inputSchema, | ||
}: { | ||
agentRunnable: Runnable; | ||
tools: Array<StructuredTool> | ToolExecutor; | ||
inputSchema?: AgentChannels<T>; | ||
}): Pregel { | ||
let toolExecutor: ToolExecutor; | ||
if (!Array.isArray(tools)) { | ||
toolExecutor = tools; | ||
} else { | ||
toolExecutor = new ToolExecutor({ | ||
tools, | ||
}); | ||
} | ||
|
||
const state = _getAgentState<T>(inputSchema); | ||
|
||
// Define logic that will be used to determine which conditional edge to go down | ||
const shouldContinue = (data: AgentState) => { | ||
if (data.agentOutcome && "returnValues" in data.agentOutcome) { | ||
return "end"; | ||
} | ||
return "continue"; | ||
}; | ||
|
||
const runAgent = async (data: AgentState) => { | ||
const agentOutcome = await agentRunnable.invoke(data); | ||
return { | ||
agentOutcome, | ||
}; | ||
}; | ||
|
||
const executeTools = async (data: AgentState) => { | ||
const agentAction = data.agentOutcome; | ||
if (!agentAction || "returnValues" in agentAction) { | ||
throw new Error("Agent has not been run yet"); | ||
} | ||
const output = await toolExecutor.invoke(agentAction); | ||
return { | ||
steps: [[agentAction, output]], | ||
}; | ||
}; | ||
|
||
// Define a new graph | ||
const workflow = new StateGraph({ | ||
channels: state, | ||
}); | ||
|
||
// Define the two nodes we will cycle between | ||
workflow.addNode("agent", new RunnableLambda({ func: runAgent })); | ||
workflow.addNode("action", new RunnableLambda({ func: executeTools })); | ||
|
||
// Set the entrypoint as `agent` | ||
// This means that this node is the first one called | ||
workflow.setEntryPoint("agent"); | ||
|
||
// We now add a conditional edge | ||
workflow.addConditionalEdges( | ||
// First, we define the start node. We use `agent`. | ||
// This means these are the edges taken after the `agent` node is called. | ||
"agent", | ||
// Next, we pass in the function that will determine which node is called next. | ||
shouldContinue, | ||
// Finally we pass in a mapping. | ||
// The keys are strings, and the values are other nodes. | ||
// END is a special node marking that the graph should finish. | ||
// What will happen is we will call `should_continue`, and then the output of that | ||
// will be matched against the keys in this mapping. | ||
// Based on which one it matches, that node will then be called. | ||
{ | ||
// If `tools`, then we call the tool node. | ||
continue: "action", | ||
// Otherwise we finish. | ||
end: END, | ||
} | ||
); | ||
|
||
// We now add a normal edge from `tools` to `agent`. | ||
// This means that after `tools` is called, `agent` node is called next. | ||
workflow.addEdge("action", "agent"); | ||
|
||
return workflow.compile(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
import { StructuredTool } from "@langchain/core/tools"; | ||
import { convertToOpenAIFunction } from "@langchain/core/utils/function_calling"; | ||
import { AgentAction } from "@langchain/core/agents"; | ||
import { FunctionMessage, BaseMessage } from "@langchain/core/messages"; | ||
import { RunnableLambda } from "@langchain/core/runnables"; | ||
import { ToolExecutor } from "./tool_executor.js"; | ||
import { StateGraph, StateGraphArgs } from "../graph/state.js"; | ||
import { END } from "../index.js"; | ||
|
||
export function createFunctionCallingExecutor<Model extends object>({ | ||
model, | ||
tools, | ||
}: { | ||
model: Model; | ||
tools: Array<StructuredTool> | ToolExecutor; | ||
}) { | ||
let toolExecutor: ToolExecutor; | ||
let toolClasses: Array<StructuredTool>; | ||
if (!Array.isArray(tools)) { | ||
toolExecutor = tools; | ||
toolClasses = tools.tools; | ||
} else { | ||
toolExecutor = new ToolExecutor({ | ||
tools, | ||
}); | ||
toolClasses = tools; | ||
} | ||
|
||
const toolsAsOpenAIFunctions = toolClasses.map((tool) => | ||
convertToOpenAIFunction(tool) | ||
); | ||
if (!("bind" in model) || typeof model.bind !== "function") { | ||
throw new Error("Model must be bindable"); | ||
} | ||
const newModel = model.bind({ | ||
functions: toolsAsOpenAIFunctions, | ||
// eslint-disable-next-line @typescript-eslint/no-explicit-any | ||
} as any); | ||
|
||
// Define the function that determines whether to continue or not | ||
const shouldContinue = (state: { messages: Array<BaseMessage> }) => { | ||
const { messages } = state; | ||
const lastMessage = messages[messages.length - 1]; | ||
// If there is no function call, then we finish | ||
if ( | ||
!("function_call" in lastMessage.additional_kwargs) || | ||
!lastMessage.additional_kwargs.function_call | ||
) { | ||
return "end"; | ||
} | ||
// Otherwise if there is, we continue | ||
return "continue"; | ||
}; | ||
|
||
// Define the function that calls the model | ||
const callModel = async (state: { messages: Array<BaseMessage> }) => { | ||
const { messages } = state; | ||
const response = await newModel.invoke(messages); | ||
// We return a list, because this will get added to the existing list | ||
return { | ||
messages: [response], | ||
}; | ||
}; | ||
|
||
// Define the function to execute tools | ||
const _getAction = (state: { messages: Array<BaseMessage> }): AgentAction => { | ||
const { messages } = state; | ||
// Based on the continue condition | ||
// we know the last message involves a function call | ||
const lastMessage = messages[messages.length - 1]; | ||
if (!lastMessage) { | ||
throw new Error("No messages found."); | ||
} | ||
if (!lastMessage.additional_kwargs.function_call) { | ||
throw new Error("No function call found in message."); | ||
} | ||
// We construct an AgentAction from the function_call | ||
return { | ||
tool: lastMessage.additional_kwargs.function_call.name, | ||
toolInput: JSON.stringify( | ||
lastMessage.additional_kwargs.function_call.arguments | ||
), | ||
log: "", | ||
}; | ||
}; | ||
|
||
const callTool = async (state: { messages: Array<BaseMessage> }) => { | ||
const action = _getAction(state); | ||
// We call the tool_executor and get back a response | ||
const response = await toolExecutor.invoke(action); | ||
// We use the response to create a FunctionMessage | ||
const functionMessage = new FunctionMessage({ | ||
content: response, | ||
name: action.tool, | ||
}); | ||
// We return a list, because this will get added to the existing list | ||
return { messages: [functionMessage] }; | ||
}; | ||
|
||
// We create the AgentState that we will pass around | ||
// This simply involves a list of messages | ||
// We want steps to return messages to append to the list | ||
// So we annotate the messages attribute with operator.add | ||
const schema: StateGraphArgs["channels"] = { | ||
messages: { | ||
value: (x: BaseMessage[], y: BaseMessage[]) => x.concat(y), | ||
default: () => [], | ||
}, | ||
}; | ||
|
||
// Define a new graph | ||
const workflow = new StateGraph({ | ||
channels: schema, | ||
}); | ||
|
||
// Define the two nodes we will cycle between | ||
workflow.addNode("agent", new RunnableLambda({ func: callModel })); | ||
workflow.addNode("action", new RunnableLambda({ func: callTool })); | ||
|
||
// Set the entrypoint as `agent` | ||
// This means that this node is the first one called | ||
workflow.setEntryPoint("agent"); | ||
|
||
// We now add a conditional edge | ||
workflow.addConditionalEdges( | ||
// First, we define the start node. We use `agent`. | ||
// This means these are the edges taken after the `agent` node is called. | ||
"agent", | ||
// Next, we pass in the function that will determine which node is called next. | ||
shouldContinue, | ||
// Finally we pass in a mapping. | ||
// The keys are strings, and the values are other nodes. | ||
// END is a special node marking that the graph should finish. | ||
// What will happen is we will call `should_continue`, and then the output of that | ||
// will be matched against the keys in this mapping. | ||
// Based on which one it matches, that node will then be called. | ||
{ | ||
// If `tools`, then we call the tool node. | ||
continue: "action", | ||
// Otherwise we finish. | ||
end: END, | ||
} | ||
); | ||
|
||
// We now add a normal edge from `tools` to `agent`. | ||
// This means that after `tools` is called, `agent` node is called next. | ||
workflow.addEdge("action", "agent"); | ||
|
||
// Finally, we compile it! | ||
// This compiles it into a LangChain Runnable, | ||
// meaning you can use it as you would any other runnable | ||
return workflow.compile(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
export { createAgentExecutor } from "./agent_executor.js"; | ||
export { createFunctionCallingExecutor } from "./chat_agent_executor.js"; | ||
export { | ||
type ToolExecutorArgs, | ||
type ToolInvocationInterface, | ||
ToolExecutor, | ||
} from "./tool_executor.js"; |
Oops, something went wrong.