Merge branch 'master' into mme_test
rohithkrn authored Jun 29, 2023
2 parents ff1574a + cfb23d8 commit 2ff12f5
Showing 20 changed files with 399 additions and 128 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/docker-nightly-publish.yml
@@ -108,7 +108,7 @@ jobs:
nightly-deepspeed:
runs-on: [ self-hosted, cpu ]
timeout-minutes: 30
timeout-minutes: 60
needs: create-runner
steps:
- uses: actions/checkout@v3
1 change: 1 addition & 0 deletions .github/workflows/lmi_dist_integration.yml
@@ -31,6 +31,7 @@ jobs:
image: deepjavalibrary/djl-serving:deepspeed-nightly
options: --gpus all --runtime=nvidia --shm-size=2gb
steps:
- uses: actions/checkout@v3
- name: Install environment
working-directory: tests/integration
run: |
82 changes: 54 additions & 28 deletions engines/python/setup/djl_python/huggingface.py
@@ -41,6 +41,13 @@
"BloomModel": "text-generation",
}

LMI_DIST_ADV_MODEL = {
"RWForCausalLM",
"GPTNeoXForCausalLM",
"T5ForConditionalGeneration",
"LlamaForCausalLM"
}


def get_torch_dtype_from_str(dtype: str):
if dtype == "auto":
@@ -58,6 +65,22 @@ def get_torch_dtype_from_str(dtype: str):
raise ValueError(f"Invalid data type: {dtype}")


def get_rolling_batch_class_from_str(rolling_batch_type: str, is_mpi: bool, model_config):
if rolling_batch_type == "auto":
architecture = model_config.architectures[0]
if architecture in LMI_DIST_ADV_MODEL and is_mpi:
from djl_python.rolling_batch.lmi_dist_rolling_batch import LmiDistRollingBatch
return LmiDistRollingBatch
else:
return SchedulerRollingBatch
elif rolling_batch_type == "scheduler":
return SchedulerRollingBatch
elif rolling_batch_type == "lmi-dist":
from djl_python.rolling_batch.lmi_dist_rolling_batch import LmiDistRollingBatch
return LmiDistRollingBatch
raise ValueError(f"Invalid rolling batch type: {rolling_batch_type}")


class HuggingFaceService(object):

def __init__(self):
@@ -69,7 +92,7 @@ def __init__(self):
self.tokenizer = None
self.trust_remote_code = os.environ.get("HF_TRUST_REMOTE_CODE",
"FALSE").lower() == 'true'
self.enable_rolling_batch = None
self.rolling_batch_type = None
self.rolling_batch = None
self.model_config = None

@@ -121,20 +144,23 @@ def initialize(self, properties: dict):
if "dtype" in properties:
kwargs["torch_dtype"] = get_torch_dtype_from_str(
properties.get("dtype"))
self.enable_rolling_batch = properties.get("rolling_batch", None)
if self.enable_rolling_batch and self.enable_rolling_batch.lower(
) == "false":
self.enable_rolling_batch = None
self.rolling_batch_type = properties.get("rolling_batch", None)

if self.rolling_batch_type:
self.rolling_batch_type = self.rolling_batch_type.lower()
is_mpi = properties.get("engine") != "Python"
if is_mpi:
self.device = int(os.getenv("LOCAL_RANK", 0))
model_config = AutoConfig.from_pretrained(model_id_or_path, **kwargs)
_rolling_batch_cls = get_rolling_batch_class_from_str(self.rolling_batch_type, is_mpi, model_config)
self.rolling_batch = _rolling_batch_cls(model_id_or_path,
self.device, properties,
**kwargs)

if self.enable_streaming:
self._init_model_and_tokenizer(model_id_or_path, **kwargs)
self.initialized = True
return
elif self.enable_rolling_batch:
# TODO: Add logic to call appropriate scheduler backend for rolling batch
self.rolling_batch = SchedulerRollingBatch(model_id_or_path,
self.device, properties,
**kwargs)
elif self.enable_streaming:
self._init_model_and_tokenizer(model_id_or_path, **kwargs)
self.initialized = True
return

@@ -169,7 +195,7 @@ def inference(self, inputs):
input_data.extend(_inputs)
else:
input_data.append(_inputs)
if first or self.enable_rolling_batch:
if first or self.rolling_batch_type:
parameters.append(input_map.pop("parameters", {}))
first = False
else:
@@ -180,7 +206,14 @@ def inference(self, inputs):

outputs = Output()

if self.enable_streaming:
if self.rolling_batch_type:
result = self.rolling_batch.inference(input_data, parameters)
for i in range(len(batch)):
res = result[i]
outputs.add_as_json(res, batch_index=i)

return outputs
elif self.enable_streaming:
outputs.add_property("content-type", "application/jsonlines")
if self.enable_streaming == "huggingface":
outputs.add_stream_content(
@@ -195,13 +228,6 @@ def inference(self, inputs):
input_data, self.device,
**parameters[0]))
return outputs
elif self.enable_rolling_batch:
result = self.rolling_batch.inference(input_data, parameters)
for i in range(len(batch)):
res = result[i]
outputs.add_as_json(res, batch_index=i)

return outputs

prediction = self.hf_pipeline(input_data, **parameters[0])

@@ -218,12 +244,12 @@ def inference(self, inputs):
def get_pipeline(self, task: str, model_id_or_path: str, kwargs):
# define tokenizer or feature extractor as kwargs to load it the pipeline correctly
if task in {
"automatic-speech-recognition",
"image-segmentation",
"image-classification",
"audio-classification",
"object-detection",
"zero-shot-image-classification",
"automatic-speech-recognition",
"image-segmentation",
"image-classification",
"audio-classification",
"object-detection",
"zero-shot-image-classification",
}:
kwargs["feature_extractor"] = model_id_or_path
else:
@@ -265,7 +291,7 @@ def _init_model_and_tokenizer(self, model_id_or_path: str, **kwargs):
self.tokenizer = AutoTokenizer.from_pretrained(model_id_or_path,
padding_side="left")
model_config = AutoConfig.from_pretrained(model_id_or_path,
kwargs=kwargs)
**kwargs)
self.model_config = model_config
architectures = model_config.architectures
if architectures and architectures[0].endswith(
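A minimal sketch of how the new selection helper above is expected to behave, assuming the djl_python package (and lmi_dist, for the lmi-dist path) is installed; the Hugging Face model id is illustrative only:

from transformers import AutoConfig
from djl_python.huggingface import get_rolling_batch_class_from_str

# "auto" picks LmiDistRollingBatch only when the model architecture is in
# LMI_DIST_ADV_MODEL and the handler runs under MPI; with "auto", any other
# architecture (or a non-MPI engine) falls back to SchedulerRollingBatch,
# "scheduler" and "lmi-dist" force a backend directly, and an unrecognized
# value raises ValueError.
config = AutoConfig.from_pretrained("huggyllama/llama-7b")  # illustrative Llama checkpoint
cls = get_rolling_batch_class_from_str("auto", True, config)
print(cls.__name__)  # LlamaForCausalLM + MPI resolves to LmiDistRollingBatch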
2 changes: 1 addition & 1 deletion engines/python/setup/djl_python/rolling_batch/__init__.py
@@ -11,4 +11,4 @@
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

from .scheduler_rolling_batch import SchedulerRollingBatch
from .scheduler_rolling_batch import SchedulerRollingBatch
162 changes: 162 additions & 0 deletions engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
@@ -0,0 +1,162 @@
#!/usr/bin/env python
#
# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

from djl_python.rolling_batch.rolling_batch import RollingBatch
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig
from lmi_dist.models import get_model
from lmi_dist.models.flash_causal_lm import FlashCausalLMBatch
from lmi_dist.models.seq2seq_lm import Seq2SeqLMBatch
from lmi_dist.utils.parameters import (
NextTokenChooserParameters,
StoppingCriteriaParameters,
)
import lmi_dist
from lmi_dist.utils.types import (
Batch,
Request,
Generation
)

import torch

ARCHITECTURE_2_BATCH_CLS = {
"RWForCausalLM": FlashCausalLMBatch,
"GPTNeoXForCausalLM": FlashCausalLMBatch,
"T5ForConditionalGeneration": Seq2SeqLMBatch,
"LlamaForCausalLM": FlashCausalLMBatch
}


def get_batch_cls_from_architecture(architecture):
if architecture in ARCHITECTURE_2_BATCH_CLS:
return ARCHITECTURE_2_BATCH_CLS[architecture]
raise ValueError("Invalid architecture, not supported by lmi-dist")


class LmiDistRollingBatch(RollingBatch):

def __init__(self, model_id_or_path, device, properties, **kwargs):
"""
Initializes the LmiDistRollingBatch.
:param model_id_or_path: model id or path
:param device: model loaded device
:param properties: other properties of the model, such as decoder strategy
:param kwargs passed while loading the model
"""

super().__init__(device)
self.properties = properties
self.batch_cls = None
self._init_model(kwargs, model_id_or_path)
self.batch_id_counter = 0
self.cache: Batch = None

def _init_model(self, kwargs, model_id_or_path):
self.config = AutoConfig.from_pretrained(model_id_or_path,
**kwargs)
self.batch_cls = get_batch_cls_from_architecture(self.config.architectures[0])
sharded = int(self.properties.get("tensor_parallel_degree", "-1")) > 1
self.model = get_model(model_id_or_path,
revision=None,
sharded=sharded,
quantize=None,
trust_remote_code=kwargs.get("trust_remote_code"))

def inference(self, input_data, parameters):
"""
Performs prefill and decode operations for the batch.
:param input_data: List of input texts for each request in a batch
:param parameters: List of kwargs for each request in a batch
:return: generated batch decoded tokens
"""
batch_size = len(input_data)
new_requests = self.get_new_requests(input_data, parameters,
batch_size)
new_batch = self.preprocess_requests(new_requests)
self._prefill_and_decode(new_batch)
return self.postprocess_results(batch_size)

def _prefill_and_decode(self, new_batch):
# prefill step
if new_batch:
generations, prefill_next_batch = self.model.generate_token(new_batch)

if self.cache:
decode_generations, decode_next_batch = self.model.generate_token(self.cache)
self.cache = decode_next_batch
generations.extend(decode_generations)

# concatenate with the existing batch of the model
self.cache = self.model.batch_type.concatenate([prefill_next_batch, self.cache])


else:
self.cache = prefill_next_batch
else:
generations, next_batch = self.model.generate_token(self.cache)
self.cache = next_batch

generation_dict = {}
for generation in generations:
generation_dict[generation.request_id] = generation

req_ids = []
for r in self.pending_requests:
generation = generation_dict[r.id]
is_last_token = generation.generated_text is not None
if not is_last_token:
req_ids.append((r.id))
r.set_next_token(generation.token_text, last_token=is_last_token)

# filter the requests that are stopped.
if self.cache:
self.cache = self.cache.filter(req_ids)

def preprocess_requests(self, requests, **kwargs):
preprocessed_requests = []
for r in requests:
param = r.parameters
parameters = NextTokenChooserParameters(
temperature=param.get("temperature", 0.5), # TODO: Find a better place to put default values
repetition_penalty=param.get("repetition_penalty", 1.0),
top_k=param.get("top_k", 4),
top_p=param.get("top_p", 1.0),
typical_p=param.get("typical_p", 1.0),
do_sample=param.get("do_sample", False),
)
stop_parameters = StoppingCriteriaParameters(stop_sequences=param.get("stop_sequences", []),
max_new_tokens=param.get("max_new_tokens", 30))

preprocessed_requests.append(lmi_dist.utils.types.Request(
id=r.id,
inputs=r.input_text,
parameters=parameters,
stopping_parameters=stop_parameters
))

if preprocessed_requests:
batch = Batch(id=self.batch_id_counter,
requests=preprocessed_requests,
size=len(preprocessed_requests))
self.batch_id_counter += 1

return self.batch_cls.get_batch(
batch,
self.model.tokenizer,
kwargs.get("torch_dtype", torch.float16),
self.device
)
else:
return None
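A hedged sketch of driving the new LmiDistRollingBatch directly, mirroring the calls the huggingface.py handler makes; it assumes a CUDA host with lmi_dist installed, and the model id, device index, and sampling values are illustrative only:

import torch
from djl_python.rolling_batch.lmi_dist_rolling_batch import LmiDistRollingBatch

properties = {"tensor_parallel_degree": "1"}   # a value > 1 would load the model sharded
rolling_batch = LmiDistRollingBatch(
    "huggyllama/llama-7b",                     # illustrative LlamaForCausalLM checkpoint
    0,                                         # the handler passes LOCAL_RANK as the device
    properties,
    torch_dtype=torch.float16,
    trust_remote_code=False)

# Each inference() call prefills the newly added requests and runs one decode
# step for anything already in flight; finished requests are filtered out of
# the cached batch so later calls keep generating only for the rest.
inputs = ["Deep learning is", "The capital of France is"]
params = [{"max_new_tokens": 32, "do_sample": False},
          {"max_new_tokens": 16, "temperature": 0.7, "top_p": 0.9}]
step_results = rolling_batch.inference(inputs, params)  # per-request results for this step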
11 changes: 9 additions & 2 deletions engines/python/setup/djl_python/rolling_batch/rolling_batch.py
@@ -24,12 +24,14 @@ class Request(object):
"""

def __init__(self, input_text: str, parameters: dict):
def __init__(self, id: int, input_text: str, parameters: dict):
"""
Initialize a request
:param id: request id
:param input_text: request's input text
"""
self.id = id
self.input_text = input_text
self.parameters = parameters
self.next_token = None
@@ -79,6 +81,7 @@ def __init__(self, device):

self.device = device
self.pending_requests = []
self.req_id_counter = 0

@abstractmethod
def inference(self, input_data, parameters):
@@ -98,9 +101,10 @@ def get_new_requests(self, input_data, parameters, batch_size):
for i in range(pending_req_len, batch_size):
data = input_data[i]
params = parameters[i] if i < len(parameters) else {}
request = Request(data, params)
request = Request(self.req_id_counter, data, params)
self.pending_requests.append(request)
new_requests.append(request)
self.req_id_counter += 1

return new_requests

@@ -119,4 +123,7 @@ def postprocess_results(self, batch_size):
if self.pending_requests[batch_size - i].is_last_token():
self.pending_requests.pop(batch_size - i)

if len(self.pending_requests) == 0:
self.req_id_counter = 0

return results
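A standalone sketch (no djl_python imports) of the id bookkeeping this change introduces: each new Request receives a monotonically increasing id, and the counter resets once every pending request has emitted its last token, mirroring get_new_requests and postprocess_results above:

req_id_counter = 0
pending_requests = []

def get_new_requests(texts):
    # mirror of RollingBatch.get_new_requests: ids are assigned in arrival order
    global req_id_counter
    for text in texts:
        pending_requests.append({"id": req_id_counter, "input_text": text})
        req_id_counter += 1

get_new_requests(["a", "b"])   # requests receive ids 0 and 1
get_new_requests(["c"])        # id 2 joins the still-pending requests
pending_requests.clear()       # pretend every request produced its last token
if not pending_requests:       # queue drained ...
    req_id_counter = 0         # ... so numbering restarts, as in postprocess_results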

