Skip to content

Commit bb88e2f

Browse files
markpollack authored and sobychacko committed
Add retry support to VertexAI embedding and chat models
Resolves #832. Introduces retry functionality to the VertexAI embedding and chat models, enhancing their resilience against transient failures. It also corrects a typo in the VertexAiEmbeddingConnectionDetails class name. Key changes: * Add RetryTemplate to VertexAiTextEmbeddingModel and VertexAiGeminiChatModel * Introduce spring-ai-retry dependency * Refactor code to support retry logic * Update auto-configuration classes to incorporate retry functionality * Fix typo in VertexAiEmbeddingConnectionDetails class name * Remove extraneous commented-out code * Add missing copyright headers, author tags, etc.
1 parent 6fc76b7 commit bb88e2f

File tree

14 files changed

+624
-115
lines changed

14 files changed

+624
-115
lines changed

models/spring-ai-vertex-ai-embedding/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@
5252
<version>${project.parent.version}</version>
5353
</dependency>
5454

55+
<dependency>
56+
<groupId>org.springframework.ai</groupId>
57+
<artifactId>spring-ai-retry</artifactId>
58+
<version>${project.parent.version}</version>
59+
</dependency>
60+
5561
<dependency>
5662
<groupId>org.springframework</groupId>
5763
<artifactId>spring-web</artifactId>
Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,15 @@
2323
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
2424

2525
/**
26-
* VertexAiEmbeddigConnectionDetails represents the details of a connection to the Vertex
26+
* VertexAiEmbeddingConnectionDetails represents the details of a connection to the Vertex
2727
* AI embedding service. It provides methods to access the project ID, location,
2828
* publisher, and PredictionServiceSettings.
29+
*
30+
* @author Christian Tzolov
31+
* @author Mark Pollack
32+
* @since 1.0.0
2933
*/
30-
public class VertexAiEmbeddigConnectionDetails {
34+
public class VertexAiEmbeddingConnectionDetails {
3135

3236
private static final String DEFAULT_LOCATION = "us-central1";
3337

@@ -55,7 +59,7 @@ public class VertexAiEmbeddigConnectionDetails {
5559

5660
private final String publisher;
5761

58-
public VertexAiEmbeddigConnectionDetails(String endpoint, String projectId, String location, String publisher) {
62+
public VertexAiEmbeddingConnectionDetails(String endpoint, String projectId, String location, String publisher) {
5963
this.projectId = projectId;
6064
this.location = location;
6165
this.publisher = publisher;
@@ -119,7 +123,7 @@ public Builder withPublisher(String publisher) {
119123
return this;
120124
}
121125

122-
public VertexAiEmbeddigConnectionDetails build() {
126+
public VertexAiEmbeddingConnectionDetails build() {
123127
if (!StringUtils.hasText(this.endpoint)) {
124128
if (!StringUtils.hasText(this.location)) {
125129
this.endpoint = DEFAULT_ENDPOINT;
@@ -134,7 +138,7 @@ public VertexAiEmbeddigConnectionDetails build() {
134138
this.publisher = DEFAULT_PUBLISHER;
135139
}
136140

137-
return new VertexAiEmbeddigConnectionDetails(this.endpoint, this.projectId, this.location, this.publisher);
141+
return new VertexAiEmbeddingConnectionDetails(this.endpoint, this.projectId, this.location, this.publisher);
138142
}
139143

140144
}

models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
import org.springframework.ai.embedding.EmbeddingResultMetadata;
3636
import org.springframework.ai.embedding.EmbeddingResultMetadata.ModalityType;
3737
import org.springframework.ai.model.ModelOptionsUtils;
38-
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails;
38+
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
3939
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage;
4040
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils;
4141
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.ImageBuilder;
@@ -59,6 +59,7 @@
5959
* is not yet fully functional and is subject to change.
6060
*
6161
* @author Christian Tzolov
62+
* @author Mark Pollack
6263
* @since 1.0.0
6364
*/
6465
public class VertexAiMultimodalEmbeddingModel implements DocumentEmbeddingModel {
@@ -76,9 +77,9 @@ public class VertexAiMultimodalEmbeddingModel implements DocumentEmbeddingModel
7677
private static final List<MimeType> SUPPORTED_IMAGE_MIME_SUB_TYPES = List.of(MimeTypeUtils.IMAGE_JPEG,
7778
MimeTypeUtils.IMAGE_GIF, MimeTypeUtils.IMAGE_PNG, MimeTypeUtils.parseMimeType("image/bmp"));
7879

79-
private final VertexAiEmbeddigConnectionDetails connectionDetails;
80+
private final VertexAiEmbeddingConnectionDetails connectionDetails;
8081

81-
public VertexAiMultimodalEmbeddingModel(VertexAiEmbeddigConnectionDetails connectionDetails,
82+
public VertexAiMultimodalEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
8283
VertexAiMultimodalEmbeddingOptions defaultEmbeddingOptions) {
8384

8485
Assert.notNull(defaultEmbeddingOptions, "VertexAiMultimodalEmbeddingOptions must not be null");

models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java

Lines changed: 69 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,17 @@
2929
import org.springframework.ai.embedding.EmbeddingResponse;
3030
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
3131
import org.springframework.ai.model.ModelOptionsUtils;
32-
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails;
32+
import org.springframework.ai.retry.RetryUtils;
33+
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
34+
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage;
3335
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils;
3436
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.TextInstanceBuilder;
3537
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.TextParametersBuilder;
36-
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage;
38+
import org.springframework.retry.support.RetryTemplate;
3739
import org.springframework.util.Assert;
3840
import org.springframework.util.StringUtils;
3941

42+
import java.io.IOException;
4043
import java.util.ArrayList;
4144
import java.util.List;
4245
import java.util.Map;
@@ -47,22 +50,29 @@
4750
* A class representing a Vertex AI Text Embedding Model.
4851
*
4952
* @author Christian Tzolov
53+
* @author Mark Pollack
5054
* @since 1.0.0
5155
*/
5256
public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel {
5357

5458
public final VertexAiTextEmbeddingOptions defaultOptions;
5559

56-
private final VertexAiEmbeddigConnectionDetails connectionDetails;
60+
private final VertexAiEmbeddingConnectionDetails connectionDetails;
61+
62+
private final RetryTemplate retryTemplate;
5763

58-
public VertexAiTextEmbeddingModel(VertexAiEmbeddigConnectionDetails connectionDetails,
64+
public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
5965
VertexAiTextEmbeddingOptions defaultEmbeddingOptions) {
66+
this(connectionDetails, defaultEmbeddingOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
67+
}
6068

69+
public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
70+
VertexAiTextEmbeddingOptions defaultEmbeddingOptions, RetryTemplate retryTemplate) {
6171
Assert.notNull(defaultEmbeddingOptions, "VertexAiTextEmbeddingOptions must not be null");
62-
72+
Assert.notNull(retryTemplate, "retryTemplate must not be null");
6373
this.defaultOptions = defaultEmbeddingOptions.initializeDefaults();
64-
6574
this.connectionDetails = connectionDetails;
75+
this.retryTemplate = retryTemplate;
6676
}
6777

6878
@Override
@@ -73,46 +83,23 @@ public float[] embed(Document document) {
7383

7484
@Override
7585
public EmbeddingResponse call(EmbeddingRequest request) {
86+
return retryTemplate.execute(context -> {
87+
VertexAiTextEmbeddingOptions finalOptions = this.defaultOptions;
7688

77-
VertexAiTextEmbeddingOptions finalOptions = this.defaultOptions;
78-
79-
if (request.getOptions() != null && request.getOptions() != EmbeddingOptions.EMPTY) {
80-
var defaultOptionsCopy = VertexAiTextEmbeddingOptions.builder().from(this.defaultOptions).build();
81-
finalOptions = ModelOptionsUtils.merge(request.getOptions(), defaultOptionsCopy,
82-
VertexAiTextEmbeddingOptions.class);
83-
}
84-
85-
try (PredictionServiceClient client = PredictionServiceClient
86-
.create(this.connectionDetails.getPredictionServiceSettings())) {
87-
88-
EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel());
89-
90-
PredictRequest.Builder predictRequestBuilder = PredictRequest.newBuilder()
91-
.setEndpoint(endpointName.toString());
92-
93-
TextParametersBuilder parametersBuilder = TextParametersBuilder.of();
94-
95-
if (finalOptions.getAutoTruncate() != null) {
96-
parametersBuilder.withAutoTruncate(finalOptions.getAutoTruncate());
97-
}
98-
99-
if (finalOptions.getDimensions() != null) {
100-
parametersBuilder.withOutputDimensionality(finalOptions.getDimensions());
89+
if (request.getOptions() != null && request.getOptions() != EmbeddingOptions.EMPTY) {
90+
var defaultOptionsCopy = VertexAiTextEmbeddingOptions.builder().from(this.defaultOptions).build();
91+
finalOptions = ModelOptionsUtils.merge(request.getOptions(), defaultOptionsCopy,
92+
VertexAiTextEmbeddingOptions.class);
10193
}
10294

103-
predictRequestBuilder.setParameters(VertexAiEmbeddingUtils.valueOf(parametersBuilder.build()));
95+
PredictionServiceClient client = createPredictionServiceClient();
10496

105-
for (int i = 0; i < request.getInstructions().size(); i++) {
97+
EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel());
10698

107-
TextInstanceBuilder instanceBuilder = TextInstanceBuilder.of(request.getInstructions().get(i))
108-
.withTaskType(finalOptions.getTaskType().name());
109-
if (StringUtils.hasText(finalOptions.getTitle())) {
110-
instanceBuilder.withTitle(finalOptions.getTitle());
111-
}
112-
predictRequestBuilder.addInstances(VertexAiEmbeddingUtils.valueOf(instanceBuilder.build()));
113-
}
99+
PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(request, endpointName,
100+
finalOptions);
114101

115-
PredictResponse embeddingResponse = client.predict(predictRequestBuilder.build());
102+
PredictResponse embeddingResponse = getPredictResponse(client, predictRequestBuilder);
116103

117104
int index = 0;
118105
int totalTokenCount = 0;
@@ -131,12 +118,53 @@ public EmbeddingResponse call(EmbeddingRequest request) {
131118
}
132119
return new EmbeddingResponse(embeddingList,
133120
generateResponseMetadata(finalOptions.getModel(), totalTokenCount));
121+
});
122+
}
123+
124+
protected PredictRequest.Builder getPredictRequestBuilder(EmbeddingRequest request, EndpointName endpointName,
125+
VertexAiTextEmbeddingOptions finalOptions) {
126+
PredictRequest.Builder predictRequestBuilder = PredictRequest.newBuilder().setEndpoint(endpointName.toString());
127+
128+
TextParametersBuilder parametersBuilder = TextParametersBuilder.of();
129+
130+
if (finalOptions.getAutoTruncate() != null) {
131+
parametersBuilder.withAutoTruncate(finalOptions.getAutoTruncate());
134132
}
135-
catch (Exception e) {
133+
134+
if (finalOptions.getDimensions() != null) {
135+
parametersBuilder.withOutputDimensionality(finalOptions.getDimensions());
136+
}
137+
138+
predictRequestBuilder.setParameters(VertexAiEmbeddingUtils.valueOf(parametersBuilder.build()));
139+
140+
for (int i = 0; i < request.getInstructions().size(); i++) {
141+
142+
TextInstanceBuilder instanceBuilder = TextInstanceBuilder.of(request.getInstructions().get(i))
143+
.withTaskType(finalOptions.getTaskType().name());
144+
if (StringUtils.hasText(finalOptions.getTitle())) {
145+
instanceBuilder.withTitle(finalOptions.getTitle());
146+
}
147+
predictRequestBuilder.addInstances(VertexAiEmbeddingUtils.valueOf(instanceBuilder.build()));
148+
}
149+
return predictRequestBuilder;
150+
}
151+
152+
// for testing
153+
PredictionServiceClient createPredictionServiceClient() {
154+
try {
155+
return PredictionServiceClient.create(this.connectionDetails.getPredictionServiceSettings());
156+
}
157+
catch (IOException e) {
136158
throw new RuntimeException(e);
137159
}
138160
}
139161

162+
// for testing
163+
PredictResponse getPredictResponse(PredictionServiceClient client, PredictRequest.Builder predictRequestBuilder) {
164+
PredictResponse embeddingResponse = client.predict(predictRequestBuilder.build());
165+
return embeddingResponse;
166+
}
167+
140168
private EmbeddingResponseMetadata generateResponseMetadata(String model, Integer totalTokens) {
141169
EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata();
142170
metadata.setModel(model);

models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelIT.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import org.springframework.ai.embedding.DocumentEmbeddingRequest;
2525
import org.springframework.ai.embedding.EmbeddingResponse;
2626
import org.springframework.ai.embedding.EmbeddingResultMetadata;
27-
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails;
27+
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
2828
import org.springframework.beans.factory.annotation.Autowired;
2929
import org.springframework.boot.SpringBootConfiguration;
3030
import org.springframework.boot.test.context.SpringBootTest;
@@ -213,16 +213,16 @@ void textImageAndVideoEmbedding() {
213213
static class Config {
214214

215215
@Bean
216-
public VertexAiEmbeddigConnectionDetails connectionDetails() {
217-
return VertexAiEmbeddigConnectionDetails.builder()
216+
public VertexAiEmbeddingConnectionDetails connectionDetails() {
217+
return VertexAiEmbeddingConnectionDetails.builder()
218218
.withProjectId(System.getenv("VERTEX_AI_GEMINI_PROJECT_ID"))
219219
.withLocation(System.getenv("VERTEX_AI_GEMINI_LOCATION"))
220220
.build();
221221
}
222222

223223
@Bean
224224
public VertexAiMultimodalEmbeddingModel vertexAiEmbeddingModel(
225-
VertexAiEmbeddigConnectionDetails connectionDetails) {
225+
VertexAiEmbeddingConnectionDetails connectionDetails) {
226226

227227
VertexAiMultimodalEmbeddingOptions options = VertexAiMultimodalEmbeddingOptions.builder()
228228
.withModel(VertexAiMultimodalEmbeddingModelName.MULTIMODAL_EMBEDDING_001)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.vertexai.embedding.text;
18+
19+
import com.google.cloud.aiplatform.v1.EndpointName;
20+
import com.google.cloud.aiplatform.v1.PredictRequest;
21+
import com.google.cloud.aiplatform.v1.PredictResponse;
22+
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
23+
import org.springframework.ai.embedding.EmbeddingRequest;
24+
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
25+
import org.springframework.retry.support.RetryTemplate;
26+
27+
import java.io.IOException;
28+
29+
public class TestVertexAiTextEmbeddingModel extends VertexAiTextEmbeddingModel {
30+
31+
private PredictionServiceClient mockPredictionServiceClient;
32+
33+
private PredictRequest.Builder mockPredictRequestBuilder;
34+
35+
public TestVertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
36+
VertexAiTextEmbeddingOptions defaultEmbeddingOptions, RetryTemplate retryTemplate) {
37+
super(connectionDetails, defaultEmbeddingOptions, retryTemplate);
38+
}
39+
40+
public void setMockPredictionServiceClient(PredictionServiceClient mockPredictionServiceClient) {
41+
this.mockPredictionServiceClient = mockPredictionServiceClient;
42+
}
43+
44+
@Override
45+
PredictionServiceClient createPredictionServiceClient() {
46+
if (mockPredictionServiceClient != null) {
47+
return mockPredictionServiceClient;
48+
}
49+
return super.createPredictionServiceClient();
50+
}
51+
52+
@Override
53+
PredictResponse getPredictResponse(PredictionServiceClient client, PredictRequest.Builder predictRequestBuilder) {
54+
if (mockPredictionServiceClient != null) {
55+
return mockPredictionServiceClient.predict(predictRequestBuilder.build());
56+
}
57+
return super.getPredictResponse(client, predictRequestBuilder);
58+
}
59+
60+
public void setMockPredictRequestBuilder(PredictRequest.Builder mockPredictRequestBuilder) {
61+
this.mockPredictRequestBuilder = mockPredictRequestBuilder;
62+
}
63+
64+
@Override
65+
protected PredictRequest.Builder getPredictRequestBuilder(EmbeddingRequest request, EndpointName endpointName,
66+
VertexAiTextEmbeddingOptions finalOptions) {
67+
if (mockPredictRequestBuilder != null) {
68+
return mockPredictRequestBuilder;
69+
}
70+
return super.getPredictRequestBuilder(request, endpointName, finalOptions);
71+
}
72+
73+
}

models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import org.junit.jupiter.params.provider.ValueSource;
2525
import org.springframework.ai.embedding.EmbeddingRequest;
2626
import org.springframework.ai.embedding.EmbeddingResponse;
27-
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails;
27+
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
2828
import org.springframework.beans.factory.annotation.Autowired;
2929
import org.springframework.boot.SpringBootConfiguration;
3030
import org.springframework.boot.test.context.SpringBootTest;
@@ -67,15 +67,15 @@ void defaultEmbedding(String modelName) {
6767
static class Config {
6868

6969
@Bean
70-
public VertexAiEmbeddigConnectionDetails connectionDetails() {
71-
return VertexAiEmbeddigConnectionDetails.builder()
70+
public VertexAiEmbeddingConnectionDetails connectionDetails() {
71+
return VertexAiEmbeddingConnectionDetails.builder()
7272
.withProjectId(System.getenv("VERTEX_AI_GEMINI_PROJECT_ID"))
7373
.withLocation(System.getenv("VERTEX_AI_GEMINI_LOCATION"))
7474
.build();
7575
}
7676

7777
@Bean
78-
public VertexAiTextEmbeddingModel vertexAiEmbeddingModel(VertexAiEmbeddigConnectionDetails connectionDetails) {
78+
public VertexAiTextEmbeddingModel vertexAiEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails) {
7979

8080
VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder()
8181
.withModel(VertexAiTextEmbeddingOptions.DEFAULT_MODEL_NAME)

0 commit comments

Comments
 (0)