[WIP] Port to ROCm/HIP #178
Conversation
Signed-off-by: fxzjshm <[email protected]>
I will look into this on the MUSA side.
If flash_attention cannot be installed, you can temporarily comment out the code that uses it:

diff --git a/ktransformers/operators/models.py b/ktransformers/operators/models.py
index 5d2e911..7af4e0d 100644
--- a/ktransformers/operators/models.py
+++ b/ktransformers/operators/models.py
@@ -19,7 +19,7 @@ import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from ktransformers.operators.dynamic_attention import DynamicScaledDotProductAttention
+# from ktransformers.operators.dynamic_attention import DynamicScaledDotProductAttention
from ktransformers.server.config.config import Config
import os
import yaml

I do not have good enough hardware here to measure speed for now; I can only confirm that a single AMD GPU runs it (the default config needs > 32 GB of VRAM), with the rate given in the comment above. For the same reason I cannot test Moore Threads hardware, so no further changes were made on that side. On a 32 GB MI100 the model can be split with the configuration below; for other VRAM sizes, adjust the layer ranges accordingly (a small helper for generating the ranges, and an import-guard alternative to the diff above, follow the config). DeepSeek-V3-Chat-mixed.yaml:
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12345][0-9])\\."
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\."
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12345][0-9])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized kernel on quantized data types
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
      generate_op: "KLinearTorch"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized kernel on quantized data types
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
      generate_op: "KLinearCPUInfer"
      prefill_op: "KLinearCPUInfer"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12345][0-9])\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # MLP module with custom forward function
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # MLP module with custom forward function
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12345][0-9])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12345][0-9])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE kernel with expert parallelism
    kwargs:
      prefill_device: "cuda:0"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda:0"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE kernel with expert parallelism
    kwargs:
      prefill_device: "cpu"
      prefill_op: "KExpertsCPU"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cpu"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12345][0-9])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 disables layer-wise prefill
      transfer_map:
        60: "cpu"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12345][0-9])\\."
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
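For reference, a hypothetical helper (not part of ktransformers or this PR; the function name and the split at layer 60 are illustrative) for regenerating the layer-range patterns above when moving the GPU/CPU split point:

# Hypothetical helper: build the layer-index alternation used in the "name"
# patterns above, so a different VRAM budget only needs a different split point.
def layer_range_regex(lo: int, hi: int) -> str:
    """Return a regex alternation matching integer layer indices lo..hi-1."""
    return "(" + "|".join(str(i) for i in range(lo, hi)) + ")"

# Mirror of the config above: layers 0-59 on the GPU, layer 60 and up on CPU
# (DeepSeek-V3 has 61 decoder layers, matching transfer_map's 60: "cpu").
gpu_layers = rf"^model\.layers\.{layer_range_regex(0, 60)}\."
cpu_layers = rf"^model\.layers\.{layer_range_regex(60, 61)}\."

The explicit alternation is equivalent to the handwritten character classes in the config, just more mechanical to regenerate.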
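And a sketch of an import-guard alternative to the comment-out diff above. This is not code from the PR: it simply catches the ImportError so the file does not need editing, and any call site using the symbol would still have to handle None.

# Sketch only (not in the PR): degrade gracefully when flash_attn, and hence
# dynamic_attention, cannot be imported, instead of commenting the import out.
try:
    from ktransformers.operators.dynamic_attention import DynamicScaledDotProductAttention
except ImportError:  # e.g. flash_attn is not installed or not built for this GPU
    DynamicScaledDotProductAttention = None  # users of this symbol must check for None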
As for the hipGraph crash in the multi-GPU case, I have no clue yet. @kingofotaku @carsonfly could you help test this?
Not enough VRAM to run on a single Radeon 7900XTX 24GB without GPTQ_marlin, and it crashed on 2 * Radeon 7900XTX 24GB.
Does it support a ROCm feature similar to CUDA Graph? Without a compute graph, the decode process can get stuck in Python overhead, resulting in a performance loss of nearly 50%.
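For context, PyTorch's stock graph-capture API is available on ROCm builds, where torch.cuda is backed by HIP, so capture records a hipGraph under the hood. A minimal sketch (the tensor shape and toy computation are illustrative, not from this PR):

import torch

static_x = torch.zeros(16, device="cuda")

# Warm up on a side stream so one-time lazy initialization is not captured.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_y = static_x * 2 + 1
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):  # capture: records the kernels instead of running them eagerly
    static_y = static_x * 2 + 1

static_x.copy_(torch.arange(16.0, device="cuda"))  # refresh inputs in place
g.replay()                                         # re-launch the captured kernels

Replay re-launches the recorded kernels with almost no Python overhead, which is why losing graph capture costs so much during decode.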
Is there any chance to update the
The multi-GPU crash is observed inside the hipGraph module, which may indicate that hipGraph is in fact enabled there, and that it may be at fault.
That is probably a question for the original authors.
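A quick way to confirm which backend a given wheel is actually using (standard PyTorch attributes, nothing specific to this PR):

import torch

# On a ROCm wheel torch.version.hip is set and torch.version.cuda is None;
# torch.cuda.* then transparently drives the HIP runtime.
print("hip:", torch.version.hip)
print("cuda:", torch.version.cuda)
print("devices:", [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())])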
Hi @fxzjshm, I am trying as below on a single and on 2 x W7900,
but I get the error below.
Do you know what I am doing wrong?
@citrix123 Since
Thanks for the quick response @fxzjshm. It seems I am still seeing the same issue. Below is the change I have made.
I have even tried to rebuild using
Is there something I misunderstood?
Thanks, I can reproduce the issue you mentioned.
Backtrace:
#0 0x00007a93b8e76df8 in ?? () from /opt/rocm/lib/libamdhip64.so.6
#1 0x00007a93b8ea697f in ?? () from /opt/rocm/lib/libamdhip64.so.6
#2 0x00007a93b8e6a876 in ?? () from /opt/rocm/lib/libamdhip64.so.6
#3 0x00007a93b8ce58a0 in ?? () from /opt/rocm/lib/libamdhip64.so.6
#4 0x00007a93b8ce9103 in ?? () from /opt/rocm/lib/libamdhip64.so.6
#5 0x00007a93b8ceb2ee in ?? () from /opt/rocm/lib/libamdhip64.so.6
#6 0x00007a93e30bc0ed in at::native::copy_device_to_device(at::TensorIterator&, bool, bool) () from /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/lib/libtorch_hip.so
#7 0x00007a93f09122e9 in at::native::copy_impl(at::Tensor&, at::Tensor const&, bool) [clone .isra.0] () from /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
#8 0x00007a93f0913c50 in at::native::copy_(at::Tensor&, at::Tensor const&, bool) () from /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
#9 0x00007a93f16d4b98 in at::_ops::copy_::call(at::Tensor&, at::Tensor const&, bool) () from /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
#10 0x00007a93f0c0e772 in at::native::_to_copy(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>) ()
from /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
#11 0x00007a93f1aab4f1 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd___to_copy>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat> > >, at::Tensor (at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>) () from /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
#12 0x00007a93f1114fcc in at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>) () from /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
#13 0x00007a93f18d3d54 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>), &at::(anonymous namespace)::_to_copy>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat> > >, at::Tensor (at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>) () from /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
#14 0x00007a93f1114fcc in at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>) () from /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
#15 0x00007a93f385f428 in torch::autograd::VariableType::(anonymous namespace)::_to_copy(c10::DispatchKeySet, at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>) () from /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
#16 0x00007a93f385f8b4 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>), &torch::autograd::VariableType::(anonymous namespace)::_to_copy>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat> > >, at::Tensor (c10::DispatchKeySet, at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>) () from /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
#17 0x00007a93f11ce626 in at::_ops::_to_copy::call(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>) ()
from /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
#18 0x00007a93f0c0b7c4 in at::native::to(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, bool, std::optional<c10::MemoryFormat>) ()
from /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
#19 0x00007a93f1cac6d7 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, bool, std::optional<c10::MemoryFormat>), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd_dtype_layout_to>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, bool, std::optional<c10::MemoryFormat> > >, at::Tensor (at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, bool, std::optional<c10::MemoryFormat>)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, bool, std::optional<c10::MemoryFormat>) () from /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
#20 0x00007a93f139076b in at::_ops::to_dtype_layout::call(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, bool, std::optional<c10::MemoryFormat>) ()
from /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
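To pair native frames like these with the Python call site, one option is CPython's standard faulthandler, enabled before loading the model (a general debugging aid, not something the PR uses):

import faulthandler

# Dumps the Python traceback of every thread when the process receives
# SIGSEGV, SIGFPE, SIGABRT, SIGBUS or SIGILL.
faulthandler.enable()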
@citrix123 Yes, it is the same command for the multi-GPU case. We managed to get access to an EPYC 9654 platform; results are updated in the first post.
An attempt to port the code to the AMD ROCm platform.
GPTQ_marlin is temporarily disabled, pending a port by experts.
The compat layer is from llama.cpp; it also contains some MUSA-related port code, which may interest @MooreThreads.
Related: #159
Credit: @leavelet
Tested on 1 * Radeon 7900XTX 48GB + 1 * EPYC 9654 (12 * DDR5-4800) w/ pytorch 2.6.0+rocm6.2.4, DeepSeek-R1-Q4_K_M,
Tested on 1 * Radeon 7900XTX 48GB + 2 * EPYC 9174F (but only (4 + 3) * DDR5-4800) w/ pytorch 2.6.0+rocm6.2.4, DeepSeek-R1-Q4_K_M,
Not working yet on 2 * 7900XTX 48GB (the model loads, but there is a segmentation fault in hipGraph).
Publishing this here in case it can help someone.