From ed9aa15b14e97e7987d4f90b8058e346bd175616 Mon Sep 17 00:00:00 2001 From: tangzhiyi11 Date: Tue, 5 Nov 2024 19:39:55 +0800 Subject: [PATCH 1/2] feat: support dynamic/llama3 rotary embedding in ascend graph mode (#2670) * feat: support dynamic ntk scaling rotary embedding in ascend graph mode * add llama3 rotary embedding * remove useless codes --- .../backends/dlinfer/rotary_embedding.py | 153 ++++++++++++++---- 1 file changed, 124 insertions(+), 29 deletions(-) diff --git a/lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py b/lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py index e97c9d1338..fab6e510f5 100644 --- a/lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py +++ b/lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py @@ -1,14 +1,44 @@ # Copyright (c) OpenMMLab. All rights reserved. +import math + import torch from torch import nn -from ..default.rotary_embedding import (Llama3RotaryEmbeddingImpl, - LlamaDynamicNTKScalingRotaryEmbedding) +from ..default.rotary_embedding import LlamaDynamicNTKScalingRotaryEmbedding from ..rotary_embedding import (Llama3Parameters, LongRoPEScalingParameters, RopeType, RotaryEmbeddingBuilder, RotaryEmbeddingImpl, YarnParameters) +def _rotary_embedding_fwd(position_ids: torch.Tensor, + inv_freq: torch.Tensor, + scaling_factor: float, + mscale: float = None, + dtype: torch.dtype = None): + """rotary embedding forward.""" + if dtype is None: + dtype = torch.float16 + + if scaling_factor != 1.0: + position_ids = position_ids.float() / scaling_factor + else: + position_ids = position_ids.float() + + inv_freq_expanded = inv_freq.view(1, -1, 1) + position_ids_expanded = position_ids.unsqueeze(1) + + tmp = torch.bmm(inv_freq_expanded, position_ids_expanded) + freqs = tmp.transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + if mscale is not None: + cos = cos * mscale + sin = sin * mscale + return cos.to(dtype=dtype), sin.to(dtype=dtype) + + class DlinferRotaryEmbeddingImpl(RotaryEmbeddingImpl, nn.Module): """base rotary embedding.""" @@ -28,34 +58,100 @@ def __init__(self, def forward(self, x, position_ids): """forward.""" # x: [bs, num_attention_heads, seq_len, head_size] + dtype = x.dtype if self.inv_freq.device != x.device: self.inv_freq = self.inv_freq.to(x.device) + return _rotary_embedding_fwd(position_ids, + self.inv_freq, + scaling_factor=self.scaling_factor, + dtype=dtype) - if self.scaling_factor != 1.0: - position_ids = position_ids.float() / self.scaling_factor - else: - position_ids = position_ids.float() - - inv_freq_expanded = self.inv_freq.view(1, -1, 1) - position_ids_expanded = position_ids.unsqueeze(1) - - # # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = device_type if isinstance( - device_type, str) and device_type != 'mps' else 'cpu' - inv_freq_expanded = inv_freq_expanded - position_ids_expanded = position_ids_expanded - tmp = torch.bmm(inv_freq_expanded, position_ids_expanded) - freqs = tmp.transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + +class DlinferLlamaDynamicNTKScalingRotaryEmbedding( + LlamaDynamicNTKScalingRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
+ + Credits to the Reddit users /u/bloc97 and /u/emozilla + """ + + def __init__(self, + dim: int, + base: int = 10000, + scaling_factor: float = 1.0, + max_position_embeddings: int = 2048): + super().__init__(dim, base, scaling_factor, max_position_embeddings) + self.dim_scale_ratio = self.dim / (self.dim - 2) + self.pos_freq_scaling = torch.arange( + 0, self.dim, 2, dtype=torch.int64).float().cuda() / self.dim + self.scale_offset = self.scaling_factor - 1 + self.pos_scale_factor = self.scaling_factor / \ + self.max_position_embeddings + + def _ntk_inv_freq(self, seq_len: torch.Tensor): + """Calculate inverse frequency with NTK scaling.""" + base = self.base * ((self.pos_scale_factor * seq_len) - + self.scale_offset)**self.dim_scale_ratio + inv_freq = 1.0 / (base**self.pos_freq_scaling) + return inv_freq + + def forward(self, x: torch.Tensor, position_ids: torch.Tensor): + """forward.""" + dtype = x.dtype + seq_len = torch.max(position_ids) + 1 + ntk_inv_freq = self._ntk_inv_freq(seq_len) + if self.inv_freq.device != x.device: + self.inv_freq = self.inv_freq.to(x.device) + inv_freq = torch.where(seq_len > self.max_position_embeddings, + ntk_inv_freq, self.inv_freq) + + cos, sin = _rotary_embedding_fwd(position_ids, + inv_freq, + scaling_factor=1.0, + dtype=dtype) + return cos, sin + + +class DlinferLlama3RotaryEmbeddingImpl(DlinferRotaryEmbeddingImpl): + """llama3 rotary embedding implementation.""" + + def __init__( + self, + dim: int, + base: int = 10000, + scaling_factor: float = 1.0, + low_freq_factor: float = 1.0, + high_freq_factor: float = 4.0, + original_max_position_embeddings: int = 8194, + ): + super().__init__(dim, base, scaling_factor) + old_context_len = original_max_position_embeddings + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + inv_freq = self.inv_freq + factor = self.scaling_factor + + wavelen = 2 * math.pi / inv_freq + # wavelen < high_freq_wavelen: do nothing + # wavelen > low_freq_wavelen: divide by factor + inv_freq_llama = torch.where(wavelen > low_freq_wavelen, + inv_freq / factor, inv_freq) + # otherwise: interpolate between the two, using a smooth factor + smooth_factor = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor) + smoothed_inv_freq = ( + 1 - smooth_factor + ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama + is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > + low_freq_wavelen) + inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, + inv_freq_llama) + self.scaling_factor = 1.0 + self.register_buffer('inv_freq', inv_freq_llama) class DlinferRotaryEmbeddingBuilder(RotaryEmbeddingBuilder): - """rotary embedding builder.""" + """rotary embedding dlinfer builder.""" @staticmethod def build( @@ -72,13 +168,12 @@ def build( if emb_type in (RopeType.Default, RopeType.LinearScaling): return DlinferRotaryEmbeddingImpl(dim, base, scaling_factor) elif emb_type == RopeType.DynamicNTKScaling: - return LlamaDynamicNTKScalingRotaryEmbedding( + return DlinferLlamaDynamicNTKScalingRotaryEmbedding( dim, base, scaling_factor, max_position_embeddings) elif emb_type == RopeType.Llama3: - return Llama3RotaryEmbeddingImpl(dim, base, scaling_factor, - llama3_params.low_freq_factor, - llama3_params.high_freq_factor, - max_position_embeddings) + return DlinferLlama3RotaryEmbeddingImpl( + dim, base, scaling_factor, llama3_params.low_freq_factor, + llama3_params.high_freq_factor, max_position_embeddings) else: raise NotImplementedError( 
f'Unsupported embedding type: {emb_type}') From 364a142916dd2c7264408af0d282ad9879b0069d Mon Sep 17 00:00:00 2001 From: yaofengchen <67218893+yao-fengchen@users.noreply.github.com> Date: Tue, 5 Nov 2024 19:41:20 +0800 Subject: [PATCH 2/2] add linear op on dlinfer platform (#2627) * add linear op on ascend platform * update code --- lmdeploy/pytorch/backends/dlinfer/linear.py | 32 +++++++++++++++++++ .../pytorch/backends/dlinfer/op_backend.py | 3 ++ lmdeploy/pytorch/kernels/dlinfer/__init__.py | 2 ++ lmdeploy/pytorch/kernels/dlinfer/linear.py | 12 +++++++ 4 files changed, 49 insertions(+) create mode 100644 lmdeploy/pytorch/backends/dlinfer/linear.py create mode 100644 lmdeploy/pytorch/kernels/dlinfer/linear.py diff --git a/lmdeploy/pytorch/backends/dlinfer/linear.py b/lmdeploy/pytorch/backends/dlinfer/linear.py new file mode 100644 index 0000000000..567a01dddf --- /dev/null +++ b/lmdeploy/pytorch/backends/dlinfer/linear.py @@ -0,0 +1,32 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch + +from lmdeploy.pytorch.kernels.dlinfer import linear + +from ..linear import LinearBuilder, LinearImpl + + +class DlinferLinearImpl(LinearImpl): + """Dlinfer linear implementation api.""" + + def forward(self, + x, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + all_reduce: bool = False): + """forward.""" + return linear(x, weight, bias, all_reduce) + + +class DlinferLinearBuilder(LinearBuilder): + """Dlinfer linear implementation builder.""" + + @staticmethod + def build(in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None): + """build.""" + return DlinferLinearImpl() diff --git a/lmdeploy/pytorch/backends/dlinfer/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/op_backend.py index 031f51fdca..52a8830595 100644 --- a/lmdeploy/pytorch/backends/dlinfer/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/op_backend.py @@ -40,6 +40,9 @@ def get_layer_impl_builder(cls, layer_type: OpType): elif layer_type == OpType.FusedMoE: from .moe import DlinferFusedMoEBuilder return DlinferFusedMoEBuilder + elif layer_type == OpType.Linear: + from .linear import DlinferLinearBuilder + return DlinferLinearBuilder elif layer_type == OpType.LinearW4A16: from .awq_modules import AwqLinearW4A16Builder return AwqLinearW4A16Builder diff --git a/lmdeploy/pytorch/kernels/dlinfer/__init__.py b/lmdeploy/pytorch/kernels/dlinfer/__init__.py index 4d678bfe68..8f86f0019a 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/__init__.py +++ b/lmdeploy/pytorch/kernels/dlinfer/__init__.py @@ -4,6 +4,7 @@ from .awq_kernels import awq_linear from .fill_kv_cache import fill_kv_cache from .fused_moe import fused_moe +from .linear import linear from .moe_gating_topk_softmax import moe_gating_topk_softmax from .pagedattention import paged_attention_fwd from .rms_norm import rms_norm @@ -15,6 +16,7 @@ 'fill_kv_cache', 'fused_moe', 'paged_attention_fwd', + 'linear', 'moe_gating_topk_softmax', 'multinomial_sampling', ] diff --git a/lmdeploy/pytorch/kernels/dlinfer/linear.py b/lmdeploy/pytorch/kernels/dlinfer/linear.py new file mode 100644 index 0000000000..695e089fd8 --- /dev/null +++ b/lmdeploy/pytorch/kernels/dlinfer/linear.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import dlinfer.ops as ext_ops +from torch import Tensor + + +def linear(x: Tensor, + weight: Tensor, + bias: Optional[Tensor] = None, + all_reduce: bool = False): + return ext_ops.linear(x, weight, bias=bias, all_reduce=all_reduce)
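
Note (not part of either patch): a minimal, CPU-only PyTorch sketch of the cos/sin cache computation performed by the new `_rotary_embedding_fwd` helper from the first patch, with `mscale` omitted. The head dim of 128, base of 10000, and sequence length of 16 are illustrative assumptions, not values taken from the code.

    import torch


    def rotary_embedding_fwd(position_ids, inv_freq, scaling_factor=1.0,
                             dtype=torch.float16):
        # Linear position scaling (the patch skips the division when
        # scaling_factor == 1.0).
        position_ids = position_ids.float() / scaling_factor
        # inv_freq: [dim//2] -> [1, dim//2, 1]; position_ids: [bs, seq] -> [bs, 1, seq].
        inv_freq_expanded = inv_freq.view(1, -1, 1)
        position_ids_expanded = position_ids.unsqueeze(1)
        # Batched outer product of frequencies and positions, then duplicate
        # along the last dim to cover the full head size.
        freqs = torch.bmm(inv_freq_expanded, position_ids_expanded).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(dtype), emb.sin().to(dtype)


    dim, base = 128, 10000
    inv_freq = 1.0 / base**(torch.arange(0, dim, 2).float() / dim)
    # bs = 1 here because torch.bmm does not broadcast the batch dimension.
    position_ids = torch.arange(16).unsqueeze(0)      # [bs=1, seq_len=16]
    cos, sin = rotary_embedding_fwd(position_ids, inv_freq)
    print(cos.shape, sin.shape)                       # torch.Size([1, 16, 128]) each

The per-variant classes in the patch differ only in how `inv_freq` (and the scaling factor) is prepared before this shared call: the dynamic NTK class rescales the base from the observed sequence length, while the Llama3 class smooths `inv_freq` between the low- and high-frequency wavelength bounds at construction time.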
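
Note (also not part of the patches): a plain-PyTorch stand-in for what the new dlinfer `linear` kernel in the second patch computes, useful for reviewing without a dlinfer-enabled device. `reference_linear` and the tensor shapes are hypothetical, and the `all_reduce` branch assumes the usual tensor-parallel semantics of summing partial outputs across ranks; the real kernel delegates to `dlinfer.ops.linear` instead.

    import torch
    import torch.nn.functional as F


    def reference_linear(x: torch.Tensor,
                         weight: torch.Tensor,
                         bias: torch.Tensor = None,
                         all_reduce: bool = False) -> torch.Tensor:
        # Semantic reference: y = x @ weight.T + bias.
        out = F.linear(x, weight, bias)
        if all_reduce and torch.distributed.is_initialized():
            # The dlinfer op handles the tensor-parallel reduction itself;
            # here it is an explicit collective purely for illustration.
            torch.distributed.all_reduce(out)
        return out


    x = torch.randn(2, 16, 4096)            # [batch, seq, in_features]
    w = torch.randn(11008, 4096)            # [out_features, in_features]
    print(reference_linear(x, w).shape)     # torch.Size([2, 16, 11008])

On a dlinfer device, `DlinferLinearBuilder.build(...)` returns a `DlinferLinearImpl` whose forward simply calls the wrapped `linear` kernel with `x`, `weight`, `bias`, and the `all_reduce` flag.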