Merge branch 'main' into support-molmo
lvhan028 committed Nov 5, 2024
2 parents 8d8f8b9 + 364a142 commit 0bfcae6
Showing 5 changed files with 173 additions and 29 deletions.
32 changes: 32 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/linear.py
@@ -0,0 +1,32 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch

from lmdeploy.pytorch.kernels.dlinfer import linear

from ..linear import LinearBuilder, LinearImpl


class DlinferLinearImpl(LinearImpl):
"""Dlinfer linear implementation api."""

def forward(self,
x,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False):
"""forward."""
return linear(x, weight, bias, all_reduce)


class DlinferLinearBuilder(LinearBuilder):
"""Dlinfer linear implementation builder."""

@staticmethod
def build(in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None):
"""build."""
return DlinferLinearImpl()
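
For context on how these pieces fit together: DlinferLinearBuilder.build() ignores the shape arguments and returns a stateless DlinferLinearImpl, whose forward() receives the weight (and optional bias) at call time and hands them to the dlinfer linear kernel. Below is a minimal, self-contained sketch of that pattern, not part of the commit; it stubs the kernel with torch.nn.functional.linear since dlinfer.ops is only importable on supported hardware, and the Sketch* names and tensor shapes are made up for illustration.

from typing import Optional

import torch
import torch.nn.functional as F


class SketchLinearImpl:
    """Stand-in for DlinferLinearImpl: stateless, the weight arrives per call."""

    def forward(self,
                x: torch.Tensor,
                weight: torch.Tensor,
                bias: Optional[torch.Tensor] = None,
                all_reduce: bool = False):
        out = F.linear(x, weight, bias)  # dlinfer.ops.linear on real hardware
        if all_reduce:
            # The dlinfer kernel can fold a tensor-parallel all-reduce into the
            # call; a plain PyTorch fallback would use torch.distributed here.
            pass
        return out


class SketchLinearBuilder:
    """Stand-in for DlinferLinearBuilder: build() just returns the stateless impl."""

    @staticmethod
    def build(in_features: int,
              out_features: int,
              bias: bool = True,
              dtype: torch.dtype = None):
        return SketchLinearImpl()


x = torch.randn(2, 16)
w = torch.randn(32, 16)
impl = SketchLinearBuilder.build(16, 32, bias=False)
print(impl.forward(x, w).shape)  # torch.Size([2, 32])
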
3 changes: 3 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/op_backend.py
@@ -40,6 +40,9 @@ def get_layer_impl_builder(cls, layer_type: OpType):
elif layer_type == OpType.FusedMoE:
from .moe import DlinferFusedMoEBuilder
return DlinferFusedMoEBuilder
elif layer_type == OpType.Linear:
from .linear import DlinferLinearBuilder
return DlinferLinearBuilder
elif layer_type == OpType.LinearW4A16:
from .awq_modules import AwqLinearW4A16Builder
return AwqLinearW4A16Builder
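
This hunk registers the new Linear op in the backend's dispatch table: get_layer_impl_builder() maps each OpType member to a lazily imported builder, so dlinfer is only pulled in when the op is actually requested. A schematic sketch of that pattern follows; the enum is a stand-in (the real OpType is defined elsewhere in lmdeploy), and running the Linear branch requires lmdeploy with dlinfer installed.

from enum import Enum, auto


class SketchOpType(Enum):
    FusedMoE = auto()
    Linear = auto()
    LinearW4A16 = auto()


def get_layer_impl_builder(layer_type: SketchOpType):
    """Resolve an op type to its builder class, importing it on demand."""
    if layer_type == SketchOpType.Linear:
        from lmdeploy.pytorch.backends.dlinfer.linear import DlinferLinearBuilder
        return DlinferLinearBuilder
    raise NotImplementedError(f'{layer_type} is not registered in this sketch')
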
153 changes: 124 additions & 29 deletions lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py
@@ -1,14 +1,44 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
from torch import nn

from ..default.rotary_embedding import (Llama3RotaryEmbeddingImpl,
LlamaDynamicNTKScalingRotaryEmbedding)
from ..default.rotary_embedding import LlamaDynamicNTKScalingRotaryEmbedding
from ..rotary_embedding import (Llama3Parameters, LongRoPEScalingParameters,
RopeType, RotaryEmbeddingBuilder,
RotaryEmbeddingImpl, YarnParameters)


def _rotary_embedding_fwd(position_ids: torch.Tensor,
inv_freq: torch.Tensor,
scaling_factor: float,
mscale: float = None,
dtype: torch.dtype = None):
"""rotary embedding forward."""
if dtype is None:
dtype = torch.float16

if scaling_factor != 1.0:
position_ids = position_ids.float() / scaling_factor
else:
position_ids = position_ids.float()

inv_freq_expanded = inv_freq.view(1, -1, 1)
position_ids_expanded = position_ids.unsqueeze(1)

tmp = torch.bmm(inv_freq_expanded, position_ids_expanded)
freqs = tmp.transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

if mscale is not None:
cos = cos * mscale
sin = sin * mscale
return cos.to(dtype=dtype), sin.to(dtype=dtype)


class DlinferRotaryEmbeddingImpl(RotaryEmbeddingImpl, nn.Module):
"""base rotary embedding."""

@@ -28,34 +58,100 @@ def __init__(self,
def forward(self, x, position_ids):
"""forward."""
# x: [bs, num_attention_heads, seq_len, head_size]
dtype = x.dtype
if self.inv_freq.device != x.device:
self.inv_freq = self.inv_freq.to(x.device)
return _rotary_embedding_fwd(position_ids,
self.inv_freq,
scaling_factor=self.scaling_factor,
dtype=dtype)

if self.scaling_factor != 1.0:
position_ids = position_ids.float() / self.scaling_factor
else:
position_ids = position_ids.float()

inv_freq_expanded = self.inv_freq.view(1, -1, 1)
position_ids_expanded = position_ids.unsqueeze(1)

# # Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(
device_type, str) and device_type != 'mps' else 'cpu'
inv_freq_expanded = inv_freq_expanded
position_ids_expanded = position_ids_expanded
tmp = torch.bmm(inv_freq_expanded, position_ids_expanded)
freqs = tmp.transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

class DlinferLlamaDynamicNTKScalingRotaryEmbedding(
LlamaDynamicNTKScalingRotaryEmbedding):
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling.
Credits to the Reddit users /u/bloc97 and /u/emozilla
"""

def __init__(self,
dim: int,
base: int = 10000,
scaling_factor: float = 1.0,
max_position_embeddings: int = 2048):
super().__init__(dim, base, scaling_factor, max_position_embeddings)
self.dim_scale_ratio = self.dim / (self.dim - 2)
self.pos_freq_scaling = torch.arange(
0, self.dim, 2, dtype=torch.int64).float().cuda() / self.dim
self.scale_offset = self.scaling_factor - 1
self.pos_scale_factor = self.scaling_factor / \
self.max_position_embeddings

def _ntk_inv_freq(self, seq_len: torch.Tensor):
"""Calculate inverse frequency with NTK scaling."""
base = self.base * ((self.pos_scale_factor * seq_len) -
self.scale_offset)**self.dim_scale_ratio
inv_freq = 1.0 / (base**self.pos_freq_scaling)
return inv_freq

def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
"""forward."""
dtype = x.dtype
seq_len = torch.max(position_ids) + 1
ntk_inv_freq = self._ntk_inv_freq(seq_len)
if self.inv_freq.device != x.device:
self.inv_freq = self.inv_freq.to(x.device)
inv_freq = torch.where(seq_len > self.max_position_embeddings,
ntk_inv_freq, self.inv_freq)

cos, sin = _rotary_embedding_fwd(position_ids,
inv_freq,
scaling_factor=1.0,
dtype=dtype)
return cos, sin


class DlinferLlama3RotaryEmbeddingImpl(DlinferRotaryEmbeddingImpl):
"""llama3 rotary embedding implementation."""

def __init__(
self,
dim: int,
base: int = 10000,
scaling_factor: float = 1.0,
low_freq_factor: float = 1.0,
high_freq_factor: float = 4.0,
original_max_position_embeddings: int = 8194,
):
super().__init__(dim, base, scaling_factor)
old_context_len = original_max_position_embeddings
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor

inv_freq = self.inv_freq
factor = self.scaling_factor

wavelen = 2 * math.pi / inv_freq
# wavelen < high_freq_wavelen: do nothing
# wavelen > low_freq_wavelen: divide by factor
inv_freq_llama = torch.where(wavelen > low_freq_wavelen,
inv_freq / factor, inv_freq)
# otherwise: interpolate between the two, using a smooth factor
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (
high_freq_factor - low_freq_factor)
smoothed_inv_freq = (
1 - smooth_factor
) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen >
low_freq_wavelen)
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq,
inv_freq_llama)
self.scaling_factor = 1.0
self.register_buffer('inv_freq', inv_freq_llama)


class DlinferRotaryEmbeddingBuilder(RotaryEmbeddingBuilder):
"""rotary embedding builder."""
"""rotary embedding dlinfer builder."""

@staticmethod
def build(
@@ -72,13 +168,12 @@ def build(
if emb_type in (RopeType.Default, RopeType.LinearScaling):
return DlinferRotaryEmbeddingImpl(dim, base, scaling_factor)
elif emb_type == RopeType.DynamicNTKScaling:
return LlamaDynamicNTKScalingRotaryEmbedding(
return DlinferLlamaDynamicNTKScalingRotaryEmbedding(
dim, base, scaling_factor, max_position_embeddings)
elif emb_type == RopeType.Llama3:
return Llama3RotaryEmbeddingImpl(dim, base, scaling_factor,
llama3_params.low_freq_factor,
llama3_params.high_freq_factor,
max_position_embeddings)
return DlinferLlama3RotaryEmbeddingImpl(
dim, base, scaling_factor, llama3_params.low_freq_factor,
llama3_params.high_freq_factor, max_position_embeddings)
else:
raise NotImplementedError(
f'Unsupported embedding type: {emb_type}')
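
To make the shapes in the rotary-embedding code above easier to follow: _rotary_embedding_fwd turns inv_freq of shape [dim/2] and the position ids into per-position cos/sin tables of shape [batch, seq_len, dim]. The walkthrough below is illustrative and not part of the commit; the [1, seq_len] layout for position_ids is an assumption based on torch.bmm, which does not broadcast the batch dimension against inv_freq.view(1, -1, 1), and the dynamic-NTK constants at the end are made up.

import torch

dim, base, seq_len = 8, 10000, 4
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # [dim/2]
position_ids = torch.arange(seq_len).view(1, -1).float()            # [1, seq_len]

freqs = torch.bmm(inv_freq.view(1, -1, 1),    # [1, dim/2, 1]
                  position_ids.unsqueeze(1)   # [1, 1, seq_len]
                  ).transpose(1, 2)           # [1, seq_len, dim/2]
emb = torch.cat((freqs, freqs), dim=-1)       # [1, seq_len, dim]
cos, sin = emb.cos(), emb.sin()               # what the impls return
print(cos.shape)  # torch.Size([1, 4, 8])

# DlinferLlamaDynamicNTKScalingRotaryEmbedding recomputes the rope base once
# the sequence grows past max_position_embeddings, using the same formula as
# its _ntk_inv_freq helper (constants below are arbitrary example values):
scaling_factor, max_pos, long_seq = 2.0, 2048, 4096
ntk_base = base * (scaling_factor * long_seq / max_pos -
                   (scaling_factor - 1)) ** (dim / (dim - 2))
ntk_inv_freq = 1.0 / (ntk_base ** (torch.arange(0, dim, 2).float() / dim))
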
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/kernels/dlinfer/__init__.py
@@ -4,6 +4,7 @@
from .awq_kernels import awq_linear
from .fill_kv_cache import fill_kv_cache
from .fused_moe import fused_moe
from .linear import linear
from .moe_gating_topk_softmax import moe_gating_topk_softmax
from .pagedattention import paged_attention_fwd
from .rms_norm import rms_norm
@@ -15,6 +16,7 @@
'fill_kv_cache',
'fused_moe',
'paged_attention_fwd',
'linear',
'moe_gating_topk_softmax',
'multinomial_sampling',
]
12 changes: 12 additions & 0 deletions lmdeploy/pytorch/kernels/dlinfer/linear.py
@@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import dlinfer.ops as ext_ops
from torch import Tensor


def linear(x: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None,
all_reduce: bool = False):
return ext_ops.linear(x, weight, bias=bias, all_reduce=all_reduce)
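
A short usage sketch for the new kernel wrapper, assuming the signature shown above. dlinfer.ops is only importable on hardware that the dlinfer backend supports, so the fallback to torch.nn.functional.linear and the shapes below are illustrative additions, not part of the commit; the fallback simply ignores all_reduce.

import torch

try:
    from lmdeploy.pytorch.kernels.dlinfer import linear
except ImportError:
    def linear(x, weight, bias=None, all_reduce=False):
        # Single-device stand-in: all_reduce is ignored here.
        return torch.nn.functional.linear(x, weight, bias)

x = torch.randn(4, 16)
w = torch.randn(32, 16)
print(linear(x, w).shape)  # torch.Size([4, 32])
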
