From eb9857d6af073ae1d86a9865b30601459855e60a Mon Sep 17 00:00:00 2001 From: hx Date: Tue, 18 Feb 2025 11:58:03 -0800 Subject: [PATCH] [MoE][PyTorch] Add prob permutation to mask-based MoE permutation; Fix FP8 related codes (#1468) * add prob permute; fix fp8tensor Signed-off-by: Hongxiao Bai * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert unnecessary changes in UT Signed-off-by: Hongxiao Bai * remove unnecessary probs dtype convert Signed-off-by: Hongxiao Bai * keep the output nums if probs is not provided Signed-off-by: Hongxiao Bai * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refine the doc string Signed-off-by: Hongxiao Bai * fix lint Signed-off-by: Hongxiao Bai * use fp32 compute type Signed-off-by: Hongxiao Bai * style fix Signed-off-by: Hongxiao Bai * fix empty input return Signed-off-by: Hongxiao Bai * separate prob related functions out Signed-off-by: Hongxiao Bai * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Hongxiao Bai Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xin Yao Co-authored-by: Phuong Nguyen --- docs/api/pytorch.rst | 4 + tests/pytorch/test_permutation.py | 435 +++++++++++++++--- transformer_engine/pytorch/__init__.py | 2 + transformer_engine/pytorch/permutation.py | 221 +++++++-- .../pytorch/triton/permutation.py | 222 ++++++--- 5 files changed, 721 insertions(+), 163 deletions(-) diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index 6d5fe6761d..4154a18598 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -48,10 +48,14 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.moe_permute +.. autoapifunction:: transformer_engine.pytorch.moe_permute_with_probs + .. autoapifunction:: transformer_engine.pytorch.moe_unpermute .. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index +.. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index_with_probs + .. autoapifunction:: transformer_engine.pytorch.initialize_ub .. autoapifunction:: transformer_engine.pytorch.destroy_ub diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index 35c6266a3f..0dc183e298 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -10,12 +10,14 @@ from transformer_engine.pytorch import ( moe_permute as te_permute, + moe_permute_with_probs as te_permute_with_probs, moe_unpermute as te_unpermute, moe_sort_chunks_by_index as te_sort_chunks_by_index, + moe_sort_chunks_by_index_with_probs as te_sort_chunks_by_index_with_probs, ) from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer import transformer_engine_torch as tex @@ -198,6 +200,16 @@ def dtype_tols(te_dtype: tex.DType) -> Dict[str, float]: raise ValueError(f"Unsuppored dtype ({te_dtype})") +def backward_wrapper( + act, backward_input, forward_input=[], retain_graph=True, accumulate_grad=False +): + # Set forward_input.grad to None to avoid grad accumulation. 
+ if accumulate_grad == False: + for i in forward_input: + i.grad = None + return act.backward(backward_input, retain_graph=retain_graph) + + def _test_permutation_index_map( te_dtype, num_tokens, @@ -265,9 +277,9 @@ def _test_permutation_index_map( permute_bwd_input = _permute_bwd_input_quantizer(permute_bwd_input) unpermute_bwd_input = _unpermute_bwd_quantizer(unpermute_bwd_input) - pytorch_permute_fwd_input = permute_fwd_input.dequantize().to(torch.float16) - pytorch_permute_bwd_input = permute_bwd_input.dequantize().to(torch.float16) - pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize().to(torch.float16) + pytorch_permute_fwd_input = permute_fwd_input.dequantize(dtype=torch.float16) + pytorch_permute_bwd_input = permute_bwd_input.dequantize(dtype=torch.float16) + pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize(dtype=torch.float16) else: pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda() @@ -341,10 +353,10 @@ def _test_permutation_index_map( tols = dtype_tols(te_dtype) if fp8: - te_permute_output_ = te_permute_output.dequantize().to(torch.float32) - te_permute_fwd_input_grad = te_permute_fwd_input.grad.dequantize().to(torch.float32) - te_unpermute_output_ = te_unpermute_output.dequantize().to(torch.float32) - te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.dequantize().to(torch.float32) + te_permute_output_ = te_permute_output.dequantize(dtype=torch.float32) + te_permute_fwd_input_grad = te_permute_fwd_input.grad.dequantize(dtype=torch.float32) + te_unpermute_output_ = te_unpermute_output.dequantize(dtype=torch.float32) + te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.dequantize(dtype=torch.float32) else: te_permute_output_ = te_permute_output.float() te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() @@ -388,15 +400,6 @@ def _test_permutation_index_map( # Benchmark # ################################################################################################################################### - def backward_wrapper( - act, backward_input, forward_input=[], retain_graph=True, accumulate_grad=False - ): - # Set forward_input.grad to None to avoid grad accumulation. 
- if accumulate_grad == False: - for i in forward_input: - i.grad = None - return act.backward(backward_input, retain_graph=retain_graph) - if BENCHMARK: t1 = perf_test_cuda_kernel( lambda: pytorch_permute_index_map(pytorch_permute_fwd_input, indices, num_out_tokens) @@ -509,19 +512,28 @@ def _test_permutation_mask_map( size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" ) - permute_fwd_input = Float8Tensor.to_float8( - permute_fwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + _permute_fwd_input_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, ) - permute_bwd_input = Float8Tensor.to_float8( - permute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + _permute_bwd_input_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, ) - unpermute_bwd_input = Float8Tensor.to_float8( - unpermute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + _unpermute_bwd_input_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, ) + permute_fwd_input = _permute_fwd_input_quantizer(permute_fwd_input) + permute_bwd_input = _permute_bwd_input_quantizer(permute_bwd_input) + unpermute_bwd_input = _unpermute_bwd_input_quantizer(unpermute_bwd_input) - pytorch_permute_fwd_input = permute_fwd_input.from_float8(torch.float16) - pytorch_permute_bwd_input = permute_bwd_input.from_float8(torch.float16) - pytorch_unpermute_bwd_input = unpermute_bwd_input.from_float8(torch.float16) + pytorch_permute_fwd_input = permute_fwd_input.dequantize(dtype=torch.float16) + pytorch_permute_bwd_input = permute_bwd_input.dequantize(dtype=torch.float16) + pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize(dtype=torch.float16) else: pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda() @@ -541,6 +553,10 @@ def _test_permutation_mask_map( probs = torch.rand(num_tokens, num_expert).cuda() * routing_map row_sums = probs.sum(dim=1, keepdim=True) probs = probs / row_sums + if fp8: + probs = probs.to(torch.float16) + else: + probs = probs.to(dtype) probs.requires_grad_(True) ################################################################################################################################### @@ -571,7 +587,7 @@ def _test_permutation_mask_map( te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach() te_permute_output, row_id_map = te_permute( - te_permute_fwd_input, routing_map, num_out_tokens, map_type="mask" + te_permute_fwd_input, routing_map, num_out_tokens=num_out_tokens, map_type="mask" ) te_permute_output.backward(te_permute_bwd_input, retain_graph=True) @@ -596,10 +612,10 @@ def _test_permutation_mask_map( tols = dtype_tols(te_dtype) if fp8: - te_permute_output_ = te_permute_output.from_float8(torch.float32) - te_permute_fwd_input_grad = te_permute_fwd_input.grad.from_float8(torch.float32) - te_unpermute_output_ = te_unpermute_output.from_float8(torch.float32) - te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.from_float8(torch.float32) + te_permute_output_ = te_permute_output.dequantize(dtype=torch.float32) + te_permute_fwd_input_grad = te_permute_fwd_input.grad.dequantize(dtype=torch.float32) + te_unpermute_output_ = te_unpermute_output.dequantize(dtype=torch.float32) + 
te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.dequantize(dtype=torch.float32) else: te_permute_output_ = te_permute_output.float() te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() @@ -644,21 +660,14 @@ def _test_permutation_mask_map( # Benchmark # ################################################################################################################################### - def backward_wrapper( - act, backward_input, forward_input=[], retain_graph=True, accumulate_grad=False - ): - # Set forward_input.grad to None to avoid grad accumulation. - if accumulate_grad == False: - for i in forward_input: - i.grad = None - return act.backward(backward_input, retain_graph=retain_graph) - if BENCHMARK: t1 = perf_test_cuda_kernel( lambda: pytorch_permute_mask_map(pytorch_permute_fwd_input, routing_map) ) t2 = perf_test_cuda_kernel( - lambda: te_permute(te_permute_fwd_input, routing_map, num_out_tokens, map_type="mask") + lambda: te_permute( + te_permute_fwd_input, routing_map, num_out_tokens=num_out_tokens, map_type="mask" + ) ) print(f"permute\t\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") @@ -752,15 +761,21 @@ def _test_moe_chunk_sort( fwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda") bwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda") - fwd_input = Float8Tensor.to_float8( - fwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + _fwd_input_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, ) - bwd_input = Float8Tensor.to_float8( - bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + _bwd_input_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, ) + fwd_input = _fwd_input_quantizer.quantize(fwd_input) + bwd_input = _bwd_input_quantizer.quantize(bwd_input) - pytorch_fwd_input = fwd_input.from_float8(torch.float16) - pytorch_bwd_input = bwd_input.from_float8(torch.float16) + pytorch_fwd_input = fwd_input.dequantize(dtype=torch.float16) + pytorch_bwd_input = bwd_input.dequantize(dtype=torch.float16) else: pytorch_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() pytorch_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() @@ -806,8 +821,8 @@ def _test_moe_chunk_sort( tols = dtype_tols(te_dtype) if fp8: - te_output_ = te_output.from_float8(torch.float32) - te_fwd_input_grad = te_fwd_input.grad.from_float8(torch.float32) + te_output_ = te_output.dequantize(dtype=torch.float32) + te_fwd_input_grad = te_fwd_input.grad.dequantize(dtype=torch.float32) else: te_output_ = te_output.float() te_fwd_input_grad = te_fwd_input.grad.float() @@ -834,15 +849,6 @@ def _test_moe_chunk_sort( # Benchmark # ################################################################################################################################### - def backward_wrapper( - act, backward_input, forward_input=[], retain_graph=True, accumulate_grad=False - ): - # Set forward_input.grad to None to avoid grad accumulation. 
- if accumulate_grad == False: - for i in forward_input: - i.grad = None - return act.backward(backward_input, retain_graph=retain_graph) - if BENCHMARK: t1 = perf_test_cuda_kernel( lambda: pytorch_sort_chunks_by_index(pytorch_fwd_input, split_sizes, sorted_idxs) @@ -873,6 +879,210 @@ def backward_wrapper( print(f"chunk sort\t\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") +def _test_permutation_mask_map_alongside_probs( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, + tp_size, +): + if topK > num_expert: + pytest.skip("topK should be smaller than the number of experts.") + + if num_out_tokens == None: + num_out_tokens = num_tokens * topK + + print( + "mask map alongside probs:" + f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}" + ) + + fp8 = False + # Convert TE dtypes to PyTorch dtypes + if te_dtype == tex.DType.kFloat32: + dtype = torch.float32 + elif te_dtype == tex.DType.kFloat16: + dtype = torch.float16 + elif te_dtype == tex.DType.kBFloat16: + dtype = torch.bfloat16 + elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3): + dtype = torch.uint8 + fp8 = True + else: + pytest.skip("Invalid dtype.") + + if fp8: + permute_fwd_input = torch.rand( + size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" + ) + unpermute_bwd_input = torch.rand( + size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" + ) + + _permute_fwd_input_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, + ) + _unpermute_bwd_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, + ) + permute_fwd_input = _permute_fwd_input_quantizer.quantize(permute_fwd_input) + unpermute_bwd_input = _unpermute_bwd_quantizer.quantize(unpermute_bwd_input) + + pytorch_permute_fwd_input = permute_fwd_input.dequantize(dtype=torch.float16) + pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize(dtype=torch.float16) + else: + pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + + pytorch_permute_fwd_input.requires_grad_(True) + + restore_shape = pytorch_permute_fwd_input.shape + + _tmp_tensor = torch.zeros((num_tokens * num_expert,)) + _tmp_tensor[: int(num_out_tokens)] = 1.0 + _tmp_idx = torch.randperm(num_tokens * num_expert) + routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda() + + probs = torch.rand(num_tokens, num_expert).cuda() * routing_map + row_sums = probs.sum(dim=1, keepdim=True) + probs = probs / row_sums + if fp8: + probs = probs.to(torch.float16) + else: + probs = probs.to(dtype) + probs.requires_grad_(True) + + split_sizes = [0] * (num_expert * tp_size) + for i in range(num_out_tokens): + idx = random.randint(0, num_expert * tp_size - 1) + split_sizes[idx] += 1 + split_sizes = torch.tensor(split_sizes, dtype=torch.int32) + split_sizes_cuda = split_sizes.to(device="cuda") + + _sorted_idxs = torch.arange(num_expert * tp_size, dtype=torch.int32) + sorted_idxs = _sorted_idxs.reshape(tp_size, num_expert).T.ravel() + sorted_idxs_cuda = sorted_idxs.to(device="cuda") + + split_sizes_2 = [split_sizes[i] for i in sorted_idxs.tolist()] + split_sizes_2 = torch.tensor(split_sizes_2, dtype=torch.int32) + split_sizes_2_cuda = split_sizes_2.to(device="cuda") + + sorted_idxs_2 
= [0] * (num_expert * tp_size) + for i in range(num_expert * tp_size): + sorted_idxs_2[sorted_idxs[i]] = i + sorted_idxs_2 = torch.tensor(sorted_idxs_2, dtype=torch.int32) + sorted_idxs_2_cuda = sorted_idxs_2.to(device="cuda") + + ################################################################################################################################### + # + # PyTorch Permutation + # + ################################################################################################################################### + pytorch_permute_output, sorted_indices = pytorch_permute_mask_map( + pytorch_permute_fwd_input, routing_map + ) + + pytorch_permute_output = pytorch_sort_chunks_by_index( + pytorch_permute_output, split_sizes, sorted_idxs + ) + + pytorch_permute_output = pytorch_sort_chunks_by_index( + pytorch_permute_output, split_sizes_2, sorted_idxs_2 + ) + + pytorch_unpermute_output = pytorch_unpermute_mask_map( + pytorch_permute_output, sorted_indices, restore_shape, probs, routing_map + ) + pytorch_unpermute_output.backward(pytorch_unpermute_bwd_input, retain_graph=True) + + ################################################################################################################################### + # + # TE Permutation + # + ################################################################################################################################### + te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach() + te_permute_fwd_input.requires_grad_(True) + + te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach() + te_probs = probs.detach() + te_probs.requires_grad_(True) + print(te_probs.shape) + + te_permute_output, te_permuted_probs, row_id_map = te_permute_with_probs( + te_permute_fwd_input, + te_probs, + routing_map, + num_out_tokens=num_out_tokens, + ) + print(te_permuted_probs.shape) + + te_permute_output, te_permuted_probs = te_sort_chunks_by_index_with_probs( + te_permute_output, te_permuted_probs, split_sizes_cuda, sorted_idxs_cuda + ) + + if fp8: + _permute_output_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, + ) + te_permute_output = te_permute_output.dequantize(dtype=torch.float32) + te_permute_output = te_permute_output * te_permuted_probs.unsqueeze(-1) + te_permute_output = _permute_output_quantizer.quantize(te_permute_output) + else: + te_permute_output_dtype = te_permute_output.dtype + print(te_permute_output.shape) + print(te_permuted_probs.shape) + te_permute_output = te_permute_output * te_permuted_probs.unsqueeze(-1) + te_permute_output = te_permute_output.to(dtype=te_permute_output_dtype) + + te_permute_output = te_sort_chunks_by_index( + te_permute_output, split_sizes_2_cuda, sorted_idxs_2_cuda + ) + + te_unpermute_output = te_unpermute( + te_permute_output, + row_id_map, + restore_shape=restore_shape, + map_type="mask", + ) + te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True) + + ############################################################################################### + + tols = dtype_tols(te_dtype) + + if fp8: + # backward of dequantize is in high precision + te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() + te_unpermute_output_ = te_unpermute_output.dequantize(dtype=torch.float32) + else: + te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() + te_unpermute_output_ = te_unpermute_output.float() + + torch.testing.assert_close( + 
pytorch_unpermute_output.float(), + te_unpermute_output_, + msg=f"Mismatch in fused_unpermute fwd", + **tols, + ) + torch.testing.assert_close( + pytorch_permute_fwd_input.grad.float(), + te_permute_fwd_input_grad, + msg=f"Mismatch in fused_permute bwd", + **tols, + ) + torch.testing.assert_close( + probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in prob grad", **tols + ) + + def perf_test_cuda_kernel(cuda_kernel_fn): if torch.cuda.is_available(): # create CUDA event @@ -959,6 +1169,63 @@ def test_permutation_mask_map( ) +@pytest.mark.parametrize("te_dtype", _te_dtypes) +def test_permutation_mask_map_empty_input(te_dtype): + with_probs = True + BENCHMARK = False + + _test_permutation_mask_map( + te_dtype=te_dtype, + num_tokens=0, + num_expert=8, + hidden_size=4096, + topK=2, + num_out_tokens=0, + with_probs=with_probs, + BENCHMARK=BENCHMARK, + ) + + +@pytest.mark.parametrize("te_dtype", _te_dtypes) +@pytest.mark.parametrize("num_tokens", [4096]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("num_out_tokens", [None, 2039]) +@pytest.mark.parametrize("tp_size", [1, 2, 8]) +def test_permutation_mask_map_alongside_probs( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, + tp_size, +): + _test_permutation_mask_map_alongside_probs( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + tp_size=tp_size, + ) + + +@pytest.mark.parametrize("te_dtype", _te_dtypes) +def test_permutation_mask_map_alongside_probs_empty_input(te_dtype): + _test_permutation_mask_map_alongside_probs( + te_dtype=te_dtype, + num_tokens=0, + num_expert=8, + hidden_size=4096, + topK=2, + num_out_tokens=0, + tp_size=2, + ) + + # Only run FP8 tests on H100. 
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() @@ -1023,6 +1290,34 @@ def test_permutation_mask_map_fp8( ) +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) +@pytest.mark.parametrize("num_tokens", [2048]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("num_out_tokens", [None, 2039]) +@pytest.mark.parametrize("tp_size", [1, 2, 8]) +def test_permutation_mask_map_alongside_probs_fp8( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, + tp_size, +): + _test_permutation_mask_map_alongside_probs( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + tp_size=tp_size, + ) + + @pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_expert", [8, 16]) @@ -1101,6 +1396,20 @@ def test_chunk_permutation( ) +@pytest.mark.parametrize("te_dtype", _te_dtypes) +def test_chunk_permutation_empty_input(te_dtype): + BENCHMARK = False + + _test_moe_chunk_sort( + te_dtype=te_dtype, + num_tokens=0, + num_expert=8, + tp_size=2, + hidden_size=4096, + BENCHMARK=BENCHMARK, + ) + + def test_permutation_single_case(): print("GPU:", torch.cuda.get_device_name(0)) @@ -1149,6 +1458,16 @@ def test_permutation_single_case(): BENCHMARK=Benchmark, ) + _test_permutation_mask_map_alongside_probs( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + tp_size=4, + ) + if __name__ == "__main__": test_permutation_single_case() diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 57addca3b9..d424b97f74 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -76,8 +76,10 @@ def _load_library(): from transformer_engine.pytorch.transformer import TransformerLayer from transformer_engine.pytorch.permutation import ( moe_permute, + moe_permute_with_probs, moe_unpermute, moe_sort_chunks_by_index, + moe_sort_chunks_by_index_with_probs, ) from transformer_engine.pytorch.fp8 import fp8_autocast from transformer_engine.pytorch.fp8 import fp8_model_init diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 2e6167a6e0..dd2f60deba 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -261,13 +261,17 @@ def forward( inp: torch.Tensor, routing_map: torch.Tensor, num_out_tokens: int, + probs: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: # pylint: disable=missing-function-docstring if not inp.numel(): - return inp, torch.tensor([], device=inp.device) + ctx.probs = probs + return inp, torch.tensor([], device=inp.device), torch.tensor([], device=inp.device) assert inp.is_cuda, "TransformerEngine needs CUDA." assert routing_map.is_cuda, "TransformerEngine needs CUDA." + if probs is not None: + assert probs.is_cuda, "TransformerEngine needs CUDA." 
assert inp.size(0) == routing_map.size(0), "Permute not possible" num_tokens, hidden_size = inp.size() @@ -282,48 +286,60 @@ def forward( if fp8: fp8_dtype = inp._fp8_dtype fp8_scale_inv = inp._scale_inv + fake_dtype = inp.dtype inp = inp._data - output = triton_permutation.permute_with_mask_map( + output, permuted_probs = triton_permutation.permute_with_mask_map( inp, row_id_map, + probs, num_tokens, num_experts, num_out_tokens, hidden_size, ) if fp8: - output = Float8Tensor(data=output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv) + output = Float8Tensor( + data=output, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + shape=output.shape, + dtype=fake_dtype, + ) ctx.save_for_backward(row_id_map) ctx.num_experts = num_experts ctx.num_tokens = num_tokens ctx.hidden_size = hidden_size - return output, row_id_map + return output, row_id_map, permuted_probs @staticmethod def backward( ctx, permuted_act_grad: torch.Tensor, _, + permuted_probs_grad: torch.Tensor, ) -> Tuple[torch.Tensor, ...]: # pylint: disable=missing-function-docstring if not permuted_act_grad.numel(): - return permuted_act_grad, None, None + return permuted_act_grad, None, None, ctx.probs act_grad = None + probs_grad = None if ctx.needs_input_grad[0]: (row_id_map,) = ctx.saved_tensors fp8 = isinstance(permuted_act_grad, Float8Tensor) if fp8: fp8_dtype = permuted_act_grad._fp8_dtype fp8_scale_inv = permuted_act_grad._scale_inv + fake_dtype = permuted_act_grad.dtype permuted_act_grad = permuted_act_grad._data else: fp8_dtype = None - act_grad = triton_permutation.unpermute_with_mask_map( + act_grad, probs_grad = triton_permutation.unpermute_with_mask_map( permuted_act_grad, row_id_map, None, + permuted_probs_grad, ctx.num_tokens, ctx.num_experts, ctx.hidden_size, @@ -334,8 +350,12 @@ def backward( data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv * ctx.num_experts, + shape=act_grad.shape, + dtype=fake_dtype, ) - return act_grad, None, None + if not ctx.needs_input_grad[3]: + probs_grad = None + return act_grad, None, None, probs_grad class _moe_unpermute_mask_map(torch.autograd.Function): @@ -346,12 +366,12 @@ def forward( ctx, inp: torch.Tensor, row_id_map: torch.Tensor, - probs: torch.Tensor, + merging_probs: torch.Tensor, restore_shape: torch.Size, ) -> torch.Tensor: # pylint: disable=missing-function-docstring if not inp.numel(): - ctx.probs = probs + ctx.merging_probs = merging_probs return inp if restore_shape is None: @@ -359,15 +379,9 @@ def forward( num_tokens, hidden_size = restore_shape num_experts = row_id_map.size(0) - with_probs = probs is not None + with_probs = merging_probs is not None if with_probs: - assert probs.is_cuda, "TransformerEngine needs CUDA." - if probs.dtype != torch.float32: - warnings.warn( - f"The data type of the input `probs` of Unpermute is {probs.dtype}! " - "The recommended type is torch.float32." - ) - probs = probs.to(torch.float32) + assert merging_probs.is_cuda, "TransformerEngine needs CUDA." # Device check assert inp.is_cuda, "TransformerEngine needs CUDA." 
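A minimal sketch (not part of the patch) of the FP8 path these hunks fix, mirroring the updated tests: the input is quantized with a single-scale Float8Quantizer, and the mask-based permute keeps the result as a Float8Tensor, which is now constructed with an explicit shape and dtype. The sizes and the random routing map below are illustrative placeholders only.

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch import moe_permute
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer

num_tokens, num_experts, hidden_size = 8, 4, 16  # illustrative sizes
routing_map = torch.rand(num_tokens, num_experts, device="cuda") < 0.5  # illustrative routing

# Single-scale quantizer, as used by the updated permutation tests.
quantizer = Float8Quantizer(
    scale=torch.ones((), device="cuda"),
    amax=torch.ones(1, device="cuda"),
    fp8_dtype=tex.DType.kFloat8E4M3,
)
inp_fp8 = quantizer(torch.randn(num_tokens, hidden_size, device="cuda"))

permuted_fp8, row_id_map = moe_permute(
    inp_fp8, routing_map, num_out_tokens=int(routing_map.sum()), map_type="mask"
)
# The permuted output is still a Float8Tensor; dequantize to inspect it in high precision.
permuted_hp = permuted_fp8.dequantize(dtype=torch.float32)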
@@ -380,13 +394,15 @@ def forward( fp8_scale_inv = inp._scale_inv * num_experts else: fp8_scale_inv = inp._scale_inv + fake_dtype = inp.dtype inp = inp._data else: fp8_dtype = None - unpermuted_output = triton_permutation.unpermute_with_mask_map( + unpermuted_output, _ = triton_permutation.unpermute_with_mask_map( inp, row_id_map, - probs, + merging_probs, + None, num_tokens, num_experts, hidden_size, @@ -394,11 +410,15 @@ def forward( ) if fp8: unpermuted_output = Float8Tensor( - data=unpermuted_output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv + data=unpermuted_output, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + shape=unpermuted_output.shape, + dtype=fake_dtype, ) if with_probs: - ctx.save_for_backward(inp, row_id_map, probs) + ctx.save_for_backward(inp, row_id_map, merging_probs) else: ctx.save_for_backward(row_id_map) ctx.num_experts = num_experts @@ -412,13 +432,13 @@ def forward( def backward(ctx, unpermuted_act_grad): # pylint: disable=missing-function-docstring if not unpermuted_act_grad.numel(): - return unpermuted_act_grad, None, ctx.probs, None + return unpermuted_act_grad, None, ctx.merging_probs, None act_grad = None probs_grad = None if ctx.needs_input_grad[0]: if ctx.with_probs: - fwd_input, row_id_map, probs = ctx.saved_tensors + fwd_input, row_id_map, merging_probs = ctx.saved_tensors else: (row_id_map,) = ctx.saved_tensors @@ -426,26 +446,30 @@ def backward(ctx, unpermuted_act_grad): if fp8: fp8_dtype = unpermuted_act_grad._fp8_dtype fp8_scale_inv = unpermuted_act_grad._scale_inv + fake_dtype = unpermuted_act_grad.dtype unpermuted_act_grad = unpermuted_act_grad._data else: fp8_dtype = None if ctx.with_probs: - act_grad, probs_grad = triton_permutation.unpermute_with_mask_map_bwd_with_probs( - unpermuted_act_grad, - row_id_map, - fwd_input, - probs, - ctx.num_tokens, - ctx.num_experts, - ctx.num_permuted_tokens, - ctx.hidden_size, - fp8_dtype, + act_grad, probs_grad = ( + triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs( + unpermuted_act_grad, + row_id_map, + fwd_input, + merging_probs, + ctx.num_tokens, + ctx.num_experts, + ctx.num_permuted_tokens, + ctx.hidden_size, + fp8_dtype, + ) ) else: - act_grad = triton_permutation.permute_with_mask_map( + act_grad, _ = triton_permutation.permute_with_mask_map( unpermuted_act_grad, row_id_map, + None, ctx.num_tokens, ctx.num_experts, ctx.num_permuted_tokens, @@ -454,7 +478,11 @@ def backward(ctx, unpermuted_act_grad): if fp8: act_grad = Float8Tensor( - data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv + data=act_grad, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + shape=act_grad.shape, + dtype=fake_dtype, ) if not ctx.needs_input_grad[2]: @@ -494,20 +522,56 @@ def moe_permute( map_type: str, default = 'mask' Type of the routing map tensor. Options are: 'mask', 'index'. + Refer to `routing_map` for more details. """ if map_type == "index": return _moe_permute_index_map.apply(inp, routing_map, num_out_tokens, max_token_num) if map_type == "mask": - return _moe_permute_mask_map.apply(inp, routing_map, num_out_tokens) + output, row_id_map, _ = _moe_permute_mask_map.apply(inp, routing_map, num_out_tokens, None) + return output, row_id_map raise ValueError("map_type should be one of 'mask' or 'index'") +def moe_permute_with_probs( + inp: torch.Tensor, + probs: torch.Tensor, + routing_map: torch.Tensor, + num_out_tokens: int = -1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Permute the tokens and probs based on the routing_map. 
+ Token with the same index will be grouped together. + Tokens with the same designated expert will be grouped together. + The routing_map indicates which experts were selected by each token. + + Parameters + ---------- + inp: torch.Tensor + Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied. + probs: torch.Tensor + The tensor of probabilities corresponding to the permuted tokens and is + of shape [num_tokens, num_experts]. It will be permuted with the tokens + according to the routing_map. + routing_map: torch.Tensor + The token to expert mapping tensor of shape [num_tokens, num_experts] and dtype 'int32'. + The values in it: 1 means the token is routed to this expert and 0 means not. + num_out_tokens: int, default = -1 + The effective output token count, representing the number of tokens not dropped. + By default, set to '-1', meaning no tokens are dropped. + """ + output, row_id_map, permuted_probs = _moe_permute_mask_map.apply( + inp, routing_map, num_out_tokens, probs + ) + return output, permuted_probs, row_id_map + + def moe_unpermute( inp: torch.Tensor, row_id_map: torch.Tensor, - probs: torch.Tensor = None, + merging_probs: torch.Tensor = None, restore_shape: torch.Tensor = None, map_type: str = "mask", + probs: torch.Tensor = None, ) -> torch.Tensor: """ Unpermute a tensor with permuted tokens, and optionally merge the tokens with their @@ -520,7 +584,7 @@ def moe_unpermute( row_id_map: torch.Tensor The tensor of a mapping table for sorted indices used to unpermute the tokens, which is the second output tensor of `Permute`. - probs: torch.Tensor + merging_probs: torch.Tensor, default = None The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will be merged with their respective probabilities. By default, set to an empty tensor, which means that the tokens are directly merged by accumulation. @@ -529,11 +593,20 @@ def moe_unpermute( map_type: str, default = 'mask' Type of the routing map tensor. Should be the same as the value passed to moe_permute. Options are: 'mask', 'index'. + probs: torch.Tensor, default = None + Renamed to merging_probs. Keep for backward compatibility. """ + if probs is not None: + if merging_probs is not None: + raise ValueError( + "Both merging_probs and probs kwarg are provided. probs is deprecated." + ) + warnings.warn("probs kwarg is deprecated. Use merging_probs kwarg instead.") + merging_probs = probs if map_type == "index": - return _moe_unpermute_index_map.apply(inp, row_id_map, probs) + return _moe_unpermute_index_map.apply(inp, row_id_map, merging_probs) if map_type == "mask": - return _moe_unpermute_mask_map.apply(inp, row_id_map, probs, restore_shape) + return _moe_unpermute_mask_map.apply(inp, row_id_map, merging_probs, restore_shape) raise ValueError("map_type should be one of 'mask' or 'index'") @@ -546,14 +619,17 @@ def forward( inp: torch.Tensor, split_sizes: torch.Tensor, sorted_idxs: torch.Tensor, + probs: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: # pylint: disable=missing-function-docstring if not inp.numel(): - return inp, torch.tensor([], device=inp.device) + return inp, probs assert inp.is_cuda, "TransformerEngine needs CUDA." assert split_sizes.is_cuda, "TransformerEngine needs CUDA." assert sorted_idxs.is_cuda, "TransformerEngine needs CUDA." + if probs is not None: + assert probs.is_cuda, "TransformerEngine needs CUDA." 
num_tokens, hidden_size = inp.shape num_splits = split_sizes.size(0) @@ -563,51 +639,69 @@ def forward( if fp8: fp8_dtype = inp._fp8_dtype fp8_scale_inv = inp._scale_inv + fake_dtype = inp.dtype inp = inp._data - output, row_id_map = triton_permutation.sort_chunks_by_idx( + output, row_id_map, permuted_probs = triton_permutation.sort_chunks_by_idx( inp, split_sizes, sorted_idxs, + probs, num_tokens, hidden_size, num_splits, ) if fp8: - output = Float8Tensor(data=output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv) + output = Float8Tensor( + data=output, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + shape=output.shape, + dtype=fake_dtype, + ) ctx.save_for_backward(row_id_map) ctx.num_tokens = num_tokens ctx.hidden_size = hidden_size - return output + return output, permuted_probs @staticmethod def backward( ctx, permuted_act_grad: torch.Tensor, + permuted_probs_grad: torch.Tensor, ) -> Tuple[torch.Tensor, ...]: # pylint: disable=missing-function-docstring if not permuted_act_grad.numel(): - return permuted_act_grad, None, None + return permuted_act_grad, None, None, permuted_probs_grad act_grad = None + probs_grad = None if ctx.needs_input_grad[0]: (row_id_map,) = ctx.saved_tensors fp8 = isinstance(permuted_act_grad, Float8Tensor) if fp8: fp8_dtype = permuted_act_grad._fp8_dtype fp8_scale_inv = permuted_act_grad._scale_inv + fake_dtype = permuted_act_grad.dtype permuted_act_grad = permuted_act_grad._data - act_grad = triton_permutation.sort_chunks_by_map( + act_grad, probs_grad = triton_permutation.sort_chunks_by_map( permuted_act_grad, row_id_map, + permuted_probs_grad, ctx.num_tokens, ctx.hidden_size, ) if fp8: act_grad = Float8Tensor( - data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv + data=act_grad, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + shape=act_grad.shape, + dtype=fake_dtype, ) - return act_grad, None, None + if not ctx.needs_input_grad[3]: + probs_grad = None + return act_grad, None, None, probs_grad def moe_sort_chunks_by_index( @@ -629,4 +723,33 @@ def moe_sort_chunks_by_index( sorted_indices: torch.Tensor Chunk indices used to permute the chunks. """ - return _moe_chunk_sort.apply(inp, split_sizes, sorted_index) + output, _ = _moe_chunk_sort.apply(inp, split_sizes, sorted_index, None) + return output + + +def moe_sort_chunks_by_index_with_probs( + inp: torch.Tensor, + probs: torch.Tensor, + split_sizes: torch.Tensor, + sorted_index: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Split and sort the input tensor and probs based on the split_sizes and sorted indices. + The inp tensor is splitted along dim-0 according to the split_sizes list and then sorted + according to the sorted_indices. + + Parameters + ---------- + inp: torch.Tensor + Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied. + probs: torch.Tensor + The tensor of probabilities corresponding to the permuted tokens and is + of shape [num_tokens]. It will be permuted with the tokens according to + the split_sizes and sorted_indices. + split_sizes: torch.Tensor + Chunk sizes of the inp tensor along the 0-th dimension. + sorted_indices: torch.Tensor + Chunk indices used to permute the chunks. 
+ """ + output, permuted_probs = _moe_chunk_sort.apply(inp, split_sizes, sorted_index, probs) + return output, permuted_probs diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index 767362e8c1..4ed92b0c80 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -125,6 +125,8 @@ def _permute_kernel( input_ptr, output_ptr, row_id_map_ptr, + probs_ptr, + permuted_probs_ptr, # sizes num_tokens, num_experts, @@ -134,7 +136,11 @@ def _permute_kernel( stride_input_hidden, stride_output_token, stride_output_hidden, + stride_probs_token, + stride_probs_expert, + stride_permuted_probs_token, # metas + PERMUTE_PROBS: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): pid = tl.program_id(0) @@ -149,12 +155,19 @@ def _permute_kernel( if dst_row != -1: output_off = dst_row * stride_output_token + cur_off * stride_output_hidden tl.store(output_ptr + output_off, inp, mask=mask) + if PERMUTE_PROBS: + if cur_pos == 0: + prob_off = pid * stride_probs_token + expert_idx * stride_probs_expert + prob = tl.load(probs_ptr + prob_off) + permuted_prob_off = dst_row * stride_permuted_probs_token + tl.store(permuted_probs_ptr + permuted_prob_off, prob) cur_pos += BLOCK_SIZE def permute_with_mask_map( inp: torch.Tensor, row_id_map: torch.Tensor, + probs: torch.Tensor, num_tokens: int, num_experts: int, num_out_tokens: int, @@ -162,11 +175,17 @@ def permute_with_mask_map( ): # pylint: disable=missing-function-docstring output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda") + if probs is not None: + permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device="cuda") + else: + permuted_probs = None grid = (num_tokens,) _permute_kernel[grid]( inp, output, row_id_map, + probs, + permuted_probs, num_tokens, num_experts, hidden_size, @@ -174,8 +193,12 @@ def permute_with_mask_map( inp.stride(1), output.stride(0), output.stride(1), + probs.stride(0) if probs is not None else None, + probs.stride(1) if probs is not None else None, + permuted_probs.stride(0) if permuted_probs is not None else None, + PERMUTE_PROBS=probs is not None, ) - return output + return output, permuted_probs @triton.autotune( @@ -194,7 +217,9 @@ def _unpermute_kernel( input_ptr, output_ptr, row_id_map_ptr, - probs_ptr, + merging_probs_ptr, + permuted_probs_ptr, + unpermuted_probs_ptr, # sizes num_tokens, num_experts, @@ -204,24 +229,27 @@ def _unpermute_kernel( stride_input_hidden, stride_output_token, stride_output_hidden, - stride_probs_token, - stride_probs_expert, + stride_merging_probs_token, + stride_merging_probs_expert, + stride_permuted_probs_token, + stride_unpermuted_probs_token, + stride_unpermuted_probs_expert, # metas - WITH_PROBS: tl.constexpr, + WITH_MERGING_PROBS: tl.constexpr, + PERMUTE_PROBS: tl.constexpr, FP8_DTYPE: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): if FP8_DTYPE == "e5m2": - compute_type = tl.float16 data_type = tl.float8e5 pytorch_tensor_dtype = tl.uint8 elif FP8_DTYPE == "e4m3": - compute_type = tl.float16 data_type = tl.float8e4nv pytorch_tensor_dtype = tl.uint8 else: - compute_type = input_ptr.dtype.element_ty + data_type = input_ptr.dtype.element_ty assert FP8_DTYPE is None + compute_type = tl.float32 pid = tl.program_id(0) current_start = 0 @@ -235,18 +263,35 @@ def _unpermute_kernel( input_off = src_row * stride_input_token + current_offset * stride_input_hidden inp = tl.load(input_ptr + input_off, mask=mask) if FP8_DTYPE is not None: - inp = inp.to(data_type, 
bitcast=True).to(compute_type) - if WITH_PROBS: - prob_off = pid * stride_probs_token + expert_idx * stride_probs_expert - prob = tl.load(probs_ptr + prob_off).to(compute_type) - inp *= prob + inp = inp.to(data_type, bitcast=True) + inp = inp.to(compute_type) + if WITH_MERGING_PROBS: + merging_prob_off = ( + pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert + ) + merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type) + inp *= merging_prob accumulator += inp + if PERMUTE_PROBS: + if current_start == 0: + unpermuted_prob_off = ( + pid * stride_unpermuted_probs_token + + expert_idx * stride_unpermuted_probs_expert + ) + if src_row != -1: + permuted_prob_off = src_row * stride_permuted_probs_token + prob = tl.load(permuted_probs_ptr + permuted_prob_off) + tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob) + else: + tl.store(unpermuted_probs_ptr + unpermuted_prob_off, 0.0) if FP8_DTYPE is not None: - if not WITH_PROBS: + if not WITH_MERGING_PROBS: # Directly adding these value may cause overflow for fp8, we scale it here. # The outside fp8_scale_inv is also scaled in the meantime. accumulator /= num_experts accumulator = accumulator.to(data_type).to(pytorch_tensor_dtype, bitcast=True) + else: + accumulator = accumulator.to(data_type) output_off = pid * stride_output_token + current_offset * stride_output_hidden tl.store(output_ptr + output_off, accumulator, mask=mask) current_start += BLOCK_SIZE @@ -255,7 +300,8 @@ def _unpermute_kernel( def unpermute_with_mask_map( inp: torch.Tensor, row_id_map: torch.Tensor, - probs: Union[torch.Tensor, None], + merging_probs: Union[torch.Tensor, None], + permuted_probs: Union[torch.Tensor, None], num_tokens: int, num_experts: int, hidden_size: int, @@ -269,12 +315,20 @@ def unpermute_with_mask_map( else: fp8_dtype = None output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda") + if permuted_probs is not None: + unpermuted_probs = torch.empty( + (num_tokens, num_experts), dtype=permuted_probs.dtype, device="cuda" + ) + else: + unpermuted_probs = None grid = (num_tokens,) _unpermute_kernel[grid]( inp, output, row_id_map, - probs, + merging_probs, + permuted_probs, + unpermuted_probs, num_tokens, num_experts, hidden_size, @@ -282,12 +336,16 @@ def unpermute_with_mask_map( inp.stride(1), output.stride(0), output.stride(1), - probs.stride(0) if probs is not None else None, - probs.stride(1) if probs is not None else None, - WITH_PROBS=probs is not None, + merging_probs.stride(0) if merging_probs is not None else None, + merging_probs.stride(1) if merging_probs is not None else None, + permuted_probs.stride(0) if permuted_probs is not None else None, + unpermuted_probs.stride(0) if unpermuted_probs is not None else None, + unpermuted_probs.stride(1) if unpermuted_probs is not None else None, + WITH_MERGING_PROBS=merging_probs is not None, + PERMUTE_PROBS=permuted_probs is not None, FP8_DTYPE=fp8_dtype, ) - return output + return output, unpermuted_probs @triton.autotune( @@ -301,13 +359,13 @@ def unpermute_with_mask_map( key=["hidden_size"], ) @triton.jit -def _unpermute_bwd_with_probs_kernel( +def _unpermute_bwd_with_merging_probs_kernel( # pointers fwd_output_grad_ptr, fwd_input_grad_ptr, fwd_input_ptr, - probs_ptr, - probs_grad_ptr, + merging_probs_ptr, + merging_probs_grad_ptr, row_id_map_ptr, # sizes num_tokens, @@ -320,31 +378,30 @@ def _unpermute_bwd_with_probs_kernel( stride_fwd_input_grad_hidden, stride_fwd_input_token, stride_fwd_input_hidden, - stride_probs_token, - 
stride_probs_expert, - stride_probs_grad_token, - stride_probs_grad_expert, + stride_merging_probs_token, + stride_merging_probs_expert, + stride_merging_probs_grad_token, + stride_merging_probs_grad_expert, # metas FP8_DTYPE: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): if FP8_DTYPE == "e5m2": - compute_type = tl.float16 data_type = tl.float8e5 pytorch_tensor_dtype = tl.uint8 elif FP8_DTYPE == "e4m3": - compute_type = tl.float16 data_type = tl.float8e4nv pytorch_tensor_dtype = tl.uint8 else: - compute_type = fwd_output_grad_ptr.dtype.element_ty + data_type = fwd_output_grad_ptr.dtype.element_ty assert FP8_DTYPE is None + compute_type = tl.float32 pid = tl.program_id(0) for expert_idx in range(num_experts): dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid) if dst_row != -1: - prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type) current_start = 0 while current_start < hidden_size: current_offset = current_start + tl.arange(0, BLOCK_SIZE) @@ -355,12 +412,16 @@ def _unpermute_bwd_with_probs_kernel( ) inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask) if FP8_DTYPE is not None: - inp = inp.to(data_type, bitcast=True).to(compute_type) - probs_off = pid * stride_probs_token + expert_idx * stride_probs_expert - prob = tl.load(probs_ptr + probs_off).to(compute_type) - output = inp * prob + inp = inp.to(data_type, bitcast=True) + inp = inp.to(compute_type) + merging_prob_off = ( + pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert + ) + merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type) + output = inp * merging_prob + output = output.to(data_type) if FP8_DTYPE is not None: - output = output.to(data_type).to(pytorch_tensor_dtype, bitcast=True) + output = output.to(pytorch_tensor_dtype, bitcast=True) output_off = ( dst_row * stride_fwd_input_grad_token + current_offset * stride_fwd_input_grad_hidden @@ -373,21 +434,27 @@ def _unpermute_bwd_with_probs_kernel( fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask) if FP8_DTYPE is not None: fwd_input = fwd_input.to(data_type, bitcast=True) - prob_grad_accum += fwd_input.to(tl.float32) * inp.to(tl.float32) + prob_grad_accum += fwd_input.to(compute_type) * inp current_start += BLOCK_SIZE - probs_grad = tl.sum(prob_grad_accum) - probs_grad_off = pid * stride_probs_grad_token + expert_idx * stride_probs_grad_expert - tl.store(probs_grad_ptr + probs_grad_off, probs_grad) + probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty) + probs_grad_off = ( + pid * stride_merging_probs_grad_token + + expert_idx * stride_merging_probs_grad_expert + ) + tl.store(merging_probs_grad_ptr + probs_grad_off, probs_grad) else: - probs_grad_off = pid * stride_probs_grad_token + expert_idx * stride_probs_grad_expert - tl.store(probs_grad_ptr + probs_grad_off, 0.0) + probs_grad_off = ( + pid * stride_merging_probs_grad_token + + expert_idx * stride_merging_probs_grad_expert + ) + tl.store(merging_probs_grad_ptr + probs_grad_off, 0.0) -def unpermute_with_mask_map_bwd_with_probs( +def unpermute_with_mask_map_bwd_with_merging_probs( fwd_output_grad: torch.Tensor, row_id_map: torch.Tensor, fwd_input: torch.Tensor, - probs: torch.Tensor, + merging_probs: torch.Tensor, num_tokens: int, num_experts: int, num_out_tokens: int, @@ -404,14 +471,16 @@ def unpermute_with_mask_map_bwd_with_probs( act_grad = torch.empty( (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda" ) - probs_grad = 
torch.empty((num_tokens, num_experts), dtype=probs.dtype, device="cuda") + merging_probs_grad = torch.empty( + (num_tokens, num_experts), dtype=merging_probs.dtype, device="cuda" + ) grid = (num_tokens,) - _unpermute_bwd_with_probs_kernel[grid]( + _unpermute_bwd_with_merging_probs_kernel[grid]( fwd_output_grad, act_grad, fwd_input, - probs, - probs_grad, + merging_probs, + merging_probs_grad, row_id_map, num_tokens, num_experts, @@ -422,13 +491,13 @@ def unpermute_with_mask_map_bwd_with_probs( act_grad.stride(1), fwd_input.stride(0), fwd_input.stride(1), - probs.stride(0), - probs.stride(1), - probs_grad.stride(0), - probs_grad.stride(1), + merging_probs.stride(0), + merging_probs.stride(1), + merging_probs_grad.stride(0), + merging_probs_grad.stride(1), fp8_dtype, ) - return act_grad, probs_grad + return act_grad, merging_probs_grad @triton.autotune( @@ -449,6 +518,8 @@ def _sort_chunks_by_idxs_kernel( sorted_indices_ptr, output_ptr, dst_rows_ptr, + probs_ptr, + permuted_probs_ptr, # sizes num_splits, hidden_size, @@ -457,7 +528,10 @@ def _sort_chunks_by_idxs_kernel( stride_input_hidden, stride_output_token, stride_output_hidden, + stride_probs_token, + stride_permuted_probs_token, # metas + PERMUTE_PROBS: tl.constexpr, IDX_LOAD_WIDTH: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): @@ -508,11 +582,18 @@ def _sort_chunks_by_idxs_kernel( tl.store(output_ptr + output_offsets, inp, mask=mask) current_start += BLOCK_SIZE + if PERMUTE_PROBS: + prob_off = pid * stride_probs_token + prob = tl.load(probs_ptr + prob_off) + permuted_prob_off = dst_row * stride_permuted_probs_token + tl.store(permuted_probs_ptr + permuted_prob_off, prob) + def sort_chunks_by_idx( inp: torch.Tensor, split_sizes: torch.Tensor, sorted_indices: torch.Tensor, + probs: torch.Tensor, num_tokens: int, hidden_size: int, num_splits: int, @@ -520,6 +601,10 @@ def sort_chunks_by_idx( # pylint: disable=missing-function-docstring row_id_map = torch.empty((num_tokens,), dtype=torch.int64, device="cuda") output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda") + if probs is not None: + permuted_probs = torch.empty((num_tokens,), dtype=probs.dtype, device="cuda") + else: + permuted_probs = None grid = (num_tokens,) _sort_chunks_by_idxs_kernel[grid]( inp, @@ -527,15 +612,20 @@ def sort_chunks_by_idx( sorted_indices, output, row_id_map, + probs, + permuted_probs, num_splits, hidden_size, inp.stride(0), inp.stride(1), output.stride(0), output.stride(1), - triton.next_power_of_2(num_splits), + probs.stride(0) if probs is not None else None, + permuted_probs.stride(0) if permuted_probs is not None else None, + PERMUTE_PROBS=probs is not None, + IDX_LOAD_WIDTH=triton.next_power_of_2(num_splits), ) - return output, row_id_map + return output, row_id_map, permuted_probs @triton.autotune( @@ -554,6 +644,8 @@ def _sort_chunks_by_map( input_ptr, output_ptr, row_id_map_ptr, + probs_ptr, + permuted_probs_ptr, # sizes hidden_size, # strides @@ -561,7 +653,10 @@ def _sort_chunks_by_map( stride_input_hidden, stride_output_token, stride_output_hidden, + stride_probs_token, + stride_permuted_probs_token, # metas + PERMUTE_PROBS: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): pid = tl.program_id(0) @@ -575,25 +670,40 @@ def _sort_chunks_by_map( inp = tl.load(input_ptr + input_offsets, mask=mask) tl.store(output_ptr + output_offsets, inp, mask=mask) current_start += BLOCK_SIZE + if PERMUTE_PROBS: + prob_off = dst_row * stride_probs_token + prob = tl.load(probs_ptr + prob_off) + permuted_prob_off = pid * stride_permuted_probs_token + 
tl.store(permuted_probs_ptr + permuted_prob_off, prob) def sort_chunks_by_map( inp: torch.Tensor, row_id_map: torch.Tensor, + probs: torch.Tensor, num_tokens: int, hidden_size: int, ): # pylint: disable=missing-function-docstring output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda") + if probs is not None: + permuted_probs = torch.empty((num_tokens,), dtype=probs.dtype, device="cuda") + else: + permuted_probs = None grid = (num_tokens,) _sort_chunks_by_map[grid]( inp, output, row_id_map, + probs, + permuted_probs, hidden_size, inp.stride(0), inp.stride(1), output.stride(0), output.stride(1), + probs.stride(0) if probs is not None else None, + permuted_probs.stride(0) if permuted_probs is not None else None, + PERMUTE_PROBS=probs is not None, ) - return output + return output, permuted_probs
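A minimal usage sketch of the APIs added by this patch, following the flow of _test_permutation_mask_map_alongside_probs: permute tokens together with their routing probabilities, optionally re-sort the permuted chunks, apply the probabilities, undo the chunk sort, and unpermute. The token counts, the top-k routing, and the reversed chunk order below are illustrative placeholders, not anything the patch prescribes.

import torch
from transformer_engine.pytorch import (
    moe_permute_with_probs,
    moe_sort_chunks_by_index,
    moe_sort_chunks_by_index_with_probs,
    moe_unpermute,
)

num_tokens, num_experts, hidden_size, topk = 8, 4, 16, 2  # illustrative sizes
hidden_states = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda")

# Illustrative top-k routing: a boolean routing map and the matching probabilities.
router_probs = torch.softmax(torch.randn(num_tokens, num_experts, device="cuda"), dim=-1)
topk_probs, topk_idx = torch.topk(router_probs, topk, dim=-1)
routing_map = torch.zeros(num_tokens, num_experts, dtype=torch.bool, device="cuda")
routing_map.scatter_(1, topk_idx, True)
probs = torch.zeros(num_tokens, num_experts, dtype=hidden_states.dtype, device="cuda")
probs.scatter_(1, topk_idx, topk_probs.to(hidden_states.dtype))

# Permute the tokens and their probs together; row_id_map is needed to unpermute later.
permuted, permuted_probs, row_id_map = moe_permute_with_probs(
    hidden_states, probs, routing_map, num_out_tokens=num_tokens * topk
)

# Optionally re-sort the permuted chunks (e.g. for expert/tensor parallelism).
# Here the per-expert chunks are simply reversed to illustrate the call.
tokens_per_expert = routing_map.sum(dim=0).to(torch.int32)
reversed_idx = torch.arange(num_experts - 1, -1, -1, dtype=torch.int32, device="cuda")
permuted, permuted_probs = moe_sort_chunks_by_index_with_probs(
    permuted, permuted_probs, tokens_per_expert, reversed_idx
)

# ... the expert MLPs would run on `permuted` here ...

# Apply the routing probabilities, undo the chunk sort, and restore the token order.
permuted = permuted * permuted_probs.unsqueeze(-1)
permuted = moe_sort_chunks_by_index(permuted, tokens_per_expert.flip(0), reversed_idx)
output = moe_unpermute(permuted, row_id_map, restore_shape=hidden_states.shape, map_type="mask")

# Alternatively, skip the explicit scaling above and let unpermute merge with the probs:
#   output = moe_unpermute(permuted, row_id_map, merging_probs=probs,
#                          restore_shape=hidden_states.shape, map_type="mask")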