diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/GroqWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/GroqWithOpenAiChatModelIT.java index 7a64dc1feda..6d8cb41d17c 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/GroqWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/GroqWithOpenAiChatModelIT.java @@ -34,7 +34,6 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.model.Media; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -45,6 +44,7 @@ import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.ai.model.Media; import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java index 46326fc8b12..4f299bd5b36 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java @@ -16,6 +16,8 @@ package org.springframework.ai.openai.chat; +import static org.assertj.core.api.Assertions.assertThat; + import java.util.List; import java.util.Map; import java.util.function.Function; @@ -26,11 +28,10 @@ import 
org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import reactor.core.publisher.Flux; - import org.springframework.ai.chat.client.AdvisedRequest; import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.client.RequestResponseAdvisor; +import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; +import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.model.function.FunctionCallbackContext; @@ -47,7 +48,7 @@ import org.springframework.context.annotation.Description; import org.springframework.core.ParameterizedTypeReference; -import static org.assertj.core.api.Assertions.assertThat; +import reactor.core.publisher.Flux; /** * @author Christian Tzolov @@ -64,10 +65,14 @@ public class OpenAiPaymentTransactionIT { record TransactionStatusResponse(String id, String status) { } - private static class LoggingAdvisor implements RequestResponseAdvisor { + private static class LoggingAdvisor implements RequestAdvisor, ResponseAdvisor { private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class); + public String getName() { + return this.getClass().getSimpleName(); + } + @Override public AdvisedRequest adviseRequest(AdvisedRequest request, Map context) { logger.info("System text: \n" + request.systemText()); diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java index 1950fd40b2a..1645152d232 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java +++ 
b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java @@ -30,7 +30,8 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.chat.client.AdvisedRequest; import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.client.RequestResponseAdvisor; +import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; +import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.model.function.FunctionCallbackWrapper.Builder.SchemaType; @@ -64,10 +65,15 @@ public class VertexAiGeminiPaymentTransactionIT { record TransactionStatusResponse(String id, String status) { } - private static class LoggingAdvisor implements RequestResponseAdvisor { + private static class LoggingAdvisor implements RequestAdvisor, ResponseAdvisor { private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class); + @Override + public String getName() { + return this.getClass().getSimpleName(); + } + @Override public AdvisedRequest adviseRequest(AdvisedRequest request, Map context) { logger.info("System text: \n" + request.systemText()); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/AdvisedRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/AdvisedRequest.java index 7be72064797..af05a1cfd33 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/AdvisedRequest.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/AdvisedRequest.java @@ -20,6 +20,7 @@ import java.util.Map; import org.springframework.ai.model.Media; +import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.model.ChatModel; import 
org.springframework.ai.chat.prompt.ChatOptions; @@ -46,7 +47,7 @@ */ public record AdvisedRequest(ChatModel chatModel, String userText, String systemText, ChatOptions chatOptions, List media, List functionNames, List functionCallbacks, List messages, - Map userParams, Map systemParams, List advisors, + Map userParams, Map systemParams, List advisors, Map advisorParams) { public static Builder from(AdvisedRequest from) { @@ -92,7 +93,7 @@ public static class Builder { private Map systemParams = Map.of(); - private List advisors = List.of(); + private List advisors = List.of(); private Map advisorParams = Map.of(); @@ -146,7 +147,7 @@ public Builder withSystemParams(Map systemParams) { return this; } - public Builder withAdvisors(List advisors) { + public Builder withAdvisors(List advisors) { this.advisors = advisors; return this; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java index 418f532f6ef..6dc66514fda 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java @@ -21,6 +21,7 @@ import java.util.Map; import java.util.function.Consumer; +import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.client.observation.ChatClientObservationConvention; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.model.ChatModel; @@ -122,9 +123,9 @@ interface AdvisorSpec { AdvisorSpec params(Map p); - AdvisorSpec advisors(RequestResponseAdvisor... advisors); + AdvisorSpec advisors(Advisor... advisors); - AdvisorSpec advisors(List advisors); + AdvisorSpec advisors(List advisors); } @@ -192,9 +193,9 @@ interface ChatClientRequestSpec { ChatClientRequestSpec advisors(Consumer consumer); - ChatClientRequestSpec advisors(RequestResponseAdvisor... 
advisors); + ChatClientRequestSpec advisors(Advisor... advisors); - ChatClientRequestSpec advisors(List advisors); + ChatClientRequestSpec advisors(List advisors); ChatClientRequestSpec messages(Message... messages); @@ -237,11 +238,11 @@ ChatClientRequestSpec function(String name, String description, Class */ interface Builder { - Builder defaultAdvisors(RequestResponseAdvisor... advisor); + Builder defaultAdvisors(Advisor... advisor); Builder defaultAdvisors(Consumer advisorSpecConsumer); - Builder defaultAdvisors(List advisors); + Builder defaultAdvisors(List advisors); Builder defaultOptions(ChatOptions chatOptions); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index c18fdeace56..767f6d26ae2 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -28,6 +28,14 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.function.Consumer; +import org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.Advisor; +import org.springframework.ai.chat.client.advisor.api.AroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor; +import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor.StreamResponseMode; +import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; import org.springframework.ai.chat.client.advisor.observation.AdvisorObservableHelper; import org.springframework.ai.chat.client.observation.ChatClientObservationContext; import 
org.springframework.ai.chat.client.observation.ChatClientObservationConvention; @@ -38,6 +46,7 @@ import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; @@ -232,7 +241,7 @@ protected Map params() { public static class DefaultAdvisorSpec implements AdvisorSpec { - private final List advisors = new ArrayList<>(); + private final List advisors = new ArrayList<>(); private final Map params = new HashMap<>(); @@ -246,17 +255,17 @@ public AdvisorSpec params(Map p) { return this; } - public AdvisorSpec advisors(RequestResponseAdvisor... advisors) { + public AdvisorSpec advisors(Advisor... advisors) { this.advisors.addAll(List.of(advisors)); return this; } - public AdvisorSpec advisors(List advisors) { + public AdvisorSpec advisors(List advisors) { this.advisors.addAll(advisors); return this; } - public List getAdvisors() { + public List getAdvisors() { return advisors; } @@ -270,10 +279,7 @@ public static class DefaultCallResponseSpec implements CallResponseSpec { private final DefaultChatClientRequestSpec request; - private final ChatModel chatModel; - - public DefaultCallResponseSpec(ChatModel chatModel, DefaultChatClientRequestSpec request) { - this.chatModel = chatModel; + public DefaultCallResponseSpec(DefaultChatClientRequestSpec request) { this.request = request; } @@ -330,8 +336,8 @@ private ChatResponse doGetObservableChatResponse(DefaultChatClientRequestSpec in formatParam, false); var observation = ChatClientObservationDocumentation.AI_CHAT_CLIENT.observation( - inputRequest.customObservationConvention, DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION, - () -> observationContext, inputRequest.observationRegistry); + 
inputRequest.getCustomObservationConvention(), DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION, + () -> observationContext, inputRequest.getObservationRegistry()); return observation.observe(() -> { ChatResponse chatResponse = doGetChatResponse(inputRequest, formatParam, observation); return chatResponse; @@ -342,35 +348,37 @@ private ChatResponse doGetObservableChatResponse(DefaultChatClientRequestSpec in private ChatResponse doGetChatResponse(DefaultChatClientRequestSpec inputRequestSpec, String formatParam, Observation parentObservation) { - Map context = new ConcurrentHashMap<>(); - context.putAll(inputRequestSpec.getAdvisorParams()); + Map advisorContext = new ConcurrentHashMap<>(); + if (StringUtils.hasText(formatParam)) { + advisorContext.put("formatParam", formatParam); + } + advisorContext.putAll(inputRequestSpec.getAdvisorParams()); - DefaultChatClientRequestSpec advisedRequestSpec = inputRequestSpec; + // DefaultChatClientRequestSpec advisedRequestSpec = inputRequestSpec; + AdvisedRequest advisedRequest = toAdvisedRequest(inputRequestSpec); if (!CollectionUtils.isEmpty(inputRequestSpec.advisors)) { - AdvisedRequest advisedRequest = toAdvisedRequest(inputRequestSpec); - - // apply the advisors onRequest - var currentAdvisors = new ArrayList<>(inputRequestSpec.advisors); - for (RequestResponseAdvisor advisor : currentAdvisors) { + // Apply the Request advisors + var currentAdvisors = new ArrayList<>( + AdvisorObservableHelper.extractRequestAdvisors(inputRequestSpec.advisors)); + for (RequestAdvisor advisor : currentAdvisors) { advisedRequest = AdvisorObservableHelper.adviseRequest(parentObservation, advisor, advisedRequest, - context); + advisorContext); } - advisedRequestSpec = toDefaultChatClientRequestSpec(advisedRequest, - inputRequestSpec.getObservationRegistry(), inputRequestSpec.getCustomObservationConvention()); } - var prompt = toPrompt(advisedRequestSpec, formatParam); + // Apply the around advisor chain that terminates with the, last, model call + // 
advisor. + ChatResponse advisedResponse = inputRequestSpec.aroundAdvisorChain.nextAroundCall(advisedRequest, + advisorContext); - var chatResponse = this.chatModel.call(prompt); - - ChatResponse advisedResponse = chatResponse; - // apply the advisors on response + // Apply the Response advisors. if (!CollectionUtils.isEmpty(inputRequestSpec.getAdvisors())) { - var currentAdvisors = new ArrayList<>(inputRequestSpec.getAdvisors()); - for (RequestResponseAdvisor advisor : currentAdvisors) { + var currentAdvisors = new ArrayList<>( + AdvisorObservableHelper.extractResponseAdvisors(inputRequestSpec.getAdvisors())); + for (ResponseAdvisor advisor : currentAdvisors) { advisedResponse = AdvisorObservableHelper.adviseResponse(parentObservation, advisor, - advisedResponse, context); + advisedResponse, advisorContext); } } @@ -387,55 +395,51 @@ public String content() { } - static Prompt toPrompt(DefaultChatClientRequestSpec advisedRequest, String formatParam) { + private static Prompt toPrompt(AdvisedRequest advisedRequest, String formatParam) { - var messages = new ArrayList(advisedRequest.getMessages()); + var messages = new ArrayList(advisedRequest.messages()); - String processedSystemText = advisedRequest.getSystemText(); + String processedSystemText = advisedRequest.systemText(); if (StringUtils.hasText(processedSystemText)) { - if (!CollectionUtils.isEmpty(advisedRequest.getSystemParams())) { - processedSystemText = new PromptTemplate(processedSystemText, advisedRequest.getSystemParams()) - .render(); + if (!CollectionUtils.isEmpty(advisedRequest.systemParams())) { + processedSystemText = new PromptTemplate(processedSystemText, advisedRequest.systemParams()).render(); } messages.add(new SystemMessage(processedSystemText)); } var processedUserText = StringUtils.hasText(formatParam) - ? advisedRequest.getUserText() + System.lineSeparator() + "{spring_ai_soc_format}" - : advisedRequest.getUserText(); + ? 
advisedRequest.userText() + System.lineSeparator() + "{spring_ai_soc_format}" + : advisedRequest.userText(); if (StringUtils.hasText(processedUserText)) { - Map userParams = new HashMap<>(advisedRequest.getUserParams()); + Map userParams = new HashMap<>(advisedRequest.userParams()); if (StringUtils.hasText(formatParam)) { userParams.put("spring_ai_soc_format", formatParam); } if (!CollectionUtils.isEmpty(userParams)) { processedUserText = new PromptTemplate(processedUserText, userParams).render(); } - messages.add(new UserMessage(processedUserText, advisedRequest.getMedia())); + messages.add(new UserMessage(processedUserText, advisedRequest.media())); } - if (advisedRequest.getChatOptions() instanceof FunctionCallingOptions functionCallingOptions) { - if (!advisedRequest.getFunctionNames().isEmpty()) { - functionCallingOptions.setFunctions(new HashSet<>(advisedRequest.getFunctionNames())); + if (advisedRequest.chatOptions() instanceof FunctionCallingOptions functionCallingOptions) { + if (!advisedRequest.functionNames().isEmpty()) { + functionCallingOptions.setFunctions(new HashSet<>(advisedRequest.functionNames())); } - if (!advisedRequest.getFunctionCallbacks().isEmpty()) { - functionCallingOptions.setFunctionCallbacks(advisedRequest.getFunctionCallbacks()); + if (!advisedRequest.functionCallbacks().isEmpty()) { + functionCallingOptions.setFunctionCallbacks(advisedRequest.functionCallbacks()); } } - return new Prompt(messages, advisedRequest.getChatOptions()); + return new Prompt(messages, advisedRequest.chatOptions()); } public static class DefaultStreamResponseSpec implements StreamResponseSpec { private final DefaultChatClientRequestSpec request; - private final ChatModel chatModel; - - public DefaultStreamResponseSpec(ChatModel chatModel, DefaultChatClientRequestSpec request) { - this.chatModel = chatModel; + public DefaultStreamResponseSpec(DefaultChatClientRequestSpec request) { this.request = request; } @@ -446,20 +450,20 @@ private Flux 
doGetObservableFluxChatResponse(DefaultChatClientRequ true); Observation observation = ChatClientObservationDocumentation.AI_CHAT_CLIENT.observation( - inputRequest.customObservationConvention, DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION, - () -> observationContext, inputRequest.observationRegistry); + inputRequest.getCustomObservationConvention(), DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION, + () -> observationContext, inputRequest.getObservationRegistry()); observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)) .start(); // @formatter:off - return doGetFluxChatResponse(inputRequest, observation) - .doOnError(observation::error) - .doFinally(s -> { - observation.stop(); - }) - .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); - // @formatter:on + return doGetFluxChatResponse(inputRequest, observation) + .doOnError(observation::error) + .doFinally(s -> { + observation.stop(); + }) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); + // @formatter:on }); } @@ -473,33 +477,93 @@ private Flux doGetFluxChatResponse(DefaultChatClientRequestSpec in var reqWithContext = new AdvisedRequestWithContext(toAdvisedRequest(inputRequest), advisorContext); - return Flux.fromIterable(inputRequest.advisors) + return Flux.fromIterable(AdvisorObservableHelper.extractRequestAdvisors(inputRequest.advisors)) .transformDeferredContextual((f, ctx) -> f // This allows us to call blocking code in reduce .publishOn(Schedulers.boundedElastic()) .reduce(reqWithContext, (rwc, advisor) -> { + // Apply the Request advisors AdvisedRequest advisedRequest = AdvisorObservableHelper.adviseRequest(parentObservation, advisor, rwc.request, rwc.advisorContext); return new AdvisedRequestWithContext(advisedRequest, rwc.advisorContext); })) .single() .flatMapMany(rwc -> { - DefaultChatClientRequestSpec advisedRequest = toDefaultChatClientRequestSpec(rwc.request, - inputRequest.getObservationRegistry(), 
inputRequest.getCustomObservationConvention()); - - var prompt = toPrompt(advisedRequest, null); - Flux fluxChatResponse = this.chatModel.stream(prompt); + // Apply the around advisor chain that terminates with the, last, + // model call advisor. + Flux advisedResponse = inputRequest.aroundAdvisorChain.nextAroundStream(rwc.request, + rwc.advisorContext); - Flux advisedResponse = fluxChatResponse; - // apply the advisors on response + // Apply the Response advisors if (!CollectionUtils.isEmpty(inputRequest.getAdvisors())) { - var currentAdvisors = new ArrayList<>(inputRequest.getAdvisors()); - for (RequestResponseAdvisor advisor : currentAdvisors) { - advisedResponse = AdvisorObservableHelper.adviseResponse(parentObservation, advisor, - advisedResponse, advisorContext); + + var responseAdvisors = new ArrayList<>( + AdvisorObservableHelper.extractResponseAdvisors(inputRequest.getAdvisors())); + + List perElementResponseAdvisors = responseAdvisors.stream() + .filter(a -> a.getStreamResponseMode() == StreamResponseMode.PER_ELEMENT) + .toList(); + + List onFinishElementResponseAdvisors = responseAdvisors.stream() + .filter(a -> a.getStreamResponseMode() == StreamResponseMode.ON_FINISH_ELEMENT) + .toList(); + + // PER_ELEMENT and ON_FINISH_ELEMENT + advisedResponse = advisedResponse.map(response -> { + // PER_ELEMENT + if (!CollectionUtils.isEmpty(perElementResponseAdvisors)) { + for (ResponseAdvisor advisor : perElementResponseAdvisors) { + response = AdvisorObservableHelper.adviseResponse(parentObservation, advisor, + response, rwc.advisorContext); + } + } + // ON_FINISH_ELEMENT + if (!CollectionUtils.isEmpty(onFinishElementResponseAdvisors)) { + for (ResponseAdvisor advisor : onFinishElementResponseAdvisors) { + boolean withFinishReason = response.getResults() + .stream() + .filter(result -> result != null && result.getMetadata() != null + && StringUtils.hasText(result.getMetadata().getFinishReason())) + .findFirst() + .isPresent(); + + if (withFinishReason) { + 
response = AdvisorObservableHelper.adviseResponse(parentObservation, advisor, + response, advisorContext); + } + } + } + return response; + }); + + // CUSTOM + // TODO: how to pass the parentObservation to the custom response + // advisor? + List customResponseAdvisors = responseAdvisors.stream() + .filter(a -> a.getStreamResponseMode() == StreamResponseMode.CUSTOM) + .toList(); + if (!CollectionUtils.isEmpty(customResponseAdvisors)) { + for (ResponseAdvisor advisor : customResponseAdvisors) { + advisedResponse = advisor.adviseResponse(advisedResponse, rwc.advisorContext); + } + } + + // AGGREGATE + List aggregateResponseAdvisors = responseAdvisors.stream() + .filter(a -> a.getStreamResponseMode() == StreamResponseMode.AGGREGATE) + .toList(); + + if (!CollectionUtils.isEmpty(aggregateResponseAdvisors)) { + advisedResponse = new MessageAggregator().aggregate(advisedResponse, chatResponse -> { + for (ResponseAdvisor advisor : aggregateResponseAdvisors) { + AdvisorObservableHelper.adviseResponse(parentObservation, advisor, chatResponse, + advisorContext); + } + }); } } + return advisedResponse; }); @@ -547,52 +611,58 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe private final Map systemParams = new HashMap<>(); - private final List advisors = new ArrayList<>(); + private final List advisors = new ArrayList<>(); private final Map advisorParams = new HashMap<>(); + private final DefaultAroundAdvisorChain aroundAdvisorChain; + + public AroundAdvisorChain getAroundAdvisorChain() { + return this.aroundAdvisorChain; + } + private ObservationRegistry getObservationRegistry() { - return observationRegistry; + return this.observationRegistry; } private ChatClientObservationConvention getCustomObservationConvention() { - return customObservationConvention; + return this.customObservationConvention; } public String getUserText() { - return userText; + return this.userText; } public Map getUserParams() { - return userParams; + return 
this.userParams; } public String getSystemText() { - return systemText; + return this.systemText; } public Map getSystemParams() { - return systemParams; + return this.systemParams; } public ChatOptions getChatOptions() { - return chatOptions; + return this.chatOptions; } - public List getAdvisors() { - return advisors; + public List getAdvisors() { + return this.advisors; } public Map getAdvisorParams() { - return advisorParams; + return this.advisorParams; } public List getMessages() { - return messages; + return this.messages; } public List getMedia() { - return media; + return this.media; } public List getFunctionNames() { @@ -600,7 +670,7 @@ public List getFunctionNames() { } public List getFunctionCallbacks() { - return functionCallbacks; + return this.functionCallbacks; } /* copy constructor */ @@ -613,8 +683,8 @@ public List getFunctionCallbacks() { public DefaultChatClientRequestSpec(ChatModel chatModel, String userText, Map userParams, String systemText, Map systemParams, List functionCallbacks, List messages, List functionNames, List media, ChatOptions chatOptions, - List advisors, Map advisorParams, - ObservationRegistry observationRegistry, ChatClientObservationConvention customObservationConvention) { + List advisors, Map advisorParams, ObservationRegistry observationRegistry, + ChatClientObservationConvention customObservationConvention) { this.chatModel = chatModel; this.chatOptions = chatOptions != null ? 
chatOptions.copy() @@ -633,6 +703,40 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, String userText, Map adviceContext, + AroundAdvisorChain chain) { + String formatParam = (String) adviceContext.get("formatParam"); + return chatModel.call(toPrompt(advisedRequest, formatParam)); + } + }) + .push(new StreamAroundAdvisor() { + + @Override + public String getName() { + return StreamAroundAdvisor.class.getSimpleName(); + } + @Override + public Flux aroundStream(AdvisedRequest advisedRequest, Map adviceContext, + AroundAdvisorChain chain) { + return chatModel.stream(toPrompt(advisedRequest, null)); + } + }) + .pushAll(this.advisors) + .build(); + // @formatter:on } /** @@ -662,18 +766,21 @@ public ChatClientRequestSpec advisors(Consumer consumer) consumer.accept(as); this.advisorParams.putAll(as.getParams()); this.advisors.addAll(as.getAdvisors()); + this.aroundAdvisorChain.pushAll(as.getAdvisors()); return this; } - public ChatClientRequestSpec advisors(RequestResponseAdvisor... advisors) { + public ChatClientRequestSpec advisors(Advisor... 
advisors) { Assert.notNull(advisors, "the advisors must be non-null"); this.advisors.addAll(Arrays.asList(advisors)); + this.aroundAdvisorChain.pushAll(Arrays.asList(advisors)); return this; } - public ChatClientRequestSpec advisors(List advisors) { + public ChatClientRequestSpec advisors(List advisors) { Assert.notNull(advisors, "the advisors must be non-null"); this.advisors.addAll(advisors); + this.aroundAdvisorChain.pushAll(advisors); return this; } @@ -797,11 +904,11 @@ public ChatClientRequestSpec user(Consumer consumer) { } public CallResponseSpec call() { - return new DefaultCallResponseSpec(chatModel, this); + return new DefaultCallResponseSpec(this); } public StreamResponseSpec stream() { - return new DefaultStreamResponseSpec(chatModel, this); + return new DefaultStreamResponseSpec(this); } } @@ -906,4 +1013,4 @@ public StreamPromptResponseSpec stream() { } -} +} \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java index cd9c7a9fa0d..4f22ad2c800 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java @@ -26,6 +26,7 @@ import org.springframework.ai.chat.client.ChatClient.PromptSystemSpec; import org.springframework.ai.chat.client.ChatClient.PromptUserSpec; import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec; +import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.client.observation.ChatClientObservationConvention; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.prompt.ChatOptions; @@ -69,7 +70,7 @@ public ChatClient build() { return new DefaultChatClient(this.chatModel, this.defaultRequest); } - public Builder 
defaultAdvisors(RequestResponseAdvisor... advisor) { + public Builder defaultAdvisors(Advisor... advisor) { this.defaultRequest.advisors(advisor); return this; } @@ -79,7 +80,7 @@ public Builder defaultAdvisors(Consumer advisorSpecConsu return this; } - public Builder defaultAdvisors(List advisors) { + public Builder defaultAdvisors(List advisors) { this.defaultRequest.advisors(advisors); return this; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java index 498287f4771..ab1ca3352d5 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java @@ -16,120 +16,41 @@ package org.springframework.ai.chat.client; +import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor; + import java.util.Map; +import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.util.StringUtils; - -import reactor.core.publisher.Flux; /** * Advisor called before and after the {@link ChatModel#call(Prompt)} and * {@link ChatModel#stream(Prompt)} methods calls. The {@link ChatClient} maintains a - * chain of advisors with chared execution context. + * chain of advisors with shared advise context. * + * @deprecated since 1.0.0 please use {@link RequestAdvisor}, {@link ResponseAdvisor} + * instead. * @author Christian Tzolov * @since 1.0.0 */ -public interface RequestResponseAdvisor { - - public enum StreamResponseMode { - - /** - * The sync advisor will be called for each response chunk (e.g. on each Flux - * item). 
- */ - PER_CHUNK, - /** - * The sync advisor is called only on chunks that contain a finish reason. Usually - * the last chunk in the stream. - */ - ON_FINISH_REASON, - /** - * The sync advisor is called only once after the stream is completed and an - * aggregated response is computed. Note that at that stage the advisor can not - * modify the response, but only observe it and react on the aggregated response. - */ - AGGREGATE, - /** - * Delegates to the stream advisor implementation. - */ - CUSTOM; +@Deprecated +public interface RequestResponseAdvisor extends RequestAdvisor, ResponseAdvisor { - } - - default StreamResponseMode getStreamResponseMode() { - return StreamResponseMode.CUSTOM; - } - - /** - * @return the advisor name. - */ + @Override default String getName() { return this.getClass().getSimpleName(); } - /** - * @param request the {@link AdvisedRequest} data to be advised. Represents the row - * {@link ChatClient.ChatClientRequestSpec} data before sealed into a {@link Prompt}. - * @param context the shared data between the advisors in the chain. It is shared - * between all request and response advising points of all advisors in the chain. - * @return the advised {@link AdvisedRequest}. - */ - default AdvisedRequest adviseRequest(AdvisedRequest request, Map context) { + @Override + default AdvisedRequest adviseRequest(AdvisedRequest request, Map adviseContext) { return request; } - /** - * @param response the {@link ChatResponse} data to be advised. Represents the row - * {@link ChatResponse} data after the {@link ChatModel#call(Prompt)} method is - * called. - * @param context the shared data between the advisors in the chain. It is shared - * between all request and response advising points of all advisors in the chain. - * @return the advised {@link ChatResponse}. 
- */ - default ChatResponse adviseResponse(ChatResponse response, Map context) { + @Override + default ChatResponse adviseResponse(ChatResponse response, Map adviseContext) { return response; } - /** - * @param fluxResponse the streaming {@link ChatResponse} data to be advised. - * Represents the row {@link ChatResponse} stream data after the - * {@link ChatModel#stream(Prompt)} method is called. - * @param context the shared data between the advisors in the chain. It is shared - * between all request and response advising points of all advisors in the chain. - * @return the advised {@link ChatResponse} flux. - */ - default Flux adviseResponse(Flux fluxResponse, Map context) { - - if (this.getStreamResponseMode() == StreamResponseMode.PER_CHUNK) { - return fluxResponse.map(chatResponse -> this.adviseResponse(chatResponse, context)); - } - else if (this.getStreamResponseMode() == StreamResponseMode.AGGREGATE) { - return new MessageAggregator().aggregate(fluxResponse, chatResponse -> { - this.adviseResponse(chatResponse, context); - }); - } - else if (this.getStreamResponseMode() == StreamResponseMode.ON_FINISH_REASON) { - return fluxResponse.map(chatResponse -> { - boolean withFinishReason = chatResponse.getResults() - .stream() - .filter(result -> result != null && result.getMetadata() != null - && StringUtils.hasText(result.getMetadata().getFinishReason())) - .findFirst() - .isPresent(); - - if (withFinishReason) { - return this.adviseResponse(chatResponse, context); - } - return chatResponse; - }); - } - - return fluxResponse; - } - -} +} \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java index 85b354ce04c..0a47f269871 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java +++ 
b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java @@ -18,7 +18,8 @@ import java.util.Map; -import org.springframework.ai.chat.client.RequestResponseAdvisor; +import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor; +import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; import org.springframework.util.Assert; /** @@ -28,7 +29,7 @@ * @author Christian Tzolov * @since 1.0.0 M1 */ -public abstract class AbstractChatMemoryAdvisor implements RequestResponseAdvisor { +public abstract class AbstractChatMemoryAdvisor implements RequestAdvisor, ResponseAdvisor { public static final String CHAT_MEMORY_CONVERSATION_ID_KEY = "chat_memory_conversation_id"; @@ -60,8 +61,8 @@ public AbstractChatMemoryAdvisor(T chatMemory, String defaultConversationId, int } @Override - public StreamResponseMode getStreamResponseMode() { - return StreamResponseMode.AGGREGATE; + public String getName() { + return this.getClass().getSimpleName(); } protected T getChatMemoryStore() { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java new file mode 100644 index 00000000000..a3b56b75f11 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java @@ -0,0 +1,158 @@ +package org.springframework.ai.chat.client.advisor; + +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.List; +import java.util.Map; + +import org.springframework.ai.chat.client.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.Advisor; +import org.springframework.ai.chat.client.advisor.api.AroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; +import 
org.springframework.ai.chat.client.advisor.observation.AdvisorObservableHelper; +import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationContext; +import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationDocumentation; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.util.Assert; + +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import reactor.core.publisher.Flux; + +public class DefaultAroundAdvisorChain implements AroundAdvisorChain { + + private final Deque callAroundAdvisors; + + private final Deque streamAroundAdvisors; + + private final ObservationRegistry observationRegistry; + + public DefaultAroundAdvisorChain(ObservationRegistry observationRegistry) { + this(observationRegistry, new ArrayDeque(), new ArrayDeque()); + } + + public DefaultAroundAdvisorChain(CallAroundAdvisor aroundAdvisor, ObservationRegistry observationRegistry) { + this(observationRegistry, new ArrayDeque(), new ArrayDeque()); + this.push(aroundAdvisor); + } + + public DefaultAroundAdvisorChain(ObservationRegistry observationRegistry, + Deque callAroundAdvisors, Deque streamAroundAdvisors) { + Assert.notNull(callAroundAdvisors, "the callAroundAdvisors must be non-null"); + this.observationRegistry = observationRegistry; + this.callAroundAdvisors = callAroundAdvisors; + this.streamAroundAdvisors = streamAroundAdvisors; + } + + public DefaultAroundAdvisorChain(ObservationRegistry observationRegistry, List advisors) { + this(observationRegistry); + Assert.notNull(advisors, "the advisors must be non-null"); + advisors.forEach(this::push); + } + + public void pushAll(List advisors) { + Assert.notNull(advisors, "the advisors must be non-null"); + advisors.forEach(this::push); + } + + public void push(Advisor aroundAdvisor) { + + Assert.notNull(aroundAdvisor, "the aroundAdvisor must be non-null"); + + if (aroundAdvisor 
instanceof CallAroundAdvisor callAroundAdvisor) { + this.callAroundAdvisors.push(callAroundAdvisor); + } + // Note: the advisor can implement both the CallAroundAdvisor and + // StreamAroundAdvisor. + if (aroundAdvisor instanceof StreamAroundAdvisor streamAroundAdvisor) { + this.streamAroundAdvisors.push(streamAroundAdvisor); + } + } + + @Override + public ChatResponse nextAroundCall(AdvisedRequest advisedRequest, Map adviceContext) { + + if (this.callAroundAdvisors.isEmpty()) { + throw new IllegalStateException("No AroundAdvisor available to execute"); + } + + var advisor = this.callAroundAdvisors.pop(); + + var observationContext = AdvisorObservationContext.builder() + .withAdvisorName(advisor.getName()) + .withAdvisorType(AdvisorObservationContext.Type.AROUND) + .withAdvisedRequest(advisedRequest) + .withAdvisorRequestContext(adviceContext) + .build(); + + return AdvisorObservationDocumentation.AI_ADVISOR + .observation(null, AdvisorObservableHelper.DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry) + .observe(() -> advisor.aroundCall(advisedRequest, adviceContext, this)); + } + + @Override + public Flux nextAroundStream(AdvisedRequest advisedRequest, Map adviceContext) { + + return Flux.deferContextual(contextView -> { + + if (this.streamAroundAdvisors.isEmpty()) { + return Flux.error(new IllegalStateException("No AroundAdvisor available to execute")); + } + + var advisor = this.streamAroundAdvisors.pop(); + + AdvisorObservationContext observationContext = AdvisorObservationContext.builder() + .withAdvisorName(advisor.getName()) + .withAdvisorType(AdvisorObservationContext.Type.AROUND) + .withAdvisedRequest(advisedRequest) + .withAdvisorRequestContext(adviceContext) + .build(); + + var observation = AdvisorObservationDocumentation.AI_ADVISOR.observation(null, + AdvisorObservableHelper.DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry); + + 
observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); + + return advisor.aroundStream(advisedRequest, adviceContext, this) + .doOnError(observation::error) + .doFinally(s -> { + observation.stop(); + }) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); + }); + } + + public static Builder builder(ObservationRegistry observationRegistry) { + return new Builder(observationRegistry); + } + + public static class Builder { + + private final DefaultAroundAdvisorChain aroundAdvisorChain; + + public Builder(ObservationRegistry observationRegistry) { + this.aroundAdvisorChain = new DefaultAroundAdvisorChain(observationRegistry); + } + + public Builder push(Advisor aroundAdvisor) { + Assert.notNull(aroundAdvisor, "the aroundAdvisor must be non-null"); + this.aroundAdvisorChain.push(aroundAdvisor); + return this; + } + + public Builder pushAll(List aroundAdvisors) { + Assert.notNull(aroundAdvisors, "the aroundAdvisors must be non-null"); + this.aroundAdvisorChain.pushAll(aroundAdvisors); + return this; + } + + public DefaultAroundAdvisorChain build() { + return this.aroundAdvisorChain; + } + + } + +} \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java index 2aeacd0fe34..a0b7c87de65 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java @@ -22,7 +22,8 @@ import java.util.stream.Collectors; import org.springframework.ai.chat.client.AdvisedRequest; -import org.springframework.ai.chat.client.RequestResponseAdvisor; +import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; +import 
org.springframework.ai.chat.client.advisor.api.ResponseAdvisor; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.document.Document; import org.springframework.ai.model.Content; @@ -40,7 +41,7 @@ * @author Christian Tzolov * @since 1.0.0 */ -public class QuestionAnswerAdvisor implements RequestResponseAdvisor { +public class QuestionAnswerAdvisor implements RequestAdvisor, ResponseAdvisor { private static final String DEFAULT_USER_TEXT_ADVISE = """ Context information is below. @@ -91,6 +92,11 @@ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchReques this.userTextAdvise = userTextAdvise; } + @Override + public String getName() { + return this.getClass().getSimpleName(); + } + @Override public AdvisedRequest adviseRequest(AdvisedRequest request, Map context) { @@ -140,9 +146,4 @@ protected Filter.Expression doGetFilterExpression(Map context) { } - @Override - public StreamResponseMode getStreamResponseMode() { - return StreamResponseMode.ON_FINISH_REASON; - } - -} +} \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java index dc21c3fcc7a..6a1a0b8480b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java @@ -21,7 +21,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.chat.client.AdvisedRequest; -import org.springframework.ai.chat.client.RequestResponseAdvisor; +import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor; +import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.model.ModelOptionsUtils; @@ -30,7 +31,7 @@ * * 
@author Christian Tzolov */ -public class SimpleLoggerAdvisor implements RequestResponseAdvisor { +public class SimpleLoggerAdvisor implements RequestAdvisor, ResponseAdvisor { private static final Logger logger = LoggerFactory.getLogger(SimpleLoggerAdvisor.class); @@ -56,6 +57,11 @@ public SimpleLoggerAdvisor(Function requestToString, this.responseToString = responseToString; } + @Override + public String getName() { + return this.getClass().getSimpleName(); + } + @Override public AdvisedRequest adviseRequest(AdvisedRequest request, Map context) { logger.debug("request: {}", this.requestToString.apply(request)); @@ -73,9 +79,4 @@ public String toString() { return SimpleLoggerAdvisor.class.getSimpleName(); } - @Override - public StreamResponseMode getStreamResponseMode() { - return StreamResponseMode.AGGREGATE; - } - } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/Advisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/Advisor.java new file mode 100644 index 00000000000..1ae2ee6d104 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/Advisor.java @@ -0,0 +1,35 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ +package org.springframework.ai.chat.client.advisor.api; + +import org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain; + +/** + * Parent advisor interface for all advisors. + * + * @author Christian Tzolov + * @since 1.0.0 + * @see {@link RequestAdvisor}, {@link ResponseAdvisor}, {@link CallAroundAdvisor}, + * {@link StreamAroundAdvisor}, {@link DefaultAroundAdvisorChain} + */ +public interface Advisor { + + /** + * @return the advisor name. + */ + String getName(); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AroundAdvisorChain.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AroundAdvisorChain.java new file mode 100644 index 00000000000..f33a686c24d --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AroundAdvisorChain.java @@ -0,0 +1,16 @@ +package org.springframework.ai.chat.client.advisor.api; + +import java.util.Map; + +import org.springframework.ai.chat.client.AdvisedRequest; +import org.springframework.ai.chat.model.ChatResponse; + +import reactor.core.publisher.Flux; + +public interface AroundAdvisorChain { + + ChatResponse nextAroundCall(AdvisedRequest advisedRequest, Map adviceContext); + + Flux nextAroundStream(AdvisedRequest advisedRequest, Map adviceContext); + +} \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisor.java new file mode 100644 index 00000000000..d22017cd1de --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisor.java @@ -0,0 +1,39 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package org.springframework.ai.chat.client.advisor.api; + +import java.util.Map; + +import org.springframework.ai.chat.client.AdvisedRequest; +import org.springframework.ai.chat.model.ChatResponse; + +/** + * @author Christian Tzolov + * @since 1.0.0 + */ + +public interface CallAroundAdvisor extends Advisor { + + /** + * Around advice that wraps the {@link ChatModel#call(Prompt)} method. + * @param advisedRequest the advised request + * @param adviceContext the advice context + * @param chain the advisor chain + * @return the response + */ + ChatResponse aroundCall(AdvisedRequest advisedRequest, Map adviceContext, AroundAdvisorChain chain); + +} \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/RequestAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/RequestAdvisor.java new file mode 100644 index 00000000000..8ba198e323a --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/RequestAdvisor.java @@ -0,0 +1,46 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client.advisor.api; + +import java.util.Map; + +import org.springframework.ai.chat.client.AdvisedRequest; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.prompt.Prompt; + +/** + * Advisor called before the {@link ChatModel#call(Prompt)} and + * {@link ChatModel#stream(Prompt)} methods are called. The {@link ChatClient} maintains a + * chain of advisors with shared advise context. + * + * @author Christian Tzolov + * @since 1.0.0 + */ +public interface RequestAdvisor extends Advisor { + + /** + * @param request the {@link AdvisedRequest} data to be advised. Represents the raw + * {@link ChatClient.ChatClientRequestSpec} data before sealed into a {@link Prompt}. + * @param adviseContext the shared data between the advisors in the chain. It is + * shared between all request and response advising points of all advisors in the + * chain. + * @return the advised {@link AdvisedRequest}. + */ + AdvisedRequest adviseRequest(AdvisedRequest request, Map adviseContext); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/ResponseAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/ResponseAdvisor.java new file mode 100644 index 00000000000..6bd7728de87 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/ResponseAdvisor.java @@ -0,0 +1,97 @@ +/* + * Copyright 2024-2024 the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client.advisor.api; + +import java.util.Map; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; + +import reactor.core.publisher.Flux; + +/** + * Advisor called after the {@link ChatModel#call(Prompt)} (or + * {@link ChatModel#stream(Prompt)}) method call. The {@link ChatClient} maintains a chain + * of advisors with shared advise context. + * + * @author Christian Tzolov + * @since 1.0.0 + */ +public interface ResponseAdvisor extends Advisor { + + /** + * @param response the {@link ChatResponse} data to be advised. Represents the raw + * {@link ChatResponse} data after the {@link ChatModel#call(Prompt)} method is + * called. + * @param adviseContext the shared data between the advisors in the chain. It is + * shared between all request and response advising points of all advisors in the + * chain. + * @return the advised {@link ChatResponse}. + */ + ChatResponse adviseResponse(ChatResponse response, Map adviseContext); + + /** + * Different modes of advising the streaming responses. + */ + public enum StreamResponseMode { + + /** + * Called for each response element in the Flux. The response advisor can modify + * the elements before they are returned to the client. 
+ */ + PER_ELEMENT, + /** + * Called only on Flux elements that contain a finish reason. Usually the last + * element in the Flux. The response advisor can modify the elements before they + * are returned to the client. + */ + ON_FINISH_ELEMENT, + /** + * Called only once after all Flux elements have been consumed. All elements are + * merged into a single ChatResponse element and provided to the response advisor + * to process.
+ * Mind that at that stage the response advisor can no longer modify the response + * returned to the client. + */ + AGGREGATE, + /** + * Delegates to the stream advisor implementation. + */ + CUSTOM; + + } + + default StreamResponseMode getStreamResponseMode() { + return StreamResponseMode.AGGREGATE; + } + + /** + * @param fluxResponse the streaming {@link ChatResponse} data to be advised. + * Represents the raw {@link ChatResponse} stream data after the + * {@link ChatModel#stream(Prompt)} method is called. + * @param adviseContext the shared data between the advisors in the chain. It is + * shared between all request and response advising points of all advisors in the + * chain. + * @return the advised {@link ChatResponse} flux. + */ + default Flux adviseResponse(Flux fluxResponse, Map adviseContext) { + return fluxResponse; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisor.java new file mode 100644 index 00000000000..d2379a7724a --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisor.java @@ -0,0 +1,41 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ +package org.springframework.ai.chat.client.advisor.api; + +import java.util.Map; + +import org.springframework.ai.chat.client.AdvisedRequest; +import org.springframework.ai.chat.model.ChatResponse; + +import reactor.core.publisher.Flux; + +/** + * @author Christian Tzolov + * @since 1.0.0 + */ +public interface StreamAroundAdvisor extends Advisor { + + /** + * Around advice that wraps the invocation of the advised request. + * @param advisedRequest the advised request + * @param adviceContext the advice context + * @param chain the chain of advisors to execute + * @return the result of the advised request + */ + Flux aroundStream(AdvisedRequest advisedRequest, Map adviceContext, + AroundAdvisorChain chain); + +} \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/around/CacheAroundAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/around/CacheAroundAdvisor.java new file mode 100644 index 00000000000..51dcc6b5ec9 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/around/CacheAroundAdvisor.java @@ -0,0 +1,149 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ +package org.springframework.ai.chat.client.advisor.around; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import org.springframework.ai.chat.client.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.model.MessageAggregator; +import org.springframework.ai.document.Document; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.util.CollectionUtils; + +import reactor.core.publisher.Flux; + +/** + * @author Christian Tzolov + * @since 1.0.0 + */ +public class CacheAroundAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { + + private final VectorStore vectorStore; + + private static final String DOCUMENT_METADATA_ADVISOR_CACHE_TAG = "advisorCacheDocument"; + + private static final String DOCUMENT_METADATA_ADVISOR_CACHE_RESPONSE = "advisorCacheResponse"; + + public CacheAroundAdvisor(VectorStore vectorStore) { + this.vectorStore = vectorStore; + } + + public String getName() { + return this.getClass().getSimpleName(); + } + + @Override + public ChatResponse aroundCall(AdvisedRequest advisedRequest, Map adviceContext, + AroundAdvisorChain chain) { + + var cachedResponseOption = getCacheEntry(advisedRequest, adviceContext); + if (cachedResponseOption.isPresent()) { + return cachedResponseOption.get(); + } + + ChatResponse chatResponse = 
chain.nextAroundCall(advisedRequest, adviceContext); + + saveCacheEntry(advisedRequest.userText(), chatResponse); + + return chatResponse; + } + + @Override + public Flux aroundStream(AdvisedRequest advisedRequest, Map adviceContext, + AroundAdvisorChain chain) { + + var cachedResponseOption = getCacheEntry(advisedRequest, adviceContext); + if (cachedResponseOption.isPresent()) { + return Flux.just(cachedResponseOption.get()); + } + + Flux fluxChatResponse = chain.nextAroundStream(advisedRequest, adviceContext); + + return new MessageAggregator().aggregate(fluxChatResponse, chatResponse -> { + saveCacheEntry(advisedRequest.userText(), chatResponse); + }); + } + + private void saveCacheEntry(String userQuestion, ChatResponse chatResponse) { + List assistantMessages = chatResponse.getResults().stream().map(g -> (Message) g.getOutput()).toList(); + if (!CollectionUtils.isEmpty(assistantMessages)) { + this.vectorStore.add(toDocuments(userQuestion, assistantMessages)); + } + } + + private Optional getCacheEntry(AdvisedRequest advisedRequest, Map adviceContext) { + + // TODO: convert into a Prompt first or at least materialize the user params. + String userText = advisedRequest.userText(); + + // @formatter:off + var searchRequest = SearchRequest.query(userText) + .withSimilarityThreshold(0.95) + .withTopK(1) + .withFilterExpression("'"+ DOCUMENT_METADATA_ADVISOR_CACHE_TAG + "' == 'true'"); + // @formatter:on + + List doc = vectorStore.similaritySearch(searchRequest); + + // return cached response + return CollectionUtils.isEmpty(doc) ? 
Optional.empty() : Optional.of(fromDocument(doc.get(0))); + } + + private ChatResponse fromDocument(Document doc) { + + if (!doc.getMetadata().containsKey(DOCUMENT_METADATA_ADVISOR_CACHE_RESPONSE)) { + throw new IllegalStateException("The document is missing the cache response metadata!"); + } + String cachedResponse = "" + doc.getMetadata().get(DOCUMENT_METADATA_ADVISOR_CACHE_RESPONSE); + + return ChatResponse.builder() + .withGenerations(List.of(new Generation(new AssistantMessage(cachedResponse, Map.of()), + ChatGenerationMetadata.from("STOP", null)))) + .build(); + } + + private List toDocuments(String userQuestion, List messages) { + + List docs = messages.stream() + .filter(m -> m.getMessageType() == MessageType.ASSISTANT) + .map(message -> { + var metadata = new HashMap<>(message.getMetadata() != null ? message.getMetadata() : new HashMap<>()); + metadata.put(DOCUMENT_METADATA_ADVISOR_CACHE_TAG, "true"); + metadata.put(DOCUMENT_METADATA_ADVISOR_CACHE_RESPONSE, message.getContent()); + // TODO: Perhaps we need to serialize the message metadata to the document + + return new Document(userQuestion, metadata); + + }) + .toList(); + + return docs; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/around/SafeGuardAroundAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/around/SafeGuardAroundAdvisor.java new file mode 100644 index 00000000000..f682ceb1b70 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/around/SafeGuardAroundAdvisor.java @@ -0,0 +1,59 @@ +package org.springframework.ai.chat.client.advisor.around; + +import java.util.List; +import java.util.Map; + +import org.springframework.ai.chat.client.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; 
+import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.util.CollectionUtils; + +import reactor.core.publisher.Flux; + +/** + * A {@link CallAroundAdvisor} and {@link StreamAroundAdvisor} that filters out the + * response if the user input contains any of the sensitive words. + * + * @author Christian Tzolov + * @since 1.0.0 + */ +public class SafeGuardAroundAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { + + private final List sensitiveWords; + + public SafeGuardAroundAdvisor(List sensitiveWords) { + this.sensitiveWords = sensitiveWords; + } + + public String getName() { + return this.getClass().getSimpleName(); + } + + @Override + public ChatResponse aroundCall(AdvisedRequest advisedRequest, Map adviceContext, + AroundAdvisorChain chain) { + + if (!CollectionUtils.isEmpty(this.sensitiveWords) + && sensitiveWords.stream().anyMatch(w -> advisedRequest.userText().contains(w))) { + return ChatResponse.builder().withGenerations(List.of()).build(); + } + + return chain.nextAroundCall(advisedRequest, adviceContext); + } + + @Override + public Flux aroundStream(AdvisedRequest advisedRequest, Map adviceContext, + AroundAdvisorChain chain) { + + if (!CollectionUtils.isEmpty(this.sensitiveWords) + && sensitiveWords.stream().anyMatch(w -> advisedRequest.userText().contains(w))) { + return Flux.empty(); + } + + return chain.nextAroundStream(advisedRequest, adviceContext); + + } + +} \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservableHelper.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservableHelper.java index f903bac15e8..e267a5ee2fe 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservableHelper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservableHelper.java @@ -15,17 +15,19 @@ 
*/ package org.springframework.ai.chat.client.advisor.observation; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; import java.util.Map; import org.springframework.ai.chat.client.AdvisedRequest; -import org.springframework.ai.chat.client.RequestResponseAdvisor; -import org.springframework.ai.chat.client.RequestResponseAdvisor.StreamResponseMode; +import org.springframework.ai.chat.client.advisor.api.Advisor; +import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; +import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor; import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.MessageAggregator; -import org.springframework.util.StringUtils; +import org.springframework.util.CollectionUtils; import io.micrometer.observation.Observation; -import reactor.core.publisher.Flux; /** * @author Christian Tzolov @@ -33,9 +35,9 @@ */ public abstract class AdvisorObservableHelper { - private static final AdvisorObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultAdvisorObservationConvention(); + public static final AdvisorObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultAdvisorObservationConvention(); - public static AdvisedRequest adviseRequest(Observation parentObservation, RequestResponseAdvisor advisor, + public static AdvisedRequest adviseRequest(Observation parentObservation, RequestAdvisor advisor, AdvisedRequest advisedRequest, Map advisorContext) { var observationContext = AdvisorObservationContext.builder() @@ -52,7 +54,7 @@ public static AdvisedRequest adviseRequest(Observation parentObservation, Reques .observe(() -> advisor.adviseRequest(advisedRequest, advisorContext)); } - public static ChatResponse adviseResponse(Observation parentObservation, RequestResponseAdvisor advisor, + public static ChatResponse adviseResponse(Observation parentObservation, ResponseAdvisor advisor, ChatResponse response, Map advisorContext) { var 
observationContext = AdvisorObservationContext.builder() @@ -68,35 +70,34 @@ public static ChatResponse adviseResponse(Observation parentObservation, Request .observe(() -> advisor.adviseResponse(response, advisorContext)); } - public static Flux adviseResponse(Observation parentObservation, RequestResponseAdvisor advisor, - Flux fluxResponse, Map advisorContext) { + public static List extractRequestAdvisors(List advisors) { + return advisors.stream() + .filter(advisor -> advisor instanceof RequestAdvisor) + .map(a -> (RequestAdvisor) a) + .toList(); + } - if (advisor.getStreamResponseMode() == StreamResponseMode.PER_CHUNK) { - return fluxResponse - .map(chatResponse -> adviseResponse(parentObservation, advisor, chatResponse, advisorContext)); - } - else if (advisor.getStreamResponseMode() == StreamResponseMode.AGGREGATE) { - return new MessageAggregator().aggregate(fluxResponse, chatResponse -> { - adviseResponse(parentObservation, advisor, chatResponse, advisorContext); - }); - } - else if (advisor.getStreamResponseMode() == StreamResponseMode.ON_FINISH_REASON) { - return fluxResponse.map(chatResponse -> { - boolean withFinishReason = chatResponse.getResults() - .stream() - .filter(result -> result != null && result.getMetadata() != null - && StringUtils.hasText(result.getMetadata().getFinishReason())) - .findFirst() - .isPresent(); - - if (withFinishReason) { - return adviseResponse(parentObservation, advisor, chatResponse, advisorContext); - } - return chatResponse; - }); + /** + * Extracts the {@link ResponseAdvisor} instances from the given list of advisors and + * returns them in reverse order. + * @param advisors list of all registered advisor types. + * @return the list of {@link ResponseAdvisor} instances in reverse order. 
+ */ + public static List extractResponseAdvisors(List advisors) { + + var list = advisors.stream() + .filter(advisor -> advisor instanceof ResponseAdvisor) + .map(a -> (ResponseAdvisor) a) + .toList(); + + // reverse the list + if (CollectionUtils.isEmpty(list)) { + return list; } - return advisor.adviseResponse(fluxResponse, advisorContext); + var reversedList = new ArrayList<>(list); + Collections.reverse(reversedList); + return Collections.unmodifiableList(reversedList); } } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java index 27e9a12d0cb..7067139d6a3 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java @@ -25,12 +25,14 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.ai.chat.client.AdvisedRequest; import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec; import org.springframework.ai.chat.client.RequestResponseAdvisor; import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.HighCardinalityKeyNames; import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.LowCardinalityKeyNames; import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.model.function.FunctionCallback; import io.micrometer.common.KeyValue; @@ -93,6 +95,17 @@ static RequestResponseAdvisor dummyAdvisor(String name) { public String getName() 
{ return name; } + + @Override + public AdvisedRequest adviseRequest(AdvisedRequest request, Map context) { + return request; + } + + @Override + public ChatResponse adviseResponse(ChatResponse response, Map adviseContext) { + return response; + } + }; }