Skip to content

Commit

Permalink
convert kv cache to nd format in ascend graph mode (InternLM#2853)
Browse files Browse the repository at this point in the history
  • Loading branch information
tangzhiyi11 authored Dec 4, 2024
1 parent 69a4306 commit 9bfdeae
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from importlib import import_module
from typing import List

import torch
import torch.distributed
import torch_npu

from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
from lmdeploy.pytorch.model_inputs import StepContext
from lmdeploy.utils import get_logger

from ...graph_runner import GraphRunner

ACL_FORMAT_ND = 2

logger = get_logger('lmdeploy')


Expand Down Expand Up @@ -110,3 +115,31 @@ def allocate_gpu_cache_mark_static(self):
return gpu_cache

setattr(cache_engine_class, func_str, allocate_gpu_cache_mark_static)

def _convert_kv_format(self,
past_key_values: List[List[torch.Tensor]]) -> None:
"""Convert key/value caches to ACL_FORMAT_ND format if needed."""
# Check format of first KV cache
if torch_npu.get_npu_format(past_key_values[0][0]) == ACL_FORMAT_ND:
return

# Convert all KV caches to ACL_FORMAT_ND
for layer_kv in past_key_values:
key_cache, value_cache = layer_kv
torch_npu.npu_format_cast(key_cache, ACL_FORMAT_ND)
torch_npu.npu_format_cast(value_cache, ACL_FORMAT_ND)

def prepare_inputs_for_generation(
self,
past_key_values: List[List[torch.Tensor]],
inputs_embeds: torch.Tensor = None,
context: StepContext = None,
):
"""prepare inputs."""
if self.enable_graph:
self._convert_kv_format(past_key_values)
return self.model.prepare_inputs_for_generation(
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
context=context,
)

0 comments on commit 9bfdeae

Please sign in to comment.