Skip to content

Commit

Permalink
Opt out of return value typing for now, swap StateType with UpdateTyp…
Browse files Browse the repository at this point in the history
…e for correct semantics, use Annotation.Root for prebuilt agent executor
  • Loading branch information
dqbd committed Nov 4, 2024
1 parent 17ab0e3 commit ab7a649
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 44 deletions.
28 changes: 14 additions & 14 deletions libs/langgraph/src/graph/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -441,39 +441,39 @@ export class Graph<
export class CompiledGraph<
N extends string,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunInput = any,
State = any,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput = any,
Update = any,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
ConfigurableFieldType extends Record<string, any> = Record<string, any>
> extends Pregel<
Record<N | typeof START, PregelNode<RunInput, RunOutput>>,
Record<N | typeof START, PregelNode<State, Update>>,
Record<N | typeof START | typeof END | string, BaseChannel>,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
ConfigurableFieldType & Record<string, any>,
RunInput,
RunOutput
Update,
State
> {
declare NodeType: N;

declare RunInput: RunInput;
declare RunInput: State;

declare RunOutput: RunOutput;
declare RunOutput: Update;

builder: Graph<N, RunInput, RunOutput>;
builder: Graph<N, State, Update>;

constructor({
builder,
...rest
}: { builder: Graph<N, RunInput, RunOutput> } & PregelParams<
Record<N | typeof START, PregelNode<RunInput, RunOutput>>,
}: { builder: Graph<N, State, Update> } & PregelParams<
Record<N | typeof START, PregelNode<State, Update>>,
Record<N | typeof START | typeof END | string, BaseChannel>
>) {
super(rest);
this.builder = builder;
}

attachNode(key: N, node: NodeSpec<RunInput, RunOutput>): void {
attachNode(key: N, node: NodeSpec<State, Update>): void {
this.channels[key] = new EphemeralValue();
this.nodes[key] = new PregelNode({
channels: [],
Expand Down Expand Up @@ -505,7 +505,7 @@ export class CompiledGraph<
attachBranch(
start: N | typeof START,
name: string,
branch: Branch<RunInput, N>
branch: Branch<State, N>
) {
// add hidden start node
if (start === START && this.nodes[START]) {
Expand Down Expand Up @@ -590,7 +590,7 @@ export class CompiledGraph<

for (const [key, nodeSpec] of Object.entries(this.builder.nodes) as [
N,
NodeSpec<RunInput, RunOutput>
NodeSpec<State, Update>
][]) {
const displayKey = _escapeMermaidKeywords(key);
const node = nodeSpec.runnable;
Expand Down Expand Up @@ -771,7 +771,7 @@ export class CompiledGraph<

for (const [key, nodeSpec] of Object.entries(this.builder.nodes) as [
N,
NodeSpec<RunInput, RunOutput>
NodeSpec<State, Update>
][]) {
const displayKey = _escapeMermaidKeywords(key);
const node = nodeSpec.runnable;
Expand Down
24 changes: 6 additions & 18 deletions libs/langgraph/src/prebuilt/react_agent_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,9 @@ import {
} from "@langchain/core/language_models/base";
import { ChatPromptTemplate } from "@langchain/core/prompts";
import { All, BaseCheckpointSaver } from "@langchain/langgraph-checkpoint";
import {
END,
messagesStateReducer,
START,
StateGraph,
} from "../graph/index.js";
import { END, START, StateGraph } from "../graph/index.js";
import { MessagesAnnotation } from "../graph/messages_annotation.js";
import { CompiledStateGraph, StateGraphArgs } from "../graph/state.js";
import { CompiledStateGraph } from "../graph/state.js";
import { ToolNode } from "./tool_node.js";

export interface AgentState {
Expand Down Expand Up @@ -107,11 +102,12 @@ export type CreateReactAgentParams = {
* // Returns the messages in the state at each step of execution
* ```
*/

export function createReactAgent(
params: CreateReactAgentParams
): CompiledStateGraph<
AgentState,
Partial<AgentState>,
(typeof MessagesAnnotation)["State"],
(typeof MessagesAnnotation)["Update"],
typeof START | "agent" | "tools"
> {
const {
Expand All @@ -122,12 +118,6 @@ export function createReactAgent(
interruptBefore,
interruptAfter,
} = params;
const schema: StateGraphArgs<AgentState>["channels"] = {
messages: {
value: messagesStateReducer,
default: () => [],
},
};

let toolClasses: (StructuredToolInterface | DynamicTool | RunnableToolLike)[];
if (!Array.isArray(tools)) {
Expand Down Expand Up @@ -160,9 +150,7 @@ export function createReactAgent(
return { messages: [await modelRunnable.invoke(messages, config)] };
};

const workflow = new StateGraph<AgentState>({
channels: schema,
})
const workflow = new StateGraph(MessagesAnnotation)
.addNode("agent", callModel)
.addNode("tools", new ToolNode<AgentState>(toolClasses))
.addEdge(START, "agent")
Expand Down
2 changes: 1 addition & 1 deletion libs/langgraph/src/pregel/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -905,7 +905,7 @@ export class Pregel<
override async stream(
input: InputType | null,
options?: Partial<PregelOptions<Nn, Cc, ConfigurableFieldType>>
): Promise<IterableReadableStream<OutputType>> {
): Promise<IterableReadableStream<PregelOutputType>> {
return super.stream(input, options);
}

Expand Down
21 changes: 15 additions & 6 deletions libs/langgraph/src/tests/graph.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,13 @@ describe("State", () => {
it("should allow reducers with different argument types", async () => {
const StateAnnotation = Annotation.Root({
val: Annotation<number>,
testval: Annotation<string[], string>({
reducer: (left, right) =>
right ? left.concat([right.toString()]) : left,
testval: Annotation<string[], string | string[]>({
reducer: (left, right) => {
if (typeof right === "string") {
return right ? left.concat([right.toString()]) : left;
}
return right.length ? left.concat(right) : left;
},
}),
});
const stateGraph = new StateGraph(StateAnnotation);
Expand All @@ -41,6 +45,7 @@ describe("State", () => {
.addEdge(START, "testnode")
.addEdge("testnode", END)
.compile();

expect(await graph.invoke({ testval: ["hello"] })).toEqual({
testval: ["hello", "hi!"],
val: 3,
Expand All @@ -51,12 +56,16 @@ describe("State", () => {
const stateGraph = new StateGraph<
unknown,
{ testval: string[] },
{ testval: string }
{ testval: string | string[] }
>({
channels: {
testval: {
reducer: (left: string[], right?: string) =>
right ? left.concat([right.toString()]) : left,
reducer: (left, right) => {
if (typeof right === "string") {
return right ? left.concat([right.toString()]) : left;
}
return right.length ? left.concat(right) : left;
},
},
},
});
Expand Down
4 changes: 3 additions & 1 deletion libs/langgraph/src/tests/prebuilt.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ describe("createReactAgent", () => {
expect(response.messages.length > 1).toBe(true);
const lastMessage = response.messages[response.messages.length - 1];
expect(lastMessage._getType()).toBe("ai");
expect(lastMessage.content.toLowerCase()).toContain("not too cold");
expect((lastMessage.content as string).toLowerCase()).toContain(
"not too cold"
);
});

it("can stream a tool call with a checkpointer", async () => {
Expand Down
3 changes: 3 additions & 0 deletions libs/langgraph/src/tests/pregel.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3541,6 +3541,7 @@ export function runPregelTests(
hello: "there",
bye: "world",
messages: ["hello"],
// @ts-expect-error This should emit a TS error
now: 345, // ignored because not in input schema
})
).toEqual({
Expand All @@ -3553,6 +3554,7 @@ export function runPregelTests(
hello: "there",
bye: "world",
messages: ["hello"],
// @ts-expect-error This should emit a TS error
now: 345, // ignored because not in input schema
})
)
Expand Down Expand Up @@ -3712,6 +3714,7 @@ export function runPregelTests(
};
const res = await app.invoke(
{
// @ts-expect-error Messages is not in schema
messages: ["initial input"],
},
config
Expand Down
14 changes: 10 additions & 4 deletions libs/langgraph/src/tests/tracing.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -470,10 +470,16 @@ Only add steps to the plan that still NEED to be done. Do not return previously
state: PlanExecuteState
): Promise<Partial<PlanExecuteState>> {
const task = state.input;
const agentResponse = await agentExecutor.invoke({ input: task });
return {
pastSteps: [task, agentResponse.agentOutcome.returnValues.output],
};
const agentResponse = await agentExecutor.invoke({
input: task ?? undefined,
});

const outcome = agentResponse.agentOutcome;
if (!outcome || !("returnValues" in outcome)) {
throw new Error("Agent did not return a valid outcome.");
}

return { pastSteps: [task, outcome.returnValues.output] };
}

async function planStep(
Expand Down

0 comments on commit ab7a649

Please sign in to comment.