Skip to content

Commit

Permalink
feat(pytorch_poc): implement ReRoPE (#625)
Browse files Browse the repository at this point in the history
* fix(pytorch_poc): memory cal

* style(pytorch_poc): lint

* style(.pre-commit-config.yaml): update

* style(pytorch_poc): remove useless

* feat(pytorch_poc): llama2 support rerope

* feat(pytorch_poc): fix long input generate

* feat(lmdeploy): add kernel

* feat(lmdeploy): update

* feat(lmdeploy): add rerope implementation

* fix(lmdeploy/pytorch_poc): apply rotary_emb

* fix(lmdeploy): update

* style(pytorch_poc): format

* style(lmdeploy): fix lint

* style(lmdeploy): typo

* style(pytorch_poc): format

* style(pytorch_poc): format

* fix(pytorch_poc): rms_norm add mask

* style(pytorch_poc/kernels): format rerope

* style(pytorch_poc): format rerope attn function description

* style(lmdeploy/pytorch_poc): format

* style(pytorch_poc): add code ref

* style(pytorch_poc): format rerope attn
  • Loading branch information
tpoisonooo authored Nov 8, 2023
1 parent f8b1afd commit 3618c0d
Show file tree
Hide file tree
Showing 8 changed files with 930 additions and 51 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ __pycache__/

# Distribution / packaging
.Python
triton-rerope/
develop-eggs/
dist/
downloads/
Expand Down Expand Up @@ -60,6 +61,7 @@ work_dir*/
*generate_config.json

# Pytorch
*.pt
*.pth
*.py~
*.sh~
Expand Down
30 changes: 22 additions & 8 deletions lmdeploy/pytorch_poc/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,15 @@ def __init__(
seq_length: torch.Tensor,
world_size: int = 1,
device='cuda',
json_config: dict = None,
):
self.block_offsets_list = block_offsets
self.history_lengths = history_lengths
self.position_ids = position_ids
self.q_start_loc = q_start_loc
self.seq_length = seq_length
self.world_size = world_size
self.json_config = json_config

# padding zero
pad_sequence = torch.nn.utils.rnn.pad_sequence
Expand Down Expand Up @@ -225,6 +227,8 @@ def _update_cache_config(model_config: ModelConfig,
if cache_config.num_gpu_blocks == 0:
cache_config.num_gpu_blocks = int(gpu_mem / cache_block_size)

logger.info('block num: {}'.format(cache_config.num_gpu_blocks))


def _get_torch_dtype(config: Any, default: str = 'float16'):
"""Get the torch dtype from the model config.
Expand All @@ -243,6 +247,7 @@ def _tp_model_loop(
extra_args: List[str],
model_config: ModelConfig,
cache_config: CacheConfig,
json_config: dict,
in_que: mp.Queue,
out_que: mp.Queue,
world_size: int,
Expand Down Expand Up @@ -406,6 +411,7 @@ def _tp_model_loop(
seq_length=inputs['seq_length'],
world_size=world_size,
),
json_config=json_config,
q_seq_info=(inputs['q_start_loc'], inputs['seq_length']),
)

Expand Down Expand Up @@ -456,14 +462,13 @@ class Engine:
tp (int): Number of tensor parallel.
"""

def __init__(
self,
model_path: str,
scheduler_config: SchedulerConfig = None,
cache_config: CacheConfig = None,
tp: int = 1,
trust_remote_code=True,
) -> None:
def __init__(self,
model_path: str,
scheduler_config: SchedulerConfig = None,
cache_config: CacheConfig = None,
tp: int = 1,
trust_remote_code=True,
json_config_file: str = 'config.json') -> None:

self.tp = tp
self.gpu_count = tp
Expand All @@ -480,6 +485,11 @@ def __init__(
cache_config = CacheConfig(block_size=64,
num_cpu_blocks=0,
num_gpu_blocks=0)

self.json_config = None
with open(os.path.join(model_path, json_config_file)) as f:
self.json_config = json.load(f)

if 'falcon' in model_path:
if hf_config.new_decoder_architecture:
# 40b-instruct, GQA
Expand Down Expand Up @@ -553,6 +563,7 @@ def __init__(
['context', 'use_origin', 'q_seq_info'],
model_config=model_config,
cache_config=cache_config,
json_config=self.json_config,
in_que=self.tp_model_in_que,
out_que=self.tp_model_out_que,
world_size=tp,
Expand Down Expand Up @@ -583,6 +594,7 @@ def patch_model_tp(
extra_args: List[str],
model_config: ModelConfig,
cache_config: CacheConfig,
json_config: dict,
in_que: mp.Queue,
out_que: mp.Queue,
world_size: int,
Expand All @@ -609,6 +621,7 @@ def patch_model_tp(
dict(
model_config=model_config,
cache_config=cache_config,
json_config=json_config,
in_que=in_que,
out_que=out_que,
world_size=world_size,
Expand Down Expand Up @@ -772,6 +785,7 @@ def _model_forward(self, inputs: Dict, swap_in_map: Dict[int, int],
position_ids=inputs['position_ids'],
q_start_loc=inputs['q_start_loc'],
seq_length=inputs['seq_length'],
json_config=self.json_config,
),
q_seq_info=(inputs['q_start_loc'], inputs['seq_length']),
)
Expand Down
3 changes: 2 additions & 1 deletion lmdeploy/pytorch_poc/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from .fill_kv_cache import fill_kv_cache
from .flashattention_nopad import context_attention_fwd
from .pagedattention import paged_attention_fwd
from .rerope_attention import rerope_attention_fwd
from .rms_norm import rms_norm

__all__ = [
'apply_rotary_pos_emb', 'context_attention_fwd', 'paged_attention_fwd',
'biased_paged_attention_fwd', 'alibi_paged_attention_fwd', 'fill_kv_cache',
'rms_norm'
'rms_norm', 'rerope_attention_fwd'
]
Loading

0 comments on commit 3618c0d

Please sign in to comment.