Skip to content

GH-1260: Batching strategy improvements in TokenCountBatchingStrategy #1280

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
package org.springframework.ai.embedding;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.springframework.ai.document.ContentFormatter;
import org.springframework.ai.document.Document;
Expand All @@ -41,6 +43,12 @@ public class TokenCountBatchingStrategy implements BatchingStrategy {
*/
private static final int MAX_INPUT_TOKEN_COUNT = 8191;

/**
* Default fraction of the max input token count that is held back as a safety margin.
* The effective limit is the max input token count minus (the max input token count
* multiplied by this factor), rounded to the nearest integer.
*/
private static final double DEFAULT_TOKEN_COUNT_THRESHOLD_FACTOR = 0.1;

private final TokenCountEstimator tokenCountEstimator;

private final int maxInputTokenCount;
Expand All @@ -50,27 +58,31 @@ public class TokenCountBatchingStrategy implements BatchingStrategy {
private final MetadataMode metadataMode;

public TokenCountBatchingStrategy() {
this(EncodingType.CL100K_BASE, MAX_INPUT_TOKEN_COUNT);
this(EncodingType.CL100K_BASE, MAX_INPUT_TOKEN_COUNT, DEFAULT_TOKEN_COUNT_THRESHOLD_FACTOR);
}

/**
* @param encodingType {@link EncodingType}
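* @param maxInputTokenCount upper limit for input tokens
* @param thresholdFactor fraction of {@code maxInputTokenCount} held back as a safety
* margin; the effective limit is
* {@code Math.round(maxInputTokenCount - (maxInputTokenCount * thresholdFactor))}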
*/
public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCount) {
this(encodingType, maxInputTokenCount, Document.DEFAULT_CONTENT_FORMATTER, MetadataMode.NONE);
public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCount, double thresholdFactor) {
this(encodingType, maxInputTokenCount, thresholdFactor, Document.DEFAULT_CONTENT_FORMATTER, MetadataMode.NONE);
}

/**
* @param encodingType {@link EncodingType}
* @param maxInputTokenCount upper limit for input tokens
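* @param thresholdFactor fraction of {@code maxInputTokenCount} held back as a safety
* margin; the effective limit is
* {@code Math.round(maxInputTokenCount - (maxInputTokenCount * thresholdFactor))}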
* @param contentFormatter {@link ContentFormatter}
* @param metadataMode {@link MetadataMode}
*/
public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCount,
public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCount, double thresholdFactor,
ContentFormatter contentFormatter, MetadataMode metadataMode) {
this.tokenCountEstimator = new JTokkitTokenCountEstimator(encodingType);
this.maxInputTokenCount = (int) Math.round(maxInputTokenCount - (maxInputTokenCount * .1));
this.maxInputTokenCount = (int) Math.round(maxInputTokenCount - (maxInputTokenCount * thresholdFactor));
this.contentFormater = contentFormatter;
this.metadataMode = metadataMode;
}
Expand All @@ -80,6 +92,7 @@ public List<List<Document>> batch(List<Document> documents) {
List<List<Document>> batches = new ArrayList<>();
int currentSize = 0;
List<Document> currentBatch = new ArrayList<>();
Map<Document, Integer> documentTokens = new HashMap<>();

for (Document document : documents) {
int tokenCount = this.tokenCountEstimator
Expand All @@ -88,6 +101,11 @@ public List<List<Document>> batch(List<Document> documents) {
throw new IllegalArgumentException(
"Tokens in a single document exceeds the maximum number of allowed input tokens");
}
documentTokens.put(document, tokenCount);
}

for (Document document : documentTokens.keySet()) {
Integer tokenCount = documentTokens.get(document);
if (currentSize + tokenCount > maxInputTokenCount) {
batches.add(currentBatch);
currentBatch.clear();
Expand Down