diff --git a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py index 4018b2a06..f0e9d0a45 100644 --- a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py @@ -23,6 +23,8 @@ import torch +QUANTIZATION_SUPPORT_ALGO = ["bitsandbytes"] + class LmiDistRollingBatch(RollingBatch): @@ -46,11 +48,22 @@ def __init__(self, model_id_or_path, device, properties, **kwargs): def _init_model(self, kwargs, model_id_or_path): self.config = AutoConfig.from_pretrained(model_id_or_path, **kwargs) sharded = int(self.properties.get("tensor_parallel_degree", "-1")) > 1 + quantize = self.properties.get("quantize", None) + dtype = self.properties.get("dtype", None) + if quantize is not None and dtype is not None: + raise ValueError( + f"Can't set both dtype: {dtype} and quantize: {quantize}") + if quantize is not None and quantize not in QUANTIZATION_SUPPORT_ALGO: + raise ValueError( + f"Invalid value for quantize: {quantize}. Valid values are: {QUANTIZATION_SUPPORT_ALGO}" + ) + if quantize is None and dtype == "int8": + quantize = "bitsandbytes" self.model = get_model( model_id_or_path, revision=None, sharded=sharded, - quantize=None, + quantize=quantize, trust_remote_code=kwargs.get("trust_remote_code")) self.batch_cls = self.model.batch_type diff --git a/serving/docker/deepspeed.Dockerfile b/serving/docker/deepspeed.Dockerfile index f151bf49a..f46b3d18b 100644 --- a/serving/docker/deepspeed.Dockerfile +++ b/serving/docker/deepspeed.Dockerfile @@ -39,6 +39,7 @@ ENV DJL_CACHE_DIR=/tmp/.djl.ai ENV HUGGINGFACE_HUB_CACHE=/tmp ENV TRANSFORMERS_CACHE=/tmp ENV PYTORCH_KERNEL_CACHE_PATH=/tmp +ENV BITSANDBYTES_NOWELCOME=1 ENTRYPOINT ["/usr/local/bin/dockerd-entrypoint.sh"] CMD ["serve"] diff --git a/serving/docker/fastertransformer.Dockerfile b/serving/docker/fastertransformer.Dockerfile index d81dde2a5..7f7f73afa 100644 --- a/serving/docker/fastertransformer.Dockerfile +++ b/serving/docker/fastertransformer.Dockerfile @@ -40,6 +40,7 @@ ENV DJL_CACHE_DIR=/tmp/.djl.ai ENV HUGGINGFACE_HUB_CACHE=/tmp ENV TRANSFORMERS_CACHE=/tmp ENV PYTORCH_KERNEL_CACHE_PATH=/tmp +ENV BITSANDBYTES_NOWELCOME=1 ENTRYPOINT ["/usr/local/bin/dockerd-entrypoint.sh"] CMD ["serve"]