Skip to content

Commit

Permalink
[tnx] update default neuron rolling batch for correct mpi mode config (
Browse files Browse the repository at this point in the history
  • Loading branch information
tosterberg authored Apr 20, 2024
1 parent 812c164 commit 8bf6106
Showing 1 changed file with 6 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ private static void setRollingBatch(
} else if (!isTextGenerationModel(modelConfig)) {
// Non text-generation use-cases are not compatible with rolling batch
rollingBatch = "disable";
} else if (isTnxEnabled(features)) {
rollingBatch = "tnx";
} else if (isLmiDistEnabled(features)
&& "lmi-dist".equals(MODEL_TO_ROLLING_BATCH.get(modelType))) {
rollingBatch = "lmi-dist";
Expand Down Expand Up @@ -175,6 +177,10 @@ private static boolean isTrtLlmEnabled(String features) {
return features != null && features.contains("trtllm");
}

private static boolean isTnxEnabled(String features) {
return features != null && features.contains("tnx");
}

private static boolean isT5TrtLlm(
LmiUtils.HuggingFaceModelConfig modelConfig, String features) {
return isTrtLlmEnabled(features) && "t5".equals(modelConfig.getModelType());
Expand Down

0 comments on commit 8bf6106

Please sign in to comment.