Skip to content

Commit

Permalink
support remote model
Browse files: browse the repository at this point in the history
  • Loading branch information
AllentDan committed Jan 14, 2025
1 parent e306c5d commit 1faddf2
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
2 changes: 2 additions & 0 deletions lmdeploy/cli/lite.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ def add_parser_smooth_quant():
ArgumentHelper.calib_search_scale(parser)
ArgumentHelper.dtype(parser)
ArgumentHelper.quant_dtype(parser)
ArgumentHelper.revision(parser)
ArgumentHelper.download_dir(parser)

@staticmethod
def auto_awq(args):
Expand Down
10 changes: 9 additions & 1 deletion lmdeploy/lite/apis/smooth_quant.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.

import os.path as osp
from typing import Literal

import fire
Expand All @@ -26,7 +27,9 @@ def smooth_quant(model: str,
dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto',
device: str = 'cuda',
quant_dtype: Literal['int8', 'fp8', 'float8_e4m3fn',
'float8_e5m2'] = 'int8'):
'float8_e5m2'] = 'int8',
revision: str = None,
download_dir: str = None):
if quant_dtype == 'fp8':
quant_dtype = 'float8_e4m3fn'

Expand All @@ -37,6 +40,11 @@ def smooth_quant(model: str,
q_dtype_info = torch.iinfo(quant_dtype)

assert q_dtype_info.bits == w_bits
if not osp.exists(model):
print(f'can\'t find model from local_path {model}, '
'try to download from remote')
from lmdeploy.utils import get_model
model = get_model(model, revision=revision, download_dir=download_dir)
model_path = model
vl_model, model, tokenizer, work_dir = calibrate(model,
calib_dataset,
Expand Down

0 comments on commit 1faddf2

Please sign in to comment.