Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Implement types for Pregel.stream using recursive type #512

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions libs/langgraph/src/channels/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,15 @@ export abstract class BaseChannel<
}
}

export function emptyChannels<Cc extends Record<string, BaseChannel>>(
channels: Cc,
export function emptyChannels<Channels extends Record<string, BaseChannel>>(
channels: Channels,
checkpoint: ReadonlyCheckpoint
): Cc {
): Channels {
const filteredChannels = Object.fromEntries(
Object.entries(channels).filter(([, value]) => isBaseChannel(value))
) as Cc;
) as Channels;

const newChannels = {} as Cc;
const newChannels = {} as Channels;
for (const k in filteredChannels) {
if (Object.prototype.hasOwnProperty.call(filteredChannels, k)) {
const channelValue = checkpoint.channel_values[k];
Expand Down
93 changes: 51 additions & 42 deletions libs/langgraph/src/graph/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ export type NodeSpec<RunInput, RunOutput> = {
export type AddNodeOptions = { metadata?: Record<string, unknown> };

export class Graph<
N extends string = typeof END,
NodeNames extends string = typeof END,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunInput = any,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
Expand All @@ -129,18 +129,18 @@ export class Graph<
RunOutput
>
> {
nodes: Record<N, NodeSpecType>;
nodes: Record<NodeNames, NodeSpecType>;

edges: Set<[N | typeof START, N | typeof END]>;
edges: Set<[NodeNames | typeof START, NodeNames | typeof END]>;

branches: Record<string, Record<string, Branch<RunInput, N>>>;
branches: Record<string, Record<string, Branch<RunInput, NodeNames>>>;

entryPoint?: string;

compiled = false;

constructor() {
this.nodes = {} as Record<N, NodeSpecType>;
this.nodes = {} as Record<NodeNames, NodeSpecType>;
this.edges = new Set();
this.branches = {};
}
Expand All @@ -159,7 +159,7 @@ export class Graph<
key: K,
action: RunnableLike<NodeInput, RunOutput>,
options?: AddNodeOptions
): Graph<N | K, RunInput, RunOutput> {
): Graph<NodeNames | K, RunInput, RunOutput> {
if (key.includes(CHECKPOINT_NAMESPACE_SEPARATOR)) {
throw new Error(
`"${CHECKPOINT_NAMESPACE_SEPARATOR}" is a reserved character and is not allowed in node names.`
Expand All @@ -176,18 +176,21 @@ export class Graph<
throw new Error(`Node \`${key}\` is reserved.`);
}

this.nodes[key as unknown as N] = {
this.nodes[key as unknown as NodeNames] = {
runnable: _coerceToRunnable<RunInput, RunOutput>(
// Account for arbitrary state due to Send API
action as RunnableLike<RunInput, RunOutput>
),
metadata: options?.metadata,
} as NodeSpecType;

return this as Graph<N | K, RunInput, RunOutput, NodeSpecType>;
return this as Graph<NodeNames | K, RunInput, RunOutput, NodeSpecType>;
}

addEdge(startKey: N | typeof START, endKey: N | typeof END): this {
addEdge(
startKey: NodeNames | typeof START,
endKey: NodeNames | typeof END
): this {
this.warnIfCompiled(
`Adding an edge to a graph that has already been compiled. This will not be reflected in the compiled graph.`
);
Expand All @@ -212,20 +215,20 @@ export class Graph<
return this;
}

addConditionalEdges(source: BranchOptions<RunInput, N>): this;
addConditionalEdges(source: BranchOptions<RunInput, NodeNames>): this;

addConditionalEdges(
source: N,
path: Branch<RunInput, N>["condition"],
pathMap?: BranchOptions<RunInput, N>["pathMap"]
source: NodeNames,
path: Branch<RunInput, NodeNames>["condition"],
pathMap?: BranchOptions<RunInput, NodeNames>["pathMap"]
): this;

addConditionalEdges(
source: N | BranchOptions<RunInput, N>,
path?: Branch<RunInput, N>["condition"],
pathMap?: BranchOptions<RunInput, N>["pathMap"]
source: NodeNames | BranchOptions<RunInput, NodeNames>,
path?: Branch<RunInput, NodeNames>["condition"],
pathMap?: BranchOptions<RunInput, NodeNames>["pathMap"]
): this {
const options: BranchOptions<RunInput, N> =
const options: BranchOptions<RunInput, NodeNames> =
typeof source === "object" ? source : { source, path: path!, pathMap };
this.warnIfCompiled(
"Adding an edge to a graph that has already been compiled. This will not be reflected in the compiled graph."
Expand All @@ -249,7 +252,7 @@ export class Graph<
/**
* @deprecated use `addEdge(START, key)` instead
*/
setEntryPoint(key: N): this {
setEntryPoint(key: NodeNames): this {
this.warnIfCompiled(
"Setting the entry point of a graph that has already been compiled. This will not be reflected in the compiled graph."
);
Expand All @@ -260,7 +263,7 @@ export class Graph<
/**
* @deprecated use `addEdge(key, END)` instead
*/
setFinishPoint(key: N): this {
setFinishPoint(key: NodeNames): this {
this.warnIfCompiled(
"Setting a finish point of a graph that has already been compiled. This will not be reflected in the compiled graph."
);
Expand All @@ -274,45 +277,48 @@ export class Graph<
interruptAfter,
}: {
checkpointer?: BaseCheckpointSaver;
interruptBefore?: N[] | All;
interruptAfter?: N[] | All;
} = {}): CompiledGraph<N> {
interruptBefore?: NodeNames[] | All;
interruptAfter?: NodeNames[] | All;
} = {}): CompiledGraph<NodeNames, RunInput, RunOutput> {
// validate the graph
this.validate([
...(Array.isArray(interruptBefore) ? interruptBefore : []),
...(Array.isArray(interruptAfter) ? interruptAfter : []),
]);

// create empty compiled graph
const compiled = new CompiledGraph({
const compiled = new CompiledGraph<NodeNames, RunInput, RunOutput>({
builder: this,
checkpointer,
interruptAfter,
interruptBefore,
autoValidate: false,
nodes: {} as Record<N | typeof START, PregelNode<RunInput, RunOutput>>,
nodes: {} as Record<
NodeNames | typeof START,
PregelNode<RunInput, RunOutput>
>,
channels: {
[START]: new EphemeralValue(),
[END]: new EphemeralValue(),
} as Record<N | typeof START | typeof END | string, BaseChannel>,
} as Record<NodeNames | typeof START | typeof END | string, BaseChannel>,
inputChannels: START,
outputChannels: END,
streamChannels: [] as N[],
streamChannels: [] as NodeNames[],
streamMode: "values",
});

// attach nodes, edges and branches
for (const [key, node] of Object.entries<NodeSpec<RunInput, RunOutput>>(
this.nodes
)) {
compiled.attachNode(key as N, node);
compiled.attachNode(key as NodeNames, node);
}
for (const [start, end] of this.edges) {
compiled.attachEdge(start, end);
}
for (const [start, branches] of Object.entries(this.branches)) {
for (const [name, branch] of Object.entries(branches)) {
compiled.attachBranch(start as N, name, branch);
compiled.attachBranch(start as NodeNames, name, branch);
}
}

Expand Down Expand Up @@ -375,35 +381,35 @@ export class Graph<
}

export class CompiledGraph<
N extends string,
NodeNames extends string,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunInput = any,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput = any
> extends Pregel<
Record<N | typeof START, PregelNode<RunInput, RunOutput>>,
Record<N | typeof START | typeof END | string, BaseChannel>
Record<NodeNames | typeof START, PregelNode<RunInput, RunOutput>>,
Record<NodeNames | typeof START | typeof END | string, BaseChannel>
> {
declare NodeType: N;
declare NodeType: NodeNames;

declare RunInput: RunInput;

declare RunOutput: RunOutput;

builder: Graph<N, RunInput, RunOutput>;
builder: Graph<NodeNames, RunInput, RunOutput>;

constructor({
builder,
...rest
}: { builder: Graph<N, RunInput, RunOutput> } & PregelParams<
Record<N | typeof START, PregelNode<RunInput, RunOutput>>,
Record<N | typeof START | typeof END | string, BaseChannel>
}: { builder: Graph<NodeNames, RunInput, RunOutput> } & PregelParams<
Record<NodeNames | typeof START, PregelNode<RunInput, RunOutput>>,
Record<NodeNames | typeof START | typeof END | string, BaseChannel>
>) {
super(rest);
this.builder = builder;
}

attachNode(key: N, node: NodeSpec<RunInput, RunOutput>): void {
attachNode(key: NodeNames, node: NodeSpec<RunInput, RunOutput>): void {
this.channels[key] = new EphemeralValue();
this.nodes[key] = new PregelNode({
channels: [],
Expand All @@ -414,10 +420,13 @@ export class CompiledGraph<
.pipe(
new ChannelWrite([{ channel: key, value: PASSTHROUGH }], [TAG_HIDDEN])
);
(this.streamChannels as N[]).push(key);
(this.streamChannels as NodeNames[]).push(key);
}

attachEdge(start: N | typeof START, end: N | typeof END): void {
attachEdge(
start: NodeNames | typeof START,
end: NodeNames | typeof END
): void {
if (end === END) {
if (start === START) {
throw new Error("Cannot have an edge from START to END");
Expand All @@ -432,9 +441,9 @@ export class CompiledGraph<
}

attachBranch(
start: N | typeof START,
start: NodeNames | typeof START,
name: string,
branch: Branch<RunInput, N>
branch: Branch<RunInput, NodeNames>
) {
// add hidden start node
if (start === START && this.nodes[START]) {
Expand All @@ -460,7 +469,7 @@ export class CompiledGraph<
// attach branch readers
const ends = branch.ends
? Object.values(branch.ends)
: (Object.keys(this.nodes) as N[]);
: (Object.keys(this.nodes) as NodeNames[]);
for (const end of ends) {
if (end !== END) {
const channelName = `branch:${start}:${name}:${end}`;
Expand Down
Loading