Merge branch 'InternLM:main' into install_fix
zhulinJulia24 authored Dec 26, 2024
2 parents 4344edd + 191a7dd commit e43d9ab
Showing 15 changed files with 266 additions and 67 deletions.
2 changes: 1 addition & 1 deletion autotest/utils/config_utils.py
@@ -113,7 +113,7 @@ def get_all_model_list(tp_num: int = None,
model_type=model_type):
if case not in case_list:
case_list.append(case)
-    return [x for x in case_list if 'w8a8' not in x]
+    return case_list


def get_quantization_model_list(type):
21 changes: 20 additions & 1 deletion lmdeploy/archs.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
-from typing import Literal, Optional, Union
+from typing import Dict, List, Literal, Optional, Union

from transformers import AutoConfig

@@ -193,3 +193,22 @@ def get_model_arch(model_path: str):
raise RuntimeError(
f'Could not find model architecture from config: {_cfg}')
return arch, cfg


def search_nested_config(config, key):
    """Recursively searches for the value associated with the given key in a
    nested configuration of a model."""
    if isinstance(config, Dict):
        for k, v in config.items():
            if k == key:
                return v
            if isinstance(v, (Dict, List)):
                result = search_nested_config(v, key)
                if result is not None:
                    return result
    elif isinstance(config, List):
        for item in config:
            result = search_nested_config(item, key)
            if result is not None:
                return result
    return None
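
A minimal usage sketch of the new helper; the nested config below is made up for illustration (the search is depth-first and returns the first match it finds):

# Hypothetical nested config, for illustration only.
cfg = {
    'llm_config': {
        'hidden_size': 4096,
        'quantization_config': {'bits': 4},
    },
    'vision_config': {'hidden_size': 1024},
}

assert search_nested_config(cfg, 'bits') == 4
# Depth-first: the first 'hidden_size' encountered wins.
assert search_nested_config(cfg, 'hidden_size') == 4096
assert search_nested_config(cfg, 'missing_key') is None
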
1 change: 1 addition & 0 deletions lmdeploy/cli/serve.py
@@ -238,6 +238,7 @@ def add_parser_proxy():
help='the strategy to dispatch requests to nodes')
ArgumentHelper.api_keys(parser)
ArgumentHelper.ssl(parser)
ArgumentHelper.log_level(parser)

@staticmethod
def gradio(args):
6 changes: 6 additions & 0 deletions lmdeploy/lite/apis/calibrate.py
@@ -262,7 +262,13 @@ def calibrate(model: str,
if dtype == 'float16':
model.half()
elif dtype == 'bfloat16':
assert torch.cuda.is_bf16_supported(
), 'your device does not support bfloat16 please set --dtype float16' # noqa
model.to(torch.bfloat16)
elif dtype == 'auto' and model.config.torch_dtype == torch.bfloat16:
print('Warning: we cast model to float16 to prevent OOM. You'
' may enforce it bfloat16 by `--dtype bfloat16`')
model.half()
model.eval()

model_type = type(model).__name__
19 changes: 14 additions & 5 deletions lmdeploy/lite/utils/load.py
@@ -12,14 +12,13 @@ def load_hf_from_pretrained(pretrained_model_name_or_path,
dtype: Literal['float16', 'bfloat16',
'auto'], **kwargs):

-    if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
+    if dtype == 'bfloat16' and not torch.cuda.is_bf16_supported():
raise RuntimeError('Your device does not supports bf16(bfloat16), '
'please change to fp16(float16)')

kwargs.pop('config', None)

    hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path,
-                                           torch_dtype=dtype,
                                            trust_remote_code=True)

# HACK hard code for qwen, other configs do not have the `fp16` attribute.
@@ -29,13 +28,23 @@ def load_hf_from_pretrained(pretrained_model_name_or_path,
else:
hf_config.fp16 = True

-    if dtype != 'auto':
-        setattr(hf_config, 'torch_dtype', dtype)
+    torch_dtype = getattr(hf_config, 'torch_dtype', torch.float16)
+    if dtype == 'bfloat16':
+        torch_dtype = torch.bfloat16
+    elif dtype == 'float16':
+        torch_dtype = torch.float16
+    elif dtype == 'auto' and torch_dtype == torch.bfloat16:
+        print('Warning: we cast model to float16 to prevent OOM. '
+              'You may enforce it bfloat16 by `--dtype bfloat16`')
+        torch_dtype = torch.float16

with LoadNoInit():
# Load model
model = AutoModelForCausalLM.from_pretrained(
-            pretrained_model_name_or_path, config=hf_config, **kwargs)
+            pretrained_model_name_or_path,
+            config=hf_config,
+            torch_dtype=torch_dtype,
+            **kwargs)
model.config.use_cache = False

return model
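
A hedged usage sketch of the loader after this change; the checkpoint name is a placeholder and the import path simply mirrors the file shown above:

from lmdeploy.lite.utils.load import load_hf_from_pretrained

# 'some/bf16-checkpoint' is a placeholder for any bfloat16 model.
# With dtype='auto' the weights end up in float16 (with the warning above)
# to reduce memory; pass dtype='bfloat16' to keep bf16 on supported GPUs.
model = load_hf_from_pretrained('some/bf16-checkpoint', dtype='auto')
assert model.config.use_cache is False
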
2 changes: 1 addition & 1 deletion lmdeploy/model.py
@@ -1921,5 +1921,5 @@ def best_match_model(query: str) -> Optional[str]:
for name, model in MODELS.module_dict.items():
if model.match(query):
return model.match(query)
-    logger.warn(f'Did not find a chat template matching {query}.')
+    logger.warning(f'Did not find a chat template matching {query}.')
return 'base'
94 changes: 94 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/flash_attention.py
@@ -0,0 +1,94 @@
# Copyright (c) OpenMMLab. All rights reserved.
from torch import Tensor

from ..flash_attention import FlashAttentionBuilder, FlashAttentionImpl


class DlinferFlashAttentionImpl(FlashAttentionImpl):
"""dlinfer flash attention implementation."""

def __init__(
self,
num_heads: int,
head_dim: int,
scale: float = None,
num_kv_heads: int = None,
v_head_dim: int = None,
causal: bool = True,
sliding_window: int = None,
logical_softcapping: float = None,
):
if scale is None:
scale = 1.0 / (head_dim**0.5)
if num_kv_heads is None:
num_kv_heads = num_heads
if v_head_dim is None:
v_head_dim = head_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.scale = scale
self.num_kv_heads = num_kv_heads
self.v_head_dim = v_head_dim
self.causal = causal
self.sliding_window = sliding_window
self.logical_softcapping = logical_softcapping
from lmdeploy.pytorch.kernels.dlinfer import flash_attention_fwd
self.flash_attention_fwd = flash_attention_fwd

def forward(self,
query: Tensor,
key: Tensor,
value: Tensor,
q_start_loc: Tensor,
q_seqlens: Tensor,
kv_start_loc: Tensor,
kv_seqlens: Tensor,
max_q_seqlen: int = None):
"""forward."""
q_shape = query.shape
o_shape = q_shape[:-1] + (self.v_head_dim, )
out = query.new_empty(o_shape)
self.flash_attention_fwd(
query,
key,
value,
out,
q_start_loc=q_start_loc,
q_seqlens=q_seqlens,
kv_start_loc=kv_start_loc,
kv_seqlens=kv_seqlens,
max_q_seqlen=max_q_seqlen,
window_size=self.sliding_window,
sm_scale=self.scale,
logit_softcapping=self.logical_softcapping,
causal=self.causal,
)
return out


class DlinferFlashAttentionBuilder(FlashAttentionBuilder):
"""dlinfer attention builder."""

@staticmethod
def build(
num_heads: int,
head_dim: int,
scale: float = None,
num_kv_heads: int = None,
v_head_dim: int = None,
causal: bool = True,
sliding_window: int = None,
logical_softcapping: float = None,
**kwargs,
) -> FlashAttentionImpl:
"""build."""
return DlinferFlashAttentionImpl(
num_heads=num_heads,
head_dim=head_dim,
scale=scale,
num_kv_heads=num_kv_heads,
v_head_dim=v_head_dim,
causal=causal,
sliding_window=sliding_window,
logical_softcapping=logical_softcapping,
)
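
A rough usage sketch of the new backend; the packed [total_tokens, num_heads, head_dim] layout, the shapes, and the sequence lengths are illustrative assumptions, and running it requires a dlinfer-supported device:

import torch

# Two packed sequences of 4 and 6 tokens; 8 heads of dim 64 (illustrative).
attn = DlinferFlashAttentionBuilder.build(num_heads=8, head_dim=64)
query = torch.randn(10, 8, 64)
key = torch.randn(10, 8, 64)
value = torch.randn(10, 8, 64)
out = attn.forward(
    query, key, value,
    q_start_loc=torch.tensor([0, 4]),
    q_seqlens=torch.tensor([4, 6]),
    kv_start_loc=torch.tensor([0, 4]),
    kv_seqlens=torch.tensor([4, 6]),
    max_q_seqlen=6,
)  # same leading shape as query, last dim = v_head_dim (here 64)
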
3 changes: 3 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/op_backend.py
@@ -25,6 +25,9 @@ def get_layer_impl_builder(cls, layer_type: OpType):
if layer_type == OpType.PagedAttention:
from .attention import DlinferAttentionBuilder
return DlinferAttentionBuilder
elif layer_type == OpType.FlashAttention:
from .flash_attention import DlinferFlashAttentionBuilder
return DlinferFlashAttentionBuilder
elif layer_type == OpType.ApplyRotaryEmb:
from .apply_rotary_emb import DlinferApplyRotaryEmbBuilder
return DlinferApplyRotaryEmbBuilder
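
A minimal sketch of how the new dispatch branch is reached; the OpType import path is an assumption, and the backend class is passed in generically because its name does not appear in this hunk:

from lmdeploy.pytorch.backends.base import OpType  # assumed location of OpType


def build_flash_attention(backend_cls, num_heads: int, head_dim: int):
    """`backend_cls` stands for the dlinfer op backend class shown above."""
    builder = backend_cls.get_layer_impl_builder(OpType.FlashAttention)
    return builder.build(num_heads=num_heads, head_dim=head_dim)
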
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/kernels/dlinfer/__init__.py
@@ -3,6 +3,7 @@
from .apply_rotary_pos_emb import apply_rotary_pos_emb
from .awq_kernels import awq_linear
from .fill_kv_cache import fill_kv_cache
from .flash_attention import flash_attention_fwd
from .fused_moe import fused_moe
from .linear import linear
from .moe_gating_topk_softmax import moe_gating_topk_softmax
@@ -16,6 +17,7 @@
'fill_kv_cache',
'fused_moe',
'paged_attention_fwd',
'flash_attention_fwd',
'linear',
'moe_gating_topk_softmax',
'multinomial_sampling',
35 changes: 35 additions & 0 deletions lmdeploy/pytorch/kernels/dlinfer/flash_attention.py
@@ -0,0 +1,35 @@
# Copyright (c) OpenMMLab. All rights reserved.
import dlinfer.ops as ext_ops
from dlinfer.utils.type_annotation import Tensor


def flash_attention_fwd(
query_states: Tensor,
key_states: Tensor,
value_states: Tensor,
attn_output: Tensor,
q_start_loc: Tensor,
q_seqlens: Tensor,
kv_start_loc: Tensor,
kv_seqlens: Tensor,
max_q_seqlen: int = None,
window_size: int = None,
sm_scale: float = None,
logit_softcapping: float = None,
causal: bool = True,
):
num_q_heads = query_states.shape[1]
num_kv_heads = value_states.shape[1]
return ext_ops.prefill_attention(
query_states,
key_states,
value_states,
q_start_loc,
q_seqlens,
max_q_seqlen,
num_q_heads,
num_kv_heads,
attn_mask=None,
softmax_scale=sm_scale,
attn_output=attn_output,
)
4 changes: 2 additions & 2 deletions lmdeploy/serve/async_engine.py
@@ -505,8 +505,8 @@ async def generate(
if gen_config.stop_token_ids is None:
gen_config.stop_token_ids = self.stop_words
if not gen_config.do_sample:
-            logger.warn(f'GenerationConfig: {gen_config}')
-            logger.warn(
+            logger.warning(f'GenerationConfig: {gen_config}')
+            logger.warning(
'Since v0.6.0, lmdeploy add `do_sample` in '
'GenerationConfig. It defaults to False, meaning greedy '
'decoding. Please set `do_sample=True` if sampling '
4 changes: 2 additions & 2 deletions lmdeploy/serve/proxy/constants.py
@@ -2,8 +2,8 @@

import enum

-LATENCY_DEEQUE_LEN = 15
-API_TIMEOUT_LEN = 100
+LATENCY_DEQUE_LEN = 15
+API_READ_TIMEOUT = 100


class Strategy(enum.Enum):