Skip to content

Provide the ability to configure OpenAI client read timeout #365

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package org.springframework.ai.autoconfigure.openai;

import java.time.Duration;
import java.util.List;

import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
Expand All @@ -33,8 +34,12 @@
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.boot.ssl.SslBundle;
import org.springframework.boot.web.client.ClientHttpRequestFactories;
import org.springframework.boot.web.client.ClientHttpRequestFactorySettings;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
Expand All @@ -60,8 +65,7 @@ public OpenAiChatClient openAiChatClient(OpenAiConnectionProperties commonProper
List<FunctionCallback> toolFunctionCallbacks, FunctionCallbackContext functionCallbackContext,
RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler) {

var openAiApi = openAiApi(chatProperties.getBaseUrl(), commonProperties.getBaseUrl(),
chatProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, responseErrorHandler);
var openAiApi = openAiApi(commonProperties, chatProperties, restClientBuilder, responseErrorHandler);

if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) {
chatProperties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks);
Expand All @@ -78,23 +82,22 @@ public OpenAiEmbeddingClient openAiEmbeddingClient(OpenAiConnectionProperties co
OpenAiEmbeddingProperties embeddingProperties, RestClient.Builder restClientBuilder,
RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler) {

var openAiApi = openAiApi(embeddingProperties.getBaseUrl(), commonProperties.getBaseUrl(),
embeddingProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, responseErrorHandler);
var openAiApi = openAiApi(commonProperties, embeddingProperties, restClientBuilder, responseErrorHandler);

return new OpenAiEmbeddingClient(openAiApi, embeddingProperties.getMetadataMode(),
embeddingProperties.getOptions(), retryTemplate);
}

private OpenAiApi openAiApi(String baseUrl, String commonBaseUrl, String apiKey, String commonApiKey,
RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {
private <T extends OpenAiParentProperties> OpenAiApi openAiApi(OpenAiConnectionProperties commonProperties,
T specificProperties, RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {

String resolvedBaseUrl = StringUtils.hasText(baseUrl) ? baseUrl : commonBaseUrl;
Assert.hasText(resolvedBaseUrl, "OpenAI base URL must be set");
OpenAiConnectionProperties overridenCommonProperties = checkAndOverrideProperties(commonProperties,
specificProperties);
RestClient.Builder overrideRestClientBuilder = overrideRestClientBuilder(restClientBuilder,
overridenCommonProperties);

String resolvedApiKey = StringUtils.hasText(apiKey) ? apiKey : commonApiKey;
Assert.hasText(resolvedApiKey, "OpenAI API key must be set");

return new OpenAiApi(resolvedBaseUrl, resolvedApiKey, restClientBuilder, responseErrorHandler);
return new OpenAiApi(overridenCommonProperties.getBaseUrl(), overridenCommonProperties.getApiKey(),
overrideRestClientBuilder, responseErrorHandler);
}

@Bean
Expand All @@ -105,41 +108,32 @@ public OpenAiImageClient openAiImageClient(OpenAiConnectionProperties commonProp
OpenAiImageProperties imageProperties, RestClient.Builder restClientBuilder, RetryTemplate retryTemplate,
ResponseErrorHandler responseErrorHandler) {

String apiKey = StringUtils.hasText(imageProperties.getApiKey()) ? imageProperties.getApiKey()
: commonProperties.getApiKey();

String baseUrl = StringUtils.hasText(imageProperties.getBaseUrl()) ? imageProperties.getBaseUrl()
: commonProperties.getBaseUrl();

Assert.hasText(apiKey, "OpenAI API key must be set");
Assert.hasText(baseUrl, "OpenAI base URL must be set");
OpenAiConnectionProperties overridenCommonProperties = checkAndOverrideProperties(commonProperties,
imageProperties);
RestClient.Builder overrideRestClientBuilder = overrideRestClientBuilder(restClientBuilder,
overridenCommonProperties);

var openAiImageApi = new OpenAiImageApi(baseUrl, apiKey, restClientBuilder, responseErrorHandler);
var openAiImageApi = new OpenAiImageApi(overridenCommonProperties.getBaseUrl(),
overridenCommonProperties.getApiKey(), overrideRestClientBuilder, responseErrorHandler);

return new OpenAiImageClient(openAiImageApi, imageProperties.getOptions(), retryTemplate);
}

@Bean
@ConditionalOnMissingBean
public OpenAiAudioTranscriptionClient openAiAudioTranscriptionClient(OpenAiConnectionProperties commonProperties,
OpenAiAudioTranscriptionProperties transcriptionProperties, RetryTemplate retryTemplate,
ResponseErrorHandler responseErrorHandler) {

String apiKey = StringUtils.hasText(transcriptionProperties.getApiKey()) ? transcriptionProperties.getApiKey()
: commonProperties.getApiKey();

String baseUrl = StringUtils.hasText(transcriptionProperties.getBaseUrl())
? transcriptionProperties.getBaseUrl() : commonProperties.getBaseUrl();

Assert.hasText(apiKey, "OpenAI API key must be set");
Assert.hasText(baseUrl, "OpenAI base URL must be set");
OpenAiAudioTranscriptionProperties transcriptionProperties, RestClient.Builder restClientBuilder,
RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler) {

var openAiAudioApi = new OpenAiAudioApi(baseUrl, apiKey, RestClient.builder(), responseErrorHandler);
OpenAiConnectionProperties overridenCommonProperties = checkAndOverrideProperties(commonProperties,
transcriptionProperties);
RestClient.Builder overrideRestClientBuilder = overrideRestClientBuilder(restClientBuilder,
overridenCommonProperties);

OpenAiAudioTranscriptionClient openAiChatClient = new OpenAiAudioTranscriptionClient(openAiAudioApi,
transcriptionProperties.getOptions(), retryTemplate);
var openAiAudioApi = new OpenAiAudioApi(overridenCommonProperties.getBaseUrl(),
overridenCommonProperties.getApiKey(), overrideRestClientBuilder, responseErrorHandler);

return openAiChatClient;
return new OpenAiAudioTranscriptionClient(openAiAudioApi, transcriptionProperties.getOptions(), retryTemplate);
}

@Bean
Expand All @@ -150,4 +144,37 @@ public FunctionCallbackContext springAiFunctionManager(ApplicationContext contex
return manager;
}

private static <T extends OpenAiParentProperties> OpenAiConnectionProperties checkAndOverrideProperties(
OpenAiConnectionProperties commonProperties, T specificProperties) {

String apiKey = StringUtils.hasText(specificProperties.getApiKey()) ? specificProperties.getApiKey()
: commonProperties.getApiKey();

String baseUrl = StringUtils.hasText(specificProperties.getBaseUrl()) ? specificProperties.getBaseUrl()
: commonProperties.getBaseUrl();

Duration readTimeout = specificProperties.getReadTimeout() != null ? specificProperties.getReadTimeout()
: commonProperties.getReadTimeout();

Assert.hasText(apiKey, "OpenAI API key must be set");
Assert.hasText(baseUrl, "OpenAI base URL must be set");
Assert.notNull(readTimeout, "OpenAI base read timeout must be set");

OpenAiConnectionProperties overridenCommonProperties = new OpenAiConnectionProperties();
overridenCommonProperties.setApiKey(apiKey);
overridenCommonProperties.setBaseUrl(baseUrl);
overridenCommonProperties.setReadTimeout(readTimeout);

return overridenCommonProperties;

}

private static RestClient.Builder overrideRestClientBuilder(RestClient.Builder restClientBuilder,
OpenAiConnectionProperties overridenCommonProperties) {
ClientHttpRequestFactorySettings requestFactorySettings = new ClientHttpRequestFactorySettings(
Duration.ofHours(1l), overridenCommonProperties.getReadTimeout(), SslBundle.of(null));
ClientHttpRequestFactory requestFactory = ClientHttpRequestFactories.get(requestFactorySettings);
return restClientBuilder.clone().requestFactory(requestFactory);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/
package org.springframework.ai.autoconfigure.openai;

import java.time.Duration;

import org.springframework.boot.context.properties.ConfigurationProperties;

@ConfigurationProperties(OpenAiConnectionProperties.CONFIG_PREFIX)
Expand All @@ -24,8 +26,11 @@ public class OpenAiConnectionProperties extends OpenAiParentProperties {

public static final String DEFAULT_BASE_URL = "https://api.openai.com";

public static final Duration DEFAULT_READ_TIMEOUT = Duration.ofMinutes(1);

public OpenAiConnectionProperties() {
super.setBaseUrl(DEFAULT_BASE_URL);
super.setReadTimeout(DEFAULT_READ_TIMEOUT);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/
package org.springframework.ai.autoconfigure.openai;

import java.time.Duration;

/**
* Internal parent properties for the OpenAI properties.
*
Expand All @@ -27,6 +29,8 @@ class OpenAiParentProperties {

private String baseUrl;

private Duration readTimeout;

public String getApiKey() {
return apiKey;
}
Expand All @@ -43,4 +47,12 @@ public void setBaseUrl(String baseUrl) {
this.baseUrl = baseUrl;
}

public Duration getReadTimeout() {
return readTimeout;
}

public void setReadTimeout(Duration readTimeout) {
this.readTimeout = readTimeout;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package org.springframework.ai.autoconfigure.openai;

import java.time.Duration;
import org.junit.jupiter.api.Test;
import org.skyscreamer.jsonassert.JSONAssert;
import org.skyscreamer.jsonassert.JSONCompareMode;
Expand Down Expand Up @@ -50,6 +51,7 @@ public void chatProperties() {
// @formatter:off
"spring.ai.openai.base-url=TEST_BASE_URL",
"spring.ai.openai.api-key=abc123",
"spring.ai.openai.read-timeout=2m",
"spring.ai.openai.chat.options.model=MODEL_XYZ",
"spring.ai.openai.chat.options.temperature=0.55")
// @formatter:on
Expand All @@ -61,9 +63,11 @@ public void chatProperties() {

assertThat(connectionProperties.getApiKey()).isEqualTo("abc123");
assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2));

assertThat(chatProperties.getApiKey()).isNull();
assertThat(chatProperties.getBaseUrl()).isNull();
assertThat(chatProperties.getReadTimeout()).isNull();

assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");
assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55f);
Expand Down Expand Up @@ -104,8 +108,10 @@ public void chatOverrideConnectionProperties() {
// @formatter:off
"spring.ai.openai.base-url=TEST_BASE_URL",
"spring.ai.openai.api-key=abc123",
"spring.ai.openai.read-timeout=2m",
"spring.ai.openai.chat.base-url=TEST_BASE_URL2",
"spring.ai.openai.chat.api-key=456",
"spring.ai.openai.chat.read-timeout=5m",
"spring.ai.openai.chat.options.model=MODEL_XYZ",
"spring.ai.openai.chat.options.temperature=0.55")
// @formatter:on
Expand All @@ -117,9 +123,11 @@ public void chatOverrideConnectionProperties() {

assertThat(connectionProperties.getApiKey()).isEqualTo("abc123");
assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2));

assertThat(chatProperties.getApiKey()).isEqualTo("456");
assertThat(chatProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL2");
assertThat(chatProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(5));

assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");
assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55f);
Expand Down Expand Up @@ -162,6 +170,7 @@ public void embeddingProperties() {
// @formatter:off
"spring.ai.openai.base-url=TEST_BASE_URL",
"spring.ai.openai.api-key=abc123",
"spring.ai.openai.read-timeout=2m",
"spring.ai.openai.embedding.options.model=MODEL_XYZ")
// @formatter:on
.withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
Expand All @@ -172,9 +181,11 @@ public void embeddingProperties() {

assertThat(connectionProperties.getApiKey()).isEqualTo("abc123");
assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2));

assertThat(embeddingProperties.getApiKey()).isNull();
assertThat(embeddingProperties.getBaseUrl()).isNull();
assertThat(embeddingProperties.getReadTimeout()).isNull();

assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");
});
Expand All @@ -187,8 +198,10 @@ public void embeddingOverrideConnectionProperties() {
// @formatter:off
"spring.ai.openai.base-url=TEST_BASE_URL",
"spring.ai.openai.api-key=abc123",
"spring.ai.openai.read-timeout=2m",
"spring.ai.openai.embedding.base-url=TEST_BASE_URL2",
"spring.ai.openai.embedding.api-key=456",
"spring.ai.openai.embedding.read-timeout=5m",
"spring.ai.openai.embedding.options.model=MODEL_XYZ")
// @formatter:on
.withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
Expand All @@ -199,9 +212,11 @@ public void embeddingOverrideConnectionProperties() {

assertThat(connectionProperties.getApiKey()).isEqualTo("abc123");
assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2));

assertThat(embeddingProperties.getApiKey()).isEqualTo("456");
assertThat(embeddingProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL2");
assertThat(embeddingProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(5));

assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");
});
Expand All @@ -213,6 +228,7 @@ public void imageProperties() {
// @formatter:off
"spring.ai.openai.base-url=TEST_BASE_URL",
"spring.ai.openai.api-key=abc123",
"spring.ai.openai.read-timeout=2m",
"spring.ai.openai.image.options.model=MODEL_XYZ",
"spring.ai.openai.image.options.n=3")
// @formatter:on
Expand All @@ -224,9 +240,11 @@ public void imageProperties() {

assertThat(connectionProperties.getApiKey()).isEqualTo("abc123");
assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2));

assertThat(imageProperties.getApiKey()).isNull();
assertThat(imageProperties.getBaseUrl()).isNull();
assertThat(imageProperties.getReadTimeout()).isNull();

assertThat(imageProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");
assertThat(imageProperties.getOptions().getN()).isEqualTo(3);
Expand All @@ -239,8 +257,10 @@ public void imageOverrideConnectionProperties() {
// @formatter:off
"spring.ai.openai.base-url=TEST_BASE_URL",
"spring.ai.openai.api-key=abc123",
"spring.ai.openai.read-timeout=2m",
"spring.ai.openai.image.base-url=TEST_BASE_URL2",
"spring.ai.openai.image.api-key=456",
"spring.ai.openai.image.read-timeout=5m",
"spring.ai.openai.image.options.model=MODEL_XYZ",
"spring.ai.openai.image.options.n=3")
// @formatter:on
Expand All @@ -252,9 +272,11 @@ public void imageOverrideConnectionProperties() {

assertThat(connectionProperties.getApiKey()).isEqualTo("abc123");
assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2));

assertThat(imageProperties.getApiKey()).isEqualTo("456");
assertThat(imageProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL2");
assertThat(imageProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(5));

assertThat(imageProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");
assertThat(imageProperties.getOptions().getN()).isEqualTo(3);
Expand All @@ -268,6 +290,7 @@ public void chatOptionsTest() {
// @formatter:off
"spring.ai.openai.api-key=API_KEY",
"spring.ai.openai.base-url=TEST_BASE_URL",
"spring.ai.openai.read-timeout=2m",

"spring.ai.openai.chat.options.model=MODEL_XYZ",
"spring.ai.openai.chat.options.frequencyPenalty=-1.5",
Expand Down Expand Up @@ -322,6 +345,7 @@ public void chatOptionsTest() {

assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY");
assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2));

assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("text-embedding-ada-002");

Expand Down Expand Up @@ -395,6 +419,7 @@ public void embeddingOptionsTest() {
// @formatter:off
"spring.ai.openai.api-key=API_KEY",
"spring.ai.openai.base-url=TEST_BASE_URL",
"spring.ai.openai.read-timeout=2m",

"spring.ai.openai.embedding.options.model=MODEL_XYZ",
"spring.ai.openai.embedding.options.encodingFormat=MyEncodingFormat",
Expand All @@ -409,6 +434,7 @@ public void embeddingOptionsTest() {

assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY");
assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2));

assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");
assertThat(embeddingProperties.getOptions().getEncodingFormat()).isEqualTo("MyEncodingFormat");
Expand All @@ -422,6 +448,7 @@ public void imageOptionsTest() {
// @formatter:off
"spring.ai.openai.api-key=API_KEY",
"spring.ai.openai.base-url=TEST_BASE_URL",
"spring.ai.openai.read-timeout=2m",

"spring.ai.openai.image.options.n=3",
"spring.ai.openai.image.options.model=MODEL_XYZ",
Expand All @@ -442,6 +469,7 @@ public void imageOptionsTest() {

assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY");
assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2));

assertThat(imageProperties.getOptions().getN()).isEqualTo(3);
assertThat(imageProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");
Expand Down