Skip to content

Commit 73d0b30

Browse files
sobychacko authored and markpollack committed
Enhance TokenCountBatchingStrategy with reserve percentage
- Precompute document token counts before batching into List<List<Document>>
- Introduce configurable reserve percentage for max input token count

Resolves #1260
1 parent 3cab5bd commit 73d0b30

File tree

1 file changed

+41
-9
lines changed

1 file changed

+41
-9
lines changed

spring-ai-core/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
package org.springframework.ai.embedding;
1717

1818
import java.util.ArrayList;
19+
import java.util.HashMap;
1920
import java.util.List;
21+
import java.util.Map;
2022

2123
import org.springframework.ai.document.ContentFormatter;
2224
import org.springframework.ai.document.Document;
@@ -31,7 +33,19 @@
3133
* max input token as the default:
3234
* https://platform.openai.com/docs/guides/embeddings/embedding-models.
3335
*
36+
* This strategy incorporates a reserve percentage to provide a buffer for potential
37+
* overhead or unexpected increases in token count during processing. The actual max input
38+
* token count used is calculated as: actualMaxInputTokenCount =
39+
* originalMaxInputTokenCount * (1 - RESERVE_PERCENTAGE)
40+
*
41+
* For example, with the default reserve percentage of 10% (0.1) and the default max input
42+
* token count of 8191, the actual max input token count used will be 7372
43+
*
44+
* The strategy batches documents based on their token counts, ensuring that each batch
45+
* does not exceed the calculated max input token count.
46+
*
3447
* @author Soby Chacko
48+
* @author Mark Pollack
3549
* @since 1.0.0
3650
*/
3751
public class TokenCountBatchingStrategy implements BatchingStrategy {
@@ -41,6 +55,12 @@ public class TokenCountBatchingStrategy implements BatchingStrategy {
4155
*/
4256
private static final int MAX_INPUT_TOKEN_COUNT = 8191;
4357

58+
/**
59+
* The default percentage of tokens to reserve when calculating the actual max input
60+
* token count.
61+
*/
62+
private static final double DEFAULT_TOKEN_COUNT_RESERVE_PERCENTAGE = 0.1;
63+
4464
private final TokenCountEstimator tokenCountEstimator;
4565

4666
private final int maxInputTokenCount;
@@ -50,27 +70,33 @@ public class TokenCountBatchingStrategy implements BatchingStrategy {
5070
private final MetadataMode metadataMode;
5171

5272
public TokenCountBatchingStrategy() {
53-
this(EncodingType.CL100K_BASE, MAX_INPUT_TOKEN_COUNT);
73+
this(EncodingType.CL100K_BASE, MAX_INPUT_TOKEN_COUNT, DEFAULT_TOKEN_COUNT_RESERVE_PERCENTAGE);
5474
}
5575

5676
/**
5777
* @param encodingType {@link EncodingType}
78+
* @param thresholdFactor the threshold factor to use on top of the max input token
79+
* count
5880
* @param maxInputTokenCount upper limit for input tokens
5981
*/
60-
public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCount) {
61-
this(encodingType, maxInputTokenCount, Document.DEFAULT_CONTENT_FORMATTER, MetadataMode.NONE);
82+
public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCount, double thresholdFactor) {
83+
this(encodingType, maxInputTokenCount, thresholdFactor, Document.DEFAULT_CONTENT_FORMATTER, MetadataMode.NONE);
6284
}
6385

6486
/**
65-
* @param encodingType {@link EncodingType}
66-
* @param maxInputTokenCount upper limit for input tokens
67-
* @param contentFormatter {@link ContentFormatter}
68-
* @param metadataMode {@link MetadataMode}
87+
* @param encodingType The {@link EncodingType} to be used for token counting.
88+
* @param maxInputTokenCount The initial upper limit for input tokens.
89+
* @param reservePercentage The percentage of tokens to reserve from the max input
90+
* token count. This creates a buffer for potential token count increases during
91+
* processing.
92+
* @param contentFormatter the {@link ContentFormatter} to be used for formatting
93+
* content.
94+
* @param metadataMode The {@link MetadataMode} to be used for handling metadata.
6995
*/
70-
public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCount,
96+
public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCount, double reservePercentage,
7197
ContentFormatter contentFormatter, MetadataMode metadataMode) {
7298
this.tokenCountEstimator = new JTokkitTokenCountEstimator(encodingType);
73-
this.maxInputTokenCount = (int) Math.round(maxInputTokenCount - (maxInputTokenCount * .1));
99+
this.maxInputTokenCount = (int) Math.round(maxInputTokenCount * (1 - reservePercentage));
74100
this.contentFormater = contentFormatter;
75101
this.metadataMode = metadataMode;
76102
}
@@ -80,6 +106,7 @@ public List<List<Document>> batch(List<Document> documents) {
80106
List<List<Document>> batches = new ArrayList<>();
81107
int currentSize = 0;
82108
List<Document> currentBatch = new ArrayList<>();
109+
Map<Document, Integer> documentTokens = new HashMap<>();
83110

84111
for (Document document : documents) {
85112
int tokenCount = this.tokenCountEstimator
@@ -88,6 +115,11 @@ public List<List<Document>> batch(List<Document> documents) {
88115
throw new IllegalArgumentException(
89116
"Tokens in a single document exceeds the maximum number of allowed input tokens");
90117
}
118+
documentTokens.put(document, tokenCount);
119+
}
120+
121+
for (Document document : documentTokens.keySet()) {
122+
Integer tokenCount = documentTokens.get(document);
91123
if (currentSize + tokenCount > maxInputTokenCount) {
92124
batches.add(currentBatch);
93125
currentBatch.clear();

0 commit comments

Comments
 (0)