diff --git a/engines/python/setup/djl_python/neuron_utils/model_loader.py b/engines/python/setup/djl_python/neuron_utils/model_loader.py index fec6f611a..26ce7ab5c 100644 --- a/engines/python/setup/djl_python/neuron_utils/model_loader.py +++ b/engines/python/setup/djl_python/neuron_utils/model_loader.py @@ -161,11 +161,10 @@ def can_use_continuous_batching(self) -> bool: :return: bool indicating if continuous batching can be used """ - use_continuous_batching = (self.config.rolling_batch != "disable" - and self.config.rolling_batch_strategy - == TnXGenerationStrategy.continuous_batching - and self.config.max_rolling_batch_size - > 1) or self.config.rolling_batch == "vllm" + use_continuous_batching = ( + self.config.rolling_batch != "disable" + and self.config.rolling_batch_strategy + == TnXGenerationStrategy.continuous_batching) return use_continuous_batching def set_neuron_config(self) -> None: diff --git a/serving/docker/pytorch-inf2.Dockerfile b/serving/docker/pytorch-inf2.Dockerfile index 1f4642413..9212b53aa 100644 --- a/serving/docker/pytorch-inf2.Dockerfile +++ b/serving/docker/pytorch-inf2.Dockerfile @@ -20,13 +20,13 @@ ARG transformers_neuronx_version=0.11.351 ARG neuronx_distributed_version=0.8.0 ARG neuronx_cc_version=2.14.227.0 ARG protobuf_version=3.19.6 -ARG transformers_version=4.43.1 +ARG transformers_version=4.43.2 ARG accelerate_version=0.29.2 ARG diffusers_version=0.28.2 ARG pydantic_version=2.6.1 ARG optimum_neuron_version=0.0.24 # %2B is the url escape for the '+' character -ARG vllm_wheel="https://publish.djl.ai/neuron_vllm/vllm-0.5.0%2Bnightly-py3-none-any.whl" +ARG vllm_wheel="https://publish.djl.ai/neuron_vllm/vllm-0.6.0%2Bnightly-py3-none-any.whl" EXPOSE 8080 # Sets up Path for Neuron tools @@ -73,12 +73,13 @@ RUN mkdir -p /opt/djl/bin && cp scripts/telemetry.sh /opt/djl/bin && \ scripts/install_djl_serving.sh $djl_version && \ scripts/install_djl_serving.sh $djl_version ${torch_version} && \ scripts/install_inferentia2.sh && \ - pip install accelerate==${accelerate_version} safetensors ${vllm_wheel} torchvision==${torchvision_version} \ + pip install accelerate==${accelerate_version} safetensors torchvision==${torchvision_version} \ neuronx-cc==${neuronx_cc_version} torch-neuronx==${torch_neuronx_version} transformers-neuronx==${transformers_neuronx_version} \ neuronx_distributed==${neuronx_distributed_version} protobuf==${protobuf_version} sentencepiece jinja2 \ diffusers==${diffusers_version} opencv-contrib-python-headless Pillow --extra-index-url=https://pip.repos.neuron.amazonaws.com \ pydantic==${pydantic_version} optimum optimum-neuron==${optimum_neuron_version} tiktoken blobfile && \ - pip install transformers==${transformers_version} && \ + pip install transformers==${transformers_version} ${vllm_wheel} && \ + echo y | pip uninstall triton && \ scripts/install_s5cmd.sh x64 && \ scripts/patch_oss_dlc.sh python && \ useradd -m -d /home/djl djl && \ diff --git a/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java b/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java index 42972405c..b8a542558 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java @@ -109,7 +109,12 @@ private static void setRollingBatch( // Non text-generation use-cases are not compatible with rolling batch rollingBatch = "disable"; } else if (isTnxEnabled(features)) { - rollingBatch = "tnx"; + if (Integer.parseInt(lmiProperties.getProperty("option.max_rolling_batch_size", "1")) + >= 12) { + rollingBatch = "vllm"; + } else { + rollingBatch = "tnx"; + } } else if (isLmiDistEnabled(features) && "lmi-dist".equals(MODEL_TO_ROLLING_BATCH.get(modelType))) { rollingBatch = "lmi-dist"; diff --git a/wlm/src/main/java/ai/djl/serving/wlm/NeuronSmartDefaultUtils.java b/wlm/src/main/java/ai/djl/serving/wlm/NeuronSmartDefaultUtils.java index d6df8a62f..8f1b791c7 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/NeuronSmartDefaultUtils.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/NeuronSmartDefaultUtils.java @@ -15,6 +15,9 @@ import ai.djl.serving.wlm.LmiUtils.HuggingFaceModelConfig; import ai.djl.util.NeuronUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import java.util.ArrayList; import java.util.List; import java.util.Properties; @@ -23,7 +26,8 @@ public class NeuronSmartDefaultUtils { private static final float BILLION = 1_000_000_000.0F; - private static final int MAX_ROLLING_BATCH = 128; // Current cap for NeuronSDK 2.19.1 + private static final int MAX_ROLLING_BATCH = + 32; // Current best throughput and latency balance batch size private static final float MEMORY_PER_CORE = 16.0F; // Currently there is only one config w/ 16 gb per core @@ -31,6 +35,8 @@ public class NeuronSmartDefaultUtils { private float modelSizeInGb; private float sequenceSizeInGb; + private static final Logger logger = LoggerFactory.getLogger(NeuronSmartDefaultUtils.class); + /** * Applies smart defaults for Neuron models. * @@ -53,6 +59,7 @@ public void applySmartDefaults(Properties prop, HuggingFaceModelConfig modelConf } prop.setProperty( "option.n_positions", String.valueOf(modelConfig.getDefaultNPositions())); + logger.info("[Smart Default] N_POSITIONS: {}.", prop.getProperty("option.n_positions")); } setInternalSettings(prop, modelConfig); setHeuristicNeuronTPDegree(prop); @@ -77,6 +84,7 @@ private void setInternalSettings(Properties prop, HuggingFaceModelConfig modelCo modelSizeInGb = (paramBytes * modelConfig.getModelParameters()) / BILLION; sequenceSizeInGb = modelConfig.getApproxMemoryForSingleSequence(nPositions, paramBytes) + * 0.95F / (1024.0F * 1024.0F * 1024.0F); } @@ -119,6 +127,9 @@ private void setHeuristicNeuronTPDegree(Properties prop) { if (prop.containsKey("option.tensor_parallel_degree") && "max".equals(prop.getProperty("option.tensor_parallel_degree"))) { prop.setProperty("option.tensor_parallel_degree", String.valueOf(availableCores)); + logger.info( + "[Smart Default] TENSOR_PARALLEL_DEGREE:" + " {}.", + prop.getProperty("option.tensor_parallel_degree")); return; } @@ -130,13 +141,17 @@ private void setHeuristicNeuronTPDegree(Properties prop) { int totalInstanceConcurrency = getMaxConcurrency(totalMemory, tpDegree); for (int coreConfig : coreConfigs) { float maxMemory = coreConfig * MEMORY_PER_CORE; - int maxConcurrency = getMaxConcurrency(maxMemory, coreConfig); + int maxConcurrency = + getMaxConcurrency(maxMemory, coreConfig) * (availableCores / coreConfig); if (maxConcurrency >= totalInstanceConcurrency && coreConfig <= tpDegree) { tpDegree = coreConfig; totalInstanceConcurrency = maxConcurrency; } } prop.setProperty("option.tensor_parallel_degree", String.valueOf(tpDegree)); + logger.info( + "[Smart Default] TENSOR_PARALLEL_DEGREE:" + " {}.", + prop.getProperty("option.tensor_parallel_degree")); } else if (!prop.containsKey("option.tensor_parallel_degree")) { // Set tensor parallel degree by minimizing TP degree that supports fixed batch size int batchSize = Integer.parseInt(prop.getProperty("option.max_rolling_batch_size")); @@ -144,13 +159,18 @@ private void setHeuristicNeuronTPDegree(Properties prop) { getMaxConcurrencyWithBatch(totalMemory, tpDegree, batchSize); for (int coreConfig : coreConfigs) { float maxMemory = coreConfig * MEMORY_PER_CORE; - int maxConcurrency = getMaxConcurrencyWithBatch(maxMemory, coreConfig, batchSize); + int maxConcurrency = + getMaxConcurrencyWithBatch(maxMemory, coreConfig, batchSize) + * (availableCores / coreConfig); if (maxConcurrency >= totalInstanceConcurrency && coreConfig <= tpDegree) { tpDegree = coreConfig; totalInstanceConcurrency = maxConcurrency; } } prop.setProperty("option.tensor_parallel_degree", String.valueOf(tpDegree)); + logger.info( + "[Smart Default] TENSOR_PARALLEL_DEGREE: {}.", + prop.getProperty("option.tensor_parallel_degree")); } } @@ -222,9 +242,9 @@ private int getMaxConcurrencyWithBatch(float totalMemory, int tpDegree, int batc private List availableCoreConfigs() { List coreConfigs = new ArrayList<>(); List availableCoreConfigs = buildCoreConfigs(availableCores); - int coresPerModel = (int) Math.ceil(modelSizeInGb / MEMORY_PER_CORE); + int coresPerModel = (int) Math.ceil(1.1F * modelSizeInGb / MEMORY_PER_CORE); for (int coreConfig : availableCoreConfigs) { - if (coresPerModel >= coreConfig) { + if (coresPerModel <= coreConfig) { coreConfigs.add(coreConfig); } } @@ -250,8 +270,10 @@ private List buildCoreConfigs(int nCores) { coreConfigs.add(i); } } - // Add the given number of cores to the list - coreConfigs.add(nCores); + // Add the given number of cores to the list if not already added + if (nCores > 8) { + coreConfigs.add(nCores); + } return coreConfigs; } @@ -274,6 +296,9 @@ private void setHeuristicNeuronMaxRollingBatch(Properties prop) { if (maxRollingBatchSize > 0) { prop.setProperty( "option.max_rolling_batch_size", String.valueOf(maxRollingBatchSize)); + logger.info( + "[Smart Default] MAX_ROLLING_BATCH_SIZE: {}.", + prop.getProperty("option.max_rolling_batch_size")); } } } diff --git a/wlm/src/test/java/ai/djl/serving/wlm/NeuronSmartDefaultUtilsTest.java b/wlm/src/test/java/ai/djl/serving/wlm/NeuronSmartDefaultUtilsTest.java index eac35a282..fa0f7fd4b 100644 --- a/wlm/src/test/java/ai/djl/serving/wlm/NeuronSmartDefaultUtilsTest.java +++ b/wlm/src/test/java/ai/djl/serving/wlm/NeuronSmartDefaultUtilsTest.java @@ -104,7 +104,7 @@ public void testApplySmartDefaultsQuantize8BModel() throws IOException { } Assert.assertEquals(prop.getProperty("option.n_positions"), "4096"); Assert.assertEquals(prop.getProperty("option.tensor_parallel_degree"), "1"); - Assert.assertEquals(prop.getProperty("option.max_rolling_batch_size"), "8"); + Assert.assertEquals(prop.getProperty("option.max_rolling_batch_size"), "16"); } @Test @@ -118,7 +118,7 @@ public void testApplySmartDefaults2BModel() throws IOException { } Assert.assertEquals(prop.getProperty("option.n_positions"), "2048"); Assert.assertEquals(prop.getProperty("option.tensor_parallel_degree"), "1"); - Assert.assertEquals(prop.getProperty("option.max_rolling_batch_size"), "64"); + Assert.assertEquals(prop.getProperty("option.max_rolling_batch_size"), "32"); } @Test @@ -133,7 +133,7 @@ public void testApplySmartDefaultsQuantize2BModel() throws IOException { } Assert.assertEquals(prop.getProperty("option.n_positions"), "2048"); Assert.assertEquals(prop.getProperty("option.tensor_parallel_degree"), "1"); - Assert.assertEquals(prop.getProperty("option.max_rolling_batch_size"), "128"); + Assert.assertEquals(prop.getProperty("option.max_rolling_batch_size"), "32"); } @Test @@ -147,7 +147,7 @@ public void testApplySmartDefaultsWithNPositions() throws IOException { smartDefaultUtils.applySmartDefaults(prop, modelConfig); } Assert.assertEquals(prop.getProperty("option.tensor_parallel_degree"), "1"); - Assert.assertEquals(prop.getProperty("option.max_rolling_batch_size"), "128"); + Assert.assertEquals(prop.getProperty("option.max_rolling_batch_size"), "32"); } @Test @@ -161,7 +161,7 @@ public void testApplySmartDefaultsWithTPDegree() throws IOException { smartDefaultUtils.applySmartDefaults(prop, modelConfig); } Assert.assertEquals(prop.getProperty("option.n_positions"), "2048"); - Assert.assertEquals(prop.getProperty("option.max_rolling_batch_size"), "64"); + Assert.assertEquals(prop.getProperty("option.max_rolling_batch_size"), "32"); } @Test @@ -190,11 +190,26 @@ public void testApplySmartDefaultsWithTPMax() throws IOException { } Assert.assertEquals(prop.getProperty("option.n_positions"), "2048"); Assert.assertEquals(prop.getProperty("option.tensor_parallel_degree"), "1"); - Assert.assertEquals(prop.getProperty("option.max_rolling_batch_size"), "64"); + Assert.assertEquals(prop.getProperty("option.max_rolling_batch_size"), "32"); + } + + @Test + public void testApplySmartDefaultsWithNeuron8bModel() throws IOException { + Properties prop = new Properties(); + LmiUtils.HuggingFaceModelConfig modelConfig = get8BLlamaHuggingFaceModelConfig(); + try (MockedStatic mockedStatic = Mockito.mockStatic(NeuronUtils.class)) { + mockedStatic.when(NeuronUtils::hasNeuron).thenReturn(true); + mockedStatic.when(NeuronUtils::getNeuronCores).thenReturn(32); + NeuronSmartDefaultUtils smartDefaultUtils = new NeuronSmartDefaultUtils(); + smartDefaultUtils.applySmartDefaults(prop, modelConfig); + } + Assert.assertEquals(prop.getProperty("option.n_positions"), "4096"); + Assert.assertEquals(prop.getProperty("option.tensor_parallel_degree"), "2"); + Assert.assertEquals(prop.getProperty("option.max_rolling_batch_size"), "16"); } @Test - public void testApplySmartDefaultsWithNeuron() throws IOException { + public void testApplySmartDefaultsWithNeuron70bModel() throws IOException { Properties prop = new Properties(); LmiUtils.HuggingFaceModelConfig modelConfig = get70BLlamaHuggingFaceModelConfig(); try (MockedStatic mockedStatic = Mockito.mockStatic(NeuronUtils.class)) {