1
+ /*
2
+ * Copyright 2023 - 2024 the original author or authors.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * https://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+ package org .springframework .ai .vectorstore ;
17
+
18
+ import org .jetbrains .annotations .NotNull ;
19
+ import org .junit .jupiter .api .Assertions ;
20
+ import org .junit .jupiter .api .DisplayName ;
21
+ import org .junit .jupiter .api .Test ;
22
+ import org .mockito .ArgumentCaptor ;
23
+ import org .postgresql .ds .PGSimpleDataSource ;
24
+ import org .springframework .ai .chat .client .ChatClient ;
25
+ import org .springframework .ai .chat .client .advisor .VectorStoreChatMemoryAdvisor ;
26
+ import org .springframework .ai .chat .messages .AssistantMessage ;
27
+ import org .springframework .ai .chat .messages .SystemMessage ;
28
+ import org .springframework .ai .chat .model .ChatModel ;
29
+ import org .springframework .ai .chat .model .ChatResponse ;
30
+ import org .springframework .ai .chat .model .Generation ;
31
+ import org .springframework .ai .chat .prompt .Prompt ;
32
+ import org .springframework .ai .document .Document ;
33
+ import org .springframework .ai .embedding .EmbeddingModel ;
34
+ import org .springframework .jdbc .core .JdbcTemplate ;
35
+ import org .testcontainers .containers .PostgreSQLContainer ;
36
+ import org .testcontainers .junit .jupiter .Container ;
37
+ import org .testcontainers .junit .jupiter .Testcontainers ;
38
+
39
+ import java .util .List ;
40
+ import java .util .Map ;
41
+
42
+ import static org .assertj .core .api .Assertions .assertThat ;
43
+ import static org .mockito .Mockito .*;
44
+
45
+ /**
46
+ * @author Fabian Krüger
47
+ */
48
+ @ Testcontainers
49
+ class PgVectorStoreWithChatMemoryAdvisorIT {
50
+
51
+ float [] embed = { 0.003961659F , -0.0073295482F , 0.02663665F };
52
+
53
+ @ Container
54
+ @ SuppressWarnings ("resource" )
55
+ static PostgreSQLContainer <?> postgresContainer = new PostgreSQLContainer <>("pgvector/pgvector:pg16" )
56
+ .withUsername ("postgres" )
57
+ .withPassword ("postgres" );
58
+
59
+ /**
60
+ * Test that chats with {@link VectorStoreChatMemoryAdvisor} get advised with similar
61
+ * messages from the (gp)vector store.
62
+ */
63
+ @ Test
64
+ @ DisplayName ("Advised chat should have similar messages from vector store" )
65
+ void advisedChatShouldHaveSimilarMessagesFromVectorStore () throws Exception {
66
+ // faked ChatModel
67
+ ChatModel chatModel = chatModelAlwaysReturnsTheSameReply ();
68
+ // faked embedding model
69
+ EmbeddingModel embeddingModel = embeddingNModelShouldAlwaysReturnFakedEmbed ();
70
+ PgVectorStore store = createPgVectorStoreUsingTestcontainer (embeddingModel );
71
+
72
+ // do the chat
73
+ ChatClient .builder (chatModel )
74
+ .build ()
75
+ .prompt ()
76
+ .user ("joke" )
77
+ .advisors (new VectorStoreChatMemoryAdvisor (store ))
78
+ .call ()
79
+ .chatResponse ();
80
+
81
+ verifyRequestHasBeenAdvisedWithMessagesFromVectorStore (chatModel );
82
+ }
83
+
84
+ private static @ NotNull ChatModel chatModelAlwaysReturnsTheSameReply () {
85
+ ChatModel chatModel = mock (ChatModel .class );
86
+ ArgumentCaptor <Prompt > argumentCaptor = ArgumentCaptor .forClass (Prompt .class );
87
+ ChatResponse chatResponse = new ChatResponse (List .of (new Generation (new AssistantMessage ("""
88
+ Why don't scientists trust atoms?
89
+ Because they make up everything!
90
+ """ ))));
91
+ when (chatModel .call (argumentCaptor .capture ())).thenReturn (chatResponse );
92
+ return chatModel ;
93
+ }
94
+
95
+ private static void initStore (PgVectorStore store ) throws Exception {
96
+ store .afterPropertiesSet ();
97
+ // fill the store
98
+ store .add (List .of (new Document ("Tell me a good joke" , Map .of ("conversationId" , "default" )),
99
+ new Document ("Tell me a bad joke" , Map .of ("conversationId" , "default" , "messageType" , "USER" ))));
100
+ }
101
+
102
+ private static PgVectorStore createPgVectorStoreUsingTestcontainer (EmbeddingModel embeddingModel ) throws Exception {
103
+ JdbcTemplate jdbcTemplate = createJdbcTemplateWithConnectionToTestcontainer ();
104
+ PgVectorStore vectorStore = new PgVectorStore .Builder (jdbcTemplate , embeddingModel ).withDimensions (3 ) // match
105
+ // embeddings
106
+ .withInitializeSchema (true )
107
+ .build ();
108
+ initStore (vectorStore );
109
+ return vectorStore ;
110
+ }
111
+
112
+ private static @ NotNull JdbcTemplate createJdbcTemplateWithConnectionToTestcontainer () {
113
+ PGSimpleDataSource ds = new PGSimpleDataSource ();
114
+ ds .setUrl ("jdbc:postgresql://localhost:" + postgresContainer .getMappedPort (5432 ) + "/postgres" );
115
+ ds .setUser (postgresContainer .getUsername ());
116
+ ds .setPassword (postgresContainer .getPassword ());
117
+ return new JdbcTemplate (ds );
118
+ }
119
+
120
+ private @ NotNull EmbeddingModel embeddingNModelShouldAlwaysReturnFakedEmbed () {
121
+ EmbeddingModel embeddingModel = mock (EmbeddingModel .class );
122
+ when (embeddingModel .embed (any (Document .class ))).thenReturn (embed );
123
+ when (embeddingModel .embed (any (String .class ))).thenReturn (embed );
124
+ return embeddingModel ;
125
+ }
126
+
127
+ private static void verifyRequestHasBeenAdvisedWithMessagesFromVectorStore (ChatModel chatModel ) {
128
+ ArgumentCaptor <Prompt > promptCaptor = ArgumentCaptor .forClass (Prompt .class );
129
+ verify (chatModel ).call (promptCaptor .capture ());
130
+ assertThat (promptCaptor .getValue ().getInstructions ().get (0 )).isInstanceOf (SystemMessage .class );
131
+ assertThat (promptCaptor .getValue ().getInstructions ().get (0 ).getContent ()).isEqualTo ("""
132
+
133
+
134
+ Use the long term conversation memory from the LONG_TERM_MEMORY section to provide accurate answers.
135
+
136
+ ---------------------
137
+ LONG_TERM_MEMORY:
138
+ Tell me a good joke
139
+ Tell me a bad joke
140
+ ---------------------
141
+
142
+ """ );
143
+ }
144
+
145
+ }
0 commit comments