29
29
import org .springframework .ai .embedding .EmbeddingResponse ;
30
30
import org .springframework .ai .embedding .EmbeddingResponseMetadata ;
31
31
import org .springframework .ai .model .ModelOptionsUtils ;
32
- import org .springframework .ai .vertexai .embedding .VertexAiEmbeddigConnectionDetails ;
32
+ import org .springframework .ai .retry .RetryUtils ;
33
+ import org .springframework .ai .vertexai .embedding .VertexAiEmbeddingConnectionDetails ;
34
+ import org .springframework .ai .vertexai .embedding .VertexAiEmbeddingUsage ;
33
35
import org .springframework .ai .vertexai .embedding .VertexAiEmbeddingUtils ;
34
36
import org .springframework .ai .vertexai .embedding .VertexAiEmbeddingUtils .TextInstanceBuilder ;
35
37
import org .springframework .ai .vertexai .embedding .VertexAiEmbeddingUtils .TextParametersBuilder ;
36
- import org .springframework .ai . vertexai . embedding . VertexAiEmbeddingUsage ;
38
+ import org .springframework .retry . support . RetryTemplate ;
37
39
import org .springframework .util .Assert ;
38
40
import org .springframework .util .StringUtils ;
39
41
42
+ import java .io .IOException ;
40
43
import java .util .ArrayList ;
41
44
import java .util .List ;
42
45
import java .util .Map ;
47
50
* A class representing a Vertex AI Text Embedding Model.
48
51
*
49
52
* @author Christian Tzolov
53
+ * @author Mark Pollack
50
54
* @since 1.0.0
51
55
*/
52
56
public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel {
53
57
54
58
public final VertexAiTextEmbeddingOptions defaultOptions ;
55
59
56
- private final VertexAiEmbeddigConnectionDetails connectionDetails ;
60
+ private final VertexAiEmbeddingConnectionDetails connectionDetails ;
61
+
62
+ private final RetryTemplate retryTemplate ;
57
63
58
/**
 * Creates a text embedding model that retries with
 * {@link RetryUtils#DEFAULT_RETRY_TEMPLATE}.
 * @param connectionDetails connection details handed to the three-argument
 * constructor
 * @param defaultEmbeddingOptions default options handed to the three-argument
 * constructor
 */
public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
        VertexAiTextEmbeddingOptions defaultEmbeddingOptions) {
    this(connectionDetails, defaultEmbeddingOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}
60
68
69
/**
 * Creates a text embedding model with an explicit retry template.
 * @param connectionDetails connection details later used to resolve the endpoint
 * name and the prediction service settings; must not be {@code null}
 * @param defaultEmbeddingOptions default options; must not be {@code null}. The
 * value stored is the result of {@code initializeDefaults()} on this argument.
 * @param retryTemplate template that wraps every {@code call}; must not be
 * {@code null}
 * @throws IllegalArgumentException if any argument is {@code null}
 */
public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
        VertexAiTextEmbeddingOptions defaultEmbeddingOptions, RetryTemplate retryTemplate) {
    Assert.notNull(defaultEmbeddingOptions, "VertexAiTextEmbeddingOptions must not be null");
    Assert.notNull(retryTemplate, "retryTemplate must not be null");
    // Fail fast: connectionDetails is dereferenced on every call, so a null here
    // would otherwise only surface later as a NullPointerException.
    Assert.notNull(connectionDetails, "connectionDetails must not be null");
    this.defaultOptions = defaultEmbeddingOptions.initializeDefaults();
    this.connectionDetails = connectionDetails;
    this.retryTemplate = retryTemplate;
}
67
77
68
78
@ Override
@@ -73,46 +83,23 @@ public float[] embed(Document document) {
73
83
74
84
@ Override
75
85
public EmbeddingResponse call (EmbeddingRequest request ) {
86
+ return retryTemplate .execute (context -> {
87
+ VertexAiTextEmbeddingOptions finalOptions = this .defaultOptions ;
76
88
77
- VertexAiTextEmbeddingOptions finalOptions = this .defaultOptions ;
78
-
79
- if (request .getOptions () != null && request .getOptions () != EmbeddingOptions .EMPTY ) {
80
- var defaultOptionsCopy = VertexAiTextEmbeddingOptions .builder ().from (this .defaultOptions ).build ();
81
- finalOptions = ModelOptionsUtils .merge (request .getOptions (), defaultOptionsCopy ,
82
- VertexAiTextEmbeddingOptions .class );
83
- }
84
-
85
- try (PredictionServiceClient client = PredictionServiceClient
86
- .create (this .connectionDetails .getPredictionServiceSettings ())) {
87
-
88
- EndpointName endpointName = this .connectionDetails .getEndpointName (finalOptions .getModel ());
89
-
90
- PredictRequest .Builder predictRequestBuilder = PredictRequest .newBuilder ()
91
- .setEndpoint (endpointName .toString ());
92
-
93
- TextParametersBuilder parametersBuilder = TextParametersBuilder .of ();
94
-
95
- if (finalOptions .getAutoTruncate () != null ) {
96
- parametersBuilder .withAutoTruncate (finalOptions .getAutoTruncate ());
97
- }
98
-
99
- if (finalOptions .getDimensions () != null ) {
100
- parametersBuilder .withOutputDimensionality (finalOptions .getDimensions ());
89
+ if (request .getOptions () != null && request .getOptions () != EmbeddingOptions .EMPTY ) {
90
+ var defaultOptionsCopy = VertexAiTextEmbeddingOptions .builder ().from (this .defaultOptions ).build ();
91
+ finalOptions = ModelOptionsUtils .merge (request .getOptions (), defaultOptionsCopy ,
92
+ VertexAiTextEmbeddingOptions .class );
101
93
}
102
94
103
- predictRequestBuilder . setParameters ( VertexAiEmbeddingUtils . valueOf ( parametersBuilder . build ()) );
95
+ PredictionServiceClient client = createPredictionServiceClient ( );
104
96
105
- for ( int i = 0 ; i < request . getInstructions (). size (); i ++) {
97
+ EndpointName endpointName = this . connectionDetails . getEndpointName ( finalOptions . getModel ());
106
98
107
- TextInstanceBuilder instanceBuilder = TextInstanceBuilder .of (request .getInstructions ().get (i ))
108
- .withTaskType (finalOptions .getTaskType ().name ());
109
- if (StringUtils .hasText (finalOptions .getTitle ())) {
110
- instanceBuilder .withTitle (finalOptions .getTitle ());
111
- }
112
- predictRequestBuilder .addInstances (VertexAiEmbeddingUtils .valueOf (instanceBuilder .build ()));
113
- }
99
+ PredictRequest .Builder predictRequestBuilder = getPredictRequestBuilder (request , endpointName ,
100
+ finalOptions );
114
101
115
- PredictResponse embeddingResponse = client . predict ( predictRequestBuilder . build () );
102
+ PredictResponse embeddingResponse = getPredictResponse ( client , predictRequestBuilder );
116
103
117
104
int index = 0 ;
118
105
int totalTokenCount = 0 ;
@@ -131,12 +118,53 @@ public EmbeddingResponse call(EmbeddingRequest request) {
131
118
}
132
119
return new EmbeddingResponse (embeddingList ,
133
120
generateResponseMetadata (finalOptions .getModel (), totalTokenCount ));
121
+ });
122
+ }
123
+
124
+ protected PredictRequest .Builder getPredictRequestBuilder (EmbeddingRequest request , EndpointName endpointName ,
125
+ VertexAiTextEmbeddingOptions finalOptions ) {
126
+ PredictRequest .Builder predictRequestBuilder = PredictRequest .newBuilder ().setEndpoint (endpointName .toString ());
127
+
128
+ TextParametersBuilder parametersBuilder = TextParametersBuilder .of ();
129
+
130
+ if (finalOptions .getAutoTruncate () != null ) {
131
+ parametersBuilder .withAutoTruncate (finalOptions .getAutoTruncate ());
134
132
}
135
- catch (Exception e ) {
133
+
134
+ if (finalOptions .getDimensions () != null ) {
135
+ parametersBuilder .withOutputDimensionality (finalOptions .getDimensions ());
136
+ }
137
+
138
+ predictRequestBuilder .setParameters (VertexAiEmbeddingUtils .valueOf (parametersBuilder .build ()));
139
+
140
+ for (int i = 0 ; i < request .getInstructions ().size (); i ++) {
141
+
142
+ TextInstanceBuilder instanceBuilder = TextInstanceBuilder .of (request .getInstructions ().get (i ))
143
+ .withTaskType (finalOptions .getTaskType ().name ());
144
+ if (StringUtils .hasText (finalOptions .getTitle ())) {
145
+ instanceBuilder .withTitle (finalOptions .getTitle ());
146
+ }
147
+ predictRequestBuilder .addInstances (VertexAiEmbeddingUtils .valueOf (instanceBuilder .build ()));
148
+ }
149
+ return predictRequestBuilder ;
150
+ }
151
+
152
+ // for testing
153
+ PredictionServiceClient createPredictionServiceClient () {
154
+ try {
155
+ return PredictionServiceClient .create (this .connectionDetails .getPredictionServiceSettings ());
156
+ }
157
+ catch (IOException e ) {
136
158
throw new RuntimeException (e );
137
159
}
138
160
}
139
161
162
+ // for testing
163
+ PredictResponse getPredictResponse (PredictionServiceClient client , PredictRequest .Builder predictRequestBuilder ) {
164
+ PredictResponse embeddingResponse = client .predict (predictRequestBuilder .build ());
165
+ return embeddingResponse ;
166
+ }
167
+
140
168
private EmbeddingResponseMetadata generateResponseMetadata (String model , Integer totalTokens ) {
141
169
EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata ();
142
170
metadata .setModel (model );
0 commit comments