diff --git a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
index 4a557f883..3858d1834 100644
--- a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
@@ -26,7 +26,7 @@
     Request,
     Generation
 )
-
+import logging
 import torch
 
 ARCHITECTURE_2_BATCH_CLS = {
@@ -67,10 +67,15 @@ def _init_model(self, kwargs, model_id_or_path):
                               **kwargs)
         self.batch_cls = get_batch_cls_from_architecture(self.config.architectures[0])
         sharded = int(self.properties.get("tensor_parallel_degree", "-1")) > 1
+        quantize = self.properties.get("quantize", None)
+        # Only bitsandbytes quantization is supported; any other value is
+        # discarded (with a warning) rather than passed through to get_model.
+        if quantize and quantize != "bitsandbytes":
+            logging.warning(f"Invalid value for quantize: {quantize}. Only `bitsandbytes` quantization is supported. "
+                            f"Setting quantization to None")
+            quantize = None
         self.model = get_model(model_id_or_path,
                                revision=None,
                                sharded=sharded,
-                               quantize=None,
+                               quantize=quantize,
                                trust_remote_code=kwargs.get("trust_remote_code"))
 
     def inference(self, input_data, parameters):