Skip to content

Support actual streaming for AzureAI #1054

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package org.springframework.ai.azure.openai;

import com.azure.ai.openai.OpenAIAsyncClient;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.models.EmbeddingItem;
import com.azure.ai.openai.models.Embeddings;
Expand All @@ -38,22 +39,22 @@ public class AzureOpenAiEmbeddingModel extends AbstractEmbeddingModel {

private static final Logger logger = LoggerFactory.getLogger(AzureOpenAiEmbeddingModel.class);

private final OpenAIClient azureOpenAiClient;
private final OpenAIAsyncClient azureOpenAiClient;

private final AzureOpenAiEmbeddingOptions defaultOptions;

private final MetadataMode metadataMode;

public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient) {
public AzureOpenAiEmbeddingModel(OpenAIAsyncClient azureOpenAiClient) {
this(azureOpenAiClient, MetadataMode.EMBED);
}

public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient, MetadataMode metadataMode) {
public AzureOpenAiEmbeddingModel(OpenAIAsyncClient azureOpenAiClient, MetadataMode metadataMode) {
this(azureOpenAiClient, metadataMode,
AzureOpenAiEmbeddingOptions.builder().withDeploymentName("text-embedding-ada-002").build());
}

public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient, MetadataMode metadataMode,
public AzureOpenAiEmbeddingModel(OpenAIAsyncClient azureOpenAiClient, MetadataMode metadataMode,
AzureOpenAiEmbeddingOptions options) {
Assert.notNull(azureOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null");
Assert.notNull(metadataMode, "Metadata mode must not be null");
Expand All @@ -78,7 +79,7 @@ public EmbeddingResponse call(EmbeddingRequest embeddingRequest) {
logger.debug("Retrieving embeddings");

EmbeddingsOptions azureOptions = toEmbeddingOptions(embeddingRequest);
Embeddings embeddings = this.azureOpenAiClient.getEmbeddings(azureOptions.getModel(), azureOptions);
Embeddings embeddings = this.azureOpenAiClient.getEmbeddings(azureOptions.getModel(), azureOptions).block();

logger.debug("Embeddings retrieved");
return generateEmbeddingResponse(embeddings);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.springframework.ai.azure.openai;

import com.azure.ai.openai.OpenAIAsyncClient;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.models.ImageGenerationOptions;
import com.azure.ai.openai.models.ImageGenerationQuality;
Expand Down Expand Up @@ -45,15 +46,15 @@ public class AzureOpenAiImageModel implements ImageModel {
private final Logger logger = LoggerFactory.getLogger(getClass());

@Autowired
private final OpenAIClient openAIClient;
private final OpenAIAsyncClient openAIClient;

private final AzureOpenAiImageOptions defaultOptions;

public AzureOpenAiImageModel(OpenAIClient openAIClient) {
public AzureOpenAiImageModel(OpenAIAsyncClient openAIClient) {
this(openAIClient, AzureOpenAiImageOptions.builder().withDeploymentName(DEFAULT_DEPLOYMENT_NAME).build());
}

public AzureOpenAiImageModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiImageOptions options) {
public AzureOpenAiImageModel(OpenAIAsyncClient microsoftOpenAiClient, AzureOpenAiImageOptions options) {
Assert.notNull(microsoftOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null");
Assert.notNull(options, "AzureOpenAiChatOptions must not be null");
this.openAIClient = microsoftOpenAiClient;
Expand All @@ -73,7 +74,7 @@ public ImageResponse call(ImagePrompt imagePrompt) {
toPrettyJson(imageGenerationOptions));
}

var images = openAIClient.getImageGenerations(deploymentOrModelName, imageGenerationOptions);
var images = openAIClient.getImageGenerations(deploymentOrModelName, imageGenerationOptions).block();

if (logger.isTraceEnabled()) {
logger.trace("Azure ImageGenerations: {}", toPrettyJson(images));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package org.springframework.ai.azure.openai;

import com.azure.ai.openai.OpenAIAsyncClient;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat;
import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat;
Expand All @@ -40,7 +41,7 @@ public class AzureChatCompletionsOptionsTests {
@Test
public void createRequestWithChatOptions() {

OpenAIClient mockClient = Mockito.mock(OpenAIClient.class);
OpenAIAsyncClient mockClient = Mockito.mock(OpenAIAsyncClient.class);

var defaultOptions = AzureOpenAiChatOptions.builder()
.withDeploymentName("DEFAULT_MODEL")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import java.util.List;

import com.azure.ai.openai.OpenAIAsyncClient;
import com.azure.ai.openai.OpenAIClient;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
Expand All @@ -35,7 +36,7 @@ public class AzureEmbeddingsOptionsTests {
@Test
public void createRequestWithChatOptions() {

OpenAIClient mockClient = Mockito.mock(OpenAIClient.class);
OpenAIAsyncClient mockClient = Mockito.mock(OpenAIAsyncClient.class);
var client = new AzureOpenAiEmbeddingModel(mockClient, MetadataMode.EMBED,
AzureOpenAiEmbeddingOptions.builder()
.withDeploymentName("DEFAULT_MODEL")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.Objects;
import java.util.stream.Collectors;

import com.azure.ai.openai.OpenAIAsyncClient;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
Expand Down Expand Up @@ -214,14 +215,14 @@ record ActorsFilmsRecord(String actor, List<String> movies) {
public static class TestConfiguration {

@Bean
public OpenAIClient openAIClient() {
public OpenAIAsyncClient openAIClient() {
return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY")))
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.buildClient();
.buildAsyncClient();
}

@Bean
public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient openAIClient) {
public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIAsyncClient openAIClient) {
return new AzureOpenAiChatModel(openAIClient,
AzureOpenAiChatOptions.builder().withDeploymentName("gpt-35-turbo").withMaxTokens(200).build());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import java.util.List;

import com.azure.ai.openai.OpenAIAsyncClient;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.core.credential.AzureKeyCredential;
Expand Down Expand Up @@ -67,14 +68,14 @@ void batchEmbedding() {
public static class TestConfiguration {

@Bean
public OpenAIClient openAIClient() {
public OpenAIAsyncClient openAIClient() {
return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY")))
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.buildClient();
.buildAsyncClient();
}

@Bean
public AzureOpenAiEmbeddingModel azureEmbeddingModel(OpenAIClient openAIClient) {
public AzureOpenAiEmbeddingModel azureEmbeddingModel(OpenAIAsyncClient openAIClient) {
return new AzureOpenAiEmbeddingModel(openAIClient);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package org.springframework.ai.azure.openai;

import com.azure.ai.openai.OpenAIAsyncClient;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;

Expand Down Expand Up @@ -51,15 +52,15 @@
public class MockAzureOpenAiTestConfiguration {

@Bean
OpenAIClient microsoftAzureOpenAiClient(MockWebServer webServer) {
OpenAIAsyncClient microsoftAzureOpenAiClient(MockWebServer webServer) {

HttpUrl baseUrl = webServer.url(MockAiTestConfiguration.SPRING_AI_API_PATH);

return new OpenAIClientBuilder().endpoint(baseUrl.toString()).buildClient();
return new OpenAIClientBuilder().endpoint(baseUrl.toString()).buildAsyncClient();
}

@Bean
AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient microsoftAzureOpenAiClient) {
AzureOpenAiChatModel azureOpenAiChatModel(OpenAIAsyncClient microsoftAzureOpenAiClient) {
return new AzureOpenAiChatModel(microsoftAzureOpenAiClient);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

import com.azure.ai.openai.OpenAIAsyncClient;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.core.credential.AzureKeyCredential;
Expand Down Expand Up @@ -85,6 +87,66 @@ void functionCallTest() {
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("15.0", "15");
}

@Test
void functionCallSequentialTest() {

UserMessage userMessage = new UserMessage(
"What's the weather like in San Francisco? If the weather is above 25 degrees, please check the weather in Tokyo and Paris.");

List<Message> messages = new ArrayList<>(List.of(userMessage));

var promptOptions = AzureOpenAiChatOptions.builder()
.withDeploymentName(selectedModel)
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
.withName("getCurrentWeather")
.withDescription("Get the current weather in a given location")
.withResponseConverter((response) -> "" + response.temp() + response.unit())
.build()))
.build();

ChatResponse response = chatModel.call(new Prompt(messages, promptOptions));

logger.info("Response: {}", response);

assertThat(response.getResult().getOutput().getContent()).containsAnyOf("30.0", "30");
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("10.0", "10");
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("15.0", "15");
}

@Test
void functionCallSequentialAndStreamTest() {

UserMessage userMessage = new UserMessage(
"What's the weather like in San Francisco? If the weather is above 25 degrees, please check the weather in Tokyo and Paris.");

List<Message> messages = new ArrayList<>(List.of(userMessage));

var promptOptions = AzureOpenAiChatOptions.builder()
.withDeploymentName(selectedModel)
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
.withName("getCurrentWeather")
.withDescription("Get the current weather in a given location")
.withResponseConverter((response) -> "" + response.temp() + response.unit())
.build()))
.build();

var response = chatModel.stream(new Prompt(messages, promptOptions));

List<ChatResponse> responses = response.collectList().block();
String stitchedResponseContent = responses.stream()
.map(ChatResponse::getResults)
.flatMap(List::stream)
.map(Generation::getOutput)
.map(AssistantMessage::getContent)
.collect(Collectors.joining());

logger.info("Response: {}", response);

assertThat(stitchedResponseContent).containsAnyOf("30.0", "30");
assertThat(stitchedResponseContent).containsAnyOf("10.0", "10");
assertThat(stitchedResponseContent).containsAnyOf("15.0", "15");
}

@Test
void streamFunctionCallTest() {
UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
Expand Down Expand Up @@ -125,14 +187,14 @@ void streamFunctionCallTest() {
public static class TestConfiguration {

@Bean
public OpenAIClient openAIClient() {
public OpenAIAsyncClient openAIClient() {
return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY")))
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.buildClient();
.buildAsyncClient();
}

@Bean
public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient openAIClient, String selectedModel) {
public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIAsyncClient openAIClient, String selectedModel) {
return new AzureOpenAiChatModel(openAIClient,
AzureOpenAiChatOptions.builder().withDeploymentName(selectedModel).withMaxTokens(500).build());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.springframework.ai.azure.openai.image;

import com.azure.ai.openai.OpenAIAsyncClient;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.core.credential.AzureKeyCredential;
Expand Down Expand Up @@ -65,14 +66,14 @@ void imageAsUrlTest() {
public static class TestConfiguration {

@Bean
public OpenAIClient openAIClient() {
public OpenAIAsyncClient openAIClient() {
return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY")))
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.buildClient();
.buildAsyncClient();
}

@Bean
public AzureOpenAiImageModel azureOpenAiImageModel(OpenAIClient openAIClient) {
public AzureOpenAiImageModel azureOpenAiImageModel(OpenAIAsyncClient openAIClient) {
return new AzureOpenAiImageModel(openAIClient,
AzureOpenAiImageOptions.builder().withDeploymentName("Dalle3").build());

Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@
<spring-boot.version>3.3.0</spring-boot.version>
<spring-framework.version>6.1.4</spring-framework.version>
<ST4.version>4.3.4</ST4.version>
<azure-open-ai-client.version>1.0.0-beta.10</azure-open-ai-client.version>
<azure-open-ai-client.version>1.0.0-beta.8</azure-open-ai-client.version>
<jtokkit.version>1.0.0</jtokkit.version>
<victools.version>4.31.1</victools.version>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ Next, create an `AzureOpenAiChatModel` instance and use it to generate text resp
var openAIClient = new OpenAIClientBuilder()
.credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY")))
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.buildClient();
.buildAsyncClient();

var openAIChatOptions = AzureOpenAiChatOptions.builder()
.withDeploymentName("gpt-35-turbo")
Expand Down
Loading