16
16
package org .springframework .ai .embedding ;
17
17
18
18
import java .util .ArrayList ;
19
+ import java .util .HashMap ;
19
20
import java .util .List ;
21
+ import java .util .Map ;
20
22
21
23
import org .springframework .ai .document .ContentFormatter ;
22
24
import org .springframework .ai .document .Document ;
31
33
* max input token as the default:
32
34
* https://platform.openai.com/docs/guides/embeddings/embedding-models.
33
35
*
36
+ * This strategy incorporates a reserve percentage to provide a buffer for potential
37
+ * overhead or unexpected increases in token count during processing. The actual max input
38
+ * token count used is calculated as: actualMaxInputTokenCount =
39
+ * originalMaxInputTokenCount * (1 - RESERVE_PERCENTAGE), rounded to the nearest integer.
40
+ *
41
+ * For example, with the default reserve percentage of 10% (0.1) and the default max input
42
+ * token count of 8191, the actual max input token count used will be 7372 (7371.9 rounded).
43
+ *
44
+ * The strategy batches documents based on their token counts, ensuring that each batch
45
+ * does not exceed the calculated max input token count.
46
+ *
34
47
* @author Soby Chacko
48
+ * @author Mark Pollack
35
49
* @since 1.0.0
36
50
*/
37
51
public class TokenCountBatchingStrategy implements BatchingStrategy {
@@ -41,6 +55,12 @@ public class TokenCountBatchingStrategy implements BatchingStrategy {
41
55
*/
42
56
private static final int MAX_INPUT_TOKEN_COUNT = 8191 ;
43
57
58
+ /**
59
+ * The default percentage of tokens to reserve when calculating the actual max input
60
+ * token count, expressed as a fraction (0.1 means 10%).
61
+ */
62
+ private static final double DEFAULT_TOKEN_COUNT_RESERVE_PERCENTAGE = 0.1 ;
63
+
44
64
private final TokenCountEstimator tokenCountEstimator ;
45
65
46
66
private final int maxInputTokenCount ;
@@ -50,27 +70,33 @@ public class TokenCountBatchingStrategy implements BatchingStrategy {
50
70
private final MetadataMode metadataMode ;
51
71
52
72
public TokenCountBatchingStrategy () {
53
- this (EncodingType .CL100K_BASE , MAX_INPUT_TOKEN_COUNT );
73
+ this (EncodingType .CL100K_BASE , MAX_INPUT_TOKEN_COUNT , DEFAULT_TOKEN_COUNT_RESERVE_PERCENTAGE );
54
74
}
55
75
56
76
/**
57
77
* @param encodingType {@link EncodingType}
78
+ * @param thresholdFactor the fraction of the max input token count to hold back as a
79
+ * reserve (for example 0.1 for 10%); forwarded as the reserve percentage
58
80
* @param maxInputTokenCount upper limit for input tokens
59
81
*/
60
- public TokenCountBatchingStrategy (EncodingType encodingType , int maxInputTokenCount ) {
61
- this (encodingType , maxInputTokenCount , Document .DEFAULT_CONTENT_FORMATTER , MetadataMode .NONE );
82
+ public TokenCountBatchingStrategy (EncodingType encodingType , int maxInputTokenCount , double thresholdFactor ) {
83
+ this (encodingType , maxInputTokenCount , thresholdFactor , Document .DEFAULT_CONTENT_FORMATTER , MetadataMode .NONE );
62
84
}
63
85
64
86
/**
65
- * @param encodingType {@link EncodingType}
66
- * @param maxInputTokenCount upper limit for input tokens
67
- * @param contentFormatter {@link ContentFormatter}
68
- * @param metadataMode {@link MetadataMode}
87
+ * @param encodingType The {@link EncodingType} to be used for token counting.
88
+ * @param maxInputTokenCount The initial upper limit for input tokens.
89
+ * @param reservePercentage The share of the max input token count to reserve, given
90
+ * as a fraction (for example 0.1 for 10%). This creates a buffer for potential token
91
+ * count increases during processing.
92
+ * @param contentFormatter the {@link ContentFormatter} to be used for formatting
93
+ * content.
94
+ * @param metadataMode The {@link MetadataMode} to be used for handling metadata.
69
95
*/
70
- public TokenCountBatchingStrategy (EncodingType encodingType , int maxInputTokenCount ,
96
+ public TokenCountBatchingStrategy (EncodingType encodingType , int maxInputTokenCount , double reservePercentage ,
71
97
ContentFormatter contentFormatter , MetadataMode metadataMode ) {
72
98
this .tokenCountEstimator = new JTokkitTokenCountEstimator (encodingType );
73
- this .maxInputTokenCount = (int ) Math .round (maxInputTokenCount - ( maxInputTokenCount * .1 ));
99
+ this .maxInputTokenCount = (int ) Math .round (maxInputTokenCount * ( 1 - reservePercentage ));
74
100
this .contentFormater = contentFormatter ;
75
101
this .metadataMode = metadataMode ;
76
102
}
@@ -80,6 +106,7 @@ public List<List<Document>> batch(List<Document> documents) {
80
106
List <List <Document >> batches = new ArrayList <>();
81
107
int currentSize = 0 ;
82
108
List <Document > currentBatch = new ArrayList <>();
109
+ Map <Document , Integer > documentTokens = new HashMap <>();
83
110
84
111
for (Document document : documents ) {
85
112
int tokenCount = this .tokenCountEstimator
@@ -88,6 +115,11 @@ public List<List<Document>> batch(List<Document> documents) {
88
115
throw new IllegalArgumentException (
89
116
"Tokens in a single document exceeds the maximum number of allowed input tokens" );
90
117
}
118
+ documentTokens .put (document , tokenCount );
119
+ }
120
+
121
+ for (Document document : documentTokens .keySet ()) {
122
+ Integer tokenCount = documentTokens .get (document );
91
123
if (currentSize + tokenCount > maxInputTokenCount ) {
92
124
batches .add (currentBatch );
93
125
currentBatch .clear ();
0 commit comments