diff --git a/python/mlc_llm/model/deepseek_v2/deepseek_v2_loader.py b/python/mlc_llm/model/deepseek_v2/deepseek_v2_loader.py
index be6dd97836..12a9dd718f 100644
--- a/python/mlc_llm/model/deepseek_v2/deepseek_v2_loader.py
+++ b/python/mlc_llm/model/deepseek_v2/deepseek_v2_loader.py
@@ -117,6 +117,43 @@ def combine_expert_gate_up(*hf_params, dtype):
                 ),
             )
 
+        # map MLA kv_b_proj weight
+        attn = f"model.layers.{i}.self_attn"
+        mapping.add_mapping(
+            f"{attn}.w_uk",
+            [f"{attn}.kv_b_proj.weight"],
+            functools.partial(
+                lambda kv_b_proj, dtype: np.split(
+                    kv_b_proj.reshape(
+                        model_config.num_key_value_heads,
+                        model_config.qk_nope_head_dim + model_config.v_head_dim,
+                        model_config.kv_lora_rank,
+                    ),
+                    indices_or_sections=[model_config.qk_nope_head_dim],
+                    axis=1,
+                )[0]
+                .transpose(0, 2, 1)
+                .astype(dtype),
+                dtype=mlc_param.dtype,
+            ),
+        )
+        mapping.add_mapping(
+            f"{attn}.w_uv",
+            [f"{attn}.kv_b_proj.weight"],
+            functools.partial(
+                lambda kv_b_proj, dtype: np.split(
+                    kv_b_proj.reshape(
+                        model_config.num_key_value_heads,
+                        model_config.qk_nope_head_dim + model_config.v_head_dim,
+                        model_config.kv_lora_rank,
+                    ),
+                    indices_or_sections=[model_config.qk_nope_head_dim],
+                    axis=1,
+                )[1].astype(dtype),
+                dtype=mlc_param.dtype,
+            ),
+        )
+
     for mlc_name, mlc_param in named_parameters.items():
         if mlc_name not in mapping.param_map:
             mapping.add_mapping(
diff --git a/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py b/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py
index 566a429003..447d5edf4d 100644
--- a/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py
+++ b/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py
@@ -220,6 +220,8 @@ def __init__(self, config: DeepseekV2Config):
             self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
             bias=False,
         )
+        self.w_uk = nn.Parameter((self.num_heads, config.kv_lora_rank, self.qk_nope_head_dim))
+        self.w_uv = nn.Parameter((self.num_heads, self.v_head_dim, config.kv_lora_rank))
 
         self.o_proj = nn.Linear(
             self.num_heads * self.v_head_dim,
@@ -241,6 +243,106 @@ def forward(
         paged_kv_cache: PagedKVCache,
         layer_id: int,
         query_positions: Tensor,
+    ):
+        return self.forward_absorb(hidden_states, paged_kv_cache, layer_id, query_positions)
+
+    def forward_absorb(
+        self,
+        hidden_states: Tensor,
+        paged_kv_cache: PagedKVCache,
+        layer_id: int,
+        query_positions: Tensor,
+    ):
+        b, s, _ = hidden_states.shape
+
+        if self.q_lora_rank is None:
+            q = self.q_proj(hidden_states)
+        else:
+            q = self.q_b_proj(
+                self.q_a_layernorm(self.q_a_proj(hidden_states))
+            )  # (b, s, num_heads * q_head_dim)
+        q = op.reshape(q, (b, s, self.num_heads, self.q_head_dim))  # (b, s, num_heads, q_head_dim)
+        q_nope, q_pe = op.split(
+            q, [self.qk_nope_head_dim], axis=-1
+        )  # (b, s, num_heads, qk_nope_head_dim), (b, s, num_heads, qk_rope_head_dim)
+        q_nope = (
+            op.matmul(
+                q_nope.reshape(b * s, self.num_heads, self.qk_nope_head_dim).permute_dims(1, 0, 2),
+                self.w_uk.permute_dims(0, 2, 1),
+            )
+            .permute_dims(1, 0, 2)
+            .reshape(b, s, self.num_heads, self.kv_lora_rank)
+        )  # (b, s, num_heads, kv_lora_rank)
+
+        compressed_kv = self.kv_a_proj_with_mqa(hidden_states).reshape(
+            b, s, 1, self.kv_lora_rank + self.qk_rope_head_dim
+        )  # (b, s, 1, kv_lora_rank + qk_rope_head_dim)
+        compressed_kv, k_pe = op.split(
+            compressed_kv, [self.config.kv_lora_rank], axis=-1
+        )  # (b, s, 1, kv_lora_rank), (b, s, 1, qk_rope_head_dim)
+
+        compressed_kv = self.kv_a_layernorm(compressed_kv)
+        k_nope = compressed_kv  # (b, s, 1, kv_lora_rank)
+        value_states = compressed_kv  # (b, s, 1, kv_lora_rank)
+
+        q_pe, k_pe = self.rotary_emb(q_pe, k_pe, query_positions)
+
+        def concat_nope_pe(num_heads: int):
+            def f_concat_nope_pe(var_nope: te.Tensor, var_pe: te.Tensor):
+                return te.compute(
+                    (b, s, num_heads, self.kv_lora_rank + self.qk_rope_head_dim),
+                    lambda _b, _s, _h, _d: te.if_then_else(
+                        _d < self.kv_lora_rank,
+                        var_nope[_b, _s, _h, _d],
+                        var_pe[_b, _s, _h, _d - self.kv_lora_rank],
+                    ),
+                )
+
+            return f_concat_nope_pe
+
+        query_states = op.tensor_expr_op(
+            concat_nope_pe(num_heads=self.num_heads), "concat_q", [q_nope, q_pe]
+        )  # (b, s, num_heads, kv_lora_rank + qk_rope_head_dim)
+        key_states = op.tensor_expr_op(
+            concat_nope_pe(num_heads=1), "concat_k", [k_nope, k_pe]
+        )  # (b, s, 1, kv_lora_rank + qk_rope_head_dim)
+        value_states = op.pad(
+            value_states, [0, 0, 0, 0, 0, 0, 0, self.qk_rope_head_dim]
+        )  # (b, s, 1, kv_lora_rank + qk_rope_head_dim)
+
+        qkv = op.concat(
+            [query_states, key_states, value_states], dim=2
+        )  # (b, s, num_heads + 2, kv_lora_rank + qk_rope_head_dim)
+        output, _ = op.split(
+            paged_kv_cache.attention_with_fused_qkv(
+                layer_id,
+                qkv,
+                self.num_heads,
+                self.softmax_scale
+                * math.sqrt(
+                    self.kv_lora_rank + self.qk_rope_head_dim
+                ),  # This is to cancel out the 1/sqrt(d) in normal attention
+            ),
+            indices_or_sections=[self.kv_lora_rank],
+            axis=-1,
+        )  # (b, s, num_heads, kv_lora_rank)
+        output = (
+            op.matmul(
+                output.reshape(b * s, self.num_heads, self.kv_lora_rank).permute_dims(1, 0, 2),
+                self.w_uv.permute_dims(0, 2, 1),
+            )
+            .permute_dims(1, 0, 2)
+            .reshape(b, s, self.num_heads * self.v_head_dim)
+        )
+
+        return self.o_proj(output)
+
+    def forward_normal(
+        self,
+        hidden_states: Tensor,
+        paged_kv_cache: PagedKVCache,
+        layer_id: int,
+        query_positions: Tensor,
     ):
         b, s, _ = hidden_states.shape
 
@@ -450,6 +552,14 @@ def _set(layer, hint):
                 self.self_attn.kv_b_proj.weight,
                 tp.ShardSingleDim("_shard_kv_b_weight", dim=0),
             )
+            _set(
+                self.self_attn.w_uk,
+                tp.ShardSingleDim("_shard_kv_b_weight_w_uk", dim=0),
+            )
+            _set(
+                self.self_attn.w_uv,
+                tp.ShardSingleDim("_shard_kv_b_weight_w_uv", dim=0),
+            )
             _set(self.self_attn.o_proj.weight, tp.ShardSingleDim("_shard_o", dim=1))
 
             if isinstance(self.mlp, DeepseekV2MoE):
@@ -517,7 +627,6 @@ def __init__(self, config: DeepseekV2Config):
 
     def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):
         hidden_states = inputs
-        print(f"inputs.shape = {inputs.shape}")
         query_positions = paged_kv_cache.get_query_positions(inputs.shape[0] * inputs.shape[1])
         for layer_id, layer in enumerate(self.layers):
             hidden_states = layer(hidden_states, paged_kv_cache, layer_id, query_positions)
@@ -535,6 +644,8 @@ def __init__(self, config: DeepseekV2Config):
         self.intermediate_size = config.intermediate_size
         self.num_attention_heads = config.num_attention_heads
         self.num_key_value_heads = config.num_key_value_heads
+        self.kv_lora_rank = config.kv_lora_rank
+        self.qk_rope_head_dim = config.qk_rope_head_dim
         self.rms_norm_eps = config.rms_norm_eps
         self.rope_theta = config.rope_theta
         self.vocab_size = config.vocab_size
@@ -621,8 +732,8 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
             support_sliding_window=support_sliding_window,
             num_hidden_layers=self.num_hidden_layers,
             num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,
-            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,
-            head_dim=256,
+            num_key_value_heads=1,
+            head_dim=self.kv_lora_rank + self.qk_rope_head_dim,
             rope_mode=RopeMode.NONE,
             rope_scale=1,
             rope_theta=self.rope_theta,
diff --git a/tests/python/integration/test_model_compile.py b/tests/python/integration/test_model_compile.py
index 3ec70b61b3..dffa582c46 100644
--- a/tests/python/integration/test_model_compile.py
+++ b/tests/python/integration/test_model_compile.py
@@ -114,6 +114,9 @@ def test_model_compile():  # pylint: disable=too-many-locals
                 if not target.startswith("cuda") and quant == "q4f16_ft":
                     # FasterTransformer only works with cuda
                     continue
+                if "deepseek_v2" in model and "32" in quant:
+                    # Skip f32 for deepseek v2 model for now.
+                    continue
                 log_file = os.path.join(tmp_dir, f"lib{idx}.log")
                 cmd = [
                     sys.executable,