diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py b/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py
index b69cb1dca5..f9664f13ff 100644
--- a/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py
+++ b/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py
@@ -1,15 +1,20 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import warnings
 from importlib import import_module
+from typing import List
 
 import torch
 import torch.distributed
+import torch_npu
 
 from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
+from lmdeploy.pytorch.model_inputs import StepContext
 from lmdeploy.utils import get_logger
 
 from ...graph_runner import GraphRunner
 
+ACL_FORMAT_ND = 2
+
 logger = get_logger('lmdeploy')
 
 
@@ -110,3 +115,48 @@ def allocate_gpu_cache_mark_static(self):
             return gpu_cache
 
         setattr(cache_engine_class, func_str, allocate_gpu_cache_mark_static)
+
+    def _convert_kv_format(self,
+                           past_key_values: List[List[torch.Tensor]]) -> None:
+        """Cast KV caches to ACL_FORMAT_ND if they use another format.
+
+        Only the first layer's key cache is inspected; when it is already
+        ACL_FORMAT_ND no cache is touched.
+
+        Args:
+            past_key_values: per-layer ``(key_cache, value_cache)`` pairs.
+        """
+        # Nothing to inspect or convert when there are no caches.
+        if not past_key_values:
+            return
+
+        # Check format of first KV cache
+        if torch_npu.get_npu_format(past_key_values[0][0]) == ACL_FORMAT_ND:
+            return
+
+        # Convert all KV caches to ACL_FORMAT_ND
+        # NOTE(review): the return value of npu_format_cast is discarded,
+        # so this relies on the cast updating each cache in place - confirm.
+        for key_cache, value_cache in past_key_values:
+            torch_npu.npu_format_cast(key_cache, ACL_FORMAT_ND)
+            torch_npu.npu_format_cast(value_cache, ACL_FORMAT_ND)
+
+    def prepare_inputs_for_generation(
+        self,
+        past_key_values: List[List[torch.Tensor]],
+        inputs_embeds: torch.Tensor = None,
+        context: StepContext = None,
+    ):
+        """Prepare inputs, converting the KV caches first in graph mode.
+
+        When ``self.enable_graph`` is truthy the caches are first passed
+        through ``_convert_kv_format``; all arguments are then forwarded
+        unchanged to ``self.model.prepare_inputs_for_generation``.
+        """
+        if self.enable_graph:
+            self._convert_kv_format(past_key_values)
+        return self.model.prepare_inputs_for_generation(
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            context=context,
+        )