Skip to content

Commit 55c524e

Browse files
author
fk
committed
Fix problem with VectorStoreChatMemoryAdvisor using pgvector
- add integration test
1 parent 37c3450 commit 55c524e

File tree

2 files changed

+146
-2
lines changed

2 files changed

+146
-2
lines changed

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,7 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map<String, Object>
8787

8888
var searchRequest = SearchRequest.query(request.userText())
8989
.withTopK(this.doGetChatMemoryRetrieveSize(context))
90-
.withFilterExpression(
91-
"'" + DOCUMENT_METADATA_CONVERSATION_ID + "'=='" + this.doGetConversationId(context) + "'");
90+
.withFilterExpression(DOCUMENT_METADATA_CONVERSATION_ID + "=='" + this.doGetConversationId(context) + "'");
9291

9392
List<Document> documents = this.getChatMemoryStore().similaritySearch(searchRequest);
9493

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
/*
2+
* Copyright 2023 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.vectorstore;
17+
18+
import org.jetbrains.annotations.NotNull;
19+
import org.junit.jupiter.api.Assertions;
20+
import org.junit.jupiter.api.DisplayName;
21+
import org.junit.jupiter.api.Test;
22+
import org.mockito.ArgumentCaptor;
23+
import org.postgresql.ds.PGSimpleDataSource;
24+
import org.springframework.ai.chat.client.ChatClient;
25+
import org.springframework.ai.chat.client.advisor.VectorStoreChatMemoryAdvisor;
26+
import org.springframework.ai.chat.messages.AssistantMessage;
27+
import org.springframework.ai.chat.messages.SystemMessage;
28+
import org.springframework.ai.chat.model.ChatModel;
29+
import org.springframework.ai.chat.model.ChatResponse;
30+
import org.springframework.ai.chat.model.Generation;
31+
import org.springframework.ai.chat.prompt.Prompt;
32+
import org.springframework.ai.document.Document;
33+
import org.springframework.ai.embedding.EmbeddingModel;
34+
import org.springframework.jdbc.core.JdbcTemplate;
35+
import org.testcontainers.containers.PostgreSQLContainer;
36+
import org.testcontainers.junit.jupiter.Container;
37+
import org.testcontainers.junit.jupiter.Testcontainers;
38+
39+
import java.util.List;
40+
import java.util.Map;
41+
42+
import static org.assertj.core.api.Assertions.assertThat;
43+
import static org.mockito.Mockito.*;
44+
45+
/**
46+
* @author Fabian Krüger
47+
*/
48+
@Testcontainers
49+
class PgVectorStoreWithChatMemoryAdvisorIT {
50+
51+
float[] embed = { 0.003961659F, -0.0073295482F, 0.02663665F };
52+
53+
@Container
54+
@SuppressWarnings("resource")
55+
static PostgreSQLContainer<?> postgresContainer = new PostgreSQLContainer<>("pgvector/pgvector:pg16")
56+
.withUsername("postgres")
57+
.withPassword("postgres");
58+
59+
/**
60+
* Test that chats with {@link VectorStoreChatMemoryAdvisor} get advised with similar
61+
* messages from the (gp)vector store.
62+
*/
63+
@Test
64+
@DisplayName("Advised chat should have similar messages from vector store")
65+
void advisedChatShouldHaveSimilarMessagesFromVectorStore() throws Exception {
66+
// faked ChatModel
67+
ChatModel chatModel = chatModelAlwaysReturnsTheSameReply();
68+
// faked embedding model
69+
EmbeddingModel embeddingModel = embeddingNModelShouldAlwaysReturnFakedEmbed();
70+
PgVectorStore store = createPgVectorStoreUsingTestcontainer(embeddingModel);
71+
72+
// do the chat
73+
ChatClient.builder(chatModel)
74+
.build()
75+
.prompt()
76+
.user("joke")
77+
.advisors(new VectorStoreChatMemoryAdvisor(store))
78+
.call()
79+
.chatResponse();
80+
81+
verifyRequestHasBeenAdvisedWithMessagesFromVectorStore(chatModel);
82+
}
83+
84+
private static @NotNull ChatModel chatModelAlwaysReturnsTheSameReply() {
85+
ChatModel chatModel = mock(ChatModel.class);
86+
ArgumentCaptor<Prompt> argumentCaptor = ArgumentCaptor.forClass(Prompt.class);
87+
ChatResponse chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("""
88+
Why don't scientists trust atoms?
89+
Because they make up everything!
90+
"""))));
91+
when(chatModel.call(argumentCaptor.capture())).thenReturn(chatResponse);
92+
return chatModel;
93+
}
94+
95+
private static void initStore(PgVectorStore store) throws Exception {
96+
store.afterPropertiesSet();
97+
// fill the store
98+
store.add(List.of(new Document("Tell me a good joke", Map.of("conversationId", "default")),
99+
new Document("Tell me a bad joke", Map.of("conversationId", "default", "messageType", "USER"))));
100+
}
101+
102+
private static PgVectorStore createPgVectorStoreUsingTestcontainer(EmbeddingModel embeddingModel) throws Exception {
103+
JdbcTemplate jdbcTemplate = createJdbcTemplateWithConnectionToTestcontainer();
104+
PgVectorStore vectorStore = new PgVectorStore.Builder(jdbcTemplate, embeddingModel).withDimensions(3) // match
105+
// embeddings
106+
.withInitializeSchema(true)
107+
.build();
108+
initStore(vectorStore);
109+
return vectorStore;
110+
}
111+
112+
private static @NotNull JdbcTemplate createJdbcTemplateWithConnectionToTestcontainer() {
113+
PGSimpleDataSource ds = new PGSimpleDataSource();
114+
ds.setUrl("jdbc:postgresql://localhost:" + postgresContainer.getMappedPort(5432) + "/postgres");
115+
ds.setUser(postgresContainer.getUsername());
116+
ds.setPassword(postgresContainer.getPassword());
117+
return new JdbcTemplate(ds);
118+
}
119+
120+
private @NotNull EmbeddingModel embeddingNModelShouldAlwaysReturnFakedEmbed() {
121+
EmbeddingModel embeddingModel = mock(EmbeddingModel.class);
122+
when(embeddingModel.embed(any(Document.class))).thenReturn(embed);
123+
when(embeddingModel.embed(any(String.class))).thenReturn(embed);
124+
return embeddingModel;
125+
}
126+
127+
private static void verifyRequestHasBeenAdvisedWithMessagesFromVectorStore(ChatModel chatModel) {
128+
ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
129+
verify(chatModel).call(promptCaptor.capture());
130+
assertThat(promptCaptor.getValue().getInstructions().get(0)).isInstanceOf(SystemMessage.class);
131+
assertThat(promptCaptor.getValue().getInstructions().get(0).getContent()).isEqualTo("""
132+
133+
134+
Use the long term conversation memory from the LONG_TERM_MEMORY section to provide accurate answers.
135+
136+
---------------------
137+
LONG_TERM_MEMORY:
138+
Tell me a good joke
139+
Tell me a bad joke
140+
---------------------
141+
142+
""");
143+
}
144+
145+
}

0 commit comments

Comments
 (0)