From 8bf61067ece7ee8e6d2454adca6ced0ee3338c12 Mon Sep 17 00:00:00 2001 From: Tyler Osterberg Date: Fri, 19 Apr 2024 22:30:17 -0700 Subject: [PATCH] [tnx] update default neuron rolling batch for correct mpi mode config (#1789) --- .../main/java/ai/djl/serving/wlm/LmiConfigRecommender.java | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java b/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java index ae0a44b65..3e23e8756 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java @@ -87,6 +87,8 @@ private static void setRollingBatch( } else if (!isTextGenerationModel(modelConfig)) { // Non text-generation use-cases are not compatible with rolling batch rollingBatch = "disable"; + } else if (isTnxEnabled(features)) { + rollingBatch = "tnx"; } else if (isLmiDistEnabled(features) && "lmi-dist".equals(MODEL_TO_ROLLING_BATCH.get(modelType))) { rollingBatch = "lmi-dist"; @@ -175,6 +177,10 @@ private static boolean isTrtLlmEnabled(String features) { return features != null && features.contains("trtllm"); } + private static boolean isTnxEnabled(String features) { + return features != null && features.contains("tnx"); + } + private static boolean isT5TrtLlm( LmiUtils.HuggingFaceModelConfig modelConfig, String features) { return isTrtLlmEnabled(features) && "t5".equals(modelConfig.getModelType());