diff --git a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
index 3858d1834e..7e9b979ca0 100644
--- a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
@@ -21,13 +21,10 @@
     StoppingCriteriaParameters,
 )
 import lmi_dist
-from lmi_dist.utils.types import (
-    Batch,
-    Request,
-    Generation
-)
+from lmi_dist.utils.types import (Batch, Request, Generation)
 import logging
 import torch
+import os
 
 ARCHITECTURE_2_BATCH_CLS = {
     "RWForCausalLM": FlashCausalLMBatch,
@@ -36,6 +33,8 @@
     "LlamaForCausalLM": FlashCausalLMBatch
 }
 
+QUANTIZATION_SUPPORT_ALGO = ["bitsandbytes"]
+
 
 def get_batch_cls_from_architecture(architecture):
     if architecture in ARCHITECTURE_2_BATCH_CLS:
@@ -54,7 +53,7 @@ def __init__(self, model_id_or_path, device, properties, **kwargs):
         :param properties: other properties of the model, such as decoder strategy
         :param kwargs passed while loading the model
         """
-
+        os.environ["BITSANDBYTES_NOWELCOME"] = "1"
         super().__init__(device)
         self.properties = properties
         self.batch_cls = None
@@ -63,20 +62,25 @@ def __init__(self, model_id_or_path, device, properties, **kwargs):
         self.cache: Batch = None
 
     def _init_model(self, kwargs, model_id_or_path):
-        self.config = AutoConfig.from_pretrained(model_id_or_path,
-                                                 **kwargs)
-        self.batch_cls = get_batch_cls_from_architecture(self.config.architectures[0])
+        self.config = AutoConfig.from_pretrained(model_id_or_path, **kwargs)
+        self.batch_cls = get_batch_cls_from_architecture(
+            self.config.architectures[0])
         sharded = int(self.properties.get("tensor_parallel_degree", "-1")) > 1
         quantize = self.properties.get("quantize", None)
-        if quantize and quantize != "bitsanadbytes":
-            logging.info(f"Invalid value for quantize: {quantize}. Only `bitsandbytes` quantization is supported. "
-                         f"Setting quantization to None")
+        dtype = self.properties.get("dtype", None)
+        if quantize is not None and dtype is not None:
+            logging.error("Can't set both dtype and quantize")
+        if quantize and quantize not in QUANTIZATION_SUPPORT_ALGO:
+            logging.error(
+                f"Invalid value for quantize: {quantize}. Only `bitsandbytes` quantization is supported."
+            )
             quantize = None
-        self.model = get_model(model_id_or_path,
-                               revision=None,
-                               sharded=sharded,
-                               quantize=quantize,
-                               trust_remote_code=kwargs.get("trust_remote_code"))
+        self.model = get_model(
+            model_id_or_path,
+            revision=None,
+            sharded=sharded,
+            quantize=quantize,
+            trust_remote_code=kwargs.get("trust_remote_code"))
 
     def inference(self, input_data, parameters):
         """
@@ -96,15 +100,18 @@ def inference(self, input_data, parameters):
     def _prefill_and_decode(self, new_batch):
         # prefill step
         if new_batch:
-            generations, prefill_next_batch = self.model.generate_token(new_batch)
+            generations, prefill_next_batch = self.model.generate_token(
+                new_batch)
 
             if self.cache:
-                decode_generations, decode_next_batch = self.model.generate_token(self.cache)
+                decode_generations, decode_next_batch = self.model.generate_token(
+                    self.cache)
                 generations.extend(decode_generations)
 
                 # concatenate with the existing batch of the model
                 if decode_next_batch:
-                    self.cache = self.model.batch_type.concatenate([prefill_next_batch, decode_next_batch])
+                    self.cache = self.model.batch_type.concatenate(
+                        [prefill_next_batch, decode_next_batch])
                 else:
                     self.cache = prefill_next_batch
             else:
@@ -113,7 +120,10 @@ def _prefill_and_decode(self, new_batch):
             generations, next_batch = self.model.generate_token(self.cache)
             self.cache = next_batch
 
-        generation_dict = {generation.request_id: generation for generation in generations}
+        generation_dict = {
+            generation.request_id: generation
+            for generation in generations
+        }
 
         req_ids = []
         for r in self.pending_requests:
@@ -132,22 +142,25 @@ def preprocess_requests(self, requests, **kwargs):
         for r in requests:
            param = r.parameters
            parameters = NextTokenChooserParameters(
-                temperature=param.get("temperature", 0.5),  # TODO: Find a better place to put default values
+                temperature=param.get(
+                    "temperature",
+                    0.5),  # TODO: Find a better place to put default values
                 repetition_penalty=param.get("repetition_penalty", 1.0),
                 top_k=param.get("top_k", 4),
                 top_p=param.get("top_p", 1.0),
                 typical_p=param.get("typical_p", 1.0),
                 do_sample=param.get("do_sample", False),
             )
-            stop_parameters = StoppingCriteriaParameters(stop_sequences=param.get("stop_sequences", []),
-                                                         max_new_tokens=param.get("max_new_tokens", 30))
+            stop_parameters = StoppingCriteriaParameters(
+                stop_sequences=param.get("stop_sequences", []),
+                max_new_tokens=param.get("max_new_tokens", 30))
 
-            preprocessed_requests.append(lmi_dist.utils.types.Request(
-                id=r.id,
-                inputs=r.input_text,
-                parameters=parameters,
-                stopping_parameters=stop_parameters
-            ))
+            preprocessed_requests.append(
+                lmi_dist.utils.types.Request(
+                    id=r.id,
+                    inputs=r.input_text,
+                    parameters=parameters,
+                    stopping_parameters=stop_parameters))
 
         if preprocessed_requests:
             batch = Batch(id=self.batch_id_counter,
@@ -156,10 +169,7 @@ def preprocess_requests(self, requests, **kwargs):
             self.batch_id_counter += 1
 
             return self.batch_cls.get_batch(
-                batch,
-                self.model.tokenizer,
-                kwargs.get("torch_dtype", torch.float16),
-                self.device
-            )
+                batch, self.model.tokenizer,
+                kwargs.get("torch_dtype", torch.float16), self.device)
         else:
             return None