Skip to content

Commit bf8dabf

Browse files
tzolovmarkpollack
authored andcommitted
Improve stream advisor processing
* Fixes an issue with advisor name resolution * Streamlines repeating code * Add a new advisor strategy for ON_FINISH_REASON streaming responses, which is used by the Q&A advisor * Improve observable instrumentation by passing the parent observation to the advisor observation
1 parent 37c3450 commit bf8dabf

File tree

12 files changed

+308
-328
lines changed

12 files changed

+308
-328
lines changed

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/vectorstore/SimplePersistentVectorStoreIT.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,11 @@ public class SimplePersistentVectorStoreIT {
4444
@Autowired
4545
private EmbeddingModel embeddingModel;
4646

47+
@TempDir(cleanup = CleanupMode.ON_SUCCESS)
48+
Path workingDir;
49+
4750
@Test
48-
void persist(@TempDir(cleanup = CleanupMode.ON_SUCCESS) Path workingDir) {
51+
void persist() {
4952
JsonReader jsonReader = new JsonReader(bikesJsonResource, new ProductMetadataGenerator(), "price", "name",
5053
"shortDescription", "description", "tags");
5154
List<Document> documents = jsonReader.get();

spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java

Lines changed: 125 additions & 145 deletions
Large diffs are not rendered by default.

spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@
1818

1919
import java.util.Map;
2020

21-
import reactor.core.publisher.Flux;
22-
23-
import org.springframework.ai.chat.model.ChatResponse;
2421
import org.springframework.ai.chat.model.ChatModel;
22+
import org.springframework.ai.chat.model.ChatResponse;
23+
import org.springframework.ai.chat.model.MessageAggregator;
2524
import org.springframework.ai.chat.prompt.Prompt;
25+
import org.springframework.util.StringUtils;
26+
27+
import reactor.core.publisher.Flux;
2628

2729
/**
2830
* Advisor called before and after the {@link ChatModel#call(Prompt)} and
@@ -34,6 +36,35 @@
3436
*/
3537
public interface RequestResponseAdvisor {
3638

39+
public enum StreamResponseMode {
40+
41+
/**
42+
* The sync advisor will be called for each response chunk (e.g. on each Flux
43+
* item).
44+
*/
45+
PER_CHUNK,
46+
/**
47+
* The sync advisor is called only on chunks that contain a finish reason. Usually
48+
* the last chunk in the stream.
49+
*/
50+
ON_FINISH_REASON,
51+
/**
52+
* The sync advisor is called only once after the stream is completed and an
53+
* aggregated response is computed. Note that at that stage the advisor can not
54+
* modify the response, but only observe it and react on the aggregated response.
55+
*/
56+
AGGREGATE,
57+
/**
58+
* Delegates to the stream advisor implementation.
59+
*/
60+
CUSTOM;
61+
62+
}
63+
64+
default StreamResponseMode getStreamResponseMode() {
65+
return StreamResponseMode.CUSTOM;
66+
}
67+
3768
/**
3869
* @return the advisor name.
3970
*/
@@ -73,6 +104,31 @@ default ChatResponse adviseResponse(ChatResponse response, Map<String, Object> c
73104
* @return the advised {@link ChatResponse} flux.
74105
*/
75106
default Flux<ChatResponse> adviseResponse(Flux<ChatResponse> fluxResponse, Map<String, Object> context) {
107+
108+
if (this.getStreamResponseMode() == StreamResponseMode.PER_CHUNK) {
109+
return fluxResponse.map(chatResponse -> this.adviseResponse(chatResponse, context));
110+
}
111+
else if (this.getStreamResponseMode() == StreamResponseMode.AGGREGATE) {
112+
return new MessageAggregator().aggregate(fluxResponse, chatResponse -> {
113+
this.adviseResponse(chatResponse, context);
114+
});
115+
}
116+
else if (this.getStreamResponseMode() == StreamResponseMode.ON_FINISH_REASON) {
117+
return fluxResponse.map(chatResponse -> {
118+
boolean withFinishReason = chatResponse.getResults()
119+
.stream()
120+
.filter(result -> result != null && result.getMetadata() != null
121+
&& StringUtils.hasText(result.getMetadata().getFinishReason()))
122+
.findFirst()
123+
.isPresent();
124+
125+
if (withFinishReason) {
126+
return this.adviseResponse(chatResponse, context);
127+
}
128+
return chatResponse;
129+
});
130+
}
131+
76132
return fluxResponse;
77133
}
78134

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ public AbstractChatMemoryAdvisor(T chatMemory, String defaultConversationId, int
5959
this.defaultChatMemoryRetrieveSize = defaultChatMemoryRetrieveSize;
6060
}
6161

62+
@Override
63+
public StreamResponseMode getStreamResponseMode() {
64+
return StreamResponseMode.AGGREGATE;
65+
}
66+
6267
protected T getChatMemoryStore() {
6368
return this.chatMemoryStore;
6469
}

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,11 @@
2020
import java.util.List;
2121
import java.util.Map;
2222

23-
import reactor.core.publisher.Flux;
24-
2523
import org.springframework.ai.chat.client.AdvisedRequest;
2624
import org.springframework.ai.chat.memory.ChatMemory;
2725
import org.springframework.ai.chat.messages.Message;
2826
import org.springframework.ai.chat.messages.UserMessage;
2927
import org.springframework.ai.chat.model.ChatResponse;
30-
import org.springframework.ai.chat.model.MessageAggregator;
3128

3229
/**
3330
* Memory is retrieved added as a collection of messages to the prompt
@@ -79,17 +76,4 @@ public ChatResponse adviseResponse(ChatResponse chatResponse, Map<String, Object
7976
return chatResponse;
8077
}
8178

82-
@Override
83-
public Flux<ChatResponse> adviseResponse(Flux<ChatResponse> fluxChatResponse, Map<String, Object> context) {
84-
85-
return new MessageAggregator().aggregate(fluxChatResponse, chatResponse -> {
86-
List<Message> assistantMessages = chatResponse.getResults()
87-
.stream()
88-
.map(g -> (Message) g.getOutput())
89-
.toList();
90-
91-
this.getChatMemoryStore().add(this.doGetConversationId(context), assistantMessages);
92-
});
93-
}
94-
9579
}

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,13 @@
2121
import java.util.Map;
2222
import java.util.stream.Collectors;
2323

24-
import org.springframework.ai.model.Content;
25-
import reactor.core.publisher.Flux;
26-
2724
import org.springframework.ai.chat.client.AdvisedRequest;
2825
import org.springframework.ai.chat.memory.ChatMemory;
2926
import org.springframework.ai.chat.messages.Message;
3027
import org.springframework.ai.chat.messages.MessageType;
3128
import org.springframework.ai.chat.messages.UserMessage;
3229
import org.springframework.ai.chat.model.ChatResponse;
33-
import org.springframework.ai.chat.model.MessageAggregator;
30+
import org.springframework.ai.model.Content;
3431

3532
/**
3633
* Memory is retrieved added into the prompt's system text.
@@ -109,17 +106,4 @@ public ChatResponse adviseResponse(ChatResponse chatResponse, Map<String, Object
109106
return chatResponse;
110107
}
111108

112-
@Override
113-
public Flux<ChatResponse> adviseResponse(Flux<ChatResponse> fluxChatResponse, Map<String, Object> context) {
114-
115-
return new MessageAggregator().aggregate(fluxChatResponse, chatResponse -> {
116-
List<Message> assistantMessages = chatResponse.getResults()
117-
.stream()
118-
.map(g -> (Message) g.getOutput())
119-
.toList();
120-
121-
this.getChatMemoryStore().add(this.doGetConversationId(context), assistantMessages);
122-
});
123-
}
124-
125109
}

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@
3333
import org.springframework.util.Assert;
3434
import org.springframework.util.StringUtils;
3535

36-
import reactor.core.publisher.Flux;
37-
3836
/**
3937
* Context for the question is retrieved from a Vector Store and added to the prompt's
4038
* user text.
@@ -132,15 +130,6 @@ public ChatResponse adviseResponse(ChatResponse response, Map<String, Object> co
132130
return chatResponseBuilder.build();
133131
}
134132

135-
@Override
136-
public Flux<ChatResponse> adviseResponse(Flux<ChatResponse> fluxResponse, Map<String, Object> context) {
137-
return fluxResponse.map(cr -> {
138-
ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(cr);
139-
chatResponseBuilder.withMetadata(RETRIEVED_DOCUMENTS, context.get(RETRIEVED_DOCUMENTS));
140-
return chatResponseBuilder.build();
141-
});
142-
}
143-
144133
protected Filter.Expression doGetFilterExpression(Map<String, Object> context) {
145134

146135
if (!context.containsKey(FILTER_EXPRESSION)
@@ -151,4 +140,9 @@ protected Filter.Expression doGetFilterExpression(Map<String, Object> context) {
151140

152141
}
153142

143+
@Override
144+
public StreamResponseMode getStreamResponseMode() {
145+
return StreamResponseMode.ON_FINISH_REASON;
146+
}
147+
154148
}

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,8 @@
2323
import org.springframework.ai.chat.client.AdvisedRequest;
2424
import org.springframework.ai.chat.client.RequestResponseAdvisor;
2525
import org.springframework.ai.chat.model.ChatResponse;
26-
import org.springframework.ai.chat.model.MessageAggregator;
2726
import org.springframework.ai.model.ModelOptionsUtils;
2827

29-
import reactor.core.publisher.Flux;
30-
3128
/**
3229
* A simple logger advisor that logs the request and response messages.
3330
*
@@ -65,12 +62,6 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map<String, Object>
6562
return request;
6663
}
6764

68-
@Override
69-
public Flux<ChatResponse> adviseResponse(Flux<ChatResponse> fluxChatResponse, Map<String, Object> context) {
70-
return new MessageAggregator().aggregate(fluxChatResponse,
71-
chatResponse -> logger.debug("stream response: {}", this.responseToString.apply(chatResponse)));
72-
}
73-
7465
@Override
7566
public ChatResponse adviseResponse(ChatResponse response, Map<String, Object> context) {
7667
logger.debug("response: {}", this.responseToString.apply(response));
@@ -82,4 +73,9 @@ public String toString() {
8273
return SimpleLoggerAdvisor.class.getSimpleName();
8374
}
8475

76+
@Override
77+
public StreamResponseMode getStreamResponseMode() {
78+
return StreamResponseMode.AGGREGATE;
79+
}
80+
8581
}

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,12 @@
2121
import java.util.Map;
2222
import java.util.stream.Collectors;
2323

24-
import org.springframework.ai.chat.messages.AssistantMessage;
25-
import reactor.core.publisher.Flux;
26-
2724
import org.springframework.ai.chat.client.AdvisedRequest;
25+
import org.springframework.ai.chat.messages.AssistantMessage;
2826
import org.springframework.ai.chat.messages.Message;
2927
import org.springframework.ai.chat.messages.MessageType;
3028
import org.springframework.ai.chat.messages.UserMessage;
3129
import org.springframework.ai.chat.model.ChatResponse;
32-
import org.springframework.ai.chat.model.MessageAggregator;
3330
import org.springframework.ai.document.Document;
3431
import org.springframework.ai.model.Content;
3532
import org.springframework.ai.vectorstore.SearchRequest;
@@ -120,19 +117,6 @@ public ChatResponse adviseResponse(ChatResponse chatResponse, Map<String, Object
120117
return chatResponse;
121118
}
122119

123-
@Override
124-
public Flux<ChatResponse> adviseResponse(Flux<ChatResponse> fluxChatResponse, Map<String, Object> context) {
125-
126-
return new MessageAggregator().aggregate(fluxChatResponse, chatResponse -> {
127-
List<Message> assistantMessages = chatResponse.getResults()
128-
.stream()
129-
.map(g -> (Message) g.getOutput())
130-
.toList();
131-
132-
this.getChatMemoryStore().write(toDocuments(assistantMessages, this.doGetConversationId(context)));
133-
});
134-
}
135-
136120
private List<Document> toDocuments(List<Message> messages, String conversationId) {
137121

138122
List<Document> docs = messages.stream()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
/*
2+
* Copyright 2024 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.chat.client.advisor.observation;
17+
18+
import java.util.Map;
19+
20+
import org.springframework.ai.chat.client.AdvisedRequest;
21+
import org.springframework.ai.chat.client.RequestResponseAdvisor;
22+
import org.springframework.ai.chat.client.RequestResponseAdvisor.StreamResponseMode;
23+
import org.springframework.ai.chat.model.ChatResponse;
24+
import org.springframework.ai.chat.model.MessageAggregator;
25+
import org.springframework.util.StringUtils;
26+
27+
import io.micrometer.observation.Observation;
28+
import reactor.core.publisher.Flux;
29+
30+
/**
31+
* @author Christian Tzolov
32+
* @since 1.0.0
33+
*/
34+
public abstract class AdvisorObservableHelper {
35+
36+
private static final AdvisorObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultAdvisorObservationConvention();
37+
38+
public static AdvisedRequest adviseRequest(Observation parentObservation, RequestResponseAdvisor advisor,
39+
AdvisedRequest advisedRequest, Map<String, Object> advisorContext) {
40+
41+
var observationContext = AdvisorObservationContext.builder()
42+
.withAdvisorName(advisor.getName())
43+
.withAdvisorType(AdvisorObservationContext.Type.BEFORE)
44+
.withAdvisedRequest(advisedRequest)
45+
.withAdvisorRequestContext(advisorContext)
46+
.build();
47+
48+
return AdvisorObservationDocumentation.AI_ADVISOR
49+
.observation(null, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
50+
parentObservation.getObservationRegistry())
51+
.parentObservation(parentObservation)
52+
.observe(() -> advisor.adviseRequest(advisedRequest, advisorContext));
53+
}
54+
55+
public static ChatResponse adviseResponse(Observation parentObservation, RequestResponseAdvisor advisor,
56+
ChatResponse response, Map<String, Object> advisorContext) {
57+
58+
var observationContext = AdvisorObservationContext.builder()
59+
.withAdvisorName(advisor.getName())
60+
.withAdvisorType(AdvisorObservationContext.Type.AFTER)
61+
.withAdvisorRequestContext(advisorContext)
62+
.build();
63+
64+
return AdvisorObservationDocumentation.AI_ADVISOR
65+
.observation(null, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
66+
parentObservation.getObservationRegistry())
67+
.parentObservation(parentObservation)
68+
.observe(() -> advisor.adviseResponse(response, advisorContext));
69+
}
70+
71+
public static Flux<ChatResponse> adviseResponse(Observation parentObservation, RequestResponseAdvisor advisor,
72+
Flux<ChatResponse> fluxResponse, Map<String, Object> advisorContext) {
73+
74+
if (advisor.getStreamResponseMode() == StreamResponseMode.PER_CHUNK) {
75+
return fluxResponse
76+
.map(chatResponse -> adviseResponse(parentObservation, advisor, chatResponse, advisorContext));
77+
}
78+
else if (advisor.getStreamResponseMode() == StreamResponseMode.AGGREGATE) {
79+
return new MessageAggregator().aggregate(fluxResponse, chatResponse -> {
80+
adviseResponse(parentObservation, advisor, chatResponse, advisorContext);
81+
});
82+
}
83+
else if (advisor.getStreamResponseMode() == StreamResponseMode.ON_FINISH_REASON) {
84+
return fluxResponse.map(chatResponse -> {
85+
boolean withFinishReason = chatResponse.getResults()
86+
.stream()
87+
.filter(result -> result != null && result.getMetadata() != null
88+
&& StringUtils.hasText(result.getMetadata().getFinishReason()))
89+
.findFirst()
90+
.isPresent();
91+
92+
if (withFinishReason) {
93+
return adviseResponse(parentObservation, advisor, chatResponse, advisorContext);
94+
}
95+
return chatResponse;
96+
});
97+
}
98+
99+
return advisor.adviseResponse(fluxResponse, advisorContext);
100+
}
101+
102+
}

0 commit comments

Comments
 (0)