diff --git a/autotest/utils/config_utils.py b/autotest/utils/config_utils.py index e207142281..87d5d73f10 100644 --- a/autotest/utils/config_utils.py +++ b/autotest/utils/config_utils.py @@ -113,7 +113,7 @@ def get_all_model_list(tp_num: int = None, model_type=model_type): if case not in case_list: case_list.append(case) - return [x for x in case_list if 'w8a8' not in x] + return case_list def get_quantization_model_list(type): diff --git a/lmdeploy/archs.py b/lmdeploy/archs.py index ce5cbd98ff..53039bdb53 100644 --- a/lmdeploy/archs.py +++ b/lmdeploy/archs.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import os -from typing import Literal, Optional, Union +from typing import Dict, List, Literal, Optional, Union from transformers import AutoConfig @@ -193,3 +193,22 @@ def get_model_arch(model_path: str): raise RuntimeError( f'Could not find model architecture from config: {_cfg}') return arch, cfg + + +def search_nested_config(config, key): + """Recursively searches for the value associated with the given key in a + nested configuration of a model.""" + if isinstance(config, Dict): + for k, v in config.items(): + if k == key: + return v + if isinstance(v, (Dict, List)): + result = search_nested_config(v, key) + if result is not None: + return result + elif isinstance(config, List): + for item in config: + result = search_nested_config(item, key) + if result is not None: + return result + return None diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 68f9de8c15..5cf3453b7e 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -238,6 +238,7 @@ def add_parser_proxy(): help='the strategy to dispatch requests to nodes') ArgumentHelper.api_keys(parser) ArgumentHelper.ssl(parser) + ArgumentHelper.log_level(parser) @staticmethod def gradio(args): diff --git a/lmdeploy/lite/apis/calibrate.py b/lmdeploy/lite/apis/calibrate.py index 307cf6d7e9..85467997e3 100644 --- a/lmdeploy/lite/apis/calibrate.py +++ b/lmdeploy/lite/apis/calibrate.py @@ -262,7 +262,13 @@ def calibrate(model: str, if dtype == 'float16': model.half() elif dtype == 'bfloat16': + assert torch.cuda.is_bf16_supported( + ), 'your device does not support bfloat16, please set --dtype float16' # noqa model.to(torch.bfloat16) + elif dtype == 'auto' and model.config.torch_dtype == torch.bfloat16: + print('Warning: we cast the model to float16 to prevent OOM. You' + ' may enforce bfloat16 with `--dtype bfloat16`') + model.half() model.eval() model_type = type(model).__name__ diff --git a/lmdeploy/lite/utils/load.py b/lmdeploy/lite/utils/load.py index 170c149778..ac4519371a 100644 --- a/lmdeploy/lite/utils/load.py +++ b/lmdeploy/lite/utils/load.py @@ -12,14 +12,13 @@ def load_hf_from_pretrained(pretrained_model_name_or_path, dtype: Literal['float16', 'bfloat16', 'auto'], **kwargs): - if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported(): + if dtype == 'bfloat16' and not torch.cuda.is_bf16_supported(): raise RuntimeError('Your device does not supports bf16(bfloat16), ' 'please change to fp16(float16)') kwargs.pop('config', None) hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, - torch_dtype=dtype, trust_remote_code=True) # HACK hard code for qwen, other configs do not have the `fp16` attribute. 
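The dtype handling added to `calibrate()` above follows one rule: an explicit `bfloat16` request requires hardware support, while `auto` on a bfloat16 checkpoint falls back to float16 to avoid OOM. Below is a minimal standalone sketch of that decision, assuming the same behaviour as the patch; the helper name `resolve_calib_dtype` is illustrative and not part of this change.

import torch

def resolve_calib_dtype(dtype: str, model_dtype: torch.dtype) -> torch.dtype:
    # Illustrative only: mirrors the dtype branch added to calibrate().
    if dtype == 'float16':
        return torch.float16
    if dtype == 'bfloat16':
        # Explicit bfloat16 is only honoured on devices that support it.
        assert torch.cuda.is_bf16_supported(), (
            'your device does not support bfloat16, please set --dtype float16')
        return torch.bfloat16
    if dtype == 'auto' and model_dtype == torch.bfloat16:
        # 'auto' on a bf16 checkpoint is downcast to float16 to prevent OOM.
        return torch.float16
    return model_dtype

With this rule, `--dtype auto` never silently keeps a bf16 checkpoint in bf16; callers who want bf16 must request it explicitly.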
@@ -29,13 +28,23 @@ def load_hf_from_pretrained(pretrained_model_name_or_path, else: hf_config.fp16 = True - if dtype != 'auto': - setattr(hf_config, 'torch_dtype', dtype) + torch_dtype = getattr(hf_config, 'torch_dtype', torch.float16) + if dtype == 'bfloat16': + torch_dtype = torch.bfloat16 + elif dtype == 'float16': + torch_dtype = torch.float16 + elif dtype == 'auto' and torch_dtype == torch.bfloat16: + print('Warning: we cast the model to float16 to prevent OOM. ' + 'You may enforce bfloat16 with `--dtype bfloat16`') + torch_dtype = torch.float16 with LoadNoInit(): # Load model model = AutoModelForCausalLM.from_pretrained( - pretrained_model_name_or_path, config=hf_config, **kwargs) + pretrained_model_name_or_path, + config=hf_config, + torch_dtype=torch_dtype, + **kwargs) model.config.use_cache = False return model diff --git a/lmdeploy/model.py b/lmdeploy/model.py index 4bbbca2298..a0b0c8e09b 100644 --- a/lmdeploy/model.py +++ b/lmdeploy/model.py @@ -1921,5 +1921,5 @@ def best_match_model(query: str) -> Optional[str]: for name, model in MODELS.module_dict.items(): if model.match(query): return model.match(query) - logger.warn(f'Did not find a chat template matching {query}.') + logger.warning(f'Did not find a chat template matching {query}.') return 'base' diff --git a/lmdeploy/pytorch/backends/dlinfer/flash_attention.py b/lmdeploy/pytorch/backends/dlinfer/flash_attention.py new file mode 100644 index 0000000000..d0d9ddbb26 --- /dev/null +++ b/lmdeploy/pytorch/backends/dlinfer/flash_attention.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch import Tensor + +from ..flash_attention import FlashAttentionBuilder, FlashAttentionImpl + + +class DlinferFlashAttentionImpl(FlashAttentionImpl): + """dlinfer flash attention implementation.""" + + def __init__( + self, + num_heads: int, + head_dim: int, + scale: float = None, + num_kv_heads: int = None, + v_head_dim: int = None, + causal: bool = True, + sliding_window: int = None, + logical_softcapping: float = None, + ): + if scale is None: + scale = 1.0 / (head_dim**0.5) + if num_kv_heads is None: + num_kv_heads = num_heads + if v_head_dim is None: + v_head_dim = head_dim + self.num_heads = num_heads + self.head_dim = head_dim + self.scale = scale + self.num_kv_heads = num_kv_heads + self.v_head_dim = v_head_dim + self.causal = causal + self.sliding_window = sliding_window + self.logical_softcapping = logical_softcapping + from lmdeploy.pytorch.kernels.dlinfer import flash_attention_fwd + self.flash_attention_fwd = flash_attention_fwd + + def forward(self, + query: Tensor, + key: Tensor, + value: Tensor, + q_start_loc: Tensor, + q_seqlens: Tensor, + kv_start_loc: Tensor, + kv_seqlens: Tensor, + max_q_seqlen: int = None): + """forward.""" + q_shape = query.shape + o_shape = q_shape[:-1] + (self.v_head_dim, ) + out = query.new_empty(o_shape) + self.flash_attention_fwd( + query, + key, + value, + out, + q_start_loc=q_start_loc, + q_seqlens=q_seqlens, + kv_start_loc=kv_start_loc, + kv_seqlens=kv_seqlens, + max_q_seqlen=max_q_seqlen, + window_size=self.sliding_window, + sm_scale=self.scale, + logit_softcapping=self.logical_softcapping, + causal=self.causal, + ) + return out + + +class DlinferFlashAttentionBuilder(FlashAttentionBuilder): + """dlinfer attention builder.""" + + @staticmethod + def build( + num_heads: int, + head_dim: int, + scale: float = None, + num_kv_heads: int = None, + v_head_dim: int = None, + causal: bool = True, + sliding_window: int = None, + logical_softcapping: float = None, + **kwargs, + ) -> 
FlashAttentionImpl: + """build.""" + return DlinferFlashAttentionImpl( + num_heads=num_heads, + head_dim=head_dim, + scale=scale, + num_kv_heads=num_kv_heads, + v_head_dim=v_head_dim, + causal=causal, + sliding_window=sliding_window, + logical_softcapping=logical_softcapping, + ) diff --git a/lmdeploy/pytorch/backends/dlinfer/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/op_backend.py index 93733fbf57..a0f04f34b1 100644 --- a/lmdeploy/pytorch/backends/dlinfer/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/op_backend.py @@ -25,6 +25,9 @@ def get_layer_impl_builder(cls, layer_type: OpType): if layer_type == OpType.PagedAttention: from .attention import DlinferAttentionBuilder return DlinferAttentionBuilder + elif layer_type == OpType.FlashAttention: + from .flash_attention import DlinferFlashAttentionBuilder + return DlinferFlashAttentionBuilder elif layer_type == OpType.ApplyRotaryEmb: from .apply_rotary_emb import DlinferApplyRotaryEmbBuilder return DlinferApplyRotaryEmbBuilder diff --git a/lmdeploy/pytorch/kernels/dlinfer/__init__.py b/lmdeploy/pytorch/kernels/dlinfer/__init__.py index 8f86f0019a..fe82010761 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/__init__.py +++ b/lmdeploy/pytorch/kernels/dlinfer/__init__.py @@ -3,6 +3,7 @@ from .apply_rotary_pos_emb import apply_rotary_pos_emb from .awq_kernels import awq_linear from .fill_kv_cache import fill_kv_cache +from .flash_attention import flash_attention_fwd from .fused_moe import fused_moe from .linear import linear from .moe_gating_topk_softmax import moe_gating_topk_softmax @@ -16,6 +17,7 @@ 'fill_kv_cache', 'fused_moe', 'paged_attention_fwd', + 'flash_attention_fwd', 'linear', 'moe_gating_topk_softmax', 'multinomial_sampling', diff --git a/lmdeploy/pytorch/kernels/dlinfer/flash_attention.py b/lmdeploy/pytorch/kernels/dlinfer/flash_attention.py new file mode 100644 index 0000000000..1788f947ee --- /dev/null +++ b/lmdeploy/pytorch/kernels/dlinfer/flash_attention.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import dlinfer.ops as ext_ops +from dlinfer.utils.type_annotation import Tensor + + +def flash_attention_fwd( + query_states: Tensor, + key_states: Tensor, + value_states: Tensor, + attn_output: Tensor, + q_start_loc: Tensor, + q_seqlens: Tensor, + kv_start_loc: Tensor, + kv_seqlens: Tensor, + max_q_seqlen: int = None, + window_size: int = None, + sm_scale: float = None, + logit_softcapping: float = None, + causal: bool = True, +): + num_q_heads = query_states.shape[1] + num_kv_heads = value_states.shape[1] + return ext_ops.prefill_attention( + query_states, + key_states, + value_states, + q_start_loc, + q_seqlens, + max_q_seqlen, + num_q_heads, + num_kv_heads, + attn_mask=None, + softmax_scale=sm_scale, + attn_output=attn_output, + ) diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index 78574a38b1..dfcf01a69d 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -505,8 +505,8 @@ async def generate( if gen_config.stop_token_ids is None: gen_config.stop_token_ids = self.stop_words if not gen_config.do_sample: - logger.warn(f'GenerationConfig: {gen_config}') - logger.warn( + logger.warning(f'GenerationConfig: {gen_config}') + logger.warning( 'Since v0.6.0, lmdeploy add `do_sample` in ' 'GenerationConfig. It defaults to False, meaning greedy ' 'decoding. 
Please set `do_sample=True` if sampling ' diff --git a/lmdeploy/serve/proxy/constants.py b/lmdeploy/serve/proxy/constants.py index 88d86a3e33..5bf6e67659 100644 --- a/lmdeploy/serve/proxy/constants.py +++ b/lmdeploy/serve/proxy/constants.py @@ -2,8 +2,8 @@ import enum -LATENCY_DEEQUE_LEN = 15 -API_TIMEOUT_LEN = 100 +LATENCY_DEQUE_LEN = 15 +API_READ_TIMEOUT = 100 class Strategy(enum.Enum): diff --git a/lmdeploy/serve/proxy/proxy.py b/lmdeploy/serve/proxy/proxy.py index 5f05930bd0..392ede3267 100644 --- a/lmdeploy/serve/proxy/proxy.py +++ b/lmdeploy/serve/proxy/proxy.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import asyncio import copy import json import os @@ -18,14 +19,15 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from pydantic import BaseModel, Field +from requests.exceptions import RequestException from lmdeploy.serve.openai.api_server import (check_api_key, create_error_response) from lmdeploy.serve.openai.protocol import ( # noqa: E501 ChatCompletionRequest, CompletionRequest, ModelCard, ModelList, ModelPermission) -from lmdeploy.serve.proxy.constants import (API_TIMEOUT_LEN, - LATENCY_DEEQUE_LEN, ErrorCodes, +from lmdeploy.serve.proxy.constants import (API_READ_TIMEOUT, + LATENCY_DEQUE_LEN, ErrorCodes, Strategy, err_msg) from lmdeploy.utils import get_logger @@ -36,7 +38,7 @@ class Status(BaseModel): """Status protocol consists of models' information.""" models: Optional[List[str]] = Field(default=[], examples=[[]]) unfinished: int = 0 - latency: Deque = Field(default=deque(maxlen=LATENCY_DEEQUE_LEN), + latency: Deque = Field(default=deque(maxlen=LATENCY_DEQUE_LEN), examples=[[]]) speed: Optional[int] = Field(default=None, examples=[None]) @@ -87,6 +89,9 @@ def __init__(self, with open(self.config_path, 'r') as config_file: self.nodes = yaml.safe_load(config_file)['nodes'] for url, status in self.nodes.items(): + latency = deque(status.get('latency', []), + maxlen=LATENCY_DEQUE_LEN) + status['latency'] = latency status = Status(**status) self.nodes[url] = status self.heart_beat_thread = threading.Thread(target=heart_beat_controller, @@ -99,7 +104,7 @@ def update_config_file(self): nodes = copy.deepcopy(self.nodes) for url, status in nodes.items(): nodes[url] = status.model_dump() - nodes[url]['latency'] = list(status.latency) + nodes[url]['latency'] = list(status.latency)[-LATENCY_DEQUE_LEN:] with open(self.config_path, 'w') as config_file: # update cfg yml yaml.dump(dict(nodes=nodes), config_file) @@ -149,7 +154,8 @@ def remove_stale_nodes_by_expiration(self): to_be_deleted.append(node_url) for node_url in to_be_deleted: self.remove(node_url) - logger.info(f'Removed node_url: {node_url}') + logger.info(f'Removed node_url: {node_url} ' + 'due to heart beat expiration') @property def model_list(self): @@ -251,7 +257,7 @@ def handle_unavailable_model(self, model_name): Args: model_name (str): the model in the request. 
""" - logger.info(f'no model name: {model_name}') + logger.warning(f'no model name: {model_name}') ret = { 'error_code': ErrorCodes.MODEL_NOT_FOUND, 'text': err_msg[ErrorCodes.MODEL_NOT_FOUND], @@ -260,51 +266,54 @@ def handle_unavailable_model(self, model_name): def handle_api_timeout(self, node_url): """Handle the api time out.""" - logger.info(f'api timeout: {node_url}') + logger.warning(f'api timeout: {node_url}') ret = { - 'error_code': ErrorCodes.API_TIMEOUT, + 'error_code': ErrorCodes.API_TIMEOUT.value, 'text': err_msg[ErrorCodes.API_TIMEOUT], } return json.dumps(ret).encode() + b'\n' - def stream_generate(self, request: Dict, node_url: str, node_path: str): + def stream_generate(self, request: Dict, node_url: str, endpoint: str): """Return a generator to handle the input request. Args: request (Dict): the input request. node_url (str): the node url. - node_path (str): the node path. Such as `/v1/chat/completions`. + endpoint (str): the endpoint. Such as `/v1/chat/completions`. """ try: response = requests.post( - node_url + node_path, + node_url + endpoint, json=request, - stream=request['stream'], - timeout=API_TIMEOUT_LEN, + stream=True, + timeout=(5, API_READ_TIMEOUT), ) for chunk in response.iter_lines(decode_unicode=False, delimiter=b'\n'): if chunk: yield chunk + b'\n\n' - except requests.exceptions.RequestException as e: # noqa + except (Exception, GeneratorExit, RequestException) as e: # noqa + logger.error(f'catched an exception: {e}') + # exception happened, reduce unfinished num yield self.handle_api_timeout(node_url) - async def generate(self, request: Dict, node_url: str, node_path: str): + async def generate(self, request: Dict, node_url: str, endpoint: str): """Return a the response of the input request. Args: request (Dict): the input request. node_url (str): the node url. - node_path (str): the node path. Such as `/v1/chat/completions`. + endpoint (str): the endpoint. Such as `/v1/chat/completions`. """ try: import httpx async with httpx.AsyncClient() as client: - response = await client.post(node_url + node_path, + response = await client.post(node_url + endpoint, json=request, - timeout=API_TIMEOUT_LEN) + timeout=API_READ_TIMEOUT) return response.text - except requests.exceptions.RequestException as e: # noqa + except (Exception, GeneratorExit, RequestException, asyncio.CancelledError) as e: # noqa # yapf: disable + logger.error(f'catched an exception: {e}') return self.handle_api_timeout(node_url) def pre_call(self, node_url): @@ -381,7 +390,11 @@ def add_node(node: Node, raw_request: Request = None): RPM or other metric. All the values of nodes should be the same metric. """ try: - node_manager.add(node.url, node.status) + res = node_manager.add(node.url, node.status) + if res is not None: + logger.error(f'add node {node.url} failed, {res}') + return res + logger.info(f'add node {node.url} successfully') return 'Added successfully' except: # noqa return 'Failed to add, please check the input url.' @@ -392,8 +405,10 @@ def remove_node(node_url: str): """Show available models.""" try: node_manager.remove(node_url) + logger.info(f'delete node {node_url} successfully') return 'Deleted successfully' except: # noqa + logger.error(f'delete node {node_url} failed.') return 'Failed to delete, please check the input url.' @@ -407,28 +422,50 @@ async def chat_completions_v1(request: ChatCompletionRequest, The request should be a JSON object with the following fields: - model: model name. Available from /v1/models. 
- - messages: string prompt or chat history in OpenAI format. A example - for chat history is `[{"role": "user", "content":"knock knock"}]`. + - messages: string prompt or chat history in OpenAI format. Chat history + example: `[{"role": "user", "content": "hi"}]`. - temperature (float): to modulate the next token probability - top_p (float): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. - n (int): How many chat completion choices to generate for each input - message. Only support one here. + message. **Only support one here**. - stream: whether to stream the results or not. Default to false. - - max_tokens (int): output token nums + - max_tokens (int | None): output token nums. Default to None. - repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty - stop (str | List[str] | None): To stop generating further tokens. Only accept stop words that's encoded to one token idex. + - response_format (Dict | None): Only pytorch backend supports formatting + response. Examples: `{"type": "json_schema", "json_schema": {"name": + "test","schema": {"properties": {"name": {"type": "string"}}, + "required": ["name"], "type": "object"}}}` + or `{"type": "regex_schema", "regex_schema": "call me [A-Za-z]{1,10}"}` + - logit_bias (Dict): Bias to logits. Only supported in pytorch engine. + - tools (List): A list of tools the model may call. Currently, only + internlm2 functions are supported as a tool. Use this to specify a + list of functions for which the model can generate JSON inputs. + - tool_choice (str | object): Controls which (if any) tool is called by + the model. `none` means the model will not call any tool and instead + generates a message. Specifying a particular tool via {"type": + "function", "function": {"name": "my_function"}} forces the model to + call that tool. `auto` or `required` will pass all the tool information + to the model. Additional arguments supported by LMDeploy: + - top_k (int): The number of the highest probability vocabulary + tokens to keep for top-k-filtering - ignore_eos (bool): indicator for ignoring eos - - session_id (int): if not specified, will set random value + - skip_special_tokens (bool): Whether or not to remove special tokens + in the decoding. Default to be True. + - min_new_tokens (int): To generate at least this number of tokens. + - min_p (float): Minimum token probability, which will be scaled by the + probability of the most likely token. It must be a value between + 0 and 1. Typical values are in the 0.01-0.2 range, comparably + selective to setting `top_p` in the 0.99-0.8 range (use the + opposite of normal `top_p` values) Currently we do not support the following features: - - function_call (Users should implement this by themselves) - - logit_bias (not supported yet) - presence_penalty (replaced with repetition_penalty) - frequency_penalty (replaced with repetition_penalty) """ @@ -439,6 +476,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, if not node_url: return node_manager.handle_unavailable_model(request.model) + logger.info(f'A request is dispatched to {node_url}') request_dict = request.model_dump() start = node_manager.pre_call(node_url) if request.stream is True: @@ -465,13 +503,13 @@ async def completions_v1(request: CompletionRequest, - model (str): model name. Available from /v1/models. - prompt (str): the input prompt. - suffix (str): The suffix that comes after a completion of inserted text. 
- - max_tokens (int): output token nums + - max_tokens (int): output token nums. Default to 16. - temperature (float): to modulate the next token probability - top_p (float): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. - n (int): How many chat completion choices to generate for each input - message. Only support one here. + message. **Only support one here**. - stream: whether to stream the results or not. Default to false. - repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty @@ -481,7 +519,8 @@ async def completions_v1(request: CompletionRequest, Additional arguments supported by LMDeploy: - ignore_eos (bool): indicator for ignoring eos - - session_id (int): if not specified, will set random value + - skip_special_tokens (bool): Whether or not to remove special tokens + in the decoding. Default to be True. - top_k (int): The number of the highest probability vocabulary tokens to keep for top-k-filtering @@ -497,6 +536,7 @@ async def completions_v1(request: CompletionRequest, if not node_url: return node_manager.handle_unavailable_model(request.model) + logger.info(f'A request is dispatched to {node_url}') request_dict = request.model_dump() start = node_manager.pre_call(node_url) if request.stream is True: @@ -517,6 +557,7 @@ def proxy(server_name: str = '0.0.0.0', 'min_observed_latency'] = 'min_expected_latency', api_keys: Optional[Union[List[str], str]] = None, ssl: bool = False, + log_level: str = 'INFO', **kwargs): """To launch the proxy server. @@ -540,6 +581,7 @@ def proxy(server_name: str = '0.0.0.0', if ssl: ssl_keyfile = os.environ['SSL_KEYFILE'] ssl_certfile = os.environ['SSL_CERTFILE'] + logger.setLevel(log_level) uvicorn.run(app=app, host=server_name, port=server_port, diff --git a/lmdeploy/turbomind/deploy/converter.py b/lmdeploy/turbomind/deploy/converter.py index 77f0bc8dc8..176c3191f4 100644 --- a/lmdeploy/turbomind/deploy/converter.py +++ b/lmdeploy/turbomind/deploy/converter.py @@ -6,7 +6,7 @@ import fire import torch -from lmdeploy.archs import get_model_arch +from lmdeploy.archs import get_model_arch, search_nested_config from lmdeploy.messages import TurbomindEngineConfig from lmdeploy.model import MODELS, best_match_model from lmdeploy.utils import get_logger, get_model @@ -129,16 +129,17 @@ def get_output_model_registered_name_and_config(model_path: str, ] else 'float16' elif dtype in ['float16', 'bfloat16']: if weight_type == 'int4': - logger.warn(f'The model {model_path} is a quantized model, so the ' - f'specified data type {dtype} is ignored') + logger.warning( + f'The model {model_path} is a quantized model, so the ' + f'specified data type {dtype} is ignored') else: weight_type = dtype else: assert 0, f'unsupported specified data type {dtype}' if weight_type == 'bfloat16' and not is_bf16_supported(): - logger.warn('data type fallback to float16 since ' - 'torch.cuda.is_bf16_supported is False') + logger.warning('data type fallback to float16 since ' + 'torch.cuda.is_bf16_supported is False') weight_type = 'float16' config.model_config.model_arch = model_arch config.model_config.weight_type = weight_type @@ -174,23 +175,6 @@ def pack_model_repository(workspace_path: str): dst=osp.join(model_repo_dir, 'postprocessing')) -def find_quantization_config(nested, target_key): - if isinstance(nested, dict): - for key, value in nested.items(): - if key == target_key: - return value - if isinstance(value, (dict, list)): - result = 
find_quantization_config(value, target_key) - if result is not None: - return result - elif isinstance(nested, list): - for item in nested: - result = find_quantization_config(item, target_key) - if result is not None: - return result - return None - - def get_tm_model(model_path, model_name, chat_template_name, @@ -213,8 +197,7 @@ def get_tm_model(model_path, If it is None, the turbomind model won't be saved """ _, cfg = get_model_arch(model_path) - quant_config = find_quantization_config(cfg.to_dict(), - 'quantization_config') + quant_config = search_nested_config(cfg.to_dict(), 'quantization_config') if quant_config: quant_method = quant_config.get('quant_method') _group_size = int(quant_config.get('group_size', 0)) diff --git a/lmdeploy/turbomind/supported_models.py b/lmdeploy/turbomind/supported_models.py index 11e99edfa0..2b9c5156ed 100644 --- a/lmdeploy/turbomind/supported_models.py +++ b/lmdeploy/turbomind/supported_models.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from lmdeploy.archs import get_model_arch +from lmdeploy.archs import get_model_arch, search_nested_config from lmdeploy.utils import get_logger logger = get_logger('lmdeploy') @@ -80,7 +80,12 @@ def _is_head_dim_supported(cfg): if os.path.exists(triton_model_path): support_by_turbomind = True else: + arch, cfg = get_model_arch(model_path) + quant_method = search_nested_config(cfg.to_dict(), 'quant_method') + if quant_method and quant_method in ['smooth_quant']: + # turbomind does not yet support models quantized by smooth_quant + return False if arch in SUPPORTED_ARCHS.keys(): support_by_turbomind = True
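For reference, a short usage sketch of the new `search_nested_config` helper that replaces `find_quantization_config`. The nested config below is a made-up example, not taken from a real model repository.

from lmdeploy.archs import search_nested_config

# Hypothetical nested HF config; the quantization block sits one level down.
cfg = {
    'model_type': 'llama',
    'text_config': {
        'quantization_config': {'quant_method': 'awq', 'group_size': 128},
    },
}

quant_cfg = search_nested_config(cfg, 'quantization_config')
print(quant_cfg)                                 # {'quant_method': 'awq', 'group_size': 128}
print(quant_cfg.get('quant_method'))             # awq
print(search_nested_config(cfg, 'missing_key'))  # None

Both `get_tm_model` and the turbomind support check rely on the helper returning `None` when a checkpoint carries no quantization metadata.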