diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index b883458481e..36b390b0359 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -15,36 +15,17 @@ */ package org.springframework.ai.azure.openai; +import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; -import com.azure.ai.openai.models.ChatChoice; -import com.azure.ai.openai.models.ChatCompletions; -import com.azure.ai.openai.models.ChatCompletionsFunctionToolCall; -import com.azure.ai.openai.models.ChatCompletionsFunctionToolDefinition; -import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat; -import com.azure.ai.openai.models.ChatCompletionsOptions; -import com.azure.ai.openai.models.ChatCompletionsResponseFormat; -import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat; -import com.azure.ai.openai.models.ChatCompletionsToolCall; -import com.azure.ai.openai.models.ChatCompletionsToolDefinition; -import com.azure.ai.openai.models.ChatMessageContentItem; -import com.azure.ai.openai.models.ChatMessageImageContentItem; -import com.azure.ai.openai.models.ChatMessageImageUrl; -import com.azure.ai.openai.models.ChatMessageTextContentItem; -import com.azure.ai.openai.models.ChatRequestAssistantMessage; -import com.azure.ai.openai.models.ChatRequestMessage; -import com.azure.ai.openai.models.ChatRequestSystemMessage; -import com.azure.ai.openai.models.ChatRequestToolMessage; -import com.azure.ai.openai.models.ChatRequestUserMessage; -import com.azure.ai.openai.models.CompletionsFinishReason; -import com.azure.ai.openai.models.ContentFilterResultsForPrompt; -import com.azure.ai.openai.models.FunctionCall; -import com.azure.ai.openai.models.FunctionDefinition; +import com.azure.ai.openai.models.*; import com.azure.core.util.BinaryData; -import com.azure.core.util.IterableStream; +import com.azure.core.util.CoreUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata; +import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.ToolResponseMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.PromptMetadata; import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata; @@ -54,19 +35,15 @@ import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.model.function.AbstractFunctionCallSupport; +import org.springframework.ai.model.function.AbstractToolCallSupport; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Optional; -import java.util.Set; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; /** * {@link ChatModel} implementation for {@literal Microsoft Azure AI} backed by @@ -82,8 +59,7 @@ * @see ChatModel * @see com.azure.ai.openai.OpenAIClient */ -public class AzureOpenAiChatModel extends - AbstractFunctionCallSupport implements ChatModel { +public class AzureOpenAiChatModel extends AbstractToolCallSupport implements ChatModel { private static final String DEFAULT_DEPLOYMENT_NAME = "gpt-35-turbo"; @@ -94,14 +70,14 @@ public class AzureOpenAiChatModel extends /** * The {@link OpenAIClient} used to interact with the Azure OpenAI service. */ - private final OpenAIClient openAIClient; + private final OpenAIAsyncClient openAIClient; /** * The configuration information for a chat completions request. */ private AzureOpenAiChatOptions defaultOptions; - public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient) { + public AzureOpenAiChatModel(OpenAIAsyncClient microsoftOpenAiClient) { this(microsoftOpenAiClient, AzureOpenAiChatOptions.builder() .withDeploymentName(DEFAULT_DEPLOYMENT_NAME) @@ -109,11 +85,11 @@ public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient) { .build()); } - public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options) { + public AzureOpenAiChatModel(OpenAIAsyncClient microsoftOpenAiClient, AzureOpenAiChatOptions options) { this(microsoftOpenAiClient, options, null); } - public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options, + public AzureOpenAiChatModel(OpenAIAsyncClient microsoftOpenAiClient, AzureOpenAiChatOptions options, FunctionCallbackContext functionCallbackContext) { super(functionCallbackContext); Assert.notNull(microsoftOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null"); @@ -124,7 +100,7 @@ public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatO /** * @deprecated since 0.8.0, use - * {@link #AzureOpenAiChatModel(OpenAIClient, AzureOpenAiChatOptions)} instead. + * {@link #AzureOpenAiChatModel(OpenAIAsyncClient, AzureOpenAiChatOptions)} instead. */ @Deprecated(forRemoval = true, since = "0.8.0") public AzureOpenAiChatModel withDefaultOptions(AzureOpenAiChatOptions defaultOptions) { @@ -144,18 +120,27 @@ public ChatResponse call(Prompt prompt) { options.setStream(false); logger.trace("Azure ChatCompletionsOptions: {}", options); - ChatCompletions chatCompletions = this.callWithFunctionSupport(options); - logger.trace("Azure ChatCompletions: {}", chatCompletions); + ChatCompletions chatCompletion = this.openAIClient.getChatCompletions(options.getModel(), options).block(); + logger.trace("Azure ChatCompletions: {}", chatCompletion); - List generations = nullSafeList(chatCompletions.getChoices()).stream() + if (isToolFunctionCall(chatCompletion)) { + List toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(), + chatCompletion); + + // Recursively call the call method with the tool call message + // conversation that contains the call responses. + return this.call(new Prompt(toolCallMessageConversation, prompt.getOptions())); + } + + List generations = nullSafeList(chatCompletion.getChoices()).stream() .map(choice -> new Generation(choice.getMessage().getContent()) .withGenerationMetadata(generateChoiceMetadata(choice))) .toList(); - PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions); + PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletion); return new ChatResponse(generations, - AzureOpenAiChatResponseMetadata.from(chatCompletions, promptFilterMetadata)); + AzureOpenAiChatResponseMetadata.from(chatCompletion, promptFilterMetadata)); } @Override @@ -164,45 +149,94 @@ public Flux stream(Prompt prompt) { ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt); options.setStream(true); - IterableStream chatCompletionsStream = this.openAIClient - .getChatCompletionsStream(options.getModel(), options); - - Flux chatCompletionsFlux = Flux.fromIterable(chatCompletionsStream); - - final var isFunctionCall = new AtomicBoolean(false); - final var accessibleChatCompletionsFlux = chatCompletionsFlux - // Note: the first chat completions can be ignored when using Azure OpenAI - // service which is a known service bug. - .skip(1) - .map(chatCompletions -> { - final var toolCalls = chatCompletions.getChoices().get(0).getDelta().getToolCalls(); - isFunctionCall.set(toolCalls != null && !toolCalls.isEmpty()); - return chatCompletions; - }) - .windowUntil(chatCompletions -> { - if (isFunctionCall.get() && chatCompletions.getChoices() - .get(0) - .getFinishReason() == CompletionsFinishReason.TOOL_CALLS) { - isFunctionCall.set(false); - return true; + // For chunked responses, only the first chunk contains the choice role. + // The rest of the chunks with same ID share the same role. + ConcurrentHashMap roleMap = new ConcurrentHashMap<>(); + + final ChatCompletions[] currentChatCompletion = { MergeUtils.emptyChatCompletions() }; + return this.openAIClient.getChatCompletionsStream(options.getModel(), options) + .filter(chatCompletion -> !CoreUtils.isNullOrEmpty(chatCompletion.getChoices())) // see: + // https://learn.microsoft.com/en-us/java/api/overview/azure/ai-openai-readme?view=azure-java-preview#streaming-chat-completions + .switchMap(chatCompletion -> { + currentChatCompletion[0] = MergeUtils.mergeChatCompletions(currentChatCompletion[0], chatCompletion); + + if (this.isToolFunctionCall(currentChatCompletion[0])) { + var toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(), + currentChatCompletion[0]); + + // Recursively call the stream method with the tool call message + // conversation that contains the call responses. + return this.stream(new Prompt(toolCallMessageConversation, prompt.getOptions())); } - return !isFunctionCall.get(); - }) - .concatMapIterable(window -> { - final var reduce = window.reduce(MergeUtils.emptyChatCompletions(), MergeUtils::mergeChatCompletions); - return List.of(reduce); - }) - .flatMap(mono -> mono); - return accessibleChatCompletionsFlux - .switchMap(accessibleChatCompletions -> handleFunctionCallOrReturnStream(options, - Flux.just(accessibleChatCompletions))) - .flatMapIterable(ChatCompletions::getChoices) - .map(choice -> { - var content = Optional.ofNullable(choice.getMessage()).orElse(choice.getDelta()).getContent(); - var generation = new Generation(content).withGenerationMetadata(generateChoiceMetadata(choice)); - return new ChatResponse(List.of(generation)); + + // Non function calling. + return Mono.just(chatCompletion).map(chatCompletion2 -> { + try { + @SuppressWarnings("null") + String id = chatCompletion2.getId(); + + List generations = chatCompletion2.getChoices().stream().map(choice -> { + if (choice.getDelta().getRole() != null) { + roleMap.putIfAbsent(id, choice.getDelta().getRole().toString()); + } + String finish = (choice.getFinishReason() != null ? choice.getFinishReason().toString() + : ""); + var generation = new Generation(choice.getDelta().getContent(), + Map.of("id", id, "role", roleMap.getOrDefault(id, ""), "finishReason", finish)); + if (choice.getFinishReason() != null) { + generation = generation.withGenerationMetadata( + ChatGenerationMetadata.from(choice.getFinishReason().toString(), null)); + } + return generation; + }).toList(); + + if (chatCompletion2.getUsage() != null) { + return new ChatResponse(generations, AzureOpenAiChatResponseMetadata.from(chatCompletion2, + generatePromptMetadata(chatCompletion2))); + } + else { + return new ChatResponse(generations); + } + } + catch (Exception e) { + logger.error("Error processing chat completion", e); + return new ChatResponse(List.of()); + } + }); }); + } + + private ChatResponseMessage extractAssistantMessage(ChatCompletions chatCompletion) { + return Optional.ofNullable(chatCompletion.getChoices().iterator().next().getMessage()) + .orElse(chatCompletion.getChoices().iterator().next().getDelta()); + } + private List handleToolCallRequests(List previousMessages, ChatCompletions chatCompletion) { + ChatResponseMessage nativeAssistantMessage = this.extractAssistantMessage(chatCompletion); + + List assistantToolCalls = nativeAssistantMessage.getToolCalls() + .stream() + .filter(x -> x instanceof ChatCompletionsFunctionToolCall) + .map(x -> (ChatCompletionsFunctionToolCall) x) + .map(toolCall -> new AssistantMessage.ToolCall(toolCall.getId(), "function", + toolCall.getFunction().getName(), toolCall.getFunction().getArguments())) + .toList(); + + AssistantMessage assistantMessage = new AssistantMessage(nativeAssistantMessage.getContent(), Map.of(), + assistantToolCalls); + + List toolResponseMessages = this.executeFuncitons(assistantMessage); + + // History + List messages = new ArrayList<>(previousMessages); + messages.add(assistantMessage); + messages.addAll(toolResponseMessages); + + return messages; + } + + List toAzureChatMessage(List messages) { + return messages.stream().map(this::fromSpringAiMessage).toList(); } /** @@ -212,10 +246,7 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) { Set functionsForThisRequest = new HashSet<>(); - List azureMessages = prompt.getInstructions() - .stream() - .map(this::fromSpringAiMessage) - .toList(); + List azureMessages = toAzureChatMessage(prompt.getInstructions()); ChatCompletionsOptions options = new ChatCompletionsOptions(azureMessages); @@ -288,7 +319,23 @@ private ChatRequestMessage fromSpringAiMessage(Message message) { case SYSTEM: return new ChatRequestSystemMessage(message.getContent()); case ASSISTANT: - return new ChatRequestAssistantMessage(message.getContent()); + var assistantMessage = (AssistantMessage) message; + List toolCalls = null; + if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { + toolCalls = assistantMessage.getToolCalls() + .stream() + .map(toolCall -> (ChatCompletionsToolCall) new ChatCompletionsFunctionToolCall(toolCall.id(), + new FunctionCall(toolCall.name(), toolCall.arguments()))) + .toList(); + } + var azureAssistantMessage = new ChatRequestAssistantMessage(message.getContent()); + if (toolCalls != null) { + azureAssistantMessage.setToolCalls(toolCalls); + } + return azureAssistantMessage; + case TOOL: + ToolResponseMessage toolMessage = (ToolResponseMessage) message; + return new ChatRequestToolMessage(message.getContent(), toolMessage.getId()); default: throw new IllegalArgumentException("Unknown message type " + message.getMessageType()); } @@ -438,58 +485,6 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions, return mergedAzureOptions; } - /** - * Merges the fromOptions into the toOptions and returns a new ChatCompletionsOptions - * instance. - * @param fromOptions the ChatCompletionsOptions to merge from. - * @param toOptions the ChatCompletionsOptions to merge to. - * @return a new ChatCompletionsOptions instance. - */ - private ChatCompletionsOptions merge(ChatCompletionsOptions fromOptions, ChatCompletionsOptions toOptions) { - - if (fromOptions == null) { - return toOptions; - } - - ChatCompletionsOptions mergedOptions = this.copy(toOptions); - - if (fromOptions.getMaxTokens() != null) { - mergedOptions.setMaxTokens(fromOptions.getMaxTokens()); - } - if (fromOptions.getLogitBias() != null) { - mergedOptions.setLogitBias(fromOptions.getLogitBias()); - } - if (fromOptions.getStop() != null) { - mergedOptions.setStop(fromOptions.getStop()); - } - if (fromOptions.getTemperature() != null) { - mergedOptions.setTemperature(fromOptions.getTemperature()); - } - if (fromOptions.getTopP() != null) { - mergedOptions.setTopP(fromOptions.getTopP()); - } - if (fromOptions.getFrequencyPenalty() != null) { - mergedOptions.setFrequencyPenalty(fromOptions.getFrequencyPenalty()); - } - if (fromOptions.getPresencePenalty() != null) { - mergedOptions.setPresencePenalty(fromOptions.getPresencePenalty()); - } - if (fromOptions.getN() != null) { - mergedOptions.setN(fromOptions.getN()); - } - if (fromOptions.getUser() != null) { - mergedOptions.setUser(fromOptions.getUser()); - } - if (fromOptions.getModel() != null) { - mergedOptions.setModel(fromOptions.getModel()); - } - if (fromOptions.getResponseFormat() != null) { - mergedOptions.setResponseFormat(fromOptions.getResponseFormat()); - } - - return mergedOptions; - } - /** * Copy the fromOptions into a new ChatCompletionsOptions instance. * @param fromOptions the ChatCompletionsOptions to copy from. @@ -537,67 +532,6 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) { return copyOptions; } - @Override - protected ChatCompletionsOptions doCreateToolResponseRequest(ChatCompletionsOptions previousRequest, - ChatRequestMessage responseMessage, List conversationHistory) { - - // Every tool-call item requires a separate function call and a response (TOOL) - // message. - for (ChatCompletionsToolCall toolCall : ((ChatRequestAssistantMessage) responseMessage).getToolCalls()) { - - var functionName = ((ChatCompletionsFunctionToolCall) toolCall).getFunction().getName(); - String functionArguments = ((ChatCompletionsFunctionToolCall) toolCall).getFunction().getArguments(); - - if (!this.functionCallbackRegister.containsKey(functionName)) { - throw new IllegalStateException("No function callback found for function name: " + functionName); - } - - String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments); - - // Add the function response to the conversation. - conversationHistory.add(new ChatRequestToolMessage(functionResponse, toolCall.getId())); - } - - // Recursively call chatCompletionWithTools until the model doesn't call a - // functions anymore. - ChatCompletionsOptions newRequest = new ChatCompletionsOptions(conversationHistory); - - newRequest = merge(previousRequest, newRequest); - - return newRequest; - } - - @Override - protected List doGetUserMessages(ChatCompletionsOptions request) { - return request.getMessages(); - } - - @Override - protected ChatRequestMessage doGetToolResponseMessage(ChatCompletions response) { - final var accessibleChatChoice = response.getChoices().get(0); - var responseMessage = Optional.ofNullable(accessibleChatChoice.getMessage()) - .orElse(accessibleChatChoice.getDelta()); - ChatRequestAssistantMessage assistantMessage = new ChatRequestAssistantMessage(""); - final var toolCalls = responseMessage.getToolCalls(); - assistantMessage.setToolCalls(toolCalls.stream().map(tc -> { - final var tc1 = (ChatCompletionsFunctionToolCall) tc; - var toDowncast = new ChatCompletionsFunctionToolCall(tc.getId(), - new FunctionCall(tc1.getFunction().getName(), tc1.getFunction().getArguments())); - return ((ChatCompletionsToolCall) toDowncast); - }).toList()); - return assistantMessage; - } - - @Override - protected ChatCompletions doChatCompletion(ChatCompletionsOptions request) { - return this.openAIClient.getChatCompletions(request.getModel(), request); - } - - @Override - protected Flux doChatCompletionStream(ChatCompletionsOptions request) { - return Flux.fromIterable(this.openAIClient.getChatCompletionsStream(request.getModel(), request)); - } - @Override protected boolean isToolFunctionCall(ChatCompletions chatCompletions) { diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java index 625c8e02be9..e53335509ff 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.azure.openai; +import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.models.EmbeddingItem; import com.azure.ai.openai.models.Embeddings; @@ -38,22 +39,22 @@ public class AzureOpenAiEmbeddingModel extends AbstractEmbeddingModel { private static final Logger logger = LoggerFactory.getLogger(AzureOpenAiEmbeddingModel.class); - private final OpenAIClient azureOpenAiClient; + private final OpenAIAsyncClient azureOpenAiClient; private final AzureOpenAiEmbeddingOptions defaultOptions; private final MetadataMode metadataMode; - public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient) { + public AzureOpenAiEmbeddingModel(OpenAIAsyncClient azureOpenAiClient) { this(azureOpenAiClient, MetadataMode.EMBED); } - public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient, MetadataMode metadataMode) { + public AzureOpenAiEmbeddingModel(OpenAIAsyncClient azureOpenAiClient, MetadataMode metadataMode) { this(azureOpenAiClient, metadataMode, AzureOpenAiEmbeddingOptions.builder().withDeploymentName("text-embedding-ada-002").build()); } - public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient, MetadataMode metadataMode, + public AzureOpenAiEmbeddingModel(OpenAIAsyncClient azureOpenAiClient, MetadataMode metadataMode, AzureOpenAiEmbeddingOptions options) { Assert.notNull(azureOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null"); Assert.notNull(metadataMode, "Metadata mode must not be null"); @@ -78,7 +79,7 @@ public EmbeddingResponse call(EmbeddingRequest embeddingRequest) { logger.debug("Retrieving embeddings"); EmbeddingsOptions azureOptions = toEmbeddingOptions(embeddingRequest); - Embeddings embeddings = this.azureOpenAiClient.getEmbeddings(azureOptions.getModel(), azureOptions); + Embeddings embeddings = this.azureOpenAiClient.getEmbeddings(azureOptions.getModel(), azureOptions).block(); logger.debug("Embeddings retrieved"); return generateEmbeddingResponse(embeddings); diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java index e6da1ebbf38..2c35122b8dc 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java @@ -1,5 +1,6 @@ package org.springframework.ai.azure.openai; +import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.models.ImageGenerationOptions; import com.azure.ai.openai.models.ImageGenerationQuality; @@ -45,15 +46,15 @@ public class AzureOpenAiImageModel implements ImageModel { private final Logger logger = LoggerFactory.getLogger(getClass()); @Autowired - private final OpenAIClient openAIClient; + private final OpenAIAsyncClient openAIClient; private final AzureOpenAiImageOptions defaultOptions; - public AzureOpenAiImageModel(OpenAIClient openAIClient) { + public AzureOpenAiImageModel(OpenAIAsyncClient openAIClient) { this(openAIClient, AzureOpenAiImageOptions.builder().withDeploymentName(DEFAULT_DEPLOYMENT_NAME).build()); } - public AzureOpenAiImageModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiImageOptions options) { + public AzureOpenAiImageModel(OpenAIAsyncClient microsoftOpenAiClient, AzureOpenAiImageOptions options) { Assert.notNull(microsoftOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null"); Assert.notNull(options, "AzureOpenAiChatOptions must not be null"); this.openAIClient = microsoftOpenAiClient; @@ -73,7 +74,7 @@ public ImageResponse call(ImagePrompt imagePrompt) { toPrettyJson(imageGenerationOptions)); } - var images = openAIClient.getImageGenerations(deploymentOrModelName, imageGenerationOptions); + var images = openAIClient.getImageGenerations(deploymentOrModelName, imageGenerationOptions).block(); if (logger.isTraceEnabled()) { logger.trace("Azure ImageGenerations: {}", toPrettyJson(images)); diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java index e9c8196ff3c..e8b7fdf6e02 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.azure.openai; +import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat; import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat; @@ -40,7 +41,7 @@ public class AzureChatCompletionsOptionsTests { @Test public void createRequestWithChatOptions() { - OpenAIClient mockClient = Mockito.mock(OpenAIClient.class); + OpenAIAsyncClient mockClient = Mockito.mock(OpenAIAsyncClient.class); var defaultOptions = AzureOpenAiChatOptions.builder() .withDeploymentName("DEFAULT_MODEL") diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureEmbeddingsOptionsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureEmbeddingsOptionsTests.java index 18fe0e56af1..851dd97d8eb 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureEmbeddingsOptionsTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureEmbeddingsOptionsTests.java @@ -17,6 +17,7 @@ import java.util.List; +import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; import org.junit.jupiter.api.Test; import org.mockito.Mockito; @@ -35,7 +36,7 @@ public class AzureEmbeddingsOptionsTests { @Test public void createRequestWithChatOptions() { - OpenAIClient mockClient = Mockito.mock(OpenAIClient.class); + OpenAIAsyncClient mockClient = Mockito.mock(OpenAIAsyncClient.class); var client = new AzureOpenAiEmbeddingModel(mockClient, MetadataMode.EMBED, AzureOpenAiEmbeddingOptions.builder() .withDeploymentName("DEFAULT_MODEL") diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java index 079b840e2cb..d832d5374e5 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java @@ -25,6 +25,7 @@ import java.util.Objects; import java.util.stream.Collectors; +import com.azure.ai.openai.OpenAIAsyncClient; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; @@ -214,14 +215,14 @@ record ActorsFilmsRecord(String actor, List movies) { public static class TestConfiguration { @Bean - public OpenAIClient openAIClient() { + public OpenAIAsyncClient openAIClient() { return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY"))) .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) - .buildClient(); + .buildAsyncClient(); } @Bean - public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient openAIClient) { + public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIAsyncClient openAIClient) { return new AzureOpenAiChatModel(openAIClient, AzureOpenAiChatOptions.builder().withDeploymentName("gpt-35-turbo").withMaxTokens(200).build()); diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModelIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModelIT.java index 7c1710e06c6..02a399b34d4 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModelIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModelIT.java @@ -17,6 +17,7 @@ import java.util.List; +import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.core.credential.AzureKeyCredential; @@ -67,14 +68,14 @@ void batchEmbedding() { public static class TestConfiguration { @Bean - public OpenAIClient openAIClient() { + public OpenAIAsyncClient openAIClient() { return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY"))) .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) - .buildClient(); + .buildAsyncClient(); } @Bean - public AzureOpenAiEmbeddingModel azureEmbeddingModel(OpenAIClient openAIClient) { + public AzureOpenAiEmbeddingModel azureEmbeddingModel(OpenAIAsyncClient openAIClient) { return new AzureOpenAiEmbeddingModel(openAIClient); } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAzureOpenAiTestConfiguration.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAzureOpenAiTestConfiguration.java index 6ae824badcf..6ac865466a1 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAzureOpenAiTestConfiguration.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAzureOpenAiTestConfiguration.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.azure.openai; +import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.OpenAIClientBuilder; @@ -51,15 +52,15 @@ public class MockAzureOpenAiTestConfiguration { @Bean - OpenAIClient microsoftAzureOpenAiClient(MockWebServer webServer) { + OpenAIAsyncClient microsoftAzureOpenAiClient(MockWebServer webServer) { HttpUrl baseUrl = webServer.url(MockAiTestConfiguration.SPRING_AI_API_PATH); - return new OpenAIClientBuilder().endpoint(baseUrl.toString()).buildClient(); + return new OpenAIClientBuilder().endpoint(baseUrl.toString()).buildAsyncClient(); } @Bean - AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient microsoftAzureOpenAiClient) { + AzureOpenAiChatModel azureOpenAiChatModel(OpenAIAsyncClient microsoftAzureOpenAiClient) { return new AzureOpenAiChatModel(microsoftAzureOpenAiClient); } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java index f4349e17211..17b76e99e5c 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java @@ -17,10 +17,12 @@ import java.util.ArrayList; import java.util.List; +import java.util.Objects; import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; +import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.core.credential.AzureKeyCredential; @@ -85,6 +87,66 @@ void functionCallTest() { assertThat(response.getResult().getOutput().getContent()).containsAnyOf("15.0", "15"); } + @Test + void functionCallSequentialTest() { + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco? If the weather is above 25 degrees, please check the weather in Tokyo and Paris."); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = AzureOpenAiChatOptions.builder() + .withDeploymentName(selectedModel) + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("getCurrentWeather") + .withDescription("Get the current weather in a given location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build())) + .build(); + + ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getContent()).containsAnyOf("30.0", "30"); + assertThat(response.getResult().getOutput().getContent()).containsAnyOf("10.0", "10"); + assertThat(response.getResult().getOutput().getContent()).containsAnyOf("15.0", "15"); + } + + @Test + void functionCallSequentialAndStreamTest() { + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco? If the weather is above 25 degrees, please check the weather in Tokyo and Paris."); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = AzureOpenAiChatOptions.builder() + .withDeploymentName(selectedModel) + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("getCurrentWeather") + .withDescription("Get the current weather in a given location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build())) + .build(); + + var response = chatModel.stream(new Prompt(messages, promptOptions)); + + List responses = response.collectList().block(); + String stitchedResponseContent = responses.stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + + logger.info("Response: {}", response); + + assertThat(stitchedResponseContent).containsAnyOf("30.0", "30"); + assertThat(stitchedResponseContent).containsAnyOf("10.0", "10"); + assertThat(stitchedResponseContent).containsAnyOf("15.0", "15"); + } + @Test void streamFunctionCallTest() { UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); @@ -125,14 +187,14 @@ void streamFunctionCallTest() { public static class TestConfiguration { @Bean - public OpenAIClient openAIClient() { + public OpenAIAsyncClient openAIClient() { return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY"))) .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) - .buildClient(); + .buildAsyncClient(); } @Bean - public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient openAIClient, String selectedModel) { + public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIAsyncClient openAIClient, String selectedModel) { return new AzureOpenAiChatModel(openAIClient, AzureOpenAiChatOptions.builder().withDeploymentName(selectedModel).withMaxTokens(500).build()); } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/image/AzureOpenAiImageModelIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/image/AzureOpenAiImageModelIT.java index f57dfb6a706..09b8d4b0bb3 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/image/AzureOpenAiImageModelIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/image/AzureOpenAiImageModelIT.java @@ -1,5 +1,6 @@ package org.springframework.ai.azure.openai.image; +import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.core.credential.AzureKeyCredential; @@ -65,14 +66,14 @@ void imageAsUrlTest() { public static class TestConfiguration { @Bean - public OpenAIClient openAIClient() { + public OpenAIAsyncClient openAIClient() { return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY"))) .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) - .buildClient(); + .buildAsyncClient(); } @Bean - public AzureOpenAiImageModel azureOpenAiImageModel(OpenAIClient openAIClient) { + public AzureOpenAiImageModel azureOpenAiImageModel(OpenAIAsyncClient openAIClient) { return new AzureOpenAiImageModel(openAIClient, AzureOpenAiImageOptions.builder().withDeploymentName("Dalle3").build()); diff --git a/pom.xml b/pom.xml index ce1bc39a47f..9a42733bced 100644 --- a/pom.xml +++ b/pom.xml @@ -147,7 +147,7 @@ 3.3.0 6.1.4 4.3.4 - 1.0.0-beta.10 + 1.0.0-beta.8 1.0.0 4.31.1 diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc index 7c38b2fd9b1..cd32d837f09 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc @@ -293,7 +293,7 @@ Next, create an `AzureOpenAiChatModel` instance and use it to generate text resp var openAIClient = new OpenAIClientBuilder() .credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY"))) .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) - .buildClient(); + .buildAsyncClient(); var openAIChatOptions = AzureOpenAiChatOptions.builder() .withDeploymentName("gpt-35-turbo") diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java index 476cca5f682..e55d5221d67 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java @@ -17,6 +17,7 @@ import java.util.List; +import com.azure.ai.openai.OpenAIAsyncClient; import org.springframework.ai.azure.openai.AzureOpenAiChatModel; import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingModel; import org.springframework.ai.azure.openai.AzureOpenAiImageModel; @@ -51,8 +52,8 @@ public class AzureOpenAiAutoConfiguration { private final static String APPLICATION_ID = "spring-ai"; @Bean - @ConditionalOnMissingBean({ OpenAIClient.class, TokenCredential.class }) - public OpenAIClient openAIClient(AzureOpenAiConnectionProperties connectionProperties) { + @ConditionalOnMissingBean({ OpenAIAsyncClient.class, TokenCredential.class }) + public OpenAIAsyncClient openAIClient(AzureOpenAiConnectionProperties connectionProperties) { if (StringUtils.hasText(connectionProperties.getApiKey())) { @@ -61,7 +62,7 @@ public OpenAIClient openAIClient(AzureOpenAiConnectionProperties connectionPrope return new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint()) .credential(new AzureKeyCredential(connectionProperties.getApiKey())) .clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID)) - .buildClient(); + .buildAsyncClient(); } // Connect to OpenAI (e.g. not the Azure OpenAI). The deploymentName property is @@ -70,7 +71,7 @@ public OpenAIClient openAIClient(AzureOpenAiConnectionProperties connectionPrope return new OpenAIClientBuilder().endpoint("https://api.openai.com/v1") .credential(new KeyCredential(connectionProperties.getOpenAiApiKey())) .clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID)) - .buildClient(); + .buildAsyncClient(); } throw new IllegalArgumentException("Either API key or OpenAI API key must not be empty"); @@ -79,7 +80,7 @@ public OpenAIClient openAIClient(AzureOpenAiConnectionProperties connectionPrope @Bean @ConditionalOnMissingBean @ConditionalOnBean(TokenCredential.class) - public OpenAIClient openAIClientWithTokenCredential(AzureOpenAiConnectionProperties connectionProperties, + public OpenAIAsyncClient openAIClientWithTokenCredential(AzureOpenAiConnectionProperties connectionProperties, TokenCredential tokenCredential) { Assert.notNull(tokenCredential, "TokenCredential must not be null"); @@ -88,13 +89,13 @@ public OpenAIClient openAIClientWithTokenCredential(AzureOpenAiConnectionPropert return new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint()) .credential(tokenCredential) .clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID)) - .buildClient(); + .buildAsyncClient(); } @Bean @ConditionalOnProperty(prefix = AzureOpenAiChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) - public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient openAIClient, + public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIAsyncClient openAIClient, AzureOpenAiChatProperties chatProperties, List toolFunctionCallbacks, FunctionCallbackContext functionCallbackContext) { @@ -111,7 +112,7 @@ public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient openAIClient, @Bean @ConditionalOnProperty(prefix = AzureOpenAiEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) - public AzureOpenAiEmbeddingModel azureOpenAiEmbeddingModel(OpenAIClient openAIClient, + public AzureOpenAiEmbeddingModel azureOpenAiEmbeddingModel(OpenAIAsyncClient openAIClient, AzureOpenAiEmbeddingProperties embeddingProperties) { return new AzureOpenAiEmbeddingModel(openAIClient, embeddingProperties.getMetadataMode(), embeddingProperties.getOptions()); @@ -128,7 +129,7 @@ public FunctionCallbackContext springAiFunctionManager(ApplicationContext contex @Bean @ConditionalOnProperty(prefix = AzureOpenAiImageOptionsProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) - public AzureOpenAiImageModel azureOpenAiImageClient(OpenAIClient openAIClient, + public AzureOpenAiImageModel azureOpenAiImageClient(OpenAIAsyncClient openAIClient, AzureOpenAiImageOptionsProperties imageProperties) { return new AzureOpenAiImageModel(openAIClient, imageProperties.getOptions()); diff --git a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java index 55306ea01fb..f26dd6e1d3b 100644 --- a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java +++ b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java @@ -21,6 +21,7 @@ import java.util.UUID; import java.util.concurrent.ExecutionException; +import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.core.credential.AzureKeyCredential; @@ -255,14 +256,14 @@ public VectorStore qdrantVectorStore(EmbeddingModel embeddingModel, QdrantClient } @Bean - public OpenAIClient openAIClient() { + public OpenAIAsyncClient openAIClient() { return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY"))) .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) - .buildClient(); + .buildAsyncClient(); } @Bean - public AzureOpenAiEmbeddingModel azureEmbeddingModel(OpenAIClient openAIClient) { + public AzureOpenAiEmbeddingModel azureEmbeddingModel(OpenAIAsyncClient openAIClient) { return new AzureOpenAiEmbeddingModel(openAIClient); }