27
27
import org .slf4j .Logger ;
28
28
import org .slf4j .LoggerFactory ;
29
29
import org .springframework .ai .document .Document ;
30
+ import org .springframework .ai .embedding .BatchingStrategy ;
30
31
import org .springframework .ai .embedding .EmbeddingModel ;
32
+ import org .springframework .ai .embedding .EmbeddingOptionsBuilder ;
33
+ import org .springframework .ai .embedding .TokenCountBatchingStrategy ;
31
34
import org .springframework .ai .observation .conventions .VectorStoreProvider ;
32
35
import org .springframework .ai .observation .conventions .VectorStoreSimilarityMetric ;
33
36
import org .springframework .ai .vectorstore .filter .FilterExpressionConverter ;
57
60
* @author Josh Long
58
61
* @author Muthukumaran Navaneethakrishnan
59
62
* @author Thomas Vitale
63
+ * @author Soby Chacko
64
+ * @since 1.0.0
60
65
*/
61
66
public class PgVectorStore extends AbstractObservationVectorStore implements InitializingBean {
62
67
@@ -90,17 +95,19 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini
90
95
91
96
private final boolean initializeSchema ;
92
97
93
- private int dimensions ;
98
+ private final int dimensions ;
94
99
95
- private PgDistanceType distanceType ;
100
+ private final PgDistanceType distanceType ;
96
101
97
- private ObjectMapper objectMapper = new ObjectMapper ();
102
+ private final ObjectMapper objectMapper = new ObjectMapper ();
98
103
99
- private boolean removeExistingVectorStoreTable ;
104
+ private final boolean removeExistingVectorStoreTable ;
100
105
101
- private PgIndexType createIndexMethod ;
106
+ private final PgIndexType createIndexMethod ;
102
107
103
- private PgVectorSchemaValidator schemaValidator ;
108
+ private final PgVectorSchemaValidator schemaValidator ;
109
+
110
+ private final BatchingStrategy batchingStrategy ;
104
111
105
112
public PgVectorStore (JdbcTemplate jdbcTemplate , EmbeddingModel embeddingModel ) {
106
113
this (jdbcTemplate , embeddingModel , INVALID_EMBEDDING_DIMENSION , PgDistanceType .COSINE_DISTANCE , false ,
@@ -134,13 +141,14 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT
134
141
135
142
this (schemaName , vectorTableName , vectorTableValidationsEnabled , jdbcTemplate , embeddingModel , dimensions ,
136
143
distanceType , removeExistingVectorStoreTable , createIndexMethod , initializeSchema ,
137
- ObservationRegistry .NOOP , null );
144
+ ObservationRegistry .NOOP , null , new TokenCountBatchingStrategy () );
138
145
}
139
146
140
147
private PgVectorStore (String schemaName , String vectorTableName , boolean vectorTableValidationsEnabled ,
141
148
JdbcTemplate jdbcTemplate , EmbeddingModel embeddingModel , int dimensions , PgDistanceType distanceType ,
142
149
boolean removeExistingVectorStoreTable , PgIndexType createIndexMethod , boolean initializeSchema ,
143
- ObservationRegistry observationRegistry , VectorStoreObservationConvention customObservationConvention ) {
150
+ ObservationRegistry observationRegistry , VectorStoreObservationConvention customObservationConvention ,
151
+ BatchingStrategy batchingStrategy ) {
144
152
145
153
super (observationRegistry , customObservationConvention );
146
154
@@ -163,6 +171,7 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT
163
171
this .createIndexMethod = createIndexMethod ;
164
172
this .initializeSchema = initializeSchema ;
165
173
this .schemaValidator = new PgVectorSchemaValidator (jdbcTemplate );
174
+ this .batchingStrategy = batchingStrategy ;
166
175
}
167
176
168
177
public PgDistanceType getDistanceType () {
@@ -174,6 +183,8 @@ public void doAdd(List<Document> documents) {
174
183
175
184
int size = documents .size ();
176
185
186
+ this .embeddingModel .embed (documents , EmbeddingOptionsBuilder .builder ().build (), this .batchingStrategy );
187
+
177
188
this .jdbcTemplate .batchUpdate (
178
189
"INSERT INTO " + getFullyQualifiedTableName ()
179
190
+ " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO "
@@ -185,8 +196,7 @@ public void setValues(PreparedStatement ps, int i) throws SQLException {
185
196
var document = documents .get (i );
186
197
var content = document .getContent ();
187
198
var json = toJson (document .getMetadata ());
188
- var embedding = embeddingModel .embed (document );
189
- document .setEmbedding (embedding );
199
+ var embedding = document .getEmbedding ();
190
200
var pGvector = new PGvector (embedding );
191
201
192
202
StatementCreatorUtils .setParameterValue (ps , 1 , SqlTypeValue .TYPE_UNKNOWN ,
@@ -497,6 +507,8 @@ public static class Builder {
497
507
498
508
private ObservationRegistry observationRegistry = ObservationRegistry .NOOP ;
499
509
510
+ private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy ();
511
+
500
512
@ Nullable
501
513
private VectorStoreObservationConvention searchObservationConvention ;
502
514
@@ -559,10 +571,16 @@ public Builder withSearchObservationConvention(VectorStoreObservationConvention
559
571
return this ;
560
572
}
561
573
574
+ public Builder withBatchingStrategy (BatchingStrategy batchingStrategy ) {
575
+ this .batchingStrategy = batchingStrategy ;
576
+ return this ;
577
+ }
578
+
562
579
public PgVectorStore build () {
563
- return new PgVectorStore (schemaName , vectorTableName , vectorTableValidationsEnabled , jdbcTemplate ,
564
- embeddingModel , dimensions , distanceType , removeExistingVectorStoreTable , indexType ,
565
- initializeSchema , observationRegistry , searchObservationConvention );
580
+ return new PgVectorStore (this .schemaName , this .vectorTableName , this .vectorTableValidationsEnabled ,
581
+ this .jdbcTemplate , this .embeddingModel , this .dimensions , this .distanceType ,
582
+ this .removeExistingVectorStoreTable , this .indexType , this .initializeSchema ,
583
+ this .observationRegistry , this .searchObservationConvention , this .batchingStrategy );
566
584
}
567
585
568
586
}
0 commit comments