Commit 232c294
fix comment
grimoire committed Dec 14, 2023
1 parent c8cdf52 commit 232c294
Showing 6 changed files with 8 additions and 20 deletions.
1 change: 0 additions & 1 deletion lmdeploy/turbomind/deploy/target_model/__init__.py
@@ -1,4 +1,3 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .bf16 import TurbomindBF16Model # noqa: F401
 from .fp import TurbomindModel # noqa: F401
 from .w4 import TurbomindW4Model # noqa: F401
16 changes: 0 additions & 16 deletions lmdeploy/turbomind/deploy/target_model/bf16.py

This file was deleted.

2 changes: 1 addition & 1 deletion lmdeploy/turbomind/deploy/target_model/fp.py
@@ -14,7 +14,7 @@ def transpose_tensor(input: List[torch.Tensor]):
     return output


-@OUTPUT_MODELS.register_module(name='fp16')
+@OUTPUT_MODELS.register_module(name=['fp16', 'bf16'])
 class TurbomindModel(BaseOutputModel):
     """Export to turbomind fp16 format."""

4 changes: 4 additions & 0 deletions src/turbomind/models/llama/flash_attention2/flash_api.cpp
@@ -126,7 +126,11 @@ class FlashAttentionOpImpl<T, FMHA_VERSION>::impl {

        fwd_params.blockmask = reinterpret_cast<void*>(params.mask);

+#ifdef ENABLE_BF16
        fwd_params.is_bf16 = std::is_same<T, __nv_bfloat16>::value;
+#else
+       fwd_params.is_bf16 = false;
+#endif
        fwd_params.is_causal = true;

        fwd_params.q_enable_seqlen = params.layout_q.use_seqlens;
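Note on the guard added above: when the project is built without ENABLE_BF16, the __nv_bfloat16 type may not be declared at all, so even naming it inside std::is_same would not compile; the #ifdef compiles the reference out and pins is_bf16 to false. Below is a minimal standalone sketch of the same pattern, not the actual turbomind code; ForwardParams and set_dtype_flags are illustrative placeholders.

#ifdef ENABLE_BF16
#include <cuda_bf16.h>  // declares __nv_bfloat16
#endif
#include <type_traits>

// Placeholder for the real fwd_params structure.
struct ForwardParams {
    bool is_bf16;
    bool is_causal;
};

template<typename T>
void set_dtype_flags(ForwardParams& p)
{
#ifdef ENABLE_BF16
    // Safe here: the bf16 header is available, so the trait can be evaluated.
    p.is_bf16 = std::is_same<T, __nv_bfloat16>::value;
#else
    // bf16 support compiled out; the flag can only ever be false.
    p.is_bf16 = false;
#endif
    p.is_causal = true;
}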
2 changes: 1 addition & 1 deletion src/turbomind/python/bind.cpp
@@ -379,7 +379,7 @@ PYBIND11_MODULE(_turbomind, m)
            model->setFfiLock(gil_control);
            return model;
        }
-       if (data_type == "bf16") {
+       else if (data_type == "bf16") {
 #ifdef ENABLE_BF16
            auto model = std::make_shared<LlamaTritonModel<__nv_bfloat16>>(
                tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir, config);
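The bind.cpp change only chains the branches with else if; the surrounding code follows the usual string-to-template-instantiation factory pattern, with the bf16 branch wrapped in the same ENABLE_BF16 guard. A minimal sketch of that pattern follows, assuming a hypothetical make_model factory and placeholder model types; the real code constructs LlamaTritonModel<T> and its error handling may differ.

#include <memory>
#include <stdexcept>
#include <string>

struct Model {
    virtual ~Model() = default;
};
struct FP16Model : Model {};
#ifdef ENABLE_BF16
struct BF16Model : Model {};
#endif

// Hypothetical factory: maps a data_type string to a concrete model type.
std::shared_ptr<Model> make_model(const std::string& data_type)
{
    if (data_type == "fp16") {
        return std::make_shared<FP16Model>();
    }
    else if (data_type == "bf16") {
#ifdef ENABLE_BF16
        return std::make_shared<BF16Model>();
#else
        // Hypothetical fallback: bf16 requested but not compiled in.
        throw std::runtime_error("bf16 support was not enabled at build time");
#endif
    }
    throw std::runtime_error("unsupported data_type: " + data_type);
}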
3 changes: 2 additions & 1 deletion src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -46,7 +46,8 @@ std::shared_ptr<AbstractTransformerModel> AbstractTransformerModel::createLlamaM
            reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size"),
            reader.GetInteger("ft_instance_hyperparameter", "enable_custom_all_reduce", 0),
            model_dir);
-   }else if (data_type == "bf16") {
+   }
+   else if (data_type == "bf16") {
 #ifdef ENABLE_BF16
        return std::make_shared<LlamaTritonModel<__nv_bfloat16>>(
            reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"),
