diff --git a/lmdeploy/cli/lite.py b/lmdeploy/cli/lite.py index 499bace48..5ef989957 100644 --- a/lmdeploy/cli/lite.py +++ b/lmdeploy/cli/lite.py @@ -127,6 +127,8 @@ def add_parser_smooth_quant(): ArgumentHelper.calib_search_scale(parser) ArgumentHelper.dtype(parser) ArgumentHelper.quant_dtype(parser) + ArgumentHelper.revision(parser) + ArgumentHelper.download_dir(parser) @staticmethod def auto_awq(args): diff --git a/lmdeploy/lite/apis/smooth_quant.py b/lmdeploy/lite/apis/smooth_quant.py index bfc19fd60..6114fdf89 100644 --- a/lmdeploy/lite/apis/smooth_quant.py +++ b/lmdeploy/lite/apis/smooth_quant.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp from typing import Literal import fire @@ -26,7 +27,9 @@ def smooth_quant(model: str, dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto', device: str = 'cuda', quant_dtype: Literal['int8', 'fp8', 'float8_e4m3fn', - 'float8_e5m2'] = 'int8'): + 'float8_e5m2'] = 'int8', + revision: str = None, + download_dir: str = None): if quant_dtype == 'fp8': quant_dtype = 'float8_e4m3fn' @@ -37,6 +40,11 @@ def smooth_quant(model: str, q_dtype_info = torch.iinfo(quant_dtype) assert q_dtype_info.bits == w_bits + if not osp.exists(model): + print(f'can\'t find model from local_path {model}, ' + 'try to download from remote') + from lmdeploy.utils import get_model + model = get_model(model, revision=revision, download_dir=download_dir) model_path = model vl_model, model, tokenizer, work_dir = calibrate(model, calib_dataset,