From 764a85f88683a246ce97d98b0d8dbb6edd5e8a46 Mon Sep 17 00:00:00 2001 From: starkt Date: Sat, 13 Jul 2024 00:31:22 +0200 Subject: [PATCH 1/6] fix sequential calling for azure --- .../ai/azure/openai/AzureOpenAiChatModel.java | 3 +++ .../AzureOpenAiChatModelFunctionCallIT.java | 26 +++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index b883458481e..c96edcd2a24 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -486,6 +486,9 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromOptions, ChatCom if (fromOptions.getResponseFormat() != null) { mergedOptions.setResponseFormat(fromOptions.getResponseFormat()); } + if (fromOptions.getTools() != null) { + mergedOptions.setTools(fromOptions.getTools()); + } return mergedOptions; } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java index f4349e17211..392289a649b 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java @@ -85,6 +85,32 @@ void functionCallTest() { assertThat(response.getResult().getOutput().getContent()).containsAnyOf("15.0", "15"); } + + @Test + void functionCallSequentialTest() { + + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco? If the weather is above 25 degrees, please check the weather in Tokyo and Paris."); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = AzureOpenAiChatOptions.builder() + .withDeploymentName(selectedModel) + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("getCurrentWeather") + .withDescription("Get the current weather in a given location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build())) + .build(); + + ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getContent()).containsAnyOf("30.0", "30"); + assertThat(response.getResult().getOutput().getContent()).containsAnyOf("10.0", "10"); + assertThat(response.getResult().getOutput().getContent()).containsAnyOf("15.0", "15"); + } + @Test void streamFunctionCallTest() { UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); From 33c197f2115038d2286841884edbe6e57a13eaf4 Mon Sep 17 00:00:00 2001 From: starkt Date: Sat, 13 Jul 2024 00:44:58 +0200 Subject: [PATCH 2/6] fix formatting --- .../AzureOpenAiChatModelFunctionCallIT.java | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java index 392289a649b..90fb5983336 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java @@ -85,22 +85,22 @@ void functionCallTest() { assertThat(response.getResult().getOutput().getContent()).containsAnyOf("15.0", "15"); } - @Test void functionCallSequentialTest() { - UserMessage userMessage = new UserMessage("What's the weather like in San Francisco? If the weather is above 25 degrees, please check the weather in Tokyo and Paris."); + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco? If the weather is above 25 degrees, please check the weather in Tokyo and Paris."); List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = AzureOpenAiChatOptions.builder() - .withDeploymentName(selectedModel) - .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) - .withName("getCurrentWeather") - .withDescription("Get the current weather in a given location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) - .build())) - .build(); + .withDeploymentName(selectedModel) + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("getCurrentWeather") + .withDescription("Get the current weather in a given location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build())) + .build(); ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); From e4918424ce10011761fe93b513387c8f8f8dffcf Mon Sep 17 00:00:00 2001 From: starkt Date: Sat, 13 Jul 2024 01:40:20 +0200 Subject: [PATCH 3/6] fix for streaming --- .../ai/azure/openai/AzureOpenAiChatModel.java | 156 +++++++++++------- .../AzureOpenAiChatModelFunctionCallIT.java | 101 ++++++++---- 2 files changed, 162 insertions(+), 95 deletions(-) diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index c96edcd2a24..12d6d4e06cb 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -44,7 +44,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata; +import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.ToolResponseMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.PromptMetadata; import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata; @@ -59,13 +61,9 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Optional; -import java.util.Set; +import java.util.*; import java.util.concurrent.atomic.AtomicBoolean; /** @@ -104,9 +102,9 @@ public class AzureOpenAiChatModel extends public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient) { this(microsoftOpenAiClient, AzureOpenAiChatOptions.builder() - .withDeploymentName(DEFAULT_DEPLOYMENT_NAME) - .withTemperature(DEFAULT_TEMPERATURE) - .build()); + .withDeploymentName(DEFAULT_DEPLOYMENT_NAME) + .withTemperature(DEFAULT_TEMPERATURE) + .build()); } public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options) { @@ -114,7 +112,7 @@ public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatO } public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options, - FunctionCallbackContext functionCallbackContext) { + FunctionCallbackContext functionCallbackContext) { super(functionCallbackContext); Assert.notNull(microsoftOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null"); Assert.notNull(options, "AzureOpenAiChatOptions must not be null"); @@ -148,9 +146,9 @@ public ChatResponse call(Prompt prompt) { logger.trace("Azure ChatCompletions: {}", chatCompletions); List generations = nullSafeList(chatCompletions.getChoices()).stream() - .map(choice -> new Generation(choice.getMessage().getContent()) - .withGenerationMetadata(generateChoiceMetadata(choice))) - .toList(); + .map(choice -> new Generation(choice.getMessage().getContent()) + .withGenerationMetadata(generateChoiceMetadata(choice))) + .toList(); PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions); @@ -160,51 +158,81 @@ public ChatResponse call(Prompt prompt) { @Override public Flux stream(Prompt prompt) { - ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt); + + // we have to map with a custom function to handle the tool call requests + // due to the existing bugs in the azure api (see comments in streamWithAzureApi) + // we have to recursively call this specific method for tool calls instead of using the one from the AbstractFunctionCallSupport + return streamWithAzureOpenAi(options).flatMapIterable(ChatCompletions::getChoices) + .map(choice -> { + var content = Optional.ofNullable(choice.getMessage()).orElse(choice.getDelta()).getContent(); + var generation = new Generation(content).withGenerationMetadata(generateChoiceMetadata(choice)); + return new ChatResponse(List.of(generation)); + }); + } + + private Flux streamWithAzureOpenAi(ChatCompletionsOptions options) { options.setStream(true); IterableStream chatCompletionsStream = this.openAIClient - .getChatCompletionsStream(options.getModel(), options); + .getChatCompletionsStream(options.getModel(), options); Flux chatCompletionsFlux = Flux.fromIterable(chatCompletionsStream); final var isFunctionCall = new AtomicBoolean(false); final var accessibleChatCompletionsFlux = chatCompletionsFlux - // Note: the first chat completions can be ignored when using Azure OpenAI - // service which is a known service bug. - .skip(1) - .map(chatCompletions -> { - final var toolCalls = chatCompletions.getChoices().get(0).getDelta().getToolCalls(); - isFunctionCall.set(toolCalls != null && !toolCalls.isEmpty()); - return chatCompletions; - }) - .windowUntil(chatCompletions -> { - if (isFunctionCall.get() && chatCompletions.getChoices() - .get(0) - .getFinishReason() == CompletionsFinishReason.TOOL_CALLS) { - isFunctionCall.set(false); - return true; - } - return !isFunctionCall.get(); - }) - .concatMapIterable(window -> { - final var reduce = window.reduce(MergeUtils.emptyChatCompletions(), MergeUtils::mergeChatCompletions); - return List.of(reduce); - }) - .flatMap(mono -> mono); + // Note: the first chat completions can be ignored when using Azure OpenAI + // service which is a known service bug. + .skip(1) + .map(chatCompletions -> { + final var toolCalls = chatCompletions.getChoices().get(0).getDelta().getToolCalls(); + isFunctionCall.set(toolCalls != null && !toolCalls.isEmpty()); + return chatCompletions; + }) + .windowUntil(chatCompletions -> { + if (isFunctionCall.get() && chatCompletions.getChoices() + .get(0) + .getFinishReason() == CompletionsFinishReason.TOOL_CALLS) { + isFunctionCall.set(false); + return true; + } + return !isFunctionCall.get(); + }) + .concatMapIterable(window -> { + final var reduce = window.reduce(MergeUtils.emptyChatCompletions(), MergeUtils::mergeChatCompletions); + return List.of(reduce); + }) + .flatMap(mono -> mono); return accessibleChatCompletionsFlux - .switchMap(accessibleChatCompletions -> handleFunctionCallOrReturnStream(options, - Flux.just(accessibleChatCompletions))) - .flatMapIterable(ChatCompletions::getChoices) - .map(choice -> { - var content = Optional.ofNullable(choice.getMessage()).orElse(choice.getDelta()).getContent(); - var generation = new Generation(content).withGenerationMetadata(generateChoiceMetadata(choice)); - return new ChatResponse(List.of(generation)); - }); + .switchMap(accessibleChatCompletions -> handleToolCallRequests(options, + Flux.just(accessibleChatCompletions))); } + private Flux handleToolCallRequests(ChatCompletionsOptions request, Flux response) { + return response.switchMap(resp -> { + if (!this.isToolFunctionCall(resp)) { + return Mono.just(resp); + } + + // The chat completion tool call requires the complete conversation + // history. Including the initial user message. + List conversationHistory = new ArrayList<>(); + + conversationHistory.addAll(this.doGetUserMessages(request)); + + ChatRequestMessage responseMessage = this.doGetToolResponseMessage(resp); + + // Add the assistant response to the message conversation history. + conversationHistory.add(responseMessage); + + ChatCompletionsOptions newRequest = this.doCreateToolResponseRequest(request, responseMessage, conversationHistory); + + // recursively go backwards and call our stream again (including all bug fixes / workarounds for the azure api) + return this.streamWithAzureOpenAi(newRequest); + }); + } + /** * Test access. */ @@ -213,9 +241,9 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) { Set functionsForThisRequest = new HashSet<>(); List azureMessages = prompt.getInstructions() - .stream() - .map(this::fromSpringAiMessage) - .toList(); + .stream() + .map(this::fromSpringAiMessage) + .toList(); ChatCompletionsOptions options = new ChatCompletionsOptions(azureMessages); @@ -250,8 +278,8 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) { if (!CollectionUtils.isEmpty(functionsForThisRequest)) { List tools = this.getFunctionTools(functionsForThisRequest); List tools2 = tools.stream() - .map(t -> ((ChatCompletionsToolDefinition) t)) - .toList(); + .map(t -> ((ChatCompletionsToolDefinition) t)) + .toList(); options.setTools(tools2); } @@ -264,7 +292,7 @@ private List getFunctionTools(Set FunctionDefinition functionDefinition = new FunctionDefinition(functionCallback.getName()); functionDefinition.setDescription(functionCallback.getDescription()); BinaryData parameters = BinaryData - .fromObject(ModelOptionsUtils.jsonToMap(functionCallback.getInputTypeSchema())); + .fromObject(ModelOptionsUtils.jsonToMap(functionCallback.getInputTypeSchema())); functionDefinition.setParameters(parameters); return new ChatCompletionsFunctionToolDefinition(functionDefinition); }).toList(); @@ -279,10 +307,10 @@ private ChatRequestMessage fromSpringAiMessage(Message message) { items.add(new ChatMessageTextContentItem(message.getContent())); if (!CollectionUtils.isEmpty(message.getMedia())) { items.addAll(message.getMedia() - .stream() - .map(media -> new ChatMessageImageContentItem( - new ChatMessageImageUrl(media.getData().toString()))) - .toList()); + .stream() + .map(media -> new ChatMessageImageContentItem( + new ChatMessageImageUrl(media.getData().toString()))) + .toList()); } return new ChatRequestUserMessage(items); case SYSTEM: @@ -305,9 +333,9 @@ private PromptMetadata generatePromptMetadata(ChatCompletions chatCompletions) { chatCompletions.getPromptFilterResults()); return PromptMetadata.of(promptFilterResults.stream() - .map(promptFilterResult -> PromptFilterMetadata.from(promptFilterResult.getPromptIndex(), - promptFilterResult.getContentFilterResults())) - .toList()); + .map(promptFilterResult -> PromptFilterMetadata.from(promptFilterResult.getPromptIndex(), + promptFilterResult.getContentFilterResults())) + .toList()); } private List nullSafeList(List list) { @@ -320,7 +348,7 @@ private List nullSafeList(List list) { * {@link ChatCompletionsOptions} instance. */ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions, - AzureOpenAiChatOptions toSpringAiOptions) { + AzureOpenAiChatOptions toSpringAiOptions) { if (toSpringAiOptions == null) { return fromAzureOptions; @@ -336,7 +364,7 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions, : toSpringAiOptions.getLogitBias()); mergedAzureOptions - .setStop(fromAzureOptions.getStop() != null ? fromAzureOptions.getStop() : toSpringAiOptions.getStop()); + .setStop(fromAzureOptions.getStop() != null ? fromAzureOptions.getStop() : toSpringAiOptions.getStop()); mergedAzureOptions.setTemperature(fromAzureOptions.getTemperature()); if (mergedAzureOptions.getTemperature() == null && toSpringAiOptions.getTemperature() != null) { @@ -366,7 +394,7 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions, mergedAzureOptions.setN(fromAzureOptions.getN() != null ? fromAzureOptions.getN() : toSpringAiOptions.getN()); mergedAzureOptions - .setUser(fromAzureOptions.getUser() != null ? fromAzureOptions.getUser() : toSpringAiOptions.getUser()); + .setUser(fromAzureOptions.getUser() != null ? fromAzureOptions.getUser() : toSpringAiOptions.getUser()); mergedAzureOptions.setModel(fromAzureOptions.getModel() != null ? fromAzureOptions.getModel() : toSpringAiOptions.getDeploymentName()); @@ -383,7 +411,7 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions, * @return a new {@link ChatCompletionsOptions} instance. */ private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions, - ChatCompletionsOptions toAzureOptions) { + ChatCompletionsOptions toAzureOptions) { if (fromSpringAiOptions == null) { return toAzureOptions; @@ -542,7 +570,7 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) { @Override protected ChatCompletionsOptions doCreateToolResponseRequest(ChatCompletionsOptions previousRequest, - ChatRequestMessage responseMessage, List conversationHistory) { + ChatRequestMessage responseMessage, List conversationHistory) { // Every tool-call item requires a separate function call and a response (TOOL) // message. @@ -579,7 +607,7 @@ protected List doGetUserMessages(ChatCompletionsOptions requ protected ChatRequestMessage doGetToolResponseMessage(ChatCompletions response) { final var accessibleChatChoice = response.getChoices().get(0); var responseMessage = Optional.ofNullable(accessibleChatChoice.getMessage()) - .orElse(accessibleChatChoice.getDelta()); + .orElse(accessibleChatChoice.getDelta()); ChatRequestAssistantMessage assistantMessage = new ChatRequestAssistantMessage(""); final var toolCalls = responseMessage.getToolCalls(); assistantMessage.setToolCalls(toolCalls.stream().map(tc -> { diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java index 90fb5983336..f3d8cff0b3d 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java @@ -17,6 +17,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Objects; import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; @@ -68,13 +69,13 @@ void functionCallTest() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = AzureOpenAiChatOptions.builder() - .withDeploymentName(selectedModel) - .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) - .withName("getCurrentWeather") - .withDescription("Get the current weather in a given location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) - .build())) - .build(); + .withDeploymentName(selectedModel) + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("getCurrentWeather") + .withDescription("Get the current weather in a given location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build())) + .build(); ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); @@ -94,13 +95,13 @@ void functionCallSequentialTest() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = AzureOpenAiChatOptions.builder() - .withDeploymentName(selectedModel) - .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) - .withName("getCurrentWeather") - .withDescription("Get the current weather in a given location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) - .build())) - .build(); + .withDeploymentName(selectedModel) + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("getCurrentWeather") + .withDescription("Get the current weather in a given location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build())) + .build(); ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); @@ -111,6 +112,44 @@ void functionCallSequentialTest() { assertThat(response.getResult().getOutput().getContent()).containsAnyOf("15.0", "15"); } + @Test + void functionCallSequentialAndStreamTest() { + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco? If the weather is above 25 degrees, please check the weather in Tokyo and Paris."); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = AzureOpenAiChatOptions.builder() + .withDeploymentName(selectedModel) + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("getCurrentWeather") + .withDescription("Get the current weather in a given location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build())) + .build(); + + var response = chatModel.stream(new Prompt(messages, promptOptions)); + + final var counter = new AtomicInteger(); + String content = response.doOnEach(listSignal -> counter.getAndIncrement()) + .collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .filter(Objects::nonNull) + .collect(Collectors.joining()); + + logger.info("Response: {}", response); + + assertThat(content).containsAnyOf("30.0", "30"); + assertThat(content).containsAnyOf("10.0", "10"); + assertThat(content).containsAnyOf("15.0", "15"); + } + @Test void streamFunctionCallTest() { UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); @@ -118,26 +157,26 @@ void streamFunctionCallTest() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = AzureOpenAiChatOptions.builder() - .withDeploymentName(selectedModel) - .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) - .withName("getCurrentWeather") - .withDescription("Get the current weather in a given location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) - .build())) - .build(); + .withDeploymentName(selectedModel) + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("getCurrentWeather") + .withDescription("Get the current weather in a given location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build())) + .build(); Flux response = chatModel.stream(new Prompt(messages, promptOptions)); final var counter = new AtomicInteger(); String content = response.doOnEach(listSignal -> counter.getAndIncrement()) - .collectList() - .block() - .stream() - .map(ChatResponse::getResults) - .flatMap(List::stream) - .map(Generation::getOutput) - .map(AssistantMessage::getContent) - .collect(Collectors.joining()); + .collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); logger.info("Response: {}", content); assertThat(counter.get()).isGreaterThan(30).as("The response should be chunked in more than 30 messages"); @@ -153,8 +192,8 @@ public static class TestConfiguration { @Bean public OpenAIClient openAIClient() { return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY"))) - .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) - .buildClient(); + .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) + .buildClient(); } @Bean From e8415e00000cf2d7c20affd10864cd37dfcb78b4 Mon Sep 17 00:00:00 2001 From: starkt Date: Sat, 13 Jul 2024 01:41:34 +0200 Subject: [PATCH 4/6] fix formatting --- .../ai/azure/openai/AzureOpenAiChatModel.java | 126 +++++++++--------- .../AzureOpenAiChatModelFunctionCallIT.java | 94 ++++++------- 2 files changed, 111 insertions(+), 109 deletions(-) diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index 12d6d4e06cb..a6d1f66471c 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -102,9 +102,9 @@ public class AzureOpenAiChatModel extends public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient) { this(microsoftOpenAiClient, AzureOpenAiChatOptions.builder() - .withDeploymentName(DEFAULT_DEPLOYMENT_NAME) - .withTemperature(DEFAULT_TEMPERATURE) - .build()); + .withDeploymentName(DEFAULT_DEPLOYMENT_NAME) + .withTemperature(DEFAULT_TEMPERATURE) + .build()); } public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options) { @@ -112,7 +112,7 @@ public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatO } public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options, - FunctionCallbackContext functionCallbackContext) { + FunctionCallbackContext functionCallbackContext) { super(functionCallbackContext); Assert.notNull(microsoftOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null"); Assert.notNull(options, "AzureOpenAiChatOptions must not be null"); @@ -146,9 +146,9 @@ public ChatResponse call(Prompt prompt) { logger.trace("Azure ChatCompletions: {}", chatCompletions); List generations = nullSafeList(chatCompletions.getChoices()).stream() - .map(choice -> new Generation(choice.getMessage().getContent()) - .withGenerationMetadata(generateChoiceMetadata(choice))) - .toList(); + .map(choice -> new Generation(choice.getMessage().getContent()) + .withGenerationMetadata(generateChoiceMetadata(choice))) + .toList(); PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions); @@ -162,54 +162,54 @@ public Flux stream(Prompt prompt) { // we have to map with a custom function to handle the tool call requests // due to the existing bugs in the azure api (see comments in streamWithAzureApi) - // we have to recursively call this specific method for tool calls instead of using the one from the AbstractFunctionCallSupport - return streamWithAzureOpenAi(options).flatMapIterable(ChatCompletions::getChoices) - .map(choice -> { - var content = Optional.ofNullable(choice.getMessage()).orElse(choice.getDelta()).getContent(); - var generation = new Generation(content).withGenerationMetadata(generateChoiceMetadata(choice)); - return new ChatResponse(List.of(generation)); - }); + // we have to recursively call this specific method for tool calls instead of + // using the one from the AbstractFunctionCallSupport + return streamWithAzureOpenAi(options).flatMapIterable(ChatCompletions::getChoices).map(choice -> { + var content = Optional.ofNullable(choice.getMessage()).orElse(choice.getDelta()).getContent(); + var generation = new Generation(content).withGenerationMetadata(generateChoiceMetadata(choice)); + return new ChatResponse(List.of(generation)); + }); } private Flux streamWithAzureOpenAi(ChatCompletionsOptions options) { options.setStream(true); IterableStream chatCompletionsStream = this.openAIClient - .getChatCompletionsStream(options.getModel(), options); + .getChatCompletionsStream(options.getModel(), options); Flux chatCompletionsFlux = Flux.fromIterable(chatCompletionsStream); final var isFunctionCall = new AtomicBoolean(false); final var accessibleChatCompletionsFlux = chatCompletionsFlux - // Note: the first chat completions can be ignored when using Azure OpenAI - // service which is a known service bug. - .skip(1) - .map(chatCompletions -> { - final var toolCalls = chatCompletions.getChoices().get(0).getDelta().getToolCalls(); - isFunctionCall.set(toolCalls != null && !toolCalls.isEmpty()); - return chatCompletions; - }) - .windowUntil(chatCompletions -> { - if (isFunctionCall.get() && chatCompletions.getChoices() - .get(0) - .getFinishReason() == CompletionsFinishReason.TOOL_CALLS) { - isFunctionCall.set(false); - return true; - } - return !isFunctionCall.get(); - }) - .concatMapIterable(window -> { - final var reduce = window.reduce(MergeUtils.emptyChatCompletions(), MergeUtils::mergeChatCompletions); - return List.of(reduce); - }) - .flatMap(mono -> mono); - return accessibleChatCompletionsFlux - .switchMap(accessibleChatCompletions -> handleToolCallRequests(options, - Flux.just(accessibleChatCompletions))); + // Note: the first chat completions can be ignored when using Azure OpenAI + // service which is a known service bug. + .skip(1) + .map(chatCompletions -> { + final var toolCalls = chatCompletions.getChoices().get(0).getDelta().getToolCalls(); + isFunctionCall.set(toolCalls != null && !toolCalls.isEmpty()); + return chatCompletions; + }) + .windowUntil(chatCompletions -> { + if (isFunctionCall.get() && chatCompletions.getChoices() + .get(0) + .getFinishReason() == CompletionsFinishReason.TOOL_CALLS) { + isFunctionCall.set(false); + return true; + } + return !isFunctionCall.get(); + }) + .concatMapIterable(window -> { + final var reduce = window.reduce(MergeUtils.emptyChatCompletions(), MergeUtils::mergeChatCompletions); + return List.of(reduce); + }) + .flatMap(mono -> mono); + return accessibleChatCompletionsFlux.switchMap( + accessibleChatCompletions -> handleToolCallRequests(options, Flux.just(accessibleChatCompletions))); } - private Flux handleToolCallRequests(ChatCompletionsOptions request, Flux response) { + private Flux handleToolCallRequests(ChatCompletionsOptions request, + Flux response) { return response.switchMap(resp -> { if (!this.isToolFunctionCall(resp)) { return Mono.just(resp); @@ -226,9 +226,11 @@ private Flux handleToolCallRequests(ChatCompletionsOptions requ // Add the assistant response to the message conversation history. conversationHistory.add(responseMessage); - ChatCompletionsOptions newRequest = this.doCreateToolResponseRequest(request, responseMessage, conversationHistory); + ChatCompletionsOptions newRequest = this.doCreateToolResponseRequest(request, responseMessage, + conversationHistory); - // recursively go backwards and call our stream again (including all bug fixes / workarounds for the azure api) + // recursively go backwards and call our stream again (including all bug fixes + // / workarounds for the azure api) return this.streamWithAzureOpenAi(newRequest); }); } @@ -241,9 +243,9 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) { Set functionsForThisRequest = new HashSet<>(); List azureMessages = prompt.getInstructions() - .stream() - .map(this::fromSpringAiMessage) - .toList(); + .stream() + .map(this::fromSpringAiMessage) + .toList(); ChatCompletionsOptions options = new ChatCompletionsOptions(azureMessages); @@ -278,8 +280,8 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) { if (!CollectionUtils.isEmpty(functionsForThisRequest)) { List tools = this.getFunctionTools(functionsForThisRequest); List tools2 = tools.stream() - .map(t -> ((ChatCompletionsToolDefinition) t)) - .toList(); + .map(t -> ((ChatCompletionsToolDefinition) t)) + .toList(); options.setTools(tools2); } @@ -292,7 +294,7 @@ private List getFunctionTools(Set FunctionDefinition functionDefinition = new FunctionDefinition(functionCallback.getName()); functionDefinition.setDescription(functionCallback.getDescription()); BinaryData parameters = BinaryData - .fromObject(ModelOptionsUtils.jsonToMap(functionCallback.getInputTypeSchema())); + .fromObject(ModelOptionsUtils.jsonToMap(functionCallback.getInputTypeSchema())); functionDefinition.setParameters(parameters); return new ChatCompletionsFunctionToolDefinition(functionDefinition); }).toList(); @@ -307,10 +309,10 @@ private ChatRequestMessage fromSpringAiMessage(Message message) { items.add(new ChatMessageTextContentItem(message.getContent())); if (!CollectionUtils.isEmpty(message.getMedia())) { items.addAll(message.getMedia() - .stream() - .map(media -> new ChatMessageImageContentItem( - new ChatMessageImageUrl(media.getData().toString()))) - .toList()); + .stream() + .map(media -> new ChatMessageImageContentItem( + new ChatMessageImageUrl(media.getData().toString()))) + .toList()); } return new ChatRequestUserMessage(items); case SYSTEM: @@ -333,9 +335,9 @@ private PromptMetadata generatePromptMetadata(ChatCompletions chatCompletions) { chatCompletions.getPromptFilterResults()); return PromptMetadata.of(promptFilterResults.stream() - .map(promptFilterResult -> PromptFilterMetadata.from(promptFilterResult.getPromptIndex(), - promptFilterResult.getContentFilterResults())) - .toList()); + .map(promptFilterResult -> PromptFilterMetadata.from(promptFilterResult.getPromptIndex(), + promptFilterResult.getContentFilterResults())) + .toList()); } private List nullSafeList(List list) { @@ -348,7 +350,7 @@ private List nullSafeList(List list) { * {@link ChatCompletionsOptions} instance. */ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions, - AzureOpenAiChatOptions toSpringAiOptions) { + AzureOpenAiChatOptions toSpringAiOptions) { if (toSpringAiOptions == null) { return fromAzureOptions; @@ -364,7 +366,7 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions, : toSpringAiOptions.getLogitBias()); mergedAzureOptions - .setStop(fromAzureOptions.getStop() != null ? fromAzureOptions.getStop() : toSpringAiOptions.getStop()); + .setStop(fromAzureOptions.getStop() != null ? fromAzureOptions.getStop() : toSpringAiOptions.getStop()); mergedAzureOptions.setTemperature(fromAzureOptions.getTemperature()); if (mergedAzureOptions.getTemperature() == null && toSpringAiOptions.getTemperature() != null) { @@ -394,7 +396,7 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions, mergedAzureOptions.setN(fromAzureOptions.getN() != null ? fromAzureOptions.getN() : toSpringAiOptions.getN()); mergedAzureOptions - .setUser(fromAzureOptions.getUser() != null ? fromAzureOptions.getUser() : toSpringAiOptions.getUser()); + .setUser(fromAzureOptions.getUser() != null ? fromAzureOptions.getUser() : toSpringAiOptions.getUser()); mergedAzureOptions.setModel(fromAzureOptions.getModel() != null ? fromAzureOptions.getModel() : toSpringAiOptions.getDeploymentName()); @@ -411,7 +413,7 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions, * @return a new {@link ChatCompletionsOptions} instance. */ private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions, - ChatCompletionsOptions toAzureOptions) { + ChatCompletionsOptions toAzureOptions) { if (fromSpringAiOptions == null) { return toAzureOptions; @@ -570,7 +572,7 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) { @Override protected ChatCompletionsOptions doCreateToolResponseRequest(ChatCompletionsOptions previousRequest, - ChatRequestMessage responseMessage, List conversationHistory) { + ChatRequestMessage responseMessage, List conversationHistory) { // Every tool-call item requires a separate function call and a response (TOOL) // message. @@ -607,7 +609,7 @@ protected List doGetUserMessages(ChatCompletionsOptions requ protected ChatRequestMessage doGetToolResponseMessage(ChatCompletions response) { final var accessibleChatChoice = response.getChoices().get(0); var responseMessage = Optional.ofNullable(accessibleChatChoice.getMessage()) - .orElse(accessibleChatChoice.getDelta()); + .orElse(accessibleChatChoice.getDelta()); ChatRequestAssistantMessage assistantMessage = new ChatRequestAssistantMessage(""); final var toolCalls = responseMessage.getToolCalls(); assistantMessage.setToolCalls(toolCalls.stream().map(tc -> { diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java index f3d8cff0b3d..e8fb7585ba4 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java @@ -69,13 +69,13 @@ void functionCallTest() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = AzureOpenAiChatOptions.builder() - .withDeploymentName(selectedModel) - .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) - .withName("getCurrentWeather") - .withDescription("Get the current weather in a given location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) - .build())) - .build(); + .withDeploymentName(selectedModel) + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("getCurrentWeather") + .withDescription("Get the current weather in a given location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build())) + .build(); ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); @@ -95,13 +95,13 @@ void functionCallSequentialTest() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = AzureOpenAiChatOptions.builder() - .withDeploymentName(selectedModel) - .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) - .withName("getCurrentWeather") - .withDescription("Get the current weather in a given location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) - .build())) - .build(); + .withDeploymentName(selectedModel) + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("getCurrentWeather") + .withDescription("Get the current weather in a given location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build())) + .build(); ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); @@ -121,27 +121,27 @@ void functionCallSequentialAndStreamTest() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = AzureOpenAiChatOptions.builder() - .withDeploymentName(selectedModel) - .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) - .withName("getCurrentWeather") - .withDescription("Get the current weather in a given location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) - .build())) - .build(); + .withDeploymentName(selectedModel) + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("getCurrentWeather") + .withDescription("Get the current weather in a given location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build())) + .build(); var response = chatModel.stream(new Prompt(messages, promptOptions)); final var counter = new AtomicInteger(); String content = response.doOnEach(listSignal -> counter.getAndIncrement()) - .collectList() - .block() - .stream() - .map(ChatResponse::getResults) - .flatMap(List::stream) - .map(Generation::getOutput) - .map(AssistantMessage::getContent) - .filter(Objects::nonNull) - .collect(Collectors.joining()); + .collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .filter(Objects::nonNull) + .collect(Collectors.joining()); logger.info("Response: {}", response); @@ -157,26 +157,26 @@ void streamFunctionCallTest() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = AzureOpenAiChatOptions.builder() - .withDeploymentName(selectedModel) - .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) - .withName("getCurrentWeather") - .withDescription("Get the current weather in a given location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) - .build())) - .build(); + .withDeploymentName(selectedModel) + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("getCurrentWeather") + .withDescription("Get the current weather in a given location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build())) + .build(); Flux response = chatModel.stream(new Prompt(messages, promptOptions)); final var counter = new AtomicInteger(); String content = response.doOnEach(listSignal -> counter.getAndIncrement()) - .collectList() - .block() - .stream() - .map(ChatResponse::getResults) - .flatMap(List::stream) - .map(Generation::getOutput) - .map(AssistantMessage::getContent) - .collect(Collectors.joining()); + .collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); logger.info("Response: {}", content); assertThat(counter.get()).isGreaterThan(30).as("The response should be chunked in more than 30 messages"); @@ -192,8 +192,8 @@ public static class TestConfiguration { @Bean public OpenAIClient openAIClient() { return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY"))) - .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) - .buildClient(); + .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) + .buildClient(); } @Bean From 748fa2d827e4ca81e84f65ef3287a48edb76a7a2 Mon Sep 17 00:00:00 2001 From: starkt Date: Tue, 16 Jul 2024 11:02:44 +0200 Subject: [PATCH 5/6] refactor AzureAi to OpenAi --- .../ai/azure/openai/AzureOpenAiChatModel.java | 344 +++++++----------- .../AzureOpenAiChatModelFunctionCallIT.java | 14 +- 2 files changed, 135 insertions(+), 223 deletions(-) diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index a6d1f66471c..ebe8ba8514f 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -16,31 +16,8 @@ package org.springframework.ai.azure.openai; import com.azure.ai.openai.OpenAIClient; -import com.azure.ai.openai.models.ChatChoice; -import com.azure.ai.openai.models.ChatCompletions; -import com.azure.ai.openai.models.ChatCompletionsFunctionToolCall; -import com.azure.ai.openai.models.ChatCompletionsFunctionToolDefinition; -import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat; -import com.azure.ai.openai.models.ChatCompletionsOptions; -import com.azure.ai.openai.models.ChatCompletionsResponseFormat; -import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat; -import com.azure.ai.openai.models.ChatCompletionsToolCall; -import com.azure.ai.openai.models.ChatCompletionsToolDefinition; -import com.azure.ai.openai.models.ChatMessageContentItem; -import com.azure.ai.openai.models.ChatMessageImageContentItem; -import com.azure.ai.openai.models.ChatMessageImageUrl; -import com.azure.ai.openai.models.ChatMessageTextContentItem; -import com.azure.ai.openai.models.ChatRequestAssistantMessage; -import com.azure.ai.openai.models.ChatRequestMessage; -import com.azure.ai.openai.models.ChatRequestSystemMessage; -import com.azure.ai.openai.models.ChatRequestToolMessage; -import com.azure.ai.openai.models.ChatRequestUserMessage; -import com.azure.ai.openai.models.CompletionsFinishReason; -import com.azure.ai.openai.models.ContentFilterResultsForPrompt; -import com.azure.ai.openai.models.FunctionCall; -import com.azure.ai.openai.models.FunctionDefinition; +import com.azure.ai.openai.models.*; import com.azure.core.util.BinaryData; -import com.azure.core.util.IterableStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata; @@ -56,7 +33,7 @@ import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.model.function.AbstractFunctionCallSupport; +import org.springframework.ai.model.function.AbstractToolCallSupport; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -64,7 +41,7 @@ import reactor.core.publisher.Mono; import java.util.*; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.ConcurrentHashMap; /** * {@link ChatModel} implementation for {@literal Microsoft Azure AI} backed by @@ -80,8 +57,7 @@ * @see ChatModel * @see com.azure.ai.openai.OpenAIClient */ -public class AzureOpenAiChatModel extends - AbstractFunctionCallSupport implements ChatModel { +public class AzureOpenAiChatModel extends AbstractToolCallSupport implements ChatModel { private static final String DEFAULT_DEPLOYMENT_NAME = "gpt-35-turbo"; @@ -142,97 +118,140 @@ public ChatResponse call(Prompt prompt) { options.setStream(false); logger.trace("Azure ChatCompletionsOptions: {}", options); - ChatCompletions chatCompletions = this.callWithFunctionSupport(options); - logger.trace("Azure ChatCompletions: {}", chatCompletions); + ChatCompletions chatCompletion = this.openAIClient.getChatCompletions(options.getModel(), options); + logger.trace("Azure ChatCompletions: {}", chatCompletion); - List generations = nullSafeList(chatCompletions.getChoices()).stream() + if (isToolFunctionCall(chatCompletion)) { + List toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(), + chatCompletion); + + // Recursively call the call method with the tool call message + // conversation that contains the call responses. + return this.call(new Prompt(toolCallMessageConversation, prompt.getOptions())); + } + + List generations = nullSafeList(chatCompletion.getChoices()).stream() .map(choice -> new Generation(choice.getMessage().getContent()) .withGenerationMetadata(generateChoiceMetadata(choice))) .toList(); - PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions); + PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletion); return new ChatResponse(generations, - AzureOpenAiChatResponseMetadata.from(chatCompletions, promptFilterMetadata)); + AzureOpenAiChatResponseMetadata.from(chatCompletion, promptFilterMetadata)); } @Override public Flux stream(Prompt prompt) { + ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt); + options.setStream(true); + + Flux completionChunks = Flux + .fromIterable(this.openAIClient.getChatCompletionsStream(options.getModel(), options)); - // we have to map with a custom function to handle the tool call requests - // due to the existing bugs in the azure api (see comments in streamWithAzureApi) - // we have to recursively call this specific method for tool calls instead of - // using the one from the AbstractFunctionCallSupport - return streamWithAzureOpenAi(options).flatMapIterable(ChatCompletions::getChoices).map(choice -> { - var content = Optional.ofNullable(choice.getMessage()).orElse(choice.getDelta()).getContent(); - var generation = new Generation(content).withGenerationMetadata(generateChoiceMetadata(choice)); - return new ChatResponse(List.of(generation)); + // For chunked responses, only the first chunk contains the choice role. + // The rest of the chunks with same ID share the same role. + ConcurrentHashMap roleMap = new ConcurrentHashMap<>(); + + final ChatCompletions[] currentChatCompletion = { MergeUtils.emptyChatCompletions() }; + return completionChunks.filter(this::filterNotRelevantDeltaChunks).switchMap(chatCompletion -> { + currentChatCompletion[0] = MergeUtils.mergeChatCompletions(currentChatCompletion[0], chatCompletion); + + if (this.isToolFunctionCall(currentChatCompletion[0])) { + var toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(), + currentChatCompletion[0]); + + // Recursively call the stream method with the tool call message + // conversation that contains the call responses. + return this.stream(new Prompt(toolCallMessageConversation, prompt.getOptions())); + } + + // Non function calling. + return Mono.just(chatCompletion).map(chatCompletion2 -> { + try { + @SuppressWarnings("null") + String id = chatCompletion2.getId(); + + List generations = chatCompletion2.getChoices().stream().map(choice -> { + if (choice.getDelta().getRole() != null) { + roleMap.putIfAbsent(id, choice.getDelta().getRole().toString()); + } + String finish = (choice.getFinishReason() != null ? choice.getFinishReason().toString() : ""); + var generation = new Generation(Optional.ofNullable(choice.getDelta().getContent()).orElse(""), + Map.of("id", id, "role", roleMap.getOrDefault(id, ""), "finishReason", finish)); + if (choice.getFinishReason() != null) { + generation = generation.withGenerationMetadata( + ChatGenerationMetadata.from(choice.getFinishReason().toString(), null)); + } + return generation; + }).toList(); + + if (chatCompletion2.getUsage() != null) { + return new ChatResponse(generations, AzureOpenAiChatResponseMetadata.from(chatCompletion2, + generatePromptMetadata(chatCompletion2))); + } + else { + return new ChatResponse(generations); + } + } + catch (Exception e) { + logger.error("Error processing chat completion", e); + return new ChatResponse(List.of()); + } + }); }); } - private Flux streamWithAzureOpenAi(ChatCompletionsOptions options) { - options.setStream(true); + private boolean filterNotRelevantDeltaChunks(ChatCompletions chatCompletion) { + if (chatCompletion.getChoices() == null || chatCompletion.getChoices().isEmpty()) { + return false; + } + return chatCompletion.getChoices().stream().anyMatch(choice -> { + if (choice.getFinishReason() != null) { + return true; + } - IterableStream chatCompletionsStream = this.openAIClient - .getChatCompletionsStream(options.getModel(), options); - - Flux chatCompletionsFlux = Flux.fromIterable(chatCompletionsStream); - - final var isFunctionCall = new AtomicBoolean(false); - final var accessibleChatCompletionsFlux = chatCompletionsFlux - // Note: the first chat completions can be ignored when using Azure OpenAI - // service which is a known service bug. - .skip(1) - .map(chatCompletions -> { - final var toolCalls = chatCompletions.getChoices().get(0).getDelta().getToolCalls(); - isFunctionCall.set(toolCalls != null && !toolCalls.isEmpty()); - return chatCompletions; - }) - .windowUntil(chatCompletions -> { - if (isFunctionCall.get() && chatCompletions.getChoices() - .get(0) - .getFinishReason() == CompletionsFinishReason.TOOL_CALLS) { - isFunctionCall.set(false); - return true; - } - return !isFunctionCall.get(); - }) - .concatMapIterable(window -> { - final var reduce = window.reduce(MergeUtils.emptyChatCompletions(), MergeUtils::mergeChatCompletions); - return List.of(reduce); - }) - .flatMap(mono -> mono); - return accessibleChatCompletionsFlux.switchMap( - accessibleChatCompletions -> handleToolCallRequests(options, Flux.just(accessibleChatCompletions))); + if (choice.getDelta() == null) { + return false; + } + return choice.getDelta().getContent() != null + || (choice.getDelta().getToolCalls() != null && !choice.getDelta().getToolCalls().isEmpty()); + }); } - private Flux handleToolCallRequests(ChatCompletionsOptions request, - Flux response) { - return response.switchMap(resp -> { - if (!this.isToolFunctionCall(resp)) { - return Mono.just(resp); - } + private ChatResponseMessage extractAssistantMessage(ChatCompletions chatCompletion) { + return Optional.ofNullable(chatCompletion.getChoices().iterator().next().getMessage()) + .orElse(chatCompletion.getChoices().iterator().next().getDelta()); + } - // The chat completion tool call requires the complete conversation - // history. Including the initial user message. - List conversationHistory = new ArrayList<>(); + private List handleToolCallRequests(List previousMessages, ChatCompletions chatCompletion) { + ChatResponseMessage nativeAssistantMessage = this.extractAssistantMessage(chatCompletion); - conversationHistory.addAll(this.doGetUserMessages(request)); + List assistantToolCalls = nativeAssistantMessage.getToolCalls() + .stream() + .filter(x -> x instanceof ChatCompletionsFunctionToolCall) + .map(x -> (ChatCompletionsFunctionToolCall) x) + .map(toolCall -> new AssistantMessage.ToolCall(toolCall.getId(), "function", + toolCall.getFunction().getName(), toolCall.getFunction().getArguments())) + .toList(); - ChatRequestMessage responseMessage = this.doGetToolResponseMessage(resp); + AssistantMessage assistantMessage = new AssistantMessage(nativeAssistantMessage.getContent(), Map.of(), + assistantToolCalls); - // Add the assistant response to the message conversation history. - conversationHistory.add(responseMessage); + List toolResponseMessages = this.executeFuncitons(assistantMessage); - ChatCompletionsOptions newRequest = this.doCreateToolResponseRequest(request, responseMessage, - conversationHistory); + // History + List messages = new ArrayList<>(previousMessages); + messages.add(assistantMessage); + messages.addAll(toolResponseMessages); - // recursively go backwards and call our stream again (including all bug fixes - // / workarounds for the azure api) - return this.streamWithAzureOpenAi(newRequest); - }); + return messages; + } + + List toAzureChatMessage(List messages) { + return messages.stream().map(this::fromSpringAiMessage).toList(); } /** @@ -242,10 +261,7 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) { Set functionsForThisRequest = new HashSet<>(); - List azureMessages = prompt.getInstructions() - .stream() - .map(this::fromSpringAiMessage) - .toList(); + List azureMessages = toAzureChatMessage(prompt.getInstructions()); ChatCompletionsOptions options = new ChatCompletionsOptions(azureMessages); @@ -318,7 +334,23 @@ private ChatRequestMessage fromSpringAiMessage(Message message) { case SYSTEM: return new ChatRequestSystemMessage(message.getContent()); case ASSISTANT: - return new ChatRequestAssistantMessage(message.getContent()); + var assistantMessage = (AssistantMessage) message; + List toolCalls = null; + if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { + toolCalls = assistantMessage.getToolCalls() + .stream() + .map(toolCall -> (ChatCompletionsToolCall) new ChatCompletionsFunctionToolCall(toolCall.id(), + new FunctionCall(toolCall.name(), toolCall.arguments()))) + .toList(); + } + var azureAssistantMessage = new ChatRequestAssistantMessage(message.getContent()); + if (toolCalls != null) { + azureAssistantMessage.setToolCalls(toolCalls); + } + return azureAssistantMessage; + case TOOL: + ToolResponseMessage toolMessage = (ToolResponseMessage) message; + return new ChatRequestToolMessage(message.getContent(), toolMessage.getId()); default: throw new IllegalArgumentException("Unknown message type " + message.getMessageType()); } @@ -468,61 +500,6 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions, return mergedAzureOptions; } - /** - * Merges the fromOptions into the toOptions and returns a new ChatCompletionsOptions - * instance. - * @param fromOptions the ChatCompletionsOptions to merge from. - * @param toOptions the ChatCompletionsOptions to merge to. - * @return a new ChatCompletionsOptions instance. - */ - private ChatCompletionsOptions merge(ChatCompletionsOptions fromOptions, ChatCompletionsOptions toOptions) { - - if (fromOptions == null) { - return toOptions; - } - - ChatCompletionsOptions mergedOptions = this.copy(toOptions); - - if (fromOptions.getMaxTokens() != null) { - mergedOptions.setMaxTokens(fromOptions.getMaxTokens()); - } - if (fromOptions.getLogitBias() != null) { - mergedOptions.setLogitBias(fromOptions.getLogitBias()); - } - if (fromOptions.getStop() != null) { - mergedOptions.setStop(fromOptions.getStop()); - } - if (fromOptions.getTemperature() != null) { - mergedOptions.setTemperature(fromOptions.getTemperature()); - } - if (fromOptions.getTopP() != null) { - mergedOptions.setTopP(fromOptions.getTopP()); - } - if (fromOptions.getFrequencyPenalty() != null) { - mergedOptions.setFrequencyPenalty(fromOptions.getFrequencyPenalty()); - } - if (fromOptions.getPresencePenalty() != null) { - mergedOptions.setPresencePenalty(fromOptions.getPresencePenalty()); - } - if (fromOptions.getN() != null) { - mergedOptions.setN(fromOptions.getN()); - } - if (fromOptions.getUser() != null) { - mergedOptions.setUser(fromOptions.getUser()); - } - if (fromOptions.getModel() != null) { - mergedOptions.setModel(fromOptions.getModel()); - } - if (fromOptions.getResponseFormat() != null) { - mergedOptions.setResponseFormat(fromOptions.getResponseFormat()); - } - if (fromOptions.getTools() != null) { - mergedOptions.setTools(fromOptions.getTools()); - } - - return mergedOptions; - } - /** * Copy the fromOptions into a new ChatCompletionsOptions instance. * @param fromOptions the ChatCompletionsOptions to copy from. @@ -570,67 +547,6 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) { return copyOptions; } - @Override - protected ChatCompletionsOptions doCreateToolResponseRequest(ChatCompletionsOptions previousRequest, - ChatRequestMessage responseMessage, List conversationHistory) { - - // Every tool-call item requires a separate function call and a response (TOOL) - // message. - for (ChatCompletionsToolCall toolCall : ((ChatRequestAssistantMessage) responseMessage).getToolCalls()) { - - var functionName = ((ChatCompletionsFunctionToolCall) toolCall).getFunction().getName(); - String functionArguments = ((ChatCompletionsFunctionToolCall) toolCall).getFunction().getArguments(); - - if (!this.functionCallbackRegister.containsKey(functionName)) { - throw new IllegalStateException("No function callback found for function name: " + functionName); - } - - String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments); - - // Add the function response to the conversation. - conversationHistory.add(new ChatRequestToolMessage(functionResponse, toolCall.getId())); - } - - // Recursively call chatCompletionWithTools until the model doesn't call a - // functions anymore. - ChatCompletionsOptions newRequest = new ChatCompletionsOptions(conversationHistory); - - newRequest = merge(previousRequest, newRequest); - - return newRequest; - } - - @Override - protected List doGetUserMessages(ChatCompletionsOptions request) { - return request.getMessages(); - } - - @Override - protected ChatRequestMessage doGetToolResponseMessage(ChatCompletions response) { - final var accessibleChatChoice = response.getChoices().get(0); - var responseMessage = Optional.ofNullable(accessibleChatChoice.getMessage()) - .orElse(accessibleChatChoice.getDelta()); - ChatRequestAssistantMessage assistantMessage = new ChatRequestAssistantMessage(""); - final var toolCalls = responseMessage.getToolCalls(); - assistantMessage.setToolCalls(toolCalls.stream().map(tc -> { - final var tc1 = (ChatCompletionsFunctionToolCall) tc; - var toDowncast = new ChatCompletionsFunctionToolCall(tc.getId(), - new FunctionCall(tc1.getFunction().getName(), tc1.getFunction().getArguments())); - return ((ChatCompletionsToolCall) toDowncast); - }).toList()); - return assistantMessage; - } - - @Override - protected ChatCompletions doChatCompletion(ChatCompletionsOptions request) { - return this.openAIClient.getChatCompletions(request.getModel(), request); - } - - @Override - protected Flux doChatCompletionStream(ChatCompletionsOptions request) { - return Flux.fromIterable(this.openAIClient.getChatCompletionsStream(request.getModel(), request)); - } - @Override protected boolean isToolFunctionCall(ChatCompletions chatCompletions) { diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java index e8fb7585ba4..f33bd6abf6a 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java @@ -131,23 +131,19 @@ void functionCallSequentialAndStreamTest() { var response = chatModel.stream(new Prompt(messages, promptOptions)); - final var counter = new AtomicInteger(); - String content = response.doOnEach(listSignal -> counter.getAndIncrement()) - .collectList() - .block() - .stream() + List responses = response.collectList().block(); + String stitchedResponseContent = responses.stream() .map(ChatResponse::getResults) .flatMap(List::stream) .map(Generation::getOutput) .map(AssistantMessage::getContent) - .filter(Objects::nonNull) .collect(Collectors.joining()); logger.info("Response: {}", response); - assertThat(content).containsAnyOf("30.0", "30"); - assertThat(content).containsAnyOf("10.0", "10"); - assertThat(content).containsAnyOf("15.0", "15"); + assertThat(stitchedResponseContent).containsAnyOf("30.0", "30"); + assertThat(stitchedResponseContent).containsAnyOf("10.0", "10"); + assertThat(stitchedResponseContent).containsAnyOf("15.0", "15"); } @Test From 2a414b80f8f33839eb7c00fd74d344ad0d136479 Mon Sep 17 00:00:00 2001 From: starkt Date: Tue, 16 Jul 2024 13:00:04 +0200 Subject: [PATCH 6/6] use Async Client for actual streaming support --- .../ai/azure/openai/AzureOpenAiChatModel.java | 121 ++++++++---------- .../openai/AzureOpenAiEmbeddingModel.java | 11 +- .../azure/openai/AzureOpenAiImageModel.java | 9 +- .../AzureChatCompletionsOptionsTests.java | 3 +- .../openai/AzureEmbeddingsOptionsTests.java | 3 +- .../azure/openai/AzureOpenAiChatModelIT.java | 7 +- .../openai/AzureOpenAiEmbeddingModelIT.java | 7 +- .../MockAzureOpenAiTestConfiguration.java | 7 +- .../AzureOpenAiChatModelFunctionCallIT.java | 7 +- .../openai/image/AzureOpenAiImageModelIT.java | 7 +- pom.xml | 2 +- .../pages/api/chat/azure-openai-chat.adoc | 2 +- .../openai/AzureOpenAiAutoConfiguration.java | 19 +-- .../qdrant/QdrantVectorStoreIT.java | 7 +- 14 files changed, 104 insertions(+), 108 deletions(-) diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index ebe8ba8514f..36b390b0359 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -15,9 +15,11 @@ */ package org.springframework.ai.azure.openai; +import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.models.*; import com.azure.core.util.BinaryData; +import com.azure.core.util.CoreUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata; @@ -68,14 +70,14 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport stream(Prompt prompt) { ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt); options.setStream(true); - Flux completionChunks = Flux - .fromIterable(this.openAIClient.getChatCompletionsStream(options.getModel(), options)); - // For chunked responses, only the first chunk contains the choice role. // The rest of the chunks with same ID share the same role. ConcurrentHashMap roleMap = new ConcurrentHashMap<>(); final ChatCompletions[] currentChatCompletion = { MergeUtils.emptyChatCompletions() }; - return completionChunks.filter(this::filterNotRelevantDeltaChunks).switchMap(chatCompletion -> { - currentChatCompletion[0] = MergeUtils.mergeChatCompletions(currentChatCompletion[0], chatCompletion); - - if (this.isToolFunctionCall(currentChatCompletion[0])) { - var toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(), - currentChatCompletion[0]); - - // Recursively call the stream method with the tool call message - // conversation that contains the call responses. - return this.stream(new Prompt(toolCallMessageConversation, prompt.getOptions())); - } - - // Non function calling. - return Mono.just(chatCompletion).map(chatCompletion2 -> { - try { - @SuppressWarnings("null") - String id = chatCompletion2.getId(); - - List generations = chatCompletion2.getChoices().stream().map(choice -> { - if (choice.getDelta().getRole() != null) { - roleMap.putIfAbsent(id, choice.getDelta().getRole().toString()); + return this.openAIClient.getChatCompletionsStream(options.getModel(), options) + .filter(chatCompletion -> !CoreUtils.isNullOrEmpty(chatCompletion.getChoices())) // see: + // https://learn.microsoft.com/en-us/java/api/overview/azure/ai-openai-readme?view=azure-java-preview#streaming-chat-completions + .switchMap(chatCompletion -> { + currentChatCompletion[0] = MergeUtils.mergeChatCompletions(currentChatCompletion[0], chatCompletion); + + if (this.isToolFunctionCall(currentChatCompletion[0])) { + var toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(), + currentChatCompletion[0]); + + // Recursively call the stream method with the tool call message + // conversation that contains the call responses. + return this.stream(new Prompt(toolCallMessageConversation, prompt.getOptions())); + } + + // Non function calling. + return Mono.just(chatCompletion).map(chatCompletion2 -> { + try { + @SuppressWarnings("null") + String id = chatCompletion2.getId(); + + List generations = chatCompletion2.getChoices().stream().map(choice -> { + if (choice.getDelta().getRole() != null) { + roleMap.putIfAbsent(id, choice.getDelta().getRole().toString()); + } + String finish = (choice.getFinishReason() != null ? choice.getFinishReason().toString() + : ""); + var generation = new Generation(choice.getDelta().getContent(), + Map.of("id", id, "role", roleMap.getOrDefault(id, ""), "finishReason", finish)); + if (choice.getFinishReason() != null) { + generation = generation.withGenerationMetadata( + ChatGenerationMetadata.from(choice.getFinishReason().toString(), null)); + } + return generation; + }).toList(); + + if (chatCompletion2.getUsage() != null) { + return new ChatResponse(generations, AzureOpenAiChatResponseMetadata.from(chatCompletion2, + generatePromptMetadata(chatCompletion2))); } - String finish = (choice.getFinishReason() != null ? choice.getFinishReason().toString() : ""); - var generation = new Generation(Optional.ofNullable(choice.getDelta().getContent()).orElse(""), - Map.of("id", id, "role", roleMap.getOrDefault(id, ""), "finishReason", finish)); - if (choice.getFinishReason() != null) { - generation = generation.withGenerationMetadata( - ChatGenerationMetadata.from(choice.getFinishReason().toString(), null)); + else { + return new ChatResponse(generations); } - return generation; - }).toList(); - - if (chatCompletion2.getUsage() != null) { - return new ChatResponse(generations, AzureOpenAiChatResponseMetadata.from(chatCompletion2, - generatePromptMetadata(chatCompletion2))); } - else { - return new ChatResponse(generations); + catch (Exception e) { + logger.error("Error processing chat completion", e); + return new ChatResponse(List.of()); } - } - catch (Exception e) { - logger.error("Error processing chat completion", e); - return new ChatResponse(List.of()); - } + }); }); - }); - } - - private boolean filterNotRelevantDeltaChunks(ChatCompletions chatCompletion) { - if (chatCompletion.getChoices() == null || chatCompletion.getChoices().isEmpty()) { - return false; - } - return chatCompletion.getChoices().stream().anyMatch(choice -> { - if (choice.getFinishReason() != null) { - return true; - } - - if (choice.getDelta() == null) { - return false; - } - - return choice.getDelta().getContent() != null - || (choice.getDelta().getToolCalls() != null && !choice.getDelta().getToolCalls().isEmpty()); - }); } private ChatResponseMessage extractAssistantMessage(ChatCompletions chatCompletion) { diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java index 625c8e02be9..e53335509ff 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.azure.openai; +import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.models.EmbeddingItem; import com.azure.ai.openai.models.Embeddings; @@ -38,22 +39,22 @@ public class AzureOpenAiEmbeddingModel extends AbstractEmbeddingModel { private static final Logger logger = LoggerFactory.getLogger(AzureOpenAiEmbeddingModel.class); - private final OpenAIClient azureOpenAiClient; + private final OpenAIAsyncClient azureOpenAiClient; private final AzureOpenAiEmbeddingOptions defaultOptions; private final MetadataMode metadataMode; - public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient) { + public AzureOpenAiEmbeddingModel(OpenAIAsyncClient azureOpenAiClient) { this(azureOpenAiClient, MetadataMode.EMBED); } - public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient, MetadataMode metadataMode) { + public AzureOpenAiEmbeddingModel(OpenAIAsyncClient azureOpenAiClient, MetadataMode metadataMode) { this(azureOpenAiClient, metadataMode, AzureOpenAiEmbeddingOptions.builder().withDeploymentName("text-embedding-ada-002").build()); } - public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient, MetadataMode metadataMode, + public AzureOpenAiEmbeddingModel(OpenAIAsyncClient azureOpenAiClient, MetadataMode metadataMode, AzureOpenAiEmbeddingOptions options) { Assert.notNull(azureOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null"); Assert.notNull(metadataMode, "Metadata mode must not be null"); @@ -78,7 +79,7 @@ public EmbeddingResponse call(EmbeddingRequest embeddingRequest) { logger.debug("Retrieving embeddings"); EmbeddingsOptions azureOptions = toEmbeddingOptions(embeddingRequest); - Embeddings embeddings = this.azureOpenAiClient.getEmbeddings(azureOptions.getModel(), azureOptions); + Embeddings embeddings = this.azureOpenAiClient.getEmbeddings(azureOptions.getModel(), azureOptions).block(); logger.debug("Embeddings retrieved"); return generateEmbeddingResponse(embeddings); diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java index e6da1ebbf38..2c35122b8dc 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java @@ -1,5 +1,6 @@ package org.springframework.ai.azure.openai; +import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.models.ImageGenerationOptions; import com.azure.ai.openai.models.ImageGenerationQuality; @@ -45,15 +46,15 @@ public class AzureOpenAiImageModel implements ImageModel { private final Logger logger = LoggerFactory.getLogger(getClass()); @Autowired - private final OpenAIClient openAIClient; + private final OpenAIAsyncClient openAIClient; private final AzureOpenAiImageOptions defaultOptions; - public AzureOpenAiImageModel(OpenAIClient openAIClient) { + public AzureOpenAiImageModel(OpenAIAsyncClient openAIClient) { this(openAIClient, AzureOpenAiImageOptions.builder().withDeploymentName(DEFAULT_DEPLOYMENT_NAME).build()); } - public AzureOpenAiImageModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiImageOptions options) { + public AzureOpenAiImageModel(OpenAIAsyncClient microsoftOpenAiClient, AzureOpenAiImageOptions options) { Assert.notNull(microsoftOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null"); Assert.notNull(options, "AzureOpenAiChatOptions must not be null"); this.openAIClient = microsoftOpenAiClient; @@ -73,7 +74,7 @@ public ImageResponse call(ImagePrompt imagePrompt) { toPrettyJson(imageGenerationOptions)); } - var images = openAIClient.getImageGenerations(deploymentOrModelName, imageGenerationOptions); + var images = openAIClient.getImageGenerations(deploymentOrModelName, imageGenerationOptions).block(); if (logger.isTraceEnabled()) { logger.trace("Azure ImageGenerations: {}", toPrettyJson(images)); diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java index e9c8196ff3c..e8b7fdf6e02 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.azure.openai; +import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat; import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat; @@ -40,7 +41,7 @@ public class AzureChatCompletionsOptionsTests { @Test public void createRequestWithChatOptions() { - OpenAIClient mockClient = Mockito.mock(OpenAIClient.class); + OpenAIAsyncClient mockClient = Mockito.mock(OpenAIAsyncClient.class); var defaultOptions = AzureOpenAiChatOptions.builder() .withDeploymentName("DEFAULT_MODEL") diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureEmbeddingsOptionsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureEmbeddingsOptionsTests.java index 18fe0e56af1..851dd97d8eb 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureEmbeddingsOptionsTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureEmbeddingsOptionsTests.java @@ -17,6 +17,7 @@ import java.util.List; +import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; import org.junit.jupiter.api.Test; import org.mockito.Mockito; @@ -35,7 +36,7 @@ public class AzureEmbeddingsOptionsTests { @Test public void createRequestWithChatOptions() { - OpenAIClient mockClient = Mockito.mock(OpenAIClient.class); + OpenAIAsyncClient mockClient = Mockito.mock(OpenAIAsyncClient.class); var client = new AzureOpenAiEmbeddingModel(mockClient, MetadataMode.EMBED, AzureOpenAiEmbeddingOptions.builder() .withDeploymentName("DEFAULT_MODEL") diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java index 079b840e2cb..d832d5374e5 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java @@ -25,6 +25,7 @@ import java.util.Objects; import java.util.stream.Collectors; +import com.azure.ai.openai.OpenAIAsyncClient; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; @@ -214,14 +215,14 @@ record ActorsFilmsRecord(String actor, List movies) { public static class TestConfiguration { @Bean - public OpenAIClient openAIClient() { + public OpenAIAsyncClient openAIClient() { return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY"))) .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) - .buildClient(); + .buildAsyncClient(); } @Bean - public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient openAIClient) { + public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIAsyncClient openAIClient) { return new AzureOpenAiChatModel(openAIClient, AzureOpenAiChatOptions.builder().withDeploymentName("gpt-35-turbo").withMaxTokens(200).build()); diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModelIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModelIT.java index 7c1710e06c6..02a399b34d4 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModelIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModelIT.java @@ -17,6 +17,7 @@ import java.util.List; +import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.core.credential.AzureKeyCredential; @@ -67,14 +68,14 @@ void batchEmbedding() { public static class TestConfiguration { @Bean - public OpenAIClient openAIClient() { + public OpenAIAsyncClient openAIClient() { return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY"))) .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) - .buildClient(); + .buildAsyncClient(); } @Bean - public AzureOpenAiEmbeddingModel azureEmbeddingModel(OpenAIClient openAIClient) { + public AzureOpenAiEmbeddingModel azureEmbeddingModel(OpenAIAsyncClient openAIClient) { return new AzureOpenAiEmbeddingModel(openAIClient); } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAzureOpenAiTestConfiguration.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAzureOpenAiTestConfiguration.java index 6ae824badcf..6ac865466a1 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAzureOpenAiTestConfiguration.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAzureOpenAiTestConfiguration.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.azure.openai; +import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.OpenAIClientBuilder; @@ -51,15 +52,15 @@ public class MockAzureOpenAiTestConfiguration { @Bean - OpenAIClient microsoftAzureOpenAiClient(MockWebServer webServer) { + OpenAIAsyncClient microsoftAzureOpenAiClient(MockWebServer webServer) { HttpUrl baseUrl = webServer.url(MockAiTestConfiguration.SPRING_AI_API_PATH); - return new OpenAIClientBuilder().endpoint(baseUrl.toString()).buildClient(); + return new OpenAIClientBuilder().endpoint(baseUrl.toString()).buildAsyncClient(); } @Bean - AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient microsoftAzureOpenAiClient) { + AzureOpenAiChatModel azureOpenAiChatModel(OpenAIAsyncClient microsoftAzureOpenAiClient) { return new AzureOpenAiChatModel(microsoftAzureOpenAiClient); } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java index f33bd6abf6a..17b76e99e5c 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java @@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; +import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.core.credential.AzureKeyCredential; @@ -186,14 +187,14 @@ void streamFunctionCallTest() { public static class TestConfiguration { @Bean - public OpenAIClient openAIClient() { + public OpenAIAsyncClient openAIClient() { return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY"))) .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) - .buildClient(); + .buildAsyncClient(); } @Bean - public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient openAIClient, String selectedModel) { + public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIAsyncClient openAIClient, String selectedModel) { return new AzureOpenAiChatModel(openAIClient, AzureOpenAiChatOptions.builder().withDeploymentName(selectedModel).withMaxTokens(500).build()); } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/image/AzureOpenAiImageModelIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/image/AzureOpenAiImageModelIT.java index f57dfb6a706..09b8d4b0bb3 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/image/AzureOpenAiImageModelIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/image/AzureOpenAiImageModelIT.java @@ -1,5 +1,6 @@ package org.springframework.ai.azure.openai.image; +import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.core.credential.AzureKeyCredential; @@ -65,14 +66,14 @@ void imageAsUrlTest() { public static class TestConfiguration { @Bean - public OpenAIClient openAIClient() { + public OpenAIAsyncClient openAIClient() { return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY"))) .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) - .buildClient(); + .buildAsyncClient(); } @Bean - public AzureOpenAiImageModel azureOpenAiImageModel(OpenAIClient openAIClient) { + public AzureOpenAiImageModel azureOpenAiImageModel(OpenAIAsyncClient openAIClient) { return new AzureOpenAiImageModel(openAIClient, AzureOpenAiImageOptions.builder().withDeploymentName("Dalle3").build()); diff --git a/pom.xml b/pom.xml index ce1bc39a47f..9a42733bced 100644 --- a/pom.xml +++ b/pom.xml @@ -147,7 +147,7 @@ 3.3.0 6.1.4 4.3.4 - 1.0.0-beta.10 + 1.0.0-beta.8 1.0.0 4.31.1 diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc index 7c38b2fd9b1..cd32d837f09 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc @@ -293,7 +293,7 @@ Next, create an `AzureOpenAiChatModel` instance and use it to generate text resp var openAIClient = new OpenAIClientBuilder() .credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY"))) .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) - .buildClient(); + .buildAsyncClient(); var openAIChatOptions = AzureOpenAiChatOptions.builder() .withDeploymentName("gpt-35-turbo") diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java index 476cca5f682..e55d5221d67 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java @@ -17,6 +17,7 @@ import java.util.List; +import com.azure.ai.openai.OpenAIAsyncClient; import org.springframework.ai.azure.openai.AzureOpenAiChatModel; import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingModel; import org.springframework.ai.azure.openai.AzureOpenAiImageModel; @@ -51,8 +52,8 @@ public class AzureOpenAiAutoConfiguration { private final static String APPLICATION_ID = "spring-ai"; @Bean - @ConditionalOnMissingBean({ OpenAIClient.class, TokenCredential.class }) - public OpenAIClient openAIClient(AzureOpenAiConnectionProperties connectionProperties) { + @ConditionalOnMissingBean({ OpenAIAsyncClient.class, TokenCredential.class }) + public OpenAIAsyncClient openAIClient(AzureOpenAiConnectionProperties connectionProperties) { if (StringUtils.hasText(connectionProperties.getApiKey())) { @@ -61,7 +62,7 @@ public OpenAIClient openAIClient(AzureOpenAiConnectionProperties connectionPrope return new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint()) .credential(new AzureKeyCredential(connectionProperties.getApiKey())) .clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID)) - .buildClient(); + .buildAsyncClient(); } // Connect to OpenAI (e.g. not the Azure OpenAI). The deploymentName property is @@ -70,7 +71,7 @@ public OpenAIClient openAIClient(AzureOpenAiConnectionProperties connectionPrope return new OpenAIClientBuilder().endpoint("https://api.openai.com/v1") .credential(new KeyCredential(connectionProperties.getOpenAiApiKey())) .clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID)) - .buildClient(); + .buildAsyncClient(); } throw new IllegalArgumentException("Either API key or OpenAI API key must not be empty"); @@ -79,7 +80,7 @@ public OpenAIClient openAIClient(AzureOpenAiConnectionProperties connectionPrope @Bean @ConditionalOnMissingBean @ConditionalOnBean(TokenCredential.class) - public OpenAIClient openAIClientWithTokenCredential(AzureOpenAiConnectionProperties connectionProperties, + public OpenAIAsyncClient openAIClientWithTokenCredential(AzureOpenAiConnectionProperties connectionProperties, TokenCredential tokenCredential) { Assert.notNull(tokenCredential, "TokenCredential must not be null"); @@ -88,13 +89,13 @@ public OpenAIClient openAIClientWithTokenCredential(AzureOpenAiConnectionPropert return new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint()) .credential(tokenCredential) .clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID)) - .buildClient(); + .buildAsyncClient(); } @Bean @ConditionalOnProperty(prefix = AzureOpenAiChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) - public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient openAIClient, + public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIAsyncClient openAIClient, AzureOpenAiChatProperties chatProperties, List toolFunctionCallbacks, FunctionCallbackContext functionCallbackContext) { @@ -111,7 +112,7 @@ public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient openAIClient, @Bean @ConditionalOnProperty(prefix = AzureOpenAiEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) - public AzureOpenAiEmbeddingModel azureOpenAiEmbeddingModel(OpenAIClient openAIClient, + public AzureOpenAiEmbeddingModel azureOpenAiEmbeddingModel(OpenAIAsyncClient openAIClient, AzureOpenAiEmbeddingProperties embeddingProperties) { return new AzureOpenAiEmbeddingModel(openAIClient, embeddingProperties.getMetadataMode(), embeddingProperties.getOptions()); @@ -128,7 +129,7 @@ public FunctionCallbackContext springAiFunctionManager(ApplicationContext contex @Bean @ConditionalOnProperty(prefix = AzureOpenAiImageOptionsProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) - public AzureOpenAiImageModel azureOpenAiImageClient(OpenAIClient openAIClient, + public AzureOpenAiImageModel azureOpenAiImageClient(OpenAIAsyncClient openAIClient, AzureOpenAiImageOptionsProperties imageProperties) { return new AzureOpenAiImageModel(openAIClient, imageProperties.getOptions()); diff --git a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java index 55306ea01fb..f26dd6e1d3b 100644 --- a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java +++ b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java @@ -21,6 +21,7 @@ import java.util.UUID; import java.util.concurrent.ExecutionException; +import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.core.credential.AzureKeyCredential; @@ -255,14 +256,14 @@ public VectorStore qdrantVectorStore(EmbeddingModel embeddingModel, QdrantClient } @Bean - public OpenAIClient openAIClient() { + public OpenAIAsyncClient openAIClient() { return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY"))) .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) - .buildClient(); + .buildAsyncClient(); } @Bean - public AzureOpenAiEmbeddingModel azureEmbeddingModel(OpenAIClient openAIClient) { + public AzureOpenAiEmbeddingModel azureEmbeddingModel(OpenAIAsyncClient openAIClient) { return new AzureOpenAiEmbeddingModel(openAIClient); }