Skip to content

Commit 57b36d8

Browse files
author
fk
committed
Fix problem with VectorStoreChatMemoryAdvisor using pgvector
- add integration test
1 parent 37c3450 commit 57b36d8

File tree

2 files changed

+145
-2
lines changed

2 files changed

+145
-2
lines changed

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,7 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map<String, Object>
8787

8888
var searchRequest = SearchRequest.query(request.userText())
8989
.withTopK(this.doGetChatMemoryRetrieveSize(context))
90-
.withFilterExpression(
91-
"'" + DOCUMENT_METADATA_CONVERSATION_ID + "'=='" + this.doGetConversationId(context) + "'");
90+
.withFilterExpression(DOCUMENT_METADATA_CONVERSATION_ID + "=='" + this.doGetConversationId(context) + "'");
9291

9392
List<Document> documents = this.getChatMemoryStore().similaritySearch(searchRequest);
9493

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
/*
2+
* Copyright 2023 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.vectorstore;
17+
18+
import org.jetbrains.annotations.NotNull;
19+
import org.junit.jupiter.api.DisplayName;
20+
import org.junit.jupiter.api.Test;
21+
import org.mockito.ArgumentCaptor;
22+
import org.postgresql.ds.PGSimpleDataSource;
23+
import org.springframework.ai.chat.client.ChatClient;
24+
import org.springframework.ai.chat.client.advisor.VectorStoreChatMemoryAdvisor;
25+
import org.springframework.ai.chat.messages.AssistantMessage;
26+
import org.springframework.ai.chat.messages.SystemMessage;
27+
import org.springframework.ai.chat.model.ChatModel;
28+
import org.springframework.ai.chat.model.ChatResponse;
29+
import org.springframework.ai.chat.model.Generation;
30+
import org.springframework.ai.chat.prompt.Prompt;
31+
import org.springframework.ai.document.Document;
32+
import org.springframework.ai.embedding.EmbeddingModel;
33+
import org.springframework.jdbc.core.JdbcTemplate;
34+
import org.testcontainers.containers.PostgreSQLContainer;
35+
import org.testcontainers.junit.jupiter.Container;
36+
import org.testcontainers.junit.jupiter.Testcontainers;
37+
38+
import java.util.List;
39+
import java.util.Map;
40+
41+
import static org.assertj.core.api.Assertions.assertThat;
42+
import static org.mockito.Mockito.*;
43+
44+
/**
45+
* @author Fabian Krüger
46+
*/
47+
@Testcontainers
48+
class PgVectorStoreWithChatMemoryAdvisorIT {
49+
50+
float[] embed = { 0.003961659F, -0.0073295482F, 0.02663665F };
51+
52+
@Container
53+
@SuppressWarnings("resource")
54+
static PostgreSQLContainer<?> postgresContainer = new PostgreSQLContainer<>("pgvector/pgvector:pg16")
55+
.withUsername("postgres")
56+
.withPassword("postgres");
57+
58+
/**
59+
* Test that chats with {@link VectorStoreChatMemoryAdvisor} get advised with similar
60+
* messages from the (gp)vector store.
61+
*/
62+
@Test
63+
@DisplayName("pgvector")
64+
void pgvector() throws Exception {
65+
// faked ChatModel
66+
ChatModel chatModel = chatModelAlwaysReturnsTheSameReply();
67+
// faked embedding model
68+
EmbeddingModel embeddingModel = embeddingNModelShouldAlwaysReturnFakedEmbed();
69+
PgVectorStore store = createPgVectorStoreUsingTestcontainer(embeddingModel);
70+
71+
// do the chat
72+
ChatClient.builder(chatModel)
73+
.build()
74+
.prompt()
75+
.user("joke")
76+
.advisors(new VectorStoreChatMemoryAdvisor(store))
77+
.call()
78+
.chatResponse();
79+
80+
verifyRequestHasBeenAdvisedWithMessagesFromVectorStore(chatModel);
81+
}
82+
83+
private static @NotNull ChatModel chatModelAlwaysReturnsTheSameReply() {
84+
ChatModel chatModel = mock(ChatModel.class);
85+
ArgumentCaptor<Prompt> argumentCaptor = ArgumentCaptor.forClass(Prompt.class);
86+
ChatResponse chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("""
87+
Why don't scientists trust atoms?
88+
Because they make up everything!
89+
"""))));
90+
when(chatModel.call(argumentCaptor.capture())).thenReturn(chatResponse);
91+
return chatModel;
92+
}
93+
94+
private static void initStore(PgVectorStore store) throws Exception {
95+
store.afterPropertiesSet();
96+
// fill the store
97+
store.add(List.of(new Document("Tell me a good joke", Map.of("conversationId", "default")),
98+
new Document("Tell me a bad joke", Map.of("conversationId", "default", "messageType", "USER"))));
99+
}
100+
101+
private static PgVectorStore createPgVectorStoreUsingTestcontainer(EmbeddingModel embeddingModel) throws Exception {
102+
JdbcTemplate jdbcTemplate = createJdbcTemplateWithConnectionToTestcontainer();
103+
PgVectorStore vectorStore = new PgVectorStore.Builder(jdbcTemplate, embeddingModel).withDimensions(3) // match
104+
// embeddings
105+
.withInitializeSchema(true)
106+
.build();
107+
initStore(vectorStore);
108+
return vectorStore;
109+
}
110+
111+
private static @NotNull JdbcTemplate createJdbcTemplateWithConnectionToTestcontainer() {
112+
PGSimpleDataSource ds = new PGSimpleDataSource();
113+
ds.setUrl("jdbc:postgresql://localhost:" + postgresContainer.getMappedPort(5432) + "/postgres");
114+
ds.setUser(postgresContainer.getUsername());
115+
ds.setPassword(postgresContainer.getPassword());
116+
return new JdbcTemplate(ds);
117+
}
118+
119+
private @NotNull EmbeddingModel embeddingNModelShouldAlwaysReturnFakedEmbed() {
120+
EmbeddingModel embeddingModel = mock(EmbeddingModel.class);
121+
when(embeddingModel.embed(any(Document.class))).thenReturn(embed);
122+
when(embeddingModel.embed(any(String.class))).thenReturn(embed);
123+
return embeddingModel;
124+
}
125+
126+
private static void verifyRequestHasBeenAdvisedWithMessagesFromVectorStore(ChatModel chatModel) {
127+
ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
128+
verify(chatModel).call(promptCaptor.capture());
129+
assertThat(promptCaptor.getValue().getInstructions().get(0)).isInstanceOf(SystemMessage.class);
130+
assertThat(promptCaptor.getValue().getInstructions().get(0).getContent()).isEqualTo("""
131+
132+
133+
Use the long term conversation memory from the LONG_TERM_MEMORY section to provide accurate answers.
134+
135+
---------------------
136+
LONG_TERM_MEMORY:
137+
Tell me a good joke
138+
Tell me a bad joke
139+
---------------------
140+
141+
""");
142+
}
143+
144+
}

0 commit comments

Comments
 (0)