Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(azure-cosmosdb): add session context for a user mongodb #7436

Merged
merged 10 commits into from
Jan 18, 2025
81 changes: 78 additions & 3 deletions libs/langchain-azure-cosmosdb/src/chat_histories/mongodb.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@ export interface AzureCosmosDBMongoChatHistoryDBConfig {
readonly collectionName?: string;
}

export type ChatSessionMongo = {
id: string;
context: Record<string, unknown>;
};

const ID_KEY = "sessionId";
const ID_USER = "userId";

export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHistory {
lc_namespace = ["langchain", "stores", "message", "azurecosmosdb"];
Expand All @@ -33,6 +39,8 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis

private initPromise?: Promise<void>;

private context: Record<string, unknown> = {};

private readonly client: MongoClient | undefined;

private database: Db;
Expand All @@ -41,11 +49,14 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis

private sessionId: string;

private userId: string;

initialize: () => Promise<void>;

constructor(
dbConfig: AzureCosmosDBMongoChatHistoryDBConfig,
sessionId: string
sessionId: string,
userId: string
) {
super();

Expand All @@ -70,6 +81,7 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis
const collectionName = dbConfig.collectionName ?? "chatHistory";

this.sessionId = sessionId;
this.userId = userId ?? "anonymous";

// Deferring initialization to the first call to `initialize`
this.initialize = () => {
Expand Down Expand Up @@ -120,6 +132,7 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis

const document = await this.collection.findOne({
[ID_KEY]: this.sessionId,
[ID_USER]: this.userId,
});
const messages = document?.messages || [];
return mapStoredMessagesToChatMessages(messages);
Expand All @@ -134,10 +147,12 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis
await this.initialize();

const messages = mapChatMessagesToStoredMessages([message]);
const context = await this.getContext();
await this.collection.updateOne(
{ [ID_KEY]: this.sessionId },
{ [ID_KEY]: this.sessionId, [ID_USER]: this.userId },
{
$push: { messages: { $each: messages } } as PushOperator<Document>,
$set: { context },
},
{ upsert: true }
);
Expand All @@ -150,6 +165,66 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis
async clear(): Promise<void> {
await this.initialize();

await this.collection.deleteOne({ [ID_KEY]: this.sessionId });
await this.collection.deleteOne({
[ID_KEY]: this.sessionId,
[ID_USER]: this.userId,
});
}

async getAllSessions(): Promise<ChatSessionMongo[]> {
await this.initialize();
const documents = await this.collection
.find({
[ID_USER]: this.userId,
})
.toArray();

const chatSessions: ChatSessionMongo[] = documents.map((doc) => ({
id: doc[ID_KEY],
user_id: doc[ID_USER],
context: doc.context || {},
}));

return chatSessions;
}

async clearAllSessions() {
await this.initialize();
try {
await this.collection.deleteMany({
[ID_USER]: this.userId,
});
} catch (error) {
console.error("Error clearing chat history sessions:", error);
throw error;
}
}

async getContext(): Promise<Record<string, unknown>> {
await this.initialize();

const document = await this.collection.findOne({
[ID_KEY]: this.sessionId,
[ID_USER]: this.userId,
});
this.context = document?.context || this.context;
return this.context;
}

async setContext(context: Record<string, unknown>): Promise<void> {
await this.initialize();

try {
await this.collection.updateOne(
{ [ID_KEY]: this.sessionId },
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need [ID_USER]: this.userId, here too?

I assume this is sufficient in this case to just pass id?

{
$set: { context },
},
{ upsert: true }
);
} catch (error) {
console.error("Error setting chat history context", error);
throw error;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ test("Test Azure Cosmos MongoDB history store", async () => {
};

const sessionId = new ObjectId().toString();
const userId = new ObjectId().toString();
const chatHistory = new AzureCosmosDBMongoChatMessageHistory(
dbcfg,
sessionId
sessionId,
userId
);

const blankResult = await chatHistory.getMessages();
Expand Down Expand Up @@ -70,9 +72,11 @@ test("Test clear Azure Cosmos MongoDB history store", async () => {
};

const sessionId = new ObjectId().toString();
const userId = new ObjectId().toString();
const chatHistory = new AzureCosmosDBMongoChatMessageHistory(
dbcfg,
sessionId
sessionId,
userId
);

await chatHistory.addUserMessage("Who is the best vocalist?");
Expand All @@ -93,3 +97,50 @@ test("Test clear Azure Cosmos MongoDB history store", async () => {

await mongoClient.close();
});

test("Test getAllSessions and clearAllSessions", async () => {
expect(process.env.AZURE_COSMOSDB_MONGODB_CONNECTION_STRING).toBeDefined();

// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const mongoClient = new MongoClient(
process.env.AZURE_COSMOSDB_MONGODB_CONNECTION_STRING!
);
const dbcfg: AzureCosmosDBMongoChatHistoryDBConfig = {
client: mongoClient,
connectionString: process.env.AZURE_COSMOSDB_MONGODB_CONNECTION_STRING,
databaseName: "langchain",
collectionName: "chathistory",
};

const sessionId1 = new ObjectId().toString();
const userId1 = new ObjectId().toString();
const sessionId2 = new ObjectId().toString();
const userId2 = new ObjectId().toString();

const chatHistory1 = new AzureCosmosDBMongoChatMessageHistory(
dbcfg,
sessionId1,
userId1
);
const chatHistory2 = new AzureCosmosDBMongoChatMessageHistory(
dbcfg,
sessionId2,
userId2
);

await chatHistory1.addUserMessage("What is AI?");
await chatHistory1.addAIChatMessage("AI stands for Artificial Intelligence.");
await chatHistory2.addUserMessage("What is the best programming language?");
await chatHistory2.addAIChatMessage("It depends on the use case.");

const allSessions = await chatHistory1.getAllSessions();
expect(allSessions.length).toBe(2);
expect(allSessions[0].id).toBe(sessionId1);
expect(allSessions[1].id).toBe(sessionId2);

await chatHistory1.clearAllSessions();
const clearedSessions = await chatHistory1.getAllSessions();
expect(clearedSessions.length).toBe(0);

await mongoClient.close();
});
Loading