From 6d8ac15c3f0da2f73767346d906c89d775b0543f Mon Sep 17 00:00:00 2001
From: Frank Liu
Date: Wed, 28 Jun 2023 10:05:08 -0700
Subject: [PATCH 1/7] [serving] Fixes unittest in multi-GPU case (#874)

1. Fixes cu118 test case failure for MXNet
2. Fixes multi-GPU throttle test case failure
---
 .../src/test/java/ai/djl/serving/ModelServerTest.java | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/serving/src/test/java/ai/djl/serving/ModelServerTest.java b/serving/src/test/java/ai/djl/serving/ModelServerTest.java
index 4badce358..f06a44df6 100644
--- a/serving/src/test/java/ai/djl/serving/ModelServerTest.java
+++ b/serving/src/test/java/ai/djl/serving/ModelServerTest.java
@@ -18,6 +18,7 @@ import static org.testng.Assert.assertNull;
 import static org.testng.Assert.assertTrue;
 
+import ai.djl.engine.Engine;
 import ai.djl.modality.Classifications.Classification;
 import ai.djl.repository.MRL;
 import ai.djl.repository.Repository;
@@ -650,7 +651,8 @@ private void testDescribeModel(Channel channel) throws InterruptedException {
         assertFalse(model.isLoadedAtStartup());
 
         DescribeWorkflowResponse.Group group = model.getWorkGroups().get(0);
-        assertEquals(group.getDevice().isGpu(), CudaUtils.hasCuda());
+        boolean hasGpu = Engine.getEngine("MXNet").getGpuCount() > 0;
+        assertEquals(group.getDevice().isGpu(), hasGpu);
         assertEquals(group.getMinWorkers(), 2);
         assertEquals(group.getMaxWorkers(), 4);
         List<DescribeWorkflowResponse.Worker> workers = group.getWorkers();
@@ -726,7 +728,10 @@ private void testThrottle(Channel channel) throws InterruptedException {
         req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, url);
         channel2.writeAndFlush(req).sync();
         Assert.assertTrue(latch2.await(2, TimeUnit.MINUTES));
-        assertEquals(httpStatus.code(), 503);
+        if (CudaUtils.getGpuCount() <= 1) {
+            // one request is not able to saturate workers in multi-GPU case
+            assertEquals(httpStatus.code(), 503);
+        }
 
         // wait for 1st response
         f.sync();

From 96e27921e74279a7cd09807492b809cee851e97e Mon Sep 17 00:00:00 2001
From: Frank Liu
Date: Wed, 28 Jun 2023 11:05:34 -0700
Subject: [PATCH 2/7] Refactor createVirtualEnv() (#875)

* Refactor createVirtualEnv()

* Updates unittest
---
 .../main/java/ai/djl/python/engine/PyEnv.java | 68 ++++++++---------
 .../java/ai/djl/python/engine/PyModel.java    |  8 +-
 .../ai/djl/python/engine/PyEngineTest.java    | 24 ------
 .../java/ai/djl/python/engine/PyEnvTest.java  | 76 +++++++++++++++++++
 .../src/test/resources/venv/requirements.txt  |  1 -
 .../test/resources/venv/serving.properties    |  5 --
 6 files changed, 112 insertions(+), 70 deletions(-)
 create mode 100644 engines/python/src/test/java/ai/djl/python/engine/PyEnvTest.java
 delete mode 100644 engines/python/src/test/resources/venv/requirements.txt
 delete mode 100644 engines/python/src/test/resources/venv/serving.properties

diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyEnv.java b/engines/python/src/main/java/ai/djl/python/engine/PyEnv.java
index 6011a05e8..56b18df08 100644
--- a/engines/python/src/main/java/ai/djl/python/engine/PyEnv.java
+++ b/engines/python/src/main/java/ai/djl/python/engine/PyEnv.java
@@ -172,37 +172,40 @@ public Map<String, String> getInitParameters() {
      *
      * @param name the virtual environment name
      */
-    public synchronized void createVirtualEnv(String name) {
-        if (venvCreated) {
-            return;
-        }
-        Path path = getVenvDir().resolve(name).toAbsolutePath();
-        if (Files.exists(path)) {
-            logger.info("Virtual environment already exists at {}.", path);
-            setPythonExecutable(path.resolve("bin").resolve("python").toString());
-            venvCreated = true;
-            return;
-        }
-        String[] cmd = {pythonExecutable, "-m", "venv", path.toString(), "--system-site-packages"};
-
-        try {
-            Process process = new ProcessBuilder(cmd).redirectErrorStream(true).start();
-            String logOutput;
-            try (InputStream is = process.getInputStream()) {
-                logOutput = Utils.toString(is);
-            }
-            int ret = process.waitFor();
-            logger.debug("{}", logOutput);
-            if (ret != 0) {
-                throw new EngineException(
-                        "Failed to create virtual environment with error code: " + ret);
-            }
-
-            logger.info("Python virtual environment created successfully at {}!", path);
-            setPythonExecutable(path.resolve("bin").resolve("python").toString());
-            venvCreated = true;
-        } catch (IOException | InterruptedException e) {
-            throw new EngineException("Python virtual failed", e);
+    public void createVirtualEnv(String name) {
+        synchronized (PyEnv.class) {
+            if (venvCreated) {
+                return;
+            }
+            Path path = getVenvDir().resolve(name).toAbsolutePath();
+            if (Files.exists(path)) {
+                logger.info("Virtual environment already exists at {}.", path);
+                setPythonExecutable(path.resolve("bin").resolve("python").toString());
+                venvCreated = true;
+                return;
+            }
+            String[] cmd = {
+                pythonExecutable, "-m", "venv", path.toString(), "--system-site-packages"
+            };
+
+            try {
+                Process process = new ProcessBuilder(cmd).redirectErrorStream(true).start();
+                String logOutput;
+                try (InputStream is = process.getInputStream()) {
+                    logOutput = Utils.toString(is);
+                }
+                int ret = process.waitFor();
+                logger.debug("{}", logOutput);
+                if (ret != 0) {
+                    throw new EngineException("Failed to create venv with error code: " + ret);
+                }
+
+                logger.info("Python virtual environment created successfully at {}!", path);
+                setPythonExecutable(path.resolve("bin").resolve("python").toString());
+                venvCreated = true;
+            } catch (IOException | InterruptedException e) {
+                throw new EngineException("Create venv failed", e);
+            }
         }
     }
@@ -212,9 +215,6 @@ public synchronized void createVirtualEnv(String name) {
      * @param name the virtual environment name
      */
     public synchronized void deleteVirtualEnv(String name) {
-        if (!venvCreated) {
-            return;
-        }
         Path path = getVenvDir().resolve(name);
         Utils.deleteQuietly(path);
     }
@@ -489,11 +489,7 @@ private static int getDefaultTimeout(String key, int def) {
     private Path getVenvDir() {
         String venvDir = Utils.getEnvOrSystemProperty("DJL_VENV_DIR");
         if (venvDir == null || venvDir.isEmpty()) {
-            Path dir = Paths.get(System.getProperty("user.home"));
-            if (!Files.isWritable(dir)) {
-                dir = Paths.get(System.getProperty("java.io.tmpdir"));
-            }
-            return dir.resolve("venv");
+            return Utils.getCacheDir().resolve("venv");
         }
         return Paths.get(venvDir);
     }
diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyModel.java b/engines/python/src/main/java/ai/djl/python/engine/PyModel.java
index 203a3dc46..6cd61d6b7 100644
--- a/engines/python/src/main/java/ai/djl/python/engine/PyModel.java
+++ b/engines/python/src/main/java/ai/djl/python/engine/PyModel.java
@@ -185,7 +185,7 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
         }
         pyEnv.setEntryPoint(entryPoint);
         if (pyEnv.isEnableVenv()) {
-            pyEnv.createVirtualEnv(getName());
+            pyEnv.createVirtualEnv(Utils.hash(modelDir.toString()));
         }
 
         if (pyEnv.isMpiMode()) {
@@ -325,12 +325,12 @@ private void createAllPyProcesses(int mpiWorkers) {
     }
 
     private void shutdown() {
-        if (pyEnv.isEnableVenv()) {
-            pyEnv.deleteVirtualEnv(getName());
-        }
         for (PyProcess process : workerQueue) {
             process.stopPythonProcess();
         }
         workerQueue.clear();
+        if (pyEnv.isEnableVenv()) {
+            pyEnv.deleteVirtualEnv(Utils.hash(modelDir.toString()));
+        }
     }
 }
diff --git a/engines/python/src/test/java/ai/djl/python/engine/PyEngineTest.java b/engines/python/src/test/java/ai/djl/python/engine/PyEngineTest.java
index c280e3c59..d0ab428b7 100644
--- a/engines/python/src/test/java/ai/djl/python/engine/PyEngineTest.java
+++ b/engines/python/src/test/java/ai/djl/python/engine/PyEngineTest.java
@@ -400,30 +400,6 @@ public void testRollingBatch() throws TranslateException, IOException, ModelExce
         }
     }
 
-    @Test
-    public void testPythonVenv() throws IOException, ModelException {
-        Criteria<Input, Output> criteria =
-                Criteria.builder()
-                        .setTypes(Input.class, Output.class)
-                        .optEngine("Python")
-                        .optModelPath(Paths.get("src/test/resources/venv"))
-                        .build();
-
-        String venvDir = "build/venv";
-        System.setProperty("DJL_VENV_DIR", venvDir);
-        Path path;
-        try (ZooModel<Input, Output> model = criteria.loadModel()) {
-            path = Paths.get(venvDir).resolve(model.getName());
-            Assert.assertTrue(Files.isDirectory(path));
-        }
-        Assert.assertFalse(Files.isDirectory(path));
-
-        // Test exception
-        venvDir = "/COM1"; // Invalid directory
-        System.setProperty("DJL_VENV_DIR", venvDir);
-        Assert.assertThrows(EngineException.class, criteria::loadModel);
-    }
-
     @Test
     public void testModelException() throws TranslateException, IOException, ModelException {
         Criteria<Input, Output> criteria =
diff --git a/engines/python/src/test/java/ai/djl/python/engine/PyEnvTest.java b/engines/python/src/test/java/ai/djl/python/engine/PyEnvTest.java
new file mode 100644
index 000000000..fff4e436a
--- /dev/null
+++ b/engines/python/src/test/java/ai/djl/python/engine/PyEnvTest.java
@@ -0,0 +1,76 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.python.engine;
+
+import ai.djl.MalformedModelException;
+import ai.djl.engine.EngineException;
+import ai.djl.modality.Input;
+import ai.djl.modality.Output;
+import ai.djl.repository.zoo.Criteria;
+import ai.djl.repository.zoo.ModelNotFoundException;
+import ai.djl.repository.zoo.ZooModel;
+import ai.djl.util.Utils;
+
+import org.testng.Assert;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+
+public class PyEnvTest {
+
+    @BeforeClass
+    public void setUp() {
+        System.setProperty("DJL_VENV_DIR", "build/venv");
+    }
+
+    @AfterClass
+    public void tearDown() {
+        System.clearProperty("DJL_VENV_DIR");
+        Utils.deleteQuietly(Paths.get("build/venv"));
+    }
+
+    @Test
+    public void testPythonVenv()
+            throws ModelNotFoundException, MalformedModelException, IOException {
+        Criteria<Input, Output> criteria =
+                Criteria.builder()
+                        .setTypes(Input.class, Output.class)
+                        .optEngine("Python")
+                        .optModelPath(Paths.get("src/test/resources/echo"))
+                        .optOption("enable_venv", "true")
+                        .build();
+
+        Path venvDir;
+        try (ZooModel<Input, Output> model = criteria.loadModel();
+                ZooModel<Input, Output> model2 = criteria.loadModel()) {
+            String venvName = Utils.hash(model.getModelPath().toString());
+            venvDir = Paths.get("build/venv").resolve(venvName);
+            Assert.assertTrue(Files.exists(venvDir));
+            Assert.assertNotNull(model2.getModelPath());
+        }
+        Assert.assertFalse(Files.exists(venvDir));
+
+        // Test exception cases
+        Criteria<Input, Output> criteria2 =
+                criteria.toBuilder().optOption("pythonExecutable", "non-exists").build();
+        Assert.assertThrows(EngineException.class, criteria2::loadModel);
+
+        System.setProperty("DJL_VENV_DIR", "/COM1"); // invalid path
+        Assert.assertThrows(EngineException.class, criteria::loadModel);
+    }
+}
diff --git a/engines/python/src/test/resources/venv/requirements.txt b/engines/python/src/test/resources/venv/requirements.txt
deleted file mode 100644
index 747b7aa97..000000000
--- a/engines/python/src/test/resources/venv/requirements.txt
+++ /dev/null
@@ -1 +0,0 @@
-transformers
\ No newline at end of file
diff --git a/engines/python/src/test/resources/venv/serving.properties b/engines/python/src/test/resources/venv/serving.properties
deleted file mode 100644
index ab2d03ebb..000000000
--- a/engines/python/src/test/resources/venv/serving.properties
+++ /dev/null
@@ -1,5 +0,0 @@
-engine=Python
-option.entryPoint=djl_python.huggingface
-option.model_id=facebook/opt-125m
-option.task=text-generation
-option.enable_venv=true

From 1a8cbcf06d62ffcd976281b50216fde543204e60 Mon Sep 17 00:00:00 2001
From: Qing Lan
Date: Wed, 28 Jun 2023 11:22:35 -0700
Subject: [PATCH 3/7] Fix a few pipeline issues (#876)

* minor changes

* bump up available memory
---
 .github/workflows/lmi_dist_integration.yml | 1 +
 tests/integration/llm/client.py            | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/lmi_dist_integration.yml b/.github/workflows/lmi_dist_integration.yml
index 22a2690f0..68ab0b785 100644
--- a/.github/workflows/lmi_dist_integration.yml
+++ b/.github/workflows/lmi_dist_integration.yml
@@ -31,6 +31,7 @@ jobs:
       image: deepjavalibrary/djl-serving:deepspeed-nightly
       options: --gpus all --runtime=nvidia --shm-size=2gb
     steps:
+      - uses: actions/checkout@v3
       - name: Install environment
         working-directory: tests/integration
         run: |
diff --git a/tests/integration/llm/client.py b/tests/integration/llm/client.py
index 31d1f096b..aac3fa7e6 100644
--- a/tests/integration/llm/client.py
+++ b/tests/integration/llm/client.py
@@ -205,7 +205,7 @@ def get_model_name():
     },
     "nomic-ai/gpt4all-j": {
         "batch_size": [1, 2],
-        "max_memory_per_gpu": 6.0
+        "max_memory_per_gpu": 9.0
     }
 }

From 9d98f2ebaae143d8a35db04d8122c6da6f835a34 Mon Sep 17 00:00:00 2001
From: Xin Yang <105740670+xyang16@users.noreply.github.com>
Date: Wed, 28 Jun 2023 16:28:01 -0700
Subject: [PATCH 4/7] [python] Rolling batch support for flash models (#865)

* [python] Add rolling batch for flash attention models

* Flash gptneox support

* fix py tests

* Set sharded to false for tp 1

* Review changes

---------

Co-authored-by: sindhuso
---
 .../python/setup/djl_python/huggingface.py    | 50 ++++--
 .../djl_python/rolling_batch/__init__.py      |  2 +-
 .../rolling_batch/lmi_dist_rolling_batch.py   | 162 ++++++++++++++++++
 .../djl_python/rolling_batch/rolling_batch.py | 11 +-
 .../ai/djl/python/engine/PyPredictor.java     |  2 +-
 .../rolling_batch/serving.properties          |  2 +-
 6 files changed, 211 insertions(+), 18 deletions(-)
 create mode 100644 engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py

diff --git a/engines/python/setup/djl_python/huggingface.py b/engines/python/setup/djl_python/huggingface.py
index 19a6eb860..7b08f451b 100644
--- a/engines/python/setup/djl_python/huggingface.py
+++ b/engines/python/setup/djl_python/huggingface.py
@@ -23,6 +23,7 @@ from djl_python.outputs import Output
 from djl_python.streaming_utils import StreamingUtils
 from djl_python.rolling_batch import SchedulerRollingBatch
+from djl_python.rolling_batch.lmi_dist_rolling_batch import LmiDistRollingBatch
 
 ARCHITECTURES_2_TASK = {
     "TapasForQuestionAnswering": "table-question-answering",
@@ -41,6 +42,13 @@
     "BloomModel": "text-generation",
 }
 
+ARCHITECTURES_2_RB_CLS = {
+    "RWForCausalLM" : LmiDistRollingBatch,
+    "GPTNeoXForCausalLM" : LmiDistRollingBatch,
+    "T5ForConditionalGeneration" : LmiDistRollingBatch,
+    "LlamaForCausalLM": LmiDistRollingBatch
+}
+
 
 def get_torch_dtype_from_str(dtype: str):
     if dtype == "auto":
@@ -58,6 +66,21 @@ def get_torch_dtype_from_str(dtype: str):
         raise ValueError(f"Invalid data type: {dtype}")
 
 
+def get_rolling_batch_class_from_str(rolling_batch_type: str, model_config):
+    if rolling_batch_type == "auto":
+        architecture = model_config.architectures[0]
+        if architecture in ARCHITECTURES_2_RB_CLS:
+            return ARCHITECTURES_2_RB_CLS[architecture]
+        else:
+            return SchedulerRollingBatch
+    elif rolling_batch_type == "scheduler":
+        return SchedulerRollingBatch
+    elif rolling_batch_type == "lmi-dist":
+        return LmiDistRollingBatch
+    raise ValueError(f"Invalid rolling batch type: {rolling_batch_type}")
+
+
+
 class HuggingFaceService(object):
 
     def __init__(self):
@@ -69,7 +92,7 @@ def __init__(self):
         self.tokenizer = None
         self.trust_remote_code = os.environ.get("HF_TRUST_REMOTE_CODE",
                                                 "FALSE").lower() == 'true'
-        self.enable_rolling_batch = None
+        self.rolling_batch_type = None
         self.rolling_batch = None
         self.model_config = None
 
@@ -121,20 +144,21 @@ def initialize(self, properties: dict):
         if "dtype" in properties:
             kwargs["torch_dtype"] = get_torch_dtype_from_str(
                 properties.get("dtype"))
-        self.enable_rolling_batch = properties.get("rolling_batch", None)
-        if self.enable_rolling_batch and self.enable_rolling_batch.lower(
-        ) == "false":
-            self.enable_rolling_batch = None
+        self.rolling_batch_type = properties.get("rolling_batch", None)
 
         if self.enable_streaming:
             self._init_model_and_tokenizer(model_id_or_path, **kwargs)
             self.initialized = True
             return
-        elif self.enable_rolling_batch:
-            # TODO: Add logic to call appropriate scheduler backend for rolling batch
-            self.rolling_batch = SchedulerRollingBatch(model_id_or_path,
-                                                       self.device, properties,
-                                                       **kwargs)
+        elif self.rolling_batch_type:
+            if properties.get("engine") != "Python":
+                self.device = int(os.getenv("LOCAL_RANK", 0))
+            model_config = AutoConfig.from_pretrained(model_id_or_path, **kwargs)
+            _rolling_batch_cls = get_rolling_batch_class_from_str(self.rolling_batch_type, model_config)
+            self.rolling_batch = _rolling_batch_cls(model_id_or_path,
+                                                    self.device, properties,
+                                                    **kwargs)
+
             self.initialized = True
             return
@@ -170,7 +194,7 @@ def inference(self, inputs):
                 input_data.extend(_inputs)
             else:
                 input_data.append(_inputs)
-                if first or self.enable_rolling_batch:
+                if first or self.rolling_batch_type:
                     parameters.append(input_map.pop("parameters", {}))
                     first = False
                 else:
@@ -196,7 +220,7 @@ def inference(self, inputs):
                         input_data, self.device, **parameters[0]))
             return outputs
-        elif self.enable_rolling_batch:
+        elif self.rolling_batch_type:
             result = self.rolling_batch.inference(input_data, parameters)
             for i in range(len(batch)):
                 res = result[i]
@@ -270,7 +294,7 @@ def _init_model_and_tokenizer(self, model_id_or_path: str, **kwargs):
         self.tokenizer = AutoTokenizer.from_pretrained(model_id_or_path,
                                                        padding_side="left")
         model_config = AutoConfig.from_pretrained(model_id_or_path,
-                                                  kwargs=kwargs)
+                                                  **kwargs)
         self.model_config = model_config
         architectures = model_config.architectures
         if architectures and architectures[0].endswith(
diff --git a/engines/python/setup/djl_python/rolling_batch/__init__.py b/engines/python/setup/djl_python/rolling_batch/__init__.py
index f4822e326..4dd28b0d4 100644
--- a/engines/python/setup/djl_python/rolling_batch/__init__.py
+++ b/engines/python/setup/djl_python/rolling_batch/__init__.py
@@ -11,4 +11,4 @@
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
 
-from .scheduler_rolling_batch import SchedulerRollingBatch
+from .scheduler_rolling_batch import SchedulerRollingBatch
\ No newline at end of file
diff --git a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
new file mode 100644
index 000000000..c7693e927
--- /dev/null
+++ b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
@@ -0,0 +1,162 @@
+#!/usr/bin/env python
+#
+# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
+# except in compliance with the License. A copy of the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
+# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
+# the specific language governing permissions and limitations under the License.
+
+from djl_python.rolling_batch.rolling_batch import RollingBatch
+from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig
+from lmi_dist.models import get_model
+from lmi_dist.models.flash_causal_lm import FlashCausalLMBatch
+from lmi_dist.models.seq2seq_lm import Seq2SeqLMBatch
+from lmi_dist.utils.parameters import (
+    NextTokenChooserParameters,
+    StoppingCriteriaParameters,
+)
+import lmi_dist
+from lmi_dist.utils.types import (
+    Batch,
+    Request,
+    Generation
+)
+
+import torch
+
+ARCHITECTURE_2_BATCH_CLS = {
+    "RWForCausalLM": FlashCausalLMBatch,
+    "GPTNeoXForCausalLM": FlashCausalLMBatch,
+    "T5ForConditionalGeneration": Seq2SeqLMBatch,
+    "LlamaForCausalLM": FlashCausalLMBatch
+}
+
+
+def get_batch_cls_from_architecture(architecture):
+    if architecture in ARCHITECTURE_2_BATCH_CLS:
+        return ARCHITECTURE_2_BATCH_CLS[architecture]
+    raise ValueError("Invalid architecture, not supported by lmi-dist")
+
+
+class LmiDistRollingBatch(RollingBatch):
+
+    def __init__(self, model_id_or_path, device, properties, **kwargs):
+        """
+        Initializes the LmiDistRollingBatch.
+
+        :param model_id_or_path: model id or path
+        :param device: model loaded device
+        :param properties: other properties of the model, such as decoder strategy
+        :param kwargs passed while loading the model
+        """
+
+        super().__init__(device)
+        self.properties = properties
+        self.batch_cls = None
+        self._init_model(kwargs, model_id_or_path)
+        self.batch_id_counter = 0
+        self.cache: Batch = None
+
+    def _init_model(self, kwargs, model_id_or_path):
+        self.config = AutoConfig.from_pretrained(model_id_or_path,
+                                                 **kwargs)
+        self.batch_cls = get_batch_cls_from_architecture(self.config.architectures[0])
+        sharded = int(self.properties.get("tensor_parallel_degree", "-1")) > 1
+        self.model = get_model(model_id_or_path,
+                               revision=None,
+                               sharded=sharded,
+                               quantize=None,
+                               trust_remote_code=kwargs.get("trust_remote_code"))
+
+    def inference(self, input_data, parameters):
+        """
+        Performs prefill and decode operations for the batch.
+
+        :param input_data: List of input texts for each request in a batch
+        :param parameters: List of kwargs for each request in a batch
+        :return: generated batch decoded tokens
+        """
+        batch_size = len(input_data)
+        new_requests = self.get_new_requests(input_data, parameters,
+                                             batch_size)
+        new_batch = self.preprocess_requests(new_requests)
+        self._prefill_and_decode(new_batch)
+        return self.postprocess_results(batch_size)
+
+    def _prefill_and_decode(self, new_batch):
+        # prefill step
+        if new_batch:
+            generations, prefill_next_batch = self.model.generate_token(new_batch)
+
+            if self.cache:
+                decode_generations, decode_next_batch = self.model.generate_token(self.cache)
+                self.cache = decode_next_batch
+                generations.extend(decode_generations)
+
+                # concatenate with the existing batch of the model
+                self.cache = self.model.batch_type.concatenate([prefill_next_batch, self.cache])
+
+
+            else:
+                self.cache = prefill_next_batch
+        else:
+            generations, next_batch = self.model.generate_token(self.cache)
+            self.cache = next_batch
+
+        generation_dict = {}
+        for generation in generations:
+            generation_dict[generation.request_id] = generation
+
+        req_ids = []
+        for r in self.pending_requests:
+            generation = generation_dict[r.id]
+            is_last_token = generation.generated_text is not None
+            if not is_last_token:
+                req_ids.append(r.id)
+            r.set_next_token(generation.token_text, last_token=is_last_token)
+
+        # filter the requests that are stopped.
+        if self.cache:
+            self.cache = self.cache.filter(req_ids)
+
+    def preprocess_requests(self, requests, **kwargs):
+        preprocessed_requests = []
+        for r in requests:
+            param = r.parameters
+            parameters = NextTokenChooserParameters(
+                temperature=param.get("temperature", 0.5),  # TODO: Find a better place to put default values
+                repetition_penalty=param.get("repetition_penalty", 1.0),
+                top_k=param.get("top_k", 4),
+                top_p=param.get("top_p", 1.0),
+                typical_p=param.get("typical_p", 1.0),
+                do_sample=param.get("do_sample", False),
+            )
+            stop_parameters = StoppingCriteriaParameters(stop_sequences=param.get("stop_sequences", []),
+                                                         max_new_tokens=param.get("max_new_tokens", 30))
+
+            preprocessed_requests.append(lmi_dist.utils.types.Request(
+                id=r.id,
+                inputs=r.input_text,
+                parameters=parameters,
+                stopping_parameters=stop_parameters
+            ))
+
+        if preprocessed_requests:
+            batch = Batch(id=self.batch_id_counter,
+                          requests=preprocessed_requests,
+                          size=len(preprocessed_requests))
+            self.batch_id_counter += 1
+
+            return self.batch_cls.get_batch(
+                batch,
+                self.model.tokenizer,
+                kwargs.get("torch_dtype", torch.float16),
+                self.device
+            )
+        else:
+            return None
diff --git a/engines/python/setup/djl_python/rolling_batch/rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/rolling_batch.py
index d943ef597..34deadf60 100644
--- a/engines/python/setup/djl_python/rolling_batch/rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/rolling_batch.py
@@ -24,12 +24,14 @@ class Request(object):
 
     """
 
-    def __init__(self, input_text: str, parameters: dict):
+    def __init__(self, id: int, input_text: str, parameters: dict):
         """
         Initialize a request
 
+        :param id: request id
         :param input_text: request's input text
         """
+        self.id = id
         self.input_text = input_text
         self.parameters = parameters
         self.next_token = None
@@ -79,6 +81,7 @@ def __init__(self, device):
 
         self.device = device
         self.pending_requests = []
+        self.req_id_counter = 0
 
     @abstractmethod
     def inference(self, input_data, parameters):
@@ -98,9 +101,10 @@ def get_new_requests(self, input_data, parameters, batch_size):
         for i in range(pending_req_len, batch_size):
             data = input_data[i]
             params = parameters[i] if i < len(parameters) else {}
-            request = Request(data, params)
+            request = Request(self.req_id_counter, data, params)
             self.pending_requests.append(request)
             new_requests.append(request)
+            self.req_id_counter += 1
 
         return new_requests
 
@@ -119,4 +123,7 @@ def postprocess_results(self, batch_size):
             if self.pending_requests[batch_size - i].is_last_token():
                 self.pending_requests.pop(batch_size - i)
 
+        if len(self.pending_requests) == 0:
+            self.req_id_counter = 0
+
         return results
diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyPredictor.java b/engines/python/src/main/java/ai/djl/python/engine/PyPredictor.java
index 152aafbb9..cd7a83c0f 100644
--- a/engines/python/src/main/java/ai/djl/python/engine/PyPredictor.java
+++ b/engines/python/src/main/java/ai/djl/python/engine/PyPredictor.java
@@ -50,7 +50,7 @@ public PyPredictor(
         super(model, translator, device, false);
         this.process = process;
         this.timeout = timeout;
-        isRollingBatch = Boolean.parseBoolean(model.getProperty("rolling_batch", "false"));
+        isRollingBatch = model.getProperty("rolling_batch") != null;
         if (isRollingBatch) {
             int maxRollingBatchSize =
                     Integer.parseInt(model.getProperty("max_rolling_batch_size", "3"));
diff --git a/engines/python/src/test/resources/rolling_batch/serving.properties b/engines/python/src/test/resources/rolling_batch/serving.properties
index 9c4c639af..43dcd9ae5 100644
--- a/engines/python/src/test/resources/rolling_batch/serving.properties
+++ b/engines/python/src/test/resources/rolling_batch/serving.properties
@@ -1,2 +1,2 @@
-option.rolling_batch=true
+option.rolling_batch=auto
 option.max_rolling_batch_size=3

From 85bcd8c85e5aa518a864cac111896634ddd78390 Mon Sep 17 00:00:00 2001
From: Qing Lan
Date: Thu, 29 Jun 2023 09:38:25 -0700
Subject: [PATCH 5/7] [CI] give longer time for building DeepSpeed container (#880)

* give longer time for building

* enable for MPI mode

* fix installation
---
 .github/workflows/docker-nightly-publish.yml |  2 +-
 .../python/setup/djl_python/huggingface.py   | 38 ++++++++++---------
 serving/docker/scripts/install_flash_attn.sh |  6 +--
 3 files changed, 24 insertions(+), 22 deletions(-)

diff --git a/.github/workflows/docker-nightly-publish.yml b/.github/workflows/docker-nightly-publish.yml
index b491721cd..3ad7105cc 100644
--- a/.github/workflows/docker-nightly-publish.yml
+++ b/.github/workflows/docker-nightly-publish.yml
@@ -108,7 +108,7 @@ jobs:
 
   nightly-deepspeed:
     runs-on: [ self-hosted, cpu ]
-    timeout-minutes: 30
+    timeout-minutes: 60
     needs: create-runner
     steps:
       - uses: actions/checkout@v3
diff --git a/engines/python/setup/djl_python/huggingface.py b/engines/python/setup/djl_python/huggingface.py
index 7b08f451b..d929bacf9 100644
--- a/engines/python/setup/djl_python/huggingface.py
+++ b/engines/python/setup/djl_python/huggingface.py
@@ -23,7 +23,6 @@ from djl_python.outputs import Output
 from djl_python.streaming_utils import StreamingUtils
 from djl_python.rolling_batch import SchedulerRollingBatch
-from djl_python.rolling_batch.lmi_dist_rolling_batch import LmiDistRollingBatch
 
 ARCHITECTURES_2_TASK = {
     "TapasForQuestionAnswering": "table-question-answering",
@@ -41,11 +40,11 @@
     "BloomModel": "text-generation",
 }
 
-ARCHITECTURES_2_RB_CLS = {
-    "RWForCausalLM" : LmiDistRollingBatch,
-    "GPTNeoXForCausalLM" : LmiDistRollingBatch,
-    "T5ForConditionalGeneration" : LmiDistRollingBatch,
-    "LlamaForCausalLM": LmiDistRollingBatch
+LMI_DIST_ADV_MODEL = {
+    "RWForCausalLM",
+    "GPTNeoXForCausalLM",
+    "T5ForConditionalGeneration",
+    "LlamaForCausalLM"
 }
 
 
@@ -65,19 +64,20 @@ def get_torch_dtype_from_str(dtype: str):
         raise ValueError(f"Invalid data type: {dtype}")
 
 
-def get_rolling_batch_class_from_str(rolling_batch_type: str, model_config):
+def get_rolling_batch_class_from_str(rolling_batch_type: str, is_mpi: bool, model_config):
     if rolling_batch_type == "auto":
         architecture = model_config.architectures[0]
-        if architecture in ARCHITECTURES_2_RB_CLS:
-            return ARCHITECTURES_2_RB_CLS[architecture]
+        if architecture in LMI_DIST_ADV_MODEL and is_mpi:
+            from djl_python.rolling_batch.lmi_dist_rolling_batch import LmiDistRollingBatch
+            return LmiDistRollingBatch
         else:
             return SchedulerRollingBatch
     elif rolling_batch_type == "scheduler":
         return SchedulerRollingBatch
     elif rolling_batch_type == "lmi-dist":
+        from djl_python.rolling_batch.lmi_dist_rolling_batch import LmiDistRollingBatch
         return LmiDistRollingBatch
     raise ValueError(f"Invalid rolling batch type: {rolling_batch_type}")
 
 
-
 class HuggingFaceService(object):
@@ -151,10 +151,12 @@ def initialize(self, properties: dict):
             self.initialized = True
             return
         elif self.rolling_batch_type:
-            if properties.get("engine") != "Python":
+            self.rolling_batch_type = self.rolling_batch_type.lower()
+            is_mpi = properties.get("engine") != "Python"
+            if is_mpi:
                 self.device = int(os.getenv("LOCAL_RANK", 0))
             model_config = AutoConfig.from_pretrained(model_id_or_path, **kwargs)
-            _rolling_batch_cls = get_rolling_batch_class_from_str(self.rolling_batch_type, model_config)
+            _rolling_batch_cls = get_rolling_batch_class_from_str(self.rolling_batch_type, is_mpi, model_config)
             self.rolling_batch = _rolling_batch_cls(model_id_or_path,
                                                     self.device, properties,
                                                     **kwargs)
@@ -247,12 +249,12 @@ def get_pipeline(self, task: str, model_id_or_path: str, kwargs):
         # define tokenizer or feature extractor as kwargs to load it the pipeline correctly
         if task in {
-                "automatic-speech-recognition",
-                "image-segmentation",
-                "image-classification",
-                "audio-classification",
-                "object-detection",
-                "zero-shot-image-classification",
+            "automatic-speech-recognition",
+            "image-segmentation",
+            "image-classification",
+            "audio-classification",
+            "object-detection",
+            "zero-shot-image-classification",
         }:
             kwargs["feature_extractor"] = model_id_or_path
         else:
diff --git a/serving/docker/scripts/install_flash_attn.sh b/serving/docker/scripts/install_flash_attn.sh
index bb554fd86..ca638ca34 100755
--- a/serving/docker/scripts/install_flash_attn.sh
+++ b/serving/docker/scripts/install_flash_attn.sh
@@ -3,8 +3,8 @@
 rm -rf flash-attention
 git clone https://github.com/HazyResearch/flash-attention.git -b v1.0.0
 pushd flash-attention || exit 1
-python setup.py install
-cd csrc/layer_norm && python setup.py install
-cd ../rotary && python setup.py install
+pip install -v .
+cd csrc/layer_norm && pip install -v .
+cd ../rotary && pip install -v .
 popd || exit 1
 rm -rf flash-attention

From 36b750b287e1ba56fb6c83281b2c8907495ea426 Mon Sep 17 00:00:00 2001
From: Qing Lan
Date: Thu, 29 Jun 2023 13:12:28 -0700
Subject: [PATCH 6/7] add MPI Engine as generic name for distributed environment (#882)

---
 ...neProvider.java => MpiEngineProvider.java} | 22 ++++++++++++-----
 .../services/ai.djl.engine.EngineProvider    |  5 +++--
 2 files changed, 19 insertions(+), 8 deletions(-)
 rename engines/python/src/main/java/ai/djl/python/engine/{DsEngineProvider.java => MpiEngineProvider.java} (64%)

diff --git a/engines/python/src/main/java/ai/djl/python/engine/DsEngineProvider.java b/engines/python/src/main/java/ai/djl/python/engine/MpiEngineProvider.java
similarity index 64%
rename from engines/python/src/main/java/ai/djl/python/engine/DsEngineProvider.java
rename to engines/python/src/main/java/ai/djl/python/engine/MpiEngineProvider.java
index b098a83c3..27550fb99 100644
--- a/engines/python/src/main/java/ai/djl/python/engine/DsEngineProvider.java
+++ b/engines/python/src/main/java/ai/djl/python/engine/MpiEngineProvider.java
@@ -15,17 +15,17 @@
 import ai.djl.engine.EngineProvider;
 
 /** {@code DsEngineProvider} is the DeepSpeed implementation of {@link EngineProvider}. */
-public class DsEngineProvider extends PyEngineProvider {
+public class MpiEngineProvider extends PyEngineProvider {
 
-    /** Constructs a new {@code DsEngineProvider} instance. */
-    public DsEngineProvider() {
+    /** Constructs a new {@code MpiEngineProvider} instance. */
+    public MpiEngineProvider() {
         mpiMode = true;
     }
 
     /** {@inheritDoc} */
     @Override
     public String getEngineName() {
-        return "DeepSpeed";
+        return "MPI";
     }
 
     /** {@inheritDoc} */
     @Override
     public int getEngineRank() {
         return PyEngine.RANK + 1;
     }
 
-    /** {@code FtEngineProvider} is the alias of {@link DsEngineProvider}. */
-    public static final class FtEngineProvider extends DsEngineProvider {
+    /** {@code FtEngineProvider} is the alias of {@link MpiEngineProvider}. */
+    public static final class FtEngineProvider extends MpiEngineProvider {
 
         /** {@inheritDoc} */
         @Override
         public String getEngineName() {
             return "FasterTransformer";
         }
     }
+
+    /** {@code DsEngineProvider} is the alias of {@link MpiEngineProvider}. */
+    public static final class DsEngineProvider extends MpiEngineProvider {
+
+        /** {@inheritDoc} */
+        @Override
+        public String getEngineName() {
+            return "DeepSpeed";
+        }
+    }
 }
diff --git a/engines/python/src/main/resources/META-INF/services/ai.djl.engine.EngineProvider b/engines/python/src/main/resources/META-INF/services/ai.djl.engine.EngineProvider
index ce1c3c0d3..d4c78eefa 100644
--- a/engines/python/src/main/resources/META-INF/services/ai.djl.engine.EngineProvider
+++ b/engines/python/src/main/resources/META-INF/services/ai.djl.engine.EngineProvider
@@ -1,3 +1,4 @@
 ai.djl.python.engine.PyEngineProvider
-ai.djl.python.engine.DsEngineProvider
-ai.djl.python.engine.DsEngineProvider$FtEngineProvider
+ai.djl.python.engine.MpiEngineProvider
+ai.djl.python.engine.MpiEngineProvider$FtEngineProvider
+ai.djl.python.engine.MpiEngineProvider$DsEngineProvider

From cfb23d85f6423fbd4d7dd93d9b9358fdad9e494f Mon Sep 17 00:00:00 2001
From: Sindhu Somasundaram <56774226+sindhuvahinis@users.noreply.github.com>
Date: Thu, 29 Jun 2023 15:19:32 -0700
Subject: [PATCH 7/7] Non streaming for rolling batch (#881)

* Non streaming for rolling batch
---
 .../python/setup/djl_python/huggingface.py    | 26 +++++++--------
 .../ai/djl/python/engine/PyPredictor.java     |  4 ++-
 .../ai/djl/python/engine/RollingBatch.java    | 33 ++++++++++++++-----
 3 files changed, 40 insertions(+), 23 deletions(-)

diff --git a/engines/python/setup/djl_python/huggingface.py b/engines/python/setup/djl_python/huggingface.py
index d929bacf9..72029f604 100644
--- a/engines/python/setup/djl_python/huggingface.py
+++ b/engines/python/setup/djl_python/huggingface.py
@@ -146,11 +146,7 @@ def initialize(self, properties: dict):
                 properties.get("dtype"))
         self.rolling_batch_type = properties.get("rolling_batch", None)
 
-        if self.enable_streaming:
-            self._init_model_and_tokenizer(model_id_or_path, **kwargs)
-            self.initialized = True
-            return
-        elif self.rolling_batch_type:
+        if self.rolling_batch_type:
             self.rolling_batch_type = self.rolling_batch_type.lower()
             is_mpi = properties.get("engine") != "Python"
             if is_mpi:
@@ -163,6 +159,10 @@ def initialize(self, properties: dict):
 
             self.initialized = True
             return
+        elif self.enable_streaming:
+            self._init_model_and_tokenizer(model_id_or_path, **kwargs)
+            self.initialized = True
+            return
 
         if not task:
             task = self.infer_task_from_model_architecture(model_id_or_path)
@@ -207,7 +207,14 @@ def inference(self, inputs):
 
         outputs = Output()
 
-        if self.enable_streaming:
+        if self.rolling_batch_type:
+            result = self.rolling_batch.inference(input_data, parameters)
+            for i in range(len(batch)):
+                res = result[i]
+                outputs.add_as_json(res, batch_index=i)
+
+            return outputs
+        elif self.enable_streaming:
             outputs.add_property("content-type", "application/jsonlines")
             if self.enable_streaming == "huggingface":
                 outputs.add_stream_content(
@@ -222,13 +229,6 @@ def inference(self, inputs):
                         input_data, self.device, **parameters[0]))
             return outputs
-        elif self.rolling_batch_type:
-            result = self.rolling_batch.inference(input_data, parameters)
-            for i in range(len(batch)):
-                res = result[i]
-                outputs.add_as_json(res, batch_index=i)
-
-            return outputs
 
         prediction = self.hf_pipeline(input_data, **parameters[0])
 
diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyPredictor.java b/engines/python/src/main/java/ai/djl/python/engine/PyPredictor.java
index cd7a83c0f..b7e229f40 100644
--- a/engines/python/src/main/java/ai/djl/python/engine/PyPredictor.java
+++ b/engines/python/src/main/java/ai/djl/python/engine/PyPredictor.java
@@ -51,10 +51,12 @@ public PyPredictor(
         this.process = process;
         this.timeout = timeout;
         isRollingBatch = model.getProperty("rolling_batch") != null;
+        boolean enableStreaming =
+                Boolean.parseBoolean(model.getProperty("enable_streaming", "false"));
         if (isRollingBatch) {
             int maxRollingBatchSize =
                     Integer.parseInt(model.getProperty("max_rolling_batch_size", "3"));
-            rollingBatch = new RollingBatch(process, maxRollingBatchSize, timeout);
+            rollingBatch = new RollingBatch(process, maxRollingBatchSize, timeout, enableStreaming);
         }
     }
diff --git a/engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java b/engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java
index a5cfb2974..00fad2340 100644
--- a/engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java
+++ b/engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java
@@ -48,11 +48,13 @@ class RollingBatch implements Runnable {
     private ReentrantLock lock;
     private Condition canAdd;
     private Condition canRead;
+    private boolean enableStreaming;
 
-    RollingBatch(PyProcess process, int maxRollingBatchSize, int timeout) {
+    RollingBatch(PyProcess process, int maxRollingBatchSize, int timeout, boolean enableStreaming) {
         this.process = process;
         this.maxRollingBatchSize = maxRollingBatchSize;
         this.timeout = timeout;
+        this.enableStreaming = enableStreaming;
         list = new ArrayList<>(3);
         lock = new ReentrantLock(true);
         canAdd = lock.newCondition();
@@ -97,7 +99,7 @@ public void run() {
                 for (int i = 0; i < size; ++i) {
                     Request status = list.get(i);
                     String json = content.get(i).getValue().getAsString();
-                    status.addResponse(json);
+                    status.addResponse(json, enableStreaming);
                 }
                 list.removeIf(status -> status.last);
                 if (list.size() < maxRollingBatchSize) {
@@ -122,7 +124,7 @@ public Output addInput(Input input, int timeout) throws TranslateException {
                 throw new TranslateException("Time out in: " + timeout);
             }
         }
-        Request req = new Request(input);
+        Request req = new Request(input, enableStreaming);
         list.add(req);
         canRead.signal();
         return req.output;
@@ -144,28 +146,41 @@ private static final class Request {
         Input input;
         ChunkedBytesSupplier data;
         Output output;
-        String nextToken;
+        StringBuilder nextToken; // NOPMD
         boolean last;
 
-        Request(Input input) {
+        Request(Input input, boolean enableStreaming) {
             this.input = input;
             data = new ChunkedBytesSupplier();
             output = new Output();
             output.add(data);
+            if (enableStreaming) {
+                nextToken = new StringBuilder();
+            } else {
+                nextToken = new StringBuilder(1024);
+            }
         }
 
         BytesSupplier getRequest() {
-            if (nextToken != null) {
+            if (nextToken.length() != 0) {
                 return BytesSupplier.wrap("{\"inputs\": [\"\"]}");
             }
             return input.getData();
        }
 
-        void addResponse(String json) {
+        void addResponse(String json, boolean enableStreaming) {
             JsonObject element = JsonUtils.GSON.fromJson(json, JsonObject.class);
             last = element.get("last").getAsBoolean();
-            nextToken = element.get("data").getAsString();
-            data.appendContent(BytesSupplier.wrap(nextToken), last);
+            if (enableStreaming) {
+                nextToken.setLength(0);
+                nextToken.append(element.get("data").getAsString());
+                data.appendContent(BytesSupplier.wrap(nextToken.toString()), last);
+            } else {
+                nextToken.append(element.get("data").getAsString());
+                if (last) {
+                    data.appendContent(BytesSupplier.wrap(nextToken.toString()), true);
+                }
+            }
         }
     }
 }
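
Usage sketch (not part of the patches above): the behavior introduced in patches 2, 4, 6 and 7 is driven by per-model properties. Assuming a serving.properties file in the model directory, a configuration that exercises these options could look like the lines below; the property names are taken from the diffs above, while the values themselves are illustrative only.

    engine=MPI
    option.rolling_batch=auto
    option.max_rolling_batch_size=3
    option.enable_streaming=false
    option.enable_venv=true

With rolling_batch=auto, get_rolling_batch_class_from_str() in huggingface.py selects LmiDistRollingBatch for the supported flash-attention architectures when running under an MPI-mode engine and falls back to SchedulerRollingBatch otherwise; leaving enable_streaming false makes RollingBatch accumulate tokens and send a single response per request (patch 7), and enable_venv runs the workers in a per-model virtual environment (patch 2).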