Skip to content

Commit 6fc76b7

Browse files
tzolovchemicL
authored andcommitted
Refactor advisor architecture in Spring AI
This commit introduces a major overhaul of the advisor system in Spring AI, improving modularity, type safety, and consistency Core Changes: - Replace RequestAdvisor and ResponseAdvisor with CallAroundAdvisor and StreamAroundAdvisor - Introduce AdvisedRequest and AdvisedResponse classes for better encapsulation - Deprecate RequestResponseAdvisor in favor of new advisor types - Remove AdvisorObservableHelper class Advisor Implementation Updates: - Update AbstractChatMemoryAdvisor, MessageChatMemoryAdvisor, PromptChatMemoryAdvisor, QuestionAnswerAdvisor, SafeGuardAroundAdvisor, SimpleLoggerAdvisor, and VectorStoreChatMemoryAdvisor to implement new advisor interfaces - Remove CacheAroundAdvisor (functionality likely moved elsewhere) - Make CallAroundAdvisor and StreamAroundAdvisor extend Ordered interface Client and Chain Management: - Modify DefaultChatClient to use new advisor chain approach - Refactor DefaultAroundAdvisorChain for better ordering and observation - Implement builder pattern for advisor chain construction in DefaultChatClient - Separate call and stream advisors in DefaultAroundAdvisorChain Observation and Context Handling: - Update observation conventions and context handling in advisors - Add order field to AdvisorObservationContext - Modify DefaultAdvisorObservationConvention to include order in high cardinality key values Testing and Integration: - Refactor ChatClientAdvisorTests and add new AdvisorsTests - Update integration tests to reflect new advisor structure - Enhance AdvisorsTests to verify correct advisor execution order New Features: - Generalize the Protect From Blocking functionality across all advisors - Add (experimental) Re2 advisor to enhance reasoning capabilities of LLMs - Add disabled Re2 test in OpenAiChatClientIT Documentation: - Add Advisors documentation - Enhance advisors documentation with order explanation and Re2 example Advisor Ordering: - Introduce Advisor constants for precedence ordering - Update AbstractChatMemoryAdvisor to use new precedence constant - Improve advisor ordering and management in DefaultAroundAdvisorChain.Builder - Remove redundant reordering logic from DefaultAroundAdvisorChain These changes aim to provide a more flexible and powerful advisor system, allowing for easier implementation of complex AI-driven interactions Co-authored-by: Dariusz Jędrzejczyk <[email protected]>
1 parent c81972e commit 6fc76b7

File tree

44 files changed

+1799
-1055
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+1799
-1055
lines changed

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@
2828
import org.junit.jupiter.params.provider.ValueSource;
2929
import org.slf4j.Logger;
3030
import org.slf4j.LoggerFactory;
31-
import org.springframework.ai.chat.client.AdvisedRequest;
3231
import org.springframework.ai.chat.client.ChatClient;
33-
import org.springframework.ai.chat.client.advisor.api.RequestAdvisor;
34-
import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor;
35-
import org.springframework.ai.chat.model.ChatResponse;
32+
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
33+
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
34+
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
35+
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
3636
import org.springframework.ai.converter.BeanOutputConverter;
3737
import org.springframework.ai.model.function.FunctionCallbackContext;
3838
import org.springframework.ai.openai.OpenAiChatModel;
@@ -65,7 +65,7 @@ public class OpenAiPaymentTransactionIT {
6565
record TransactionStatusResponse(String id, String status) {
6666
}
6767

68-
private static class LoggingAdvisor implements RequestAdvisor, ResponseAdvisor {
68+
private static class LoggingAdvisor implements CallAroundAdvisor {
6969

7070
private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class);
7171

@@ -74,7 +74,23 @@ public String getName() {
7474
}
7575

7676
@Override
77-
public AdvisedRequest adviseRequest(AdvisedRequest request, Map<String, Object> context) {
77+
public int getOrder() {
78+
return 0;
79+
}
80+
81+
@Override
82+
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
83+
84+
advisedRequest = this.before(advisedRequest);
85+
86+
AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest);
87+
88+
this.observeAfter(advisedResponse);
89+
90+
return advisedResponse;
91+
}
92+
93+
private AdvisedRequest before(AdvisedRequest request) {
7894
logger.info("System text: \n" + request.systemText());
7995
logger.info("System params: " + request.systemParams());
8096
logger.info("User text: \n" + request.userText());
@@ -86,10 +102,8 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map<String, Object>
86102
return request;
87103
}
88104

89-
@Override
90-
public ChatResponse adviseResponse(ChatResponse response, Map<String, Object> context) {
91-
logger.info("Response: " + response);
92-
return response;
105+
private void observeAfter(AdvisedResponse advisedResponse) {
106+
logger.info("Response: " + advisedResponse.response());
93107
}
94108

95109
}

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.Map;
2323
import java.util.stream.Collectors;
2424

25+
import org.junit.jupiter.api.Disabled;
2526
import org.junit.jupiter.api.Test;
2627
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2728
import org.junit.jupiter.params.ParameterizedTest;
@@ -64,6 +65,38 @@ class OpenAiChatClientIT extends AbstractIT {
6465
record ActorsFilms(String actor, List<String> movies) {
6566
}
6667

68+
@Test
69+
@Disabled("Although the Re2 advisor improves the response correctness it is not always guarantied to work.")
70+
void re2() {
71+
// .user(" Could Scooby Doo fit in a Kangaroo Pouch? Choices: (A) Yes (B) No")
72+
// .user("Roger has 5 tennis balls. He buys 2 more cans of tennis " +
73+
// "balls. Each can has 3 tennis balls. How many tennis balls " +
74+
// "does he have now?")
75+
76+
String REASON_QUESTION = """
77+
What do these words have in common?
78+
Freight Stone Often Canine.
79+
""";
80+
81+
// @formatter:off
82+
ChatClient chatClient = ChatClient.builder(chatModel)
83+
.defaultOptions(OpenAiChatOptions.builder()
84+
.withModel(OpenAiApi.ChatModel.GPT_4_O.getValue()).build())
85+
.defaultUser(REASON_QUESTION)
86+
.build();
87+
88+
String response = chatClient.prompt()
89+
.advisors(new ReReadingAdvisor())
90+
.call()
91+
.content();
92+
// @formatter:on
93+
94+
logger.info("" + response);
95+
assertThat(response.toLowerCase().replace("(", " ").replace(")", " ").replace("\"", " ").replace("\"", " "))
96+
.contains(" eight", " one", " ten", " nine");
97+
98+
}
99+
67100
@Test
68101
void call() {
69102

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.openai.chat.client;
17+
18+
import java.util.HashMap;
19+
import java.util.Map;
20+
21+
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
22+
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
23+
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
24+
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
25+
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
26+
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
27+
28+
import reactor.core.publisher.Flux;
29+
30+
/**
31+
* Drawing inspiration from the human strategy of re-reading, this advisor implements a
32+
* re-reading strategy for LLM reasoning, dubbed RE2, to enhance understanding in the
33+
* input phase. Based on the article:
34+
* <a href="https://arxiv.org/pdf/2309.06275">Re-Reading Improves Reasoning in Large
35+
* Language Models</a>
36+
*
37+
* @author Christian Tzolov
38+
* @since 1.0.0
39+
*/
40+
public class ReReadingAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {
41+
42+
private static final String DEFAULT_RE2_ADVISE_TEMPLATE = """
43+
{re2_input_query}
44+
Read the question again: {re2_input_query}
45+
""";
46+
47+
private final String re2AdviseTemplate;
48+
49+
private int order = 0;
50+
51+
public ReReadingAdvisor() {
52+
this(DEFAULT_RE2_ADVISE_TEMPLATE);
53+
}
54+
55+
public ReReadingAdvisor(String re2AdviseTemplate) {
56+
this.re2AdviseTemplate = re2AdviseTemplate;
57+
}
58+
59+
public String getName() {
60+
return this.getClass().getSimpleName();
61+
}
62+
63+
private AdvisedRequest before(AdvisedRequest advisedRequest) {
64+
65+
Map<String, Object> advisedUserParams = new HashMap<>(advisedRequest.userParams());
66+
advisedUserParams.put("re2_input_query", advisedRequest.userText());
67+
68+
return AdvisedRequest.from(advisedRequest)
69+
.withUserText(this.re2AdviseTemplate)
70+
.withUserParams(advisedUserParams)
71+
.build();
72+
}
73+
74+
@Override
75+
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
76+
return chain.nextAroundCall(this.before(advisedRequest));
77+
}
78+
79+
@Override
80+
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
81+
return chain.nextAroundStream(this.before(advisedRequest));
82+
}
83+
84+
@Override
85+
public int getOrder() {
86+
return this.order;
87+
}
88+
89+
public ReReadingAdvisor withOrder(int order) {
90+
this.order = order;
91+
return this;
92+
}
93+
94+
}

models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@
2828
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2929
import org.slf4j.Logger;
3030
import org.slf4j.LoggerFactory;
31-
import org.springframework.ai.chat.client.AdvisedRequest;
3231
import org.springframework.ai.chat.client.ChatClient;
33-
import org.springframework.ai.chat.client.advisor.api.RequestAdvisor;
34-
import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor;
35-
import org.springframework.ai.chat.model.ChatResponse;
32+
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
33+
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
34+
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
35+
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
3636
import org.springframework.ai.model.function.FunctionCallbackContext;
3737
import org.springframework.ai.model.function.FunctionCallbackWrapper.Builder.SchemaType;
3838
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel;
@@ -65,7 +65,7 @@ public class VertexAiGeminiPaymentTransactionIT {
6565
record TransactionStatusResponse(String id, String status) {
6666
}
6767

68-
private static class LoggingAdvisor implements RequestAdvisor, ResponseAdvisor {
68+
private static class LoggingAdvisor implements CallAroundAdvisor {
6969

7070
private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class);
7171

@@ -75,7 +75,18 @@ public String getName() {
7575
}
7676

7777
@Override
78-
public AdvisedRequest adviseRequest(AdvisedRequest request, Map<String, Object> context) {
78+
public int getOrder() {
79+
return 0;
80+
}
81+
82+
@Override
83+
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
84+
var response = chain.nextAroundCall(before(advisedRequest));
85+
observeAfter(response);
86+
return response;
87+
}
88+
89+
private AdvisedRequest before(AdvisedRequest request) {
7990
logger.info("System text: \n" + request.systemText());
8091
logger.info("System params: " + request.systemParams());
8192
logger.info("User text: \n" + request.userText());
@@ -87,10 +98,8 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map<String, Object>
8798
return request;
8899
}
89100

90-
@Override
91-
public ChatResponse adviseResponse(ChatResponse response, Map<String, Object> context) {
92-
logger.info("Response: " + response);
93-
return response;
101+
private void observeAfter(AdvisedResponse advisedResponse) {
102+
logger.info("Response: " + advisedResponse.response());
94103
}
95104

96105
}

0 commit comments

Comments
 (0)