Format python
maaquib committed Jun 30, 2023
1 parent 9c502aa commit ffd8dc7
Showing 2 changed files with 47 additions and 36 deletions.
2 changes: 2 additions & 0 deletions engines/python/setup/djl_python/huggingface.py
@@ -24,6 +24,8 @@
from djl_python.streaming_utils import StreamingUtils
from djl_python.rolling_batch import SchedulerRollingBatch

os.environ["BITSANDBYTES_NOWELCOME"] = "1"

ARCHITECTURES_2_TASK = {
"TapasForQuestionAnswering": "table-question-answering",
"ForQuestionAnswering": "question-answering",
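The line added above silences the bitsandbytes import banner. A minimal sketch of the idea, assuming bitsandbytes is installed and that the installed release still checks the BITSANDBYTES_NOWELCOME flag at import time (this snippet is illustrative and not part of the commit):

import os

# Must be set before bitsandbytes is imported for the first time;
# the library reads this flag on import and skips its welcome banner.
os.environ["BITSANDBYTES_NOWELCOME"] = "1"

import bitsandbytes  # noqa: E402  (imported after the env var on purpose)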
81 changes: 45 additions & 36 deletions
@@ -21,12 +21,8 @@
StoppingCriteriaParameters,
)
import lmi_dist
from lmi_dist.utils.types import (
Batch,
Request,
Generation
)
import logging
from lmi_dist.utils.types import (Batch, Request, Generation)

import torch

ARCHITECTURE_2_BATCH_CLS = {
@@ -36,6 +32,8 @@
"LlamaForCausalLM": FlashCausalLMBatch
}

QUANTIZATION_SUPPORT_ALGO = ["bitsandbytes"]


def get_batch_cls_from_architecture(architecture):
if architecture in ARCHITECTURE_2_BATCH_CLS:
@@ -63,20 +61,25 @@ def __init__(self, model_id_or_path, device, properties, **kwargs):
self.cache: Batch = None

def _init_model(self, kwargs, model_id_or_path):
self.config = AutoConfig.from_pretrained(model_id_or_path,
**kwargs)
self.batch_cls = get_batch_cls_from_architecture(self.config.architectures[0])
self.config = AutoConfig.from_pretrained(model_id_or_path, **kwargs)
self.batch_cls = get_batch_cls_from_architecture(
self.config.architectures[0])
sharded = int(self.properties.get("tensor_parallel_degree", "-1")) > 1
quantize = self.properties.get("quantize", None)
if quantize and quantize != "bitsanadbytes":
logging.info(f"Invalid value for quantize: {quantize}. Only `bitsandbytes` quantization is supported. "
f"Setting quantization to None")
quantize = None
self.model = get_model(model_id_or_path,
revision=None,
sharded=sharded,
quantize=quantize,
trust_remote_code=kwargs.get("trust_remote_code"))
dtype = self.properties.get("dtype", None)
if quantize is not None and dtype is not None:
raise ValueError(
f"Can't set both dtype: {dtype} and quantize: {quantize}")
if quantize is not None and quantize not in QUANTIZATION_SUPPORT_ALGO:
raise ValueError(
f"Invalid value for quantize: {quantize}. Valid values are: {QUANTIZATION_SUPPORT_ALGO}"
)
self.model = get_model(
model_id_or_path,
revision=None,
sharded=sharded,
quantize=quantize,
trust_remote_code=kwargs.get("trust_remote_code"))

def inference(self, input_data, parameters):
"""
@@ -96,15 +99,18 @@ def inference(self, input_data, parameters):
def _prefill_and_decode(self, new_batch):
# prefill step
if new_batch:
generations, prefill_next_batch = self.model.generate_token(new_batch)
generations, prefill_next_batch = self.model.generate_token(
new_batch)

if self.cache:
decode_generations, decode_next_batch = self.model.generate_token(self.cache)
decode_generations, decode_next_batch = self.model.generate_token(
self.cache)
generations.extend(decode_generations)

# concatenate with the existing batch of the model
if decode_next_batch:
self.cache = self.model.batch_type.concatenate([prefill_next_batch, decode_next_batch])
self.cache = self.model.batch_type.concatenate(
[prefill_next_batch, decode_next_batch])
else:
self.cache = prefill_next_batch
else:
@@ -113,7 +119,10 @@ def _prefill_and_decode(self, new_batch):
generations, next_batch = self.model.generate_token(self.cache)
self.cache = next_batch

generation_dict = {generation.request_id: generation for generation in generations}
generation_dict = {
generation.request_id: generation
for generation in generations
}

req_ids = []
for r in self.pending_requests:
@@ -132,22 +141,25 @@ def preprocess_requests(self, requests, **kwargs):
for r in requests:
param = r.parameters
parameters = NextTokenChooserParameters(
temperature=param.get("temperature", 0.5), # TODO: Find a better place to put default values
temperature=param.get(
"temperature",
0.5), # TODO: Find a better place to put default values
repetition_penalty=param.get("repetition_penalty", 1.0),
top_k=param.get("top_k", 4),
top_p=param.get("top_p", 1.0),
typical_p=param.get("typical_p", 1.0),
do_sample=param.get("do_sample", False),
)
stop_parameters = StoppingCriteriaParameters(stop_sequences=param.get("stop_sequences", []),
max_new_tokens=param.get("max_new_tokens", 30))
stop_parameters = StoppingCriteriaParameters(
stop_sequences=param.get("stop_sequences", []),
max_new_tokens=param.get("max_new_tokens", 30))

preprocessed_requests.append(lmi_dist.utils.types.Request(
id=r.id,
inputs=r.input_text,
parameters=parameters,
stopping_parameters=stop_parameters
))
preprocessed_requests.append(
lmi_dist.utils.types.Request(
id=r.id,
inputs=r.input_text,
parameters=parameters,
stopping_parameters=stop_parameters))

if preprocessed_requests:
batch = Batch(id=self.batch_id_counter,
@@ -156,10 +168,7 @@ def preprocess_requests(self, requests, **kwargs):
self.batch_id_counter += 1

return self.batch_cls.get_batch(
batch,
self.model.tokenizer,
kwargs.get("torch_dtype", torch.float16),
self.device
)
batch, self.model.tokenizer,
kwargs.get("torch_dtype", torch.float16), self.device)
else:
return None
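
Beyond the yapf reformatting, the substantive change in this commit is the stricter quantization handling in _init_model: an unsupported quantize value, or a quantize value combined with an explicit dtype, now raises instead of being logged and silently dropped. A standalone sketch of that check, using a hypothetical validate_quantize helper and a plain dict in place of the handler's properties:

QUANTIZATION_SUPPORT_ALGO = ["bitsandbytes"]


def validate_quantize(properties):
    # Mirrors the checks added in _init_model (hypothetical helper, not in the commit).
    quantize = properties.get("quantize", None)
    dtype = properties.get("dtype", None)
    if quantize is not None and dtype is not None:
        raise ValueError(
            f"Can't set both dtype: {dtype} and quantize: {quantize}")
    if quantize is not None and quantize not in QUANTIZATION_SUPPORT_ALGO:
        raise ValueError(
            f"Invalid value for quantize: {quantize}. Valid values are: {QUANTIZATION_SUPPORT_ALGO}"
        )
    return quantize


validate_quantize({"quantize": "bitsandbytes"})  # returns "bitsandbytes"
# validate_quantize({"quantize": "gptq"})  -> ValueError (unsupported algorithm)
# validate_quantize({"quantize": "bitsandbytes", "dtype": "fp16"})  -> ValueError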
