support-moe-fp8 (InternLM#3007)
* support-moe-fp8-w8a8

* disable transpose weights for int8/fp8

* fix conflicts
RunningLeon authored Jan 13, 2025
1 parent 7890684 commit 4ac1894
Showing 4 changed files with 67 additions and 33 deletions.
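Summary of the change: the W8A8 fused-MoE path gains a quant_dtype parameter (default torch.int8) that is threaded from the quantization config through FusedMoEW8A8, the backend builder/impl, and per_token_quant_int8 down to the Triton kernel, which now also selects its accumulator dtype (int32 for int8 inputs, fp32 for fp8) and no longer repacks int8/fp8 weights in update_weights. The snippet below is not lmdeploy's per_token_quant_int8; it is only a sketch of what per-token quantization with a configurable target dtype looks like, assuming a symmetric absmax scheme and torch.float8_e4m3fn as the fp8 type.

import torch

def per_token_quant_sketch(x: torch.Tensor, eps: float = 1e-7,
                           quant_dtype: torch.dtype = torch.int8):
    """Per-token symmetric quantization sketch for int8 or fp8 (e4m3)."""
    # one scale per token (row), from the row-wise absolute maximum
    absmax = x.abs().amax(dim=-1, keepdim=True).to(torch.float32).clamp_min(eps)
    qmax = (torch.finfo(quant_dtype).max if quant_dtype.is_floating_point
            else torch.iinfo(quant_dtype).max)
    scale = absmax / qmax
    q = x.to(torch.float32) / scale
    if not quant_dtype.is_floating_point:
        q = q.round()
    q = q.clamp(-qmax, qmax).to(quant_dtype)
    return q, scale

x = torch.randn(4, 8, dtype=torch.float16)
q_int8, s_int8 = per_token_quant_sketch(x, quant_dtype=torch.int8)
q_fp8, s_fp8 = per_token_quant_sketch(x, quant_dtype=torch.float8_e4m3fn)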
37 changes: 21 additions & 16 deletions lmdeploy/pytorch/backends/cuda/moe.py
@@ -87,24 +87,24 @@ def build(top_k: int, num_experts: int, renormalize: bool = False):
 class TritonFusedMoEW8A8Impl(FusedMoEW8A8Impl):
     """triton fused moe w8a8 implementation."""
 
-    def __init__(self,
-                 top_k: int,
-                 num_experts: int,
-                 renormalize: bool = False,
-                 out_dtype: torch.dtype = torch.float16):
+    def __init__(
+        self,
+        top_k: int,
+        num_experts: int,
+        renormalize: bool = False,
+        out_dtype: torch.dtype = torch.float16,
+        quant_dtype: torch.dtype = torch.int8,
+    ):
         self.num_experts = num_experts
         self.top_k = top_k
         self.renormalize = renormalize
         self.out_dtype = out_dtype
+        self.quant_dtype = quant_dtype
 
     def update_weights(self, gate_up_weights: torch.Tensor,
                        down_weights: torch.Tensor, gate_up_scale: torch.Tensor,
                        down_scale: torch.Tensor):
-        gate_up_weights = gate_up_weights.transpose(1,
-                                                    2).contiguous().transpose(
-                                                        1, 2)
-        down_weights = down_weights.transpose(1,
-                                              2).contiguous().transpose(1, 2)
+        # do not transpose weight for int8/fp8
         return gate_up_weights, down_weights, gate_up_scale, down_scale
 
     def support_ep(self):
@@ -133,7 +133,7 @@ def forward(self,
         if isinstance(hidden_states, torch.Tensor):
             hidden_states = hidden_states.contiguous()
             input_quant, input_scale = per_token_quant_int8(
-                hidden_states, 1e-7)
+                hidden_states, 1e-7, quant_dtype=self.quant_dtype)
         else:
             assert isinstance(hidden_states, QTensor)
             input_quant, input_scale = (hidden_states.tensor,
@@ -154,6 +154,7 @@ def forward(self,
             topk_ids=topk_ids,
             topk=self.top_k,
             out_dtype=self.out_dtype,
+            quant_dtype=self.quant_dtype,
             expert_offset=expert_offset,
             num_experts=num_experts,
             renormalize=self.renormalize)
@@ -163,15 +164,19 @@ class TritonFusedMoEW8A8Builder(FusedMoEW8A8Builder):
     """triton fused moe w8a8 builder."""
 
     @staticmethod
-    def build(top_k: int,
-              num_experts: int,
-              renormalize: bool = False,
-              out_dtype: torch.dtype = torch.float16):
+    def build(
+        top_k: int,
+        num_experts: int,
+        renormalize: bool = False,
+        out_dtype: torch.dtype = torch.float16,
+        quant_dtype: torch.dtype = torch.int8,
+    ):
         """build from mlp."""
         return TritonFusedMoEW8A8Impl(top_k=top_k,
                                       num_experts=num_experts,
                                       renormalize=renormalize,
-                                      out_dtype=out_dtype)
+                                      out_dtype=out_dtype,
+                                      quant_dtype=quant_dtype)
 
 
 class TritonFusedMoEBlockedF8Impl(FusedMoEBlockedF8Impl):
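A note on the update_weights change above: the removed transpose(1, 2).contiguous().transpose(1, 2) pattern never changed the logical (experts, N, K) shape of a weight, it only repacked its memory layout so that the N dimension became the contiguous one. For int8/fp8 weights that repacking is now skipped and the tensors are returned untouched. The snippet below is not from the repository; it only makes the stride difference of the old repacking visible.

import torch

w = torch.randint(-128, 127, (4, 2048, 1024), dtype=torch.int8)  # (experts, N, K)
repacked = w.transpose(1, 2).contiguous().transpose(1, 2)        # old behaviour

print(w.shape, repacked.shape)  # same logical shape: (4, 2048, 1024)
print(w.stride())               # (2097152, 1024, 1)  -> K is contiguous
print(repacked.stride())        # (2097152, 1, 2048)  -> N is contiguous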
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/backends/moe.py
@@ -102,7 +102,8 @@ class FusedMoEW8A8Builder(ABC):
     def build(top_k: int,
               num_experts: int,
               renormalize: bool = False,
-              out_dtype: torch.dtype = torch.float16):
+              out_dtype: torch.dtype = torch.float16,
+              quant_dtype: torch.dtype = torch.int8):
         """build from mlp."""
         raise NotImplementedError
 
25 changes: 22 additions & 3 deletions lmdeploy/pytorch/kernels/cuda/w8a8_fused_moe.py
@@ -39,6 +39,15 @@ def get_cuda_autotune_config():
             },
             num_stages=4,
             num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 128,
+                'BLOCK_SIZE_N': 128,
+                'BLOCK_SIZE_K': 128,
+                'GROUP_SIZE_M': 1,
+            },
+            num_stages=3,
+            num_warps=8),
     ]
 
 
@@ -79,6 +88,7 @@ def fused_moe_w8a8_kernel(
     expert_offset: tl.constexpr,
     reindex_a: tl.constexpr,
     reindex_c: tl.constexpr,
+    ACCUMULATOR_DTYPE: tl.constexpr,
 ):
     """fused moe kernel."""
     exp_id = tl.program_id(1)
@@ -129,7 +139,8 @@ def fused_moe_w8a8_kernel(
                       offs_bn[None, :] * stride_bn)
     bs_ptrs = B_scale + exp_id * stride_bse + offs_bn
 
-    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N),
+                           dtype=ACCUMULATOR_DTYPE)
 
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         a = tl.load(a_ptrs,
@@ -139,7 +150,10 @@ def fused_moe_w8a8_kernel(
         b = tl.load(b_ptrs,
                     mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
                     other=0.0)
-        accumulator = tl.dot(a, b, acc=accumulator)
+        accumulator = tl.dot(a,
+                             b,
+                             acc=accumulator,
+                             out_dtype=ACCUMULATOR_DTYPE)
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
 
@@ -189,6 +203,7 @@ def fused_moe_w8a8_kernel_launcher(
 
     assert A_scale.is_contiguous()
     assert B_scale.is_contiguous()
+    accumulator_dtype = tl.float32 if A.is_floating_point() else tl.int32
 
     def _grid_fn(META):
         grid = (triton.cdiv(M_NP2, META['BLOCK_SIZE_M']) *
@@ -226,6 +241,7 @@ def _grid_fn(META):
         reindex_a=reindex_a,
         reindex_c=reindex_c,
         M_NP2=M_NP2,
+        ACCUMULATOR_DTYPE=accumulator_dtype,
         **kernel_meta,
     )
 
@@ -240,6 +256,7 @@ def fused_moe_w8a8(input: torch.Tensor,
                    topk_ids: torch.Tensor,
                    topk: int,
                    out_dtype: torch.dtype = torch.float16,
+                   quant_dtype: torch.dtype = torch.int8,
                    expert_offset: int = 0,
                    num_experts: int = None,
                    renormalize: bool = False) -> torch.Tensor:
@@ -283,7 +300,9 @@ def fused_moe_w8a8(input: torch.Tensor,
     gate_cache = silu_and_mul(intermediate_cache1)
     del intermediate_cache1
     gate_cache = gate_cache.unflatten(0, unflat_size)
-    gate_cache, gate_scale = per_token_quant_int8(gate_cache, 1e-7)
+    gate_cache, gate_scale = per_token_quant_int8(gate_cache,
+                                                  1e-7,
+                                                  quant_dtype=quant_dtype)
 
     intermediate_cache2 = _make_intermediate((M, topk, w2.shape[1]),
                                              dtype=out_dtype,
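The kernel changes above make the accumulator dtype a compile-time parameter: the launcher picks tl.int32 when the quantized activations are integer (int8 w8a8) and tl.float32 when they are floating point (fp8), and passes it to both tl.zeros and tl.dot. A minimal illustration of that selection logic, assuming A stands in for the already-quantized activation tensor:

import torch
import triton.language as tl

def pick_accumulator_dtype(A: torch.Tensor):
    # int8 GEMMs accumulate in int32; fp8 GEMMs accumulate in fp32
    return tl.float32 if A.is_floating_point() else tl.int32

print(pick_accumulator_dtype(torch.empty(2, 2, dtype=torch.int8)))           # int32
print(pick_accumulator_dtype(torch.empty(2, 2, dtype=torch.float8_e4m3fn)))  # fp32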
35 changes: 22 additions & 13 deletions lmdeploy/pytorch/nn/moe.py
@@ -208,13 +208,14 @@ def __init__(self,
                  weight_type: str,
                  device: torch.device,
                  expert_list: List[int] = None,
-                 ep: bool = False):
+                 ep: bool = False,
+                 quant_dtype: torch.dtype = torch.int8):
         super().__init__(
             num_experts=num_experts,
             in_features=in_features,
             out_features=out_features,
             weight_type=weight_type,
-            dtype=torch.int8,
+            dtype=quant_dtype,
             device=device,
             expert_list=expert_list,
             ep=ep,
@@ -267,17 +268,23 @@ def __init__(self,
                  top_k: int,
                  renormalize: bool = False,
                  dtype: Optional[torch.dtype] = None,
+                 quant_dtype: Optional[torch.dtype] = torch.int8,
                  device: Optional[torch.device] = None,
                  all_reduce: bool = True,
                  enable_ep: bool = False):
         super().__init__()
 
         if device is None:
             device = torch.device('cpu')
         dtype = torch.float16 if dtype is None else dtype
 
         impl_builder = get_backend().get_layer_impl_builder(
             OpType.FusedMoEW8A8)
-        self.impl = impl_builder.build(top_k, num_experts, renormalize, dtype)
+        self.impl = impl_builder.build(top_k,
+                                       num_experts,
+                                       renormalize,
+                                       dtype,
+                                       quant_dtype=quant_dtype)
 
         enable_ep = enable_ep and self.impl.support_ep()
         if enable_ep:
@@ -295,16 +302,16 @@ def __init__(self,
             weight_type='gate_up',
             device=device,
             expert_list=expert_list,
-            ep=enable_ep)
-        self.down = LinearWeightsW8A8(
-            num_experts,
-            ffn_dim,
-            hidden_dim,
-            weight_type='down',
-            device=device,
-            expert_list=expert_list,
-            ep=enable_ep,
-        )
+            ep=enable_ep,
+            quant_dtype=quant_dtype)
+        self.down = LinearWeightsW8A8(num_experts,
+                                      ffn_dim,
+                                      hidden_dim,
+                                      weight_type='down',
+                                      device=device,
+                                      expert_list=expert_list,
+                                      ep=enable_ep,
+                                      quant_dtype=quant_dtype)
 
         self.hidden_dim = hidden_dim
         self.ffn_dim = ffn_dim
@@ -520,13 +527,15 @@ def build_fused_moe(
 
     quant_method = quant_config['quant_method']
     if quant_method == 'smooth_quant':
+        quant_dtype = eval('torch.' + quant_config.get('quant_dtype', 'int8'))
        return FusedMoEW8A8(
             hidden_dim=hidden_dim,
             ffn_dim=ffn_dim,
             num_experts=num_experts,
             top_k=top_k,
             renormalize=renormalize,
             dtype=dtype,
+            quant_dtype=quant_dtype,
             device=device,
             all_reduce=all_reduce,
             enable_ep=enable_ep,
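The config hookup above reads quant_dtype as a string from the smooth_quant section and converts it to a torch dtype with eval('torch.' + ...). For the values this path expects ('int8', or an fp8 name such as 'float8_e4m3fn'), getattr produces the same dtype without evaluating an arbitrary string; the sketch below uses a hypothetical quant_config dict only to show the equivalence.

import torch

quant_config = {'quant_method': 'smooth_quant', 'quant_dtype': 'float8_e4m3fn'}

quant_dtype = getattr(torch, quant_config.get('quant_dtype', 'int8'))
assert isinstance(quant_dtype, torch.dtype)
print(quant_dtype)  # torch.float8_e4m3fn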
