Prebuilt code! (#16)
* Prebuilt code!

* format

* added agent executor prebuilt

* cr

* lint

* readme

* lint
bracesproul authored Jan 18, 2024
1 parent d040843 commit f9b25ca
Showing 13 changed files with 675 additions and 365 deletions.
README.md: 374 changes (24 additions, 350 deletions)

Large diffs are not rendered by default.

langgraph/package.json: 8 changes (6 additions, 2 deletions)
@@ -32,16 +32,20 @@
"author": "LangChain",
"license": "MIT",
"dependencies": {
"@langchain/core": "^0.1.15"
"@langchain/community": "^0.0.17",
"@langchain/core": "^0.1.15",
"@langchain/openai": "^0.0.12",
"langchain": "^0.1.3",
"zod": "^3.22.4"
},
"devDependencies": {
"@jest/globals": "^29.5.0",
"@langchain/openai": "^0.0.12",
"@swc/core": "^1.3.90",
"@swc/jest": "^0.2.29",
"@tsconfig/recommended": "^1.0.3",
"@typescript-eslint/eslint-plugin": "^6.12.0",
"@typescript-eslint/parser": "^6.12.0",
"zod": "^3.22.4",
"dotenv": "^16.3.1",
"dpdm": "^3.12.0",
"eslint": "^8.33.0",
langgraph/src/graph/state.ts: 3 changes (2 additions, 1 deletion)
@@ -10,7 +10,8 @@ import { ChannelRead } from "../pregel/read.js";

export const START = "__start__";

export interface StateGraphArgs<T> {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export interface StateGraphArgs<T = any> {
  channels: Record<
    string,
    {
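
The default type parameter added here is what lets callers write `StateGraphArgs["channels"]` with no explicit type argument, as `chat_agent_executor.ts` below does. A minimal illustration, not part of this commit's diff (import path and message type are just examples):

import { BaseMessage } from "@langchain/core/messages";
import { StateGraphArgs } from "./state.js";

// With `T` defaulting to `any`, no type argument is required here.
const channels: StateGraphArgs["channels"] = {
  messages: {
    // Reducer: append incoming messages to the accumulated list.
    value: (x: BaseMessage[], y: BaseMessage[]) => x.concat(y),
    // Start from an empty list.
    default: () => [],
  },
};
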
langgraph/src/prebuilt/agent_executor.ts: 133 changes (133 additions, 0 deletions)
@@ -0,0 +1,133 @@
import { AgentAction, AgentFinish } from "@langchain/core/agents";
import { BaseMessage } from "@langchain/core/messages";
import { Runnable, RunnableLambda } from "@langchain/core/runnables";
import { StructuredTool } from "@langchain/core/tools";
import { ToolExecutor } from "./tool_executor.js";
import { StateGraph, StateGraphArgs } from "../graph/state.js";
import { END } from "../index.js";
import { Pregel } from "../pregel/index.js";

interface AgentStateBase {
  agentOutcome?: AgentAction | AgentFinish;
  steps: Array<[AgentAction, string]>;
}

interface AgentState extends AgentStateBase {
  input: string;
  chatHistory?: BaseMessage[];
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
type AgentChannels<T> = StateGraphArgs<Array<any> | T>["channels"];

// eslint-disable-next-line @typescript-eslint/no-explicit-any
function _getAgentState<T extends Array<any> = Array<any>>(
  inputSchema?: AgentChannels<T>
): AgentChannels<T> {
  if (!inputSchema) {
    return {
      input: {
        value: null,
      },
      agentOutcome: {
        value: null,
      },
      steps: {
        value: (x, y) => x.concat(y),
        default: () => [],
      },
    };
  } else {
    return inputSchema;
  }
}

export function createAgentExecutor<
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  T extends Array<any> = Array<any>
>({
  agentRunnable,
  tools,
  inputSchema,
}: {
  agentRunnable: Runnable;
  tools: Array<StructuredTool> | ToolExecutor;
  inputSchema?: AgentChannels<T>;
}): Pregel {
  let toolExecutor: ToolExecutor;
  if (!Array.isArray(tools)) {
    toolExecutor = tools;
  } else {
    toolExecutor = new ToolExecutor({
      tools,
    });
  }

  const state = _getAgentState<T>(inputSchema);

  // Define logic that will be used to determine which conditional edge to go down
  const shouldContinue = (data: AgentState) => {
    if (data.agentOutcome && "returnValues" in data.agentOutcome) {
      return "end";
    }
    return "continue";
  };

  const runAgent = async (data: AgentState) => {
    const agentOutcome = await agentRunnable.invoke(data);
    return {
      agentOutcome,
    };
  };

  const executeTools = async (data: AgentState) => {
    const agentAction = data.agentOutcome;
    if (!agentAction || "returnValues" in agentAction) {
      throw new Error("Agent has not been run yet");
    }
    const output = await toolExecutor.invoke(agentAction);
    return {
      steps: [[agentAction, output]],
    };
  };

  // Define a new graph
  const workflow = new StateGraph({
    channels: state,
  });

  // Define the two nodes we will cycle between
  workflow.addNode("agent", new RunnableLambda({ func: runAgent }));
  workflow.addNode("action", new RunnableLambda({ func: executeTools }));

  // Set the entrypoint as `agent`
  // This means that this node is the first one called
  workflow.setEntryPoint("agent");

  // We now add a conditional edge
  workflow.addConditionalEdges(
    // First, we define the start node. We use `agent`.
    // This means these are the edges taken after the `agent` node is called.
    "agent",
    // Next, we pass in the function that will determine which node is called next.
    shouldContinue,
    // Finally we pass in a mapping.
    // The keys are strings, and the values are other nodes.
    // END is a special node marking that the graph should finish.
    // What will happen is we will call `shouldContinue`, and then the output of that
    // will be matched against the keys in this mapping.
    // Based on which one it matches, that node will then be called.
    {
      // If `continue`, then we call the tool node.
      continue: "action",
      // Otherwise we finish.
      end: END,
    }
  );

  // We now add a normal edge from `action` to `agent`.
  // This means that after `action` is called, the `agent` node is called next.
  workflow.addEdge("action", "agent");

  return workflow.compile();
}
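
For orientation, here is a self-contained sketch of how createAgentExecutor might be wired up. It is not part of this commit: the stub tool and the stub "agent" (a RunnableLambda that issues one tool call, then finishes) are assumptions standing in for a real LangChain agent runnable, which only needs to accept the state object and return an AgentAction or AgentFinish. It also assumes ToolExecutor dispatches on the action's `tool` name and passes `toolInput` to that tool. Run as an ES module.

import { AgentAction, AgentFinish } from "@langchain/core/agents";
import { RunnableLambda } from "@langchain/core/runnables";
import { StructuredTool } from "@langchain/core/tools";
import { z } from "zod";
import { createAgentExecutor } from "./agent_executor.js";

// A stub tool so the sketch needs no API keys.
class WeatherTool extends StructuredTool {
  name = "get_weather";
  description = "Look up the current weather for a city.";
  schema = z.object({ city: z.string().describe("The city to look up") });

  async _call({ city }: { city: string }): Promise<string> {
    return `It is always sunny in ${city}.`;
  }
}

// A stub agent: first pass requests the tool, second pass finishes with the
// observation. A real agent runnable would plug in the same way.
const agentRunnable = new RunnableLambda({
  func: async (state: {
    input: string;
    steps: Array<[AgentAction, string]>;
  }): Promise<AgentAction | AgentFinish> => {
    if (state.steps.length === 0) {
      return {
        tool: "get_weather",
        toolInput: { city: "San Francisco" },
        log: "",
      };
    }
    const [, observation] = state.steps[state.steps.length - 1];
    return { returnValues: { output: observation }, log: "" };
  },
});

const app = createAgentExecutor({
  agentRunnable,
  tools: [new WeatherTool()],
});

// The compiled graph loops agent -> action until the agent returns an AgentFinish.
const result = await app.invoke({ input: "What is the weather in San Francisco?" });
console.log(result);
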
langgraph/src/prebuilt/chat_agent_executor.ts: 153 changes (153 additions, 0 deletions)
@@ -0,0 +1,153 @@
import { StructuredTool } from "@langchain/core/tools";
import { convertToOpenAIFunction } from "@langchain/core/utils/function_calling";
import { AgentAction } from "@langchain/core/agents";
import { FunctionMessage, BaseMessage } from "@langchain/core/messages";
import { RunnableLambda } from "@langchain/core/runnables";
import { ToolExecutor } from "./tool_executor.js";
import { StateGraph, StateGraphArgs } from "../graph/state.js";
import { END } from "../index.js";

export function createFunctionCallingExecutor<Model extends object>({
  model,
  tools,
}: {
  model: Model;
  tools: Array<StructuredTool> | ToolExecutor;
}) {
  let toolExecutor: ToolExecutor;
  let toolClasses: Array<StructuredTool>;
  if (!Array.isArray(tools)) {
    toolExecutor = tools;
    toolClasses = tools.tools;
  } else {
    toolExecutor = new ToolExecutor({
      tools,
    });
    toolClasses = tools;
  }

  const toolsAsOpenAIFunctions = toolClasses.map((tool) =>
    convertToOpenAIFunction(tool)
  );
  if (!("bind" in model) || typeof model.bind !== "function") {
    throw new Error("Model must be bindable");
  }
  const newModel = model.bind({
    functions: toolsAsOpenAIFunctions,
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
  } as any);

  // Define the function that determines whether to continue or not
  const shouldContinue = (state: { messages: Array<BaseMessage> }) => {
    const { messages } = state;
    const lastMessage = messages[messages.length - 1];
    // If there is no function call, then we finish
    if (
      !("function_call" in lastMessage.additional_kwargs) ||
      !lastMessage.additional_kwargs.function_call
    ) {
      return "end";
    }
    // Otherwise if there is, we continue
    return "continue";
  };

  // Define the function that calls the model
  const callModel = async (state: { messages: Array<BaseMessage> }) => {
    const { messages } = state;
    const response = await newModel.invoke(messages);
    // We return a list, because this will get added to the existing list
    return {
      messages: [response],
    };
  };

  // Define the function to execute tools
  const _getAction = (state: { messages: Array<BaseMessage> }): AgentAction => {
    const { messages } = state;
    // Based on the continue condition
    // we know the last message involves a function call
    const lastMessage = messages[messages.length - 1];
    if (!lastMessage) {
      throw new Error("No messages found.");
    }
    if (!lastMessage.additional_kwargs.function_call) {
      throw new Error("No function call found in message.");
    }
    // We construct an AgentAction from the function_call
    return {
      tool: lastMessage.additional_kwargs.function_call.name,
      toolInput: JSON.stringify(
        lastMessage.additional_kwargs.function_call.arguments
      ),
      log: "",
    };
  };

  const callTool = async (state: { messages: Array<BaseMessage> }) => {
    const action = _getAction(state);
    // We call the tool_executor and get back a response
    const response = await toolExecutor.invoke(action);
    // We use the response to create a FunctionMessage
    const functionMessage = new FunctionMessage({
      content: response,
      name: action.tool,
    });
    // We return a list, because this will get added to the existing list
    return { messages: [functionMessage] };
  };

  // We create the AgentState that we will pass around
  // This simply involves a list of messages
  // We want steps to return messages to append to the list
  // So we give the `messages` channel a concatenating (append) reducer
  const schema: StateGraphArgs["channels"] = {
    messages: {
      value: (x: BaseMessage[], y: BaseMessage[]) => x.concat(y),
      default: () => [],
    },
  };

  // Define a new graph
  const workflow = new StateGraph({
    channels: schema,
  });

  // Define the two nodes we will cycle between
  workflow.addNode("agent", new RunnableLambda({ func: callModel }));
  workflow.addNode("action", new RunnableLambda({ func: callTool }));

  // Set the entrypoint as `agent`
  // This means that this node is the first one called
  workflow.setEntryPoint("agent");

  // We now add a conditional edge
  workflow.addConditionalEdges(
    // First, we define the start node. We use `agent`.
    // This means these are the edges taken after the `agent` node is called.
    "agent",
    // Next, we pass in the function that will determine which node is called next.
    shouldContinue,
    // Finally we pass in a mapping.
    // The keys are strings, and the values are other nodes.
    // END is a special node marking that the graph should finish.
    // What will happen is we will call `shouldContinue`, and then the output of that
    // will be matched against the keys in this mapping.
    // Based on which one it matches, that node will then be called.
    {
      // If `continue`, then we call the tool node.
      continue: "action",
      // Otherwise we finish.
      end: END,
    }
  );

  // We now add a normal edge from `action` to `agent`.
  // This means that after `action` is called, the `agent` node is called next.
  workflow.addEdge("action", "agent");

  // Finally, we compile it!
  // This compiles it into a LangChain Runnable,
  // meaning you can use it as you would any other runnable
  return workflow.compile();
}
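
And a similarly hedged usage sketch for createFunctionCallingExecutor, again not part of this commit. It assumes ChatOpenAI (which exposes the `.bind({ functions })` method the executor requires) and a stub tool; OPENAI_API_KEY must be set for the model call to succeed. Run as an ES module.

import { ChatOpenAI } from "@langchain/openai";
import { HumanMessage } from "@langchain/core/messages";
import { StructuredTool } from "@langchain/core/tools";
import { z } from "zod";
import { createFunctionCallingExecutor } from "./chat_agent_executor.js";

// A stub tool; any StructuredTool (or a preconfigured ToolExecutor) works here.
class WeatherTool extends StructuredTool {
  name = "get_weather";
  description = "Look up the current weather for a city.";
  schema = z.object({ city: z.string().describe("The city to look up") });

  async _call({ city }: { city: string }): Promise<string> {
    return `It is always sunny in ${city}.`;
  }
}

// The executor converts the tools to OpenAI functions and binds them to the model.
const model = new ChatOpenAI({ temperature: 0 });
const app = createFunctionCallingExecutor({
  model,
  tools: [new WeatherTool()],
});

// The graph loops agent -> action until the model stops emitting a function call;
// every step appends to the `messages` channel, so the final state ends with the
// model's answer.
const finalState = await app.invoke({
  messages: [new HumanMessage("What is the weather in San Francisco?")],
});
console.log(finalState);
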
langgraph/src/prebuilt/index.ts: 7 changes (7 additions, 0 deletions)
@@ -0,0 +1,7 @@
export { createAgentExecutor } from "./agent_executor.js";
export { createFunctionCallingExecutor } from "./chat_agent_executor.js";
export {
  type ToolExecutorArgs,
  type ToolInvocationInterface,
  ToolExecutor,
} from "./tool_executor.js";