From 541ece1037d48eea8aa545400d6082e5a6fee5fd Mon Sep 17 00:00:00 2001 From: Ben Burns <803016+benjamincburns@users.noreply.github.com> Date: Thu, 24 Oct 2024 21:29:26 +1300 Subject: [PATCH] fix(checkpoint-mongodb): fix state deltas, pendingWrites, pending_sends Fixes #595 Fixes #589 --- libs/checkpoint-mongodb/src/index.ts | 326 +++++++++++++----- .../src/tests/checkpoints.int.test.ts | 6 +- libs/checkpoint-validation/src/spec/list.ts | 7 +- libs/checkpoint-validation/src/spec/put.ts | 3 - 4 files changed, 254 insertions(+), 88 deletions(-) diff --git a/libs/checkpoint-mongodb/src/index.ts b/libs/checkpoint-mongodb/src/index.ts index eba6c50f..2c6c1ee1 100644 --- a/libs/checkpoint-mongodb/src/index.ts +++ b/libs/checkpoint-mongodb/src/index.ts @@ -1,4 +1,9 @@ -import { type MongoClient, type Db as MongoDatabase } from "mongodb"; +import { + Binary, + WithId, + type MongoClient, + type Db as MongoDatabase, +} from "mongodb"; import type { RunnableConfig } from "@langchain/core/runnables"; import { BaseCheckpointSaver, @@ -10,6 +15,10 @@ import { type CheckpointMetadata, CheckpointPendingWrite, validCheckpointMetadataKeys, + ChannelVersions, + copyCheckpoint, + TASKS, + SendProtocol, } from "@langchain/langgraph-checkpoint"; import { applyMigrations, needsMigration } from "./migrations/index.js"; @@ -23,9 +32,45 @@ export type MongoDBSaverParams = { dbName?: string; checkpointCollectionName?: string; checkpointWritesCollectionName?: string; + channelVersionsCollectionName?: string; schemaVersionCollectionName?: string; }; +interface CheckpointDoc { + thread_id: string; + checkpoint_ns: string; + checkpoint_id: string; + parent_checkpoint_id: string | null; + type: string; + checkpoint: Binary; + metadata: CheckpointMetadata; +} + +interface CheckpointWriteDoc { + thread_id: string; + checkpoint_ns: string; + checkpoint_id: string; + task_id: string; + idx: number; + channel: string; + type: string; + value: Binary; +} + +interface ChannelVersionDoc { + thread_id: string; + checkpoint_ns: string; + checkpoint_id: string; + channel: string; + version: string | number; + type: string; + value: Binary; +} + +interface SchemaVersionDoc { + version: number; +} + /** * A LangGraph checkpoint saver backed by a MongoDB database. */ @@ -40,6 +85,8 @@ export class MongoDBSaver extends BaseCheckpointSaver { checkpointWritesCollectionName = "checkpoint_writes"; + channelVersionsCollectionName = "channel_versions"; + schemaVersionCollectionName = "schema_version"; constructor( @@ -48,6 +95,7 @@ export class MongoDBSaver extends BaseCheckpointSaver { dbName, checkpointCollectionName, checkpointWritesCollectionName, + channelVersionsCollectionName, schemaVersionCollectionName, }: MongoDBSaverParams, serde?: SerializerProtocol @@ -61,6 +109,8 @@ export class MongoDBSaver extends BaseCheckpointSaver { checkpointWritesCollectionName ?? this.checkpointWritesCollectionName; this.schemaVersionCollectionName = schemaVersionCollectionName ?? this.schemaVersionCollectionName; + this.channelVersionsCollectionName = + channelVersionsCollectionName ?? this.channelVersionsCollectionName; } /** @@ -90,16 +140,12 @@ export class MongoDBSaver extends BaseCheckpointSaver { } private async initializeSchemaVersion(): Promise { - const schemaVersionCollection = this.db.collection( + const schemaVersionCollection = this.db.collection( this.schemaVersionCollectionName ); // empty database, no migrations needed - just set the schema version and move on if (await this.isDatabaseEmpty()) { - const schemaVersionCollection = this.db.collection( - this.schemaVersionCollectionName - ); - const versionDoc = await schemaVersionCollection.findOne({}); if (!versionDoc) { await schemaVersionCollection.insertOne({ @@ -173,40 +219,66 @@ export class MongoDBSaver extends BaseCheckpointSaver { } } - /** - * Retrieves a checkpoint from the MongoDB database based on the - * provided config. If the config contains a "checkpoint_id" key, the checkpoint with - * the matching thread ID and checkpoint ID is retrieved. Otherwise, the latest checkpoint - * for the given thread ID is retrieved. - */ - async getTuple(config: RunnableConfig): Promise { - await this.setup(); + private async getChannelValues( + thread_id: string, + checkpoint_ns: string, + checkpoint_id: string, + channel_versions: ChannelVersions + ): Promise> { + return Object.fromEntries( + await Promise.all( + Object.entries(channel_versions).map(async ([channel, version]) => { + const doc = await this.db + .collection(this.channelVersionsCollectionName) + .findOne({ + thread_id, + checkpoint_ns, + checkpoint_id, + channel, + version, + }); + if (!doc) { + return []; + } + return [ + channel, + await this.serde.loadsTyped(doc.type, doc.value.value()), + ]; + }) + ) + ); + } - const { - thread_id, - checkpoint_ns = "", - checkpoint_id, - } = config.configurable ?? {}; - let query; - if (checkpoint_id) { - query = { - thread_id, - checkpoint_ns, - checkpoint_id, - }; - } else { - query = { thread_id, checkpoint_ns }; - } - const result = await this.db - .collection(this.checkpointCollectionName) - .find(query) - .sort("checkpoint_id", -1) - .limit(1) - .toArray(); - if (result.length === 0) { - return undefined; - } - const doc = result[0]; + private async getPendingSends( + thread_id: string, + checkpoint_ns: string, + parent_checkpoint_id: string + ): Promise { + return Promise.all( + ( + await this.db + .collection(this.checkpointWritesCollectionName) + .find({ + thread_id, + checkpoint_ns, + checkpoint_id: parent_checkpoint_id, + channel: TASKS, + }) + .toArray() + ).map((write) => { + return this.serde.loadsTyped( + write.type, + write.value.value() + ) as SendProtocol; + }) + ); + } + + private async constructCheckpointTuple( + thread_id: string, + checkpoint_ns: string, + doc: CheckpointDoc + ): Promise { const configurableValues = { thread_id, checkpoint_ns, @@ -216,10 +288,37 @@ export class MongoDBSaver extends BaseCheckpointSaver { doc.type, doc.checkpoint.value() )) as Checkpoint; + + checkpoint.pending_sends = doc.parent_checkpoint_id + ? await this.getPendingSends( + thread_id, + checkpoint_ns, + doc.parent_checkpoint_id + ) + : []; + + checkpoint.channel_values = checkpoint.channel_values ?? {}; + checkpoint.channel_versions = checkpoint.channel_versions ?? {}; + + // fetch channel values if they weren't stored with the rest of the checkpoint data + if ( + Object.keys(checkpoint.channel_versions).length !== + Object.keys(checkpoint.channel_values).length + ) { + checkpoint.channel_values = + (await this.getChannelValues( + thread_id, + checkpoint_ns, + doc.checkpoint_id, + checkpoint.channel_versions + )) ?? {}; + } + const serializedWrites = await this.db - .collection(this.checkpointWritesCollectionName) + .collection(this.checkpointWritesCollectionName) .find(configurableValues) .toArray(); + const pendingWrites: CheckpointPendingWrite[] = await Promise.all( serializedWrites.map(async (serializedWrite) => { return [ @@ -236,7 +335,7 @@ export class MongoDBSaver extends BaseCheckpointSaver { config: { configurable: configurableValues }, checkpoint, pendingWrites, - metadata: doc.metadata as CheckpointMetadata, + metadata: doc.metadata, parentConfig: doc.parent_checkpoint_id != null ? { @@ -250,6 +349,43 @@ export class MongoDBSaver extends BaseCheckpointSaver { }; } + /** + * Retrieves a checkpoint from the MongoDB database based on the + * provided config. If the config contains a "checkpoint_id" key, the checkpoint with + * the matching thread ID and checkpoint ID is retrieved. Otherwise, the latest checkpoint + * for the given thread ID is retrieved. + */ + async getTuple(config: RunnableConfig): Promise { + await this.setup(); + + const { + thread_id, + checkpoint_ns = "", + checkpoint_id, + } = config.configurable ?? {}; + let query; + if (checkpoint_id) { + query = { + thread_id, + checkpoint_ns, + checkpoint_id, + }; + } else { + query = { thread_id, checkpoint_ns }; + } + const result = await this.db + .collection(this.checkpointCollectionName) + .find(query) + .sort("checkpoint_id", -1) + .limit(1) + .toArray(); + if (result.length === 0) { + return undefined; + } + const doc = result[0]; + return this.constructCheckpointTuple(thread_id, checkpoint_ns, doc); + } + /** * Retrieve a list of checkpoint tuples from the MongoDB database based * on the provided config. The checkpoints are ordered by checkpoint ID @@ -302,35 +438,50 @@ export class MongoDBSaver extends BaseCheckpointSaver { } for await (const doc of result) { - const checkpoint = (await this.serde.loadsTyped( - doc.type, - doc.checkpoint.value() - )) as Checkpoint; - const metadata = doc.metadata as CheckpointMetadata; - - yield { - config: { - configurable: { - thread_id: doc.thread_id, - checkpoint_ns: doc.checkpoint_ns, - checkpoint_id: doc.checkpoint_id, - }, - }, - checkpoint, - metadata, - parentConfig: doc.parent_checkpoint_id - ? { - configurable: { - thread_id: doc.thread_id, - checkpoint_ns: doc.checkpoint_ns, - checkpoint_id: doc.parent_checkpoint_id, - }, - } - : undefined, - }; + yield this.constructCheckpointTuple( + doc.thread_id, + doc.checkpoint_ns, + doc as WithId + ); } } + private async putChannelData( + thread_id: string, + checkpoint_ns: string, + checkpoint_id: string, + channel_values: Record, + newVersions: ChannelVersions + ) { + await Promise.all( + Object.entries(newVersions).map(async ([channel, version]) => { + const [type, value] = this.serde.dumpsTyped(channel_values[channel]); + + const doc: ChannelVersionDoc = { + thread_id, + checkpoint_ns, + checkpoint_id, + channel, + version, + type, + value: new Binary(value), + }; + + const upsertQuery = { + thread_id, + checkpoint_ns, + checkpoint_id, + channel, + version, + }; + + await this.db + .collection(this.channelVersionsCollectionName) + .updateOne(upsertQuery, { $set: doc }, { upsert: true }); + }) + ); + } + /** * Saves a checkpoint to the MongoDB database. The checkpoint is associated * with the provided config and its parent config (if any). @@ -338,7 +489,8 @@ export class MongoDBSaver extends BaseCheckpointSaver { async put( config: RunnableConfig, checkpoint: Checkpoint, - metadata: CheckpointMetadata + metadata: CheckpointMetadata, + newVersions: ChannelVersions ): Promise { await this.setup(); @@ -350,12 +502,28 @@ export class MongoDBSaver extends BaseCheckpointSaver { `The provided config must contain a configurable field with a "thread_id" field.` ); } + + const preparedCheckpoint: Partial = copyCheckpoint(checkpoint); + delete preparedCheckpoint.pending_sends; + delete preparedCheckpoint.channel_values; + + await this.putChannelData( + thread_id, + checkpoint_ns, + checkpoint_id, + checkpoint.channel_values, + newVersions + ); + const [checkpointType, serializedCheckpoint] = - this.serde.dumpsTyped(checkpoint); - const doc = { + this.serde.dumpsTyped(preparedCheckpoint); + const doc: CheckpointDoc = { + thread_id, + checkpoint_ns, + checkpoint_id, parent_checkpoint_id: config.configurable?.checkpoint_id, type: checkpointType, - checkpoint: serializedCheckpoint, + checkpoint: new Binary(serializedCheckpoint), metadata, }; const upsertQuery = { @@ -363,13 +531,15 @@ export class MongoDBSaver extends BaseCheckpointSaver { checkpoint_ns, checkpoint_id, }; - await this.db.collection(this.checkpointCollectionName).updateOne( - upsertQuery, - { - $set: doc, - }, - { upsert: true } - ); + await this.db + .collection(this.checkpointCollectionName) + .updateOne( + upsertQuery, + { + $set: doc, + }, + { upsert: true } + ); return { configurable: { thread_id, diff --git a/libs/checkpoint-mongodb/src/tests/checkpoints.int.test.ts b/libs/checkpoint-mongodb/src/tests/checkpoints.int.test.ts index d5528ee3..31ffa948 100644 --- a/libs/checkpoint-mongodb/src/tests/checkpoints.int.test.ts +++ b/libs/checkpoint-mongodb/src/tests/checkpoints.int.test.ts @@ -68,7 +68,8 @@ describe("MongoDBSaver", () => { const runnableConfig = await saver.put( { configurable: { thread_id: "1" } }, checkpoint1, - { source: "update", step: -1, writes: null, parents: {} } + { source: "update", step: -1, writes: null, parents: {} }, + checkpoint1.channel_versions ); expect(runnableConfig).toEqual({ configurable: { @@ -117,7 +118,8 @@ describe("MongoDBSaver", () => { }, }, checkpoint2, - { source: "update", step: -1, writes: null, parents: {} } + { source: "update", step: -1, writes: null, parents: {} }, + checkpoint2.channel_versions ); // verify that parentTs is set and retrieved correctly for second checkpoint diff --git a/libs/checkpoint-validation/src/spec/list.ts b/libs/checkpoint-validation/src/spec/list.ts index bb00f590..cad583df 100644 --- a/libs/checkpoint-validation/src/spec/list.ts +++ b/libs/checkpoint-validation/src/spec/list.ts @@ -116,14 +116,11 @@ export function listTests( } else { expect(actualTuplesMap.size).toEqual(expectedTuplesMap.size); for (const [key, value] of actualTuplesMap.entries()) { - // TODO: MongoDBSaver and SQLiteSaver don't return pendingWrites on list, so we need to special case them - // see: https://github.com/langchain-ai/langgraphjs/issues/589 + // TODO: SQLiteSaver doesn't return pendingWrites on list, so we need to special case it // see: https://github.com/langchain-ai/langgraphjs/issues/590 const checkpointerIncludesPendingWritesOnList = initializer.checkpointerName !== - "@langchain/langgraph-checkpoint-mongodb" && - initializer.checkpointerName !== - "@langchain/langgraph-checkpoint-sqlite"; + "@langchain/langgraph-checkpoint-sqlite"; const expectedTuple = expectedTuplesMap.get(key); if (!checkpointerIncludesPendingWritesOnList) { diff --git a/libs/checkpoint-validation/src/spec/put.ts b/libs/checkpoint-validation/src/spec/put.ts index 5ecc343d..04177510 100644 --- a/libs/checkpoint-validation/src/spec/put.ts +++ b/libs/checkpoint-validation/src/spec/put.ts @@ -216,10 +216,7 @@ export function putTests( // TODO: all of the checkpointers below store full channel_values on every put, rather than storing deltas // see: https://github.com/langchain-ai/langgraphjs/issues/593 // see: https://github.com/langchain-ai/langgraphjs/issues/594 - // see: https://github.com/langchain-ai/langgraphjs/issues/595 MemorySaver: "TODO: MemorySaver doesn't store channel deltas", - "@langchain/langgraph-checkpoint-mongodb": - "TODO: MongoDBSaver doesn't store channel deltas", "@langchain/langgraph-checkpoint-sqlite": "TODO: SQLiteSaver doesn't store channel deltas", })(