Skip to content

Commit 087de16

Browse files
sobychacko authored and markpollack committed
Add batching strategy for embedding documents in PgVectorStore
- Precompute all embeddings using a BatchingStrategy before inserting into the vector store. This optimization improves efficiency when adding multiple documents. Related to #1261.
1 parent 73d0b30 commit 087de16

File tree

2 files changed

+45
-14
lines changed

2 files changed

+45
-14
lines changed

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,11 +13,14 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16+
1617
package org.springframework.ai.autoconfigure.vectorstore.pgvector;
1718

1819
import javax.sql.DataSource;
1920

21+
import org.springframework.ai.embedding.BatchingStrategy;
2022
import org.springframework.ai.embedding.EmbeddingModel;
23+
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
2124
import org.springframework.ai.vectorstore.PgVectorStore;
2225
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
2326
import org.springframework.beans.factory.ObjectProvider;
@@ -34,17 +37,26 @@
3437
/**
3538
* @author Christian Tzolov
3639
* @author Josh Long
40+
* @author Soby Chacko
41+
* @since 1.0.0
3742
*/
3843
@AutoConfiguration(after = JdbcTemplateAutoConfiguration.class)
3944
@ConditionalOnClass({ PgVectorStore.class, DataSource.class, JdbcTemplate.class })
4045
@EnableConfigurationProperties(PgVectorStoreProperties.class)
4146
public class PgVectorStoreAutoConfiguration {
4247

48+
@Bean
49+
@ConditionalOnMissingBean(BatchingStrategy.class)
50+
BatchingStrategy pgVectorStoreBatchingStrategy() {
51+
return new TokenCountBatchingStrategy();
52+
}
53+
4354
@Bean
4455
@ConditionalOnMissingBean
4556
public PgVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel,
4657
PgVectorStoreProperties properties, ObjectProvider<ObservationRegistry> observationRegistry,
47-
ObjectProvider<VectorStoreObservationConvention> customObservationConvention) {
58+
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
59+
BatchingStrategy batchingStrategy) {
4860

4961
var initializeSchema = properties.isInitializeSchema();
5062

@@ -58,6 +70,7 @@ public PgVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embed
5870
.withInitializeSchema(initializeSchema)
5971
.withObservationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
6072
.withSearchObservationConvention(customObservationConvention.getIfAvailable(() -> null))
73+
.withBatchingStrategy(batchingStrategy)
6174
.build();
6275
}
6376

vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java

Lines changed: 31 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,10 @@
2727
import org.slf4j.Logger;
2828
import org.slf4j.LoggerFactory;
2929
import org.springframework.ai.document.Document;
30+
import org.springframework.ai.embedding.BatchingStrategy;
3031
import org.springframework.ai.embedding.EmbeddingModel;
32+
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
33+
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
3134
import org.springframework.ai.observation.conventions.VectorStoreProvider;
3235
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
3336
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
@@ -57,6 +60,8 @@
5760
* @author Josh Long
5861
* @author Muthukumaran Navaneethakrishnan
5962
* @author Thomas Vitale
63+
* @author Soby Chacko
64+
* @since 1.0.0
6065
*/
6166
public class PgVectorStore extends AbstractObservationVectorStore implements InitializingBean {
6267

@@ -90,17 +95,19 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini
9095

9196
private final boolean initializeSchema;
9297

93-
private int dimensions;
98+
private final int dimensions;
9499

95-
private PgDistanceType distanceType;
100+
private final PgDistanceType distanceType;
96101

97-
private ObjectMapper objectMapper = new ObjectMapper();
102+
private final ObjectMapper objectMapper = new ObjectMapper();
98103

99-
private boolean removeExistingVectorStoreTable;
104+
private final boolean removeExistingVectorStoreTable;
100105

101-
private PgIndexType createIndexMethod;
106+
private final PgIndexType createIndexMethod;
102107

103-
private PgVectorSchemaValidator schemaValidator;
108+
private final PgVectorSchemaValidator schemaValidator;
109+
110+
private final BatchingStrategy batchingStrategy;
104111

105112
public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
106113
this(jdbcTemplate, embeddingModel, INVALID_EMBEDDING_DIMENSION, PgDistanceType.COSINE_DISTANCE, false,
@@ -134,13 +141,14 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT
134141

135142
this(schemaName, vectorTableName, vectorTableValidationsEnabled, jdbcTemplate, embeddingModel, dimensions,
136143
distanceType, removeExistingVectorStoreTable, createIndexMethod, initializeSchema,
137-
ObservationRegistry.NOOP, null);
144+
ObservationRegistry.NOOP, null, new TokenCountBatchingStrategy());
138145
}
139146

140147
private PgVectorStore(String schemaName, String vectorTableName, boolean vectorTableValidationsEnabled,
141148
JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions, PgDistanceType distanceType,
142149
boolean removeExistingVectorStoreTable, PgIndexType createIndexMethod, boolean initializeSchema,
143-
ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention) {
150+
ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention,
151+
BatchingStrategy batchingStrategy) {
144152

145153
super(observationRegistry, customObservationConvention);
146154

@@ -163,6 +171,7 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT
163171
this.createIndexMethod = createIndexMethod;
164172
this.initializeSchema = initializeSchema;
165173
this.schemaValidator = new PgVectorSchemaValidator(jdbcTemplate);
174+
this.batchingStrategy = batchingStrategy;
166175
}
167176

168177
public PgDistanceType getDistanceType() {
@@ -174,6 +183,8 @@ public void doAdd(List<Document> documents) {
174183

175184
int size = documents.size();
176185

186+
this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
187+
177188
this.jdbcTemplate.batchUpdate(
178189
"INSERT INTO " + getFullyQualifiedTableName()
179190
+ " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO "
@@ -185,8 +196,7 @@ public void setValues(PreparedStatement ps, int i) throws SQLException {
185196
var document = documents.get(i);
186197
var content = document.getContent();
187198
var json = toJson(document.getMetadata());
188-
var embedding = embeddingModel.embed(document);
189-
document.setEmbedding(embedding);
199+
var embedding = document.getEmbedding();
190200
var pGvector = new PGvector(embedding);
191201

192202
StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN,
@@ -497,6 +507,8 @@ public static class Builder {
497507

498508
private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
499509

510+
private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();
511+
500512
@Nullable
501513
private VectorStoreObservationConvention searchObservationConvention;
502514

@@ -559,10 +571,16 @@ public Builder withSearchObservationConvention(VectorStoreObservationConvention
559571
return this;
560572
}
561573

574+
public Builder withBatchingStrategy(BatchingStrategy batchingStrategy) {
575+
this.batchingStrategy = batchingStrategy;
576+
return this;
577+
}
578+
562579
public PgVectorStore build() {
563-
return new PgVectorStore(schemaName, vectorTableName, vectorTableValidationsEnabled, jdbcTemplate,
564-
embeddingModel, dimensions, distanceType, removeExistingVectorStoreTable, indexType,
565-
initializeSchema, observationRegistry, searchObservationConvention);
580+
return new PgVectorStore(this.schemaName, this.vectorTableName, this.vectorTableValidationsEnabled,
581+
this.jdbcTemplate, this.embeddingModel, this.dimensions, this.distanceType,
582+
this.removeExistingVectorStoreTable, this.indexType, this.initializeSchema,
583+
this.observationRegistry, this.searchObservationConvention, this.batchingStrategy);
566584
}
567585

568586
}

0 commit comments

Comments
 (0)