1
+ /*
2
+ * Copyright 2023 - 2024 the original author or authors.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * https://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+ package org .springframework .ai .vectorstore ;
17
+
18
+ import org .jetbrains .annotations .NotNull ;
19
+ import org .junit .jupiter .api .DisplayName ;
20
+ import org .junit .jupiter .api .Test ;
21
+ import org .mockito .ArgumentCaptor ;
22
+ import org .postgresql .ds .PGSimpleDataSource ;
23
+ import org .springframework .ai .chat .client .ChatClient ;
24
+ import org .springframework .ai .chat .client .advisor .VectorStoreChatMemoryAdvisor ;
25
+ import org .springframework .ai .chat .messages .AssistantMessage ;
26
+ import org .springframework .ai .chat .messages .SystemMessage ;
27
+ import org .springframework .ai .chat .model .ChatModel ;
28
+ import org .springframework .ai .chat .model .ChatResponse ;
29
+ import org .springframework .ai .chat .model .Generation ;
30
+ import org .springframework .ai .chat .prompt .Prompt ;
31
+ import org .springframework .ai .document .Document ;
32
+ import org .springframework .ai .embedding .EmbeddingModel ;
33
+ import org .springframework .jdbc .core .JdbcTemplate ;
34
+ import org .testcontainers .containers .PostgreSQLContainer ;
35
+ import org .testcontainers .junit .jupiter .Container ;
36
+ import org .testcontainers .junit .jupiter .Testcontainers ;
37
+
38
+ import java .util .List ;
39
+ import java .util .Map ;
40
+
41
+ import static org .assertj .core .api .Assertions .assertThat ;
42
+ import static org .mockito .Mockito .*;
43
+
44
+ /**
45
+ * @author Fabian Krüger
46
+ */
47
+ @ Testcontainers
48
+ class PgVectorStoreWithChatMemoryAdvisorIT {
49
+
50
+ float [] embed = { 0.003961659F , -0.0073295482F , 0.02663665F };
51
+
52
+ @ Container
53
+ @ SuppressWarnings ("resource" )
54
+ static PostgreSQLContainer <?> postgresContainer = new PostgreSQLContainer <>("pgvector/pgvector:pg16" )
55
+ .withUsername ("postgres" )
56
+ .withPassword ("postgres" );
57
+
58
+ /**
59
+ * Test that chats with {@link VectorStoreChatMemoryAdvisor} get advised with similar
60
+ * messages from the (gp)vector store.
61
+ */
62
+ @ Test
63
+ @ DisplayName ("pgvector" )
64
+ void pgvector () throws Exception {
65
+ // faked ChatModel
66
+ ChatModel chatModel = chatModelAlwaysReturnsTheSameReply ();
67
+ // faked embedding model
68
+ EmbeddingModel embeddingModel = embeddingNModelShouldAlwaysReturnFakedEmbed ();
69
+ PgVectorStore store = createPgVectorStoreUsingTestcontainer (embeddingModel );
70
+
71
+ // do the chat
72
+ ChatClient .builder (chatModel )
73
+ .build ()
74
+ .prompt ()
75
+ .user ("joke" )
76
+ .advisors (new VectorStoreChatMemoryAdvisor (store ))
77
+ .call ()
78
+ .chatResponse ();
79
+
80
+ verifyRequestHasBeenAdvisedWithMessagesFromVectorStore (chatModel );
81
+ }
82
+
83
+ private static @ NotNull ChatModel chatModelAlwaysReturnsTheSameReply () {
84
+ ChatModel chatModel = mock (ChatModel .class );
85
+ ArgumentCaptor <Prompt > argumentCaptor = ArgumentCaptor .forClass (Prompt .class );
86
+ ChatResponse chatResponse = new ChatResponse (List .of (new Generation (new AssistantMessage ("""
87
+ Why don't scientists trust atoms?
88
+ Because they make up everything!
89
+ """ ))));
90
+ when (chatModel .call (argumentCaptor .capture ())).thenReturn (chatResponse );
91
+ return chatModel ;
92
+ }
93
+
94
+ private static void initStore (PgVectorStore store ) throws Exception {
95
+ store .afterPropertiesSet ();
96
+ // fill the store
97
+ store .add (List .of (new Document ("Tell me a good joke" , Map .of ("conversationId" , "default" )),
98
+ new Document ("Tell me a bad joke" , Map .of ("conversationId" , "default" , "messageType" , "USER" ))));
99
+ }
100
+
101
+ private static PgVectorStore createPgVectorStoreUsingTestcontainer (EmbeddingModel embeddingModel ) throws Exception {
102
+ JdbcTemplate jdbcTemplate = createJdbcTemplateWithConnectionToTestcontainer ();
103
+ PgVectorStore vectorStore = new PgVectorStore .Builder (jdbcTemplate , embeddingModel ).withDimensions (3 ) // match
104
+ // embeddings
105
+ .withInitializeSchema (true )
106
+ .build ();
107
+ initStore (vectorStore );
108
+ return vectorStore ;
109
+ }
110
+
111
+ private static @ NotNull JdbcTemplate createJdbcTemplateWithConnectionToTestcontainer () {
112
+ PGSimpleDataSource ds = new PGSimpleDataSource ();
113
+ ds .setUrl ("jdbc:postgresql://localhost:" + postgresContainer .getMappedPort (5432 ) + "/postgres" );
114
+ ds .setUser (postgresContainer .getUsername ());
115
+ ds .setPassword (postgresContainer .getPassword ());
116
+ return new JdbcTemplate (ds );
117
+ }
118
+
119
+ private @ NotNull EmbeddingModel embeddingNModelShouldAlwaysReturnFakedEmbed () {
120
+ EmbeddingModel embeddingModel = mock (EmbeddingModel .class );
121
+ when (embeddingModel .embed (any (Document .class ))).thenReturn (embed );
122
+ when (embeddingModel .embed (any (String .class ))).thenReturn (embed );
123
+ return embeddingModel ;
124
+ }
125
+
126
+ private static void verifyRequestHasBeenAdvisedWithMessagesFromVectorStore (ChatModel chatModel ) {
127
+ ArgumentCaptor <Prompt > promptCaptor = ArgumentCaptor .forClass (Prompt .class );
128
+ verify (chatModel ).call (promptCaptor .capture ());
129
+ assertThat (promptCaptor .getValue ().getInstructions ().get (0 )).isInstanceOf (SystemMessage .class );
130
+ assertThat (promptCaptor .getValue ().getInstructions ().get (0 ).getContent ()).isEqualTo ("""
131
+
132
+
133
+ Use the long term conversation memory from the LONG_TERM_MEMORY section to provide accurate answers.
134
+
135
+ ---------------------
136
+ LONG_TERM_MEMORY:
137
+ Tell me a good joke
138
+ Tell me a bad joke
139
+ ---------------------
140
+
141
+ """ );
142
+ }
143
+
144
+ }
0 commit comments