From c782b239a0bf2a706c9e941a80a63c98773e9e36 Mon Sep 17 00:00:00 2001 From: Josh Longenecker Date: Wed, 11 Dec 2024 15:35:29 -0500 Subject: [PATCH 1/6] add support for phi3 models --- .../models/phi3/modeling_phi3.py | 459 ++++++++++++++++++ 1 file changed, 459 insertions(+) create mode 100644 src/neuronx_distributed_inference/models/phi3/modeling_phi3.py diff --git a/src/neuronx_distributed_inference/models/phi3/modeling_phi3.py b/src/neuronx_distributed_inference/models/phi3/modeling_phi3.py new file mode 100644 index 0000000..8fee8b6 --- /dev/null +++ b/src/neuronx_distributed_inference/models/phi3/modeling_phi3.py @@ -0,0 +1,459 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PyTorch Phi-3 model for NXD inference.""" + +import gc +from typing import List, Optional, Tuple, Type +from transformers import Phi3ForCausalLM +import torch +from neuronx_distributed.parallel_layers import parallel_state # noqa: E402 +from neuronx_distributed.parallel_layers.layers import ( # noqa: E402; noqa: E402; noqa: E402; noqa: E402; noqa: E402 + ColumnParallelLinear, + ParallelEmbedding, + RowParallelLinear, +) +from neuronx_distributed.parallel_layers.mappings import _gather_along_dim +from torch import nn +import torch.utils.checkpoint + +from transformers.activations import ACT2FN + + +from neuronx_distributed_inference.models.config import InferenceConfig, NeuronConfig # noqa: E402 +from neuronx_distributed_inference.models.model_base import ( # noqa: E402 + NeuronBaseForCausalLM, + NeuronBaseModel, +) +from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBase, +) +from neuronx_distributed_inference.modules.attention.gqa import ( # noqa: E402 + GroupQueryAttention_QKV, + GroupQueryAttention_O, +) + +from neuronx_distributed.parallel_layers import utils +from transformers.models.phi3.modeling_phi3 import ( + Phi3RotaryEmbedding, + Phi3RMSNorm, + Phi3LongRoPEScaledRotaryEmbedding, +) +from transformers.models.phi3.configuration_phi3 import Phi3Config +import logging + +# Set up basic configuration for logging +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s - %(levelname)s - %(message)s", + filename="debug.txt", # This will write to a file named debug.log + filemode="w", +) # 'w' mode overwrites the file each time + +# Create a logger +logger = logging.getLogger(__name__) +_CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct" +_PHI3_ATTENTION_CLASSES = {} + + +def get_rmsnorm_cls(): + return Phi3RMSNorm + + +def _register_module(key: str, cls: Type[nn.Module]): + _PHI3_ATTENTION_CLASSES[key] = cls + + +def register_module(key: str): + """ + Register a module for use in NeuronLlama. + + Arguments: + key: String used to identify the module + + Example: + @register_module("NeuronPhi3Attention") + class NeuronPhi3Attention(nn.Module): + ... 
+ """ + + def inner(cls: Type[nn.Module]): + _register_module(key, cls) + return cls + + return inner + + +def convert_state_dict_to_non_fused_qkv(phi3_state_dict, cfg: InferenceConfig): + for l in range(cfg.num_hidden_layers): + # Keep the original fused weight as Wqkv.weight + phi3_state_dict[f"layers.{l}.self_attn.Wqkv.weight"] = phi3_state_dict[ + f"layers.{l}.self_attn.qkv_proj.weight" + ] + + # Get the fused QKV weight + fused_weight = phi3_state_dict[f"layers.{l}.self_attn.qkv_proj.weight"] + fused_gate_up = phi3_state_dict[f"layers.{l}.mlp.gate_up_proj.weight"] + + # Split the fused weight into Q, K, and V using torch.chunk + q_weight, k_weight, v_weight = torch.chunk(fused_weight, 3, dim=0) + gate, up = torch.chunk(fused_gate_up, 2, dim=0) + + # Add the split weights to the state dict + phi3_state_dict[f"layers.{l}.self_attn.qkv_proj.q_proj.weight"] = q_weight + phi3_state_dict[f"layers.{l}.self_attn.qkv_proj.k_proj.weight"] = k_weight + phi3_state_dict[f"layers.{l}.self_attn.qkv_proj.v_proj.weight"] = v_weight + phi3_state_dict[f"layers.{l}.mlp.gate_proj.weight"] = gate + phi3_state_dict[f"layers.{l}.mlp.up_proj.weight"] = up + + # Remove the original qkv_proj weight + del phi3_state_dict[f"layers.{l}.self_attn.qkv_proj.weight"] + del phi3_state_dict[f"layers.{l}.mlp.gate_up_proj.weight"] + + gc.collect() + + return phi3_state_dict + + +class NeuronPhi3InferenceConfig(InferenceConfig): + def get_required_attributes(self) -> List[str]: + return [ + "hidden_size", + "num_attention_heads", + "num_hidden_layers", + "num_key_value_heads", + "vocab_size", + "max_position_embeddings", + "rope_theta", + "rms_norm_eps", + "pad_token_id", + "hidden_act", + ] + + @classmethod + def get_neuron_config_cls(cls) -> Type[NeuronConfig]: + return NeuronConfig + + +class NeuronPhi3MLP(nn.Module): + def __init__(self, config: InferenceConfig): + super().__init__() + + self.config = config + self.neuron_config = config.neuron_config + + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.activation_fn = ACT2FN[config.hidden_act] + + self.sequence_parallel_enabled = getattr( + self.neuron_config, "sequence_parallel_enabled", False + ) + self.sequence_dimension = 1 if self.sequence_parallel_enabled else None + + if parallel_state.model_parallel_is_initialized(): + self.gate_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + bias=False, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + pad=True, + sequence_parallel_enabled=False, + sequence_dimension=None, + ) + + self.up_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + bias=False, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + pad=True, + sequence_parallel_enabled=False, + sequence_dimension=None, + ) + + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + input_is_parallel=True, + dtype=config.neuron_config.torch_dtype, + pad=True, + sequence_parallel_enabled=self.sequence_parallel_enabled, + sequence_dimension=self.sequence_dimension, + ) + else: + self.gate_proj = nn.Linear( + self.hidden_size, self.intermediate_size, bias=False + ) + self.up_proj = nn.Linear( + self.hidden_size, self.intermediate_size, bias=False + ) + self.down_proj = nn.Linear( + self.intermediate_size, self.hidden_size, bias=False + ) + + def forward(self, hidden_state): + if self.sequence_parallel_enabled: + x = _gather_along_dim(x, 
self.sequence_dimension) + else: + x = hidden_state + + return self.down_proj(self.activation_fn(self.gate_proj(x)) * self.up_proj(x)) + + +@register_module("NeuronPhi3Attention") +class NeuronPhi3Attention(NeuronAttentionBase): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: InferenceConfig): + super().__init__() + self.config = config + self.neuron_config = config.neuron_config + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.padding_side = config.neuron_config.padding_side + self.torch_dtype = config.neuron_config.torch_dtype + + if parallel_state.model_parallel_is_initialized(): + self.tp_degree = parallel_state.get_tensor_model_parallel_size() + else: + self.tp_degree = 1 + + self.fused_qkv = config.neuron_config.fused_qkv + self.clip_qkv = None + + self.sequence_parallel_enabled = self.neuron_config.sequence_parallel_enabled + self.sequence_dimension = 1 if self.sequence_parallel_enabled else None + + self.init_custom_gqa_properties() + + self.init_rope() + + def init_custom_gqa_properties(self): + if (self.head_dim * self.num_attention_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_attention_heads})." + ) + + self.qkv_proj = GroupQueryAttention_QKV( + hidden_size=self.hidden_size, + head_dim=self.head_dim, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + tp_degree=self.tp_degree, + dtype=self.torch_dtype, + bias=False, + gather_output=False, + fused_qkv=self.fused_qkv, + clip_qkv=self.clip_qkv, + sequence_parallel_enabled=self.sequence_parallel_enabled, + sequence_dimension=self.sequence_dimension, + ) + self.o_proj = GroupQueryAttention_O( + hidden_size=self.hidden_size, + head_dim=self.head_dim, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + tp_degree=self.tp_degree, + dtype=self.torch_dtype, + bias=False, + input_is_parallel=True, + layer_name=self.o_proj_layer_name, + sequence_parallel_enabled=self.sequence_parallel_enabled, + sequence_dimension=self.sequence_dimension, + ) + self.num_heads = utils.divide( + self.qkv_proj.get_num_attention_heads(), self.tp_degree + ) + self.num_key_value_heads = utils.divide( + self.qkv_proj.get_num_key_value_heads(), self.tp_degree + ) + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + + def init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = Phi3RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling.get("type") + if scaling_type == "longrope": + self.rotary_emb = Phi3LongRoPEScaledRotaryEmbedding( + self.head_dim, self.config + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + +class NeuronPhi3DecoderLayer(nn.Module): + def __init__(self, config: Phi3Config, layer_idx: int): + super().__init__() + + self.self_attn = NeuronPhi3Attention( + config=config, + ) + self.hidden_size = config.hidden_size + + self.mlp = NeuronPhi3MLP(config) + self.input_layernorm = get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps 
+ ) + + self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) + self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) + self.post_attention_layernorm = get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + attn_outs = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + **kwargs, + ) + + hidden_states, present_key_value = attn_outs + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return (hidden_states, present_key_value) + + +class NeuronPhi3Model(NeuronBaseModel): + def setup_attr_for_model(self, config: InferenceConfig): + # Needed for init_inference_optimization() + self.on_device_sampling = ( + config.neuron_config.on_device_sampling_config is not None + ) + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.max_batch_size = config.neuron_config.max_batch_size + self.buckets = config.neuron_config.buckets + + # self._attn_implementation = config._attn_implementation + + def init_model(self, config: InferenceConfig): + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + if parallel_state.model_parallel_is_initialized(): + self.embed_tokens = ParallelEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=config.neuron_config.torch_dtype, + shard_across_embedding=True, + # We choose to shard across embedding dimension because this stops XLA from introducing + # rank specific constant parameters into the HLO. We could shard across vocab, but that + # would require us to use non SPMD parallel_model_trace. + pad=True, + ) + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + bias=False, + pad=True, + ) + else: + self.embed_tokens = nn.Embedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + ) + self.lm_head = nn.Linear( + config.hidden_size, + config.vocab_size, + bias=False, + pad=True, + ) + + self.layers = nn.ModuleList( + [ + NeuronPhi3DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + +class NeuronPhi3ForCausalLM(NeuronBaseForCausalLM): + """ + This class extends Phi3ForCausalLM create traceable + blocks for Neuron. 
+ + Args: + LlamaForCausalLM (_type_): _description_ + """ + + _model_cls = NeuronPhi3Model + + @staticmethod + def load_hf_model(model_path): + return Phi3ForCausalLM.from_pretrained(model_path) + + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: dict, config: InferenceConfig + ) -> dict: + """This function should be over-ridden in child classes as needed""" + + state_dict = convert_state_dict_to_non_fused_qkv(state_dict, config) + print(state_dict) + return state_dict + + @classmethod + def get_config_cls(cls): + return NeuronPhi3InferenceConfig From 3765ca4c968a7097d9dca0fde5af90e44b533b7d Mon Sep 17 00:00:00 2001 From: jlonge4 <91354480+jlonge4@users.noreply.github.com> Date: Sat, 21 Dec 2024 15:04:55 -0500 Subject: [PATCH 2/6] add clone().detach() to convert_to_neuron function --- .../models/phi3/modeling_phi3.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/neuronx_distributed_inference/models/phi3/modeling_phi3.py b/src/neuronx_distributed_inference/models/phi3/modeling_phi3.py index 8fee8b6..80da03b 100644 --- a/src/neuronx_distributed_inference/models/phi3/modeling_phi3.py +++ b/src/neuronx_distributed_inference/models/phi3/modeling_phi3.py @@ -96,16 +96,16 @@ def inner(cls: Type[nn.Module]): return inner -def convert_state_dict_to_non_fused_qkv(phi3_state_dict, cfg: InferenceConfig): +def convert_state_dict_to_neuron(phi3_state_dict, cfg: InferenceConfig): for l in range(cfg.num_hidden_layers): # Keep the original fused weight as Wqkv.weight phi3_state_dict[f"layers.{l}.self_attn.Wqkv.weight"] = phi3_state_dict[ f"layers.{l}.self_attn.qkv_proj.weight" - ] + ].clone().detach() # Get the fused QKV weight - fused_weight = phi3_state_dict[f"layers.{l}.self_attn.qkv_proj.weight"] - fused_gate_up = phi3_state_dict[f"layers.{l}.mlp.gate_up_proj.weight"] + fused_weight = phi3_state_dict[f"layers.{l}.self_attn.qkv_proj.weight"].clone().detach() + fused_gate_up = phi3_state_dict[f"layers.{l}.mlp.gate_up_proj.weight"].clone().detach() # Split the fused weight into Q, K, and V using torch.chunk q_weight, k_weight, v_weight = torch.chunk(fused_weight, 3, dim=0) @@ -450,8 +450,7 @@ def convert_hf_to_neuron_state_dict( ) -> dict: """This function should be over-ridden in child classes as needed""" - state_dict = convert_state_dict_to_non_fused_qkv(state_dict, config) - print(state_dict) + state_dict = convert_state_dict_to_neuron(state_dict, config) return state_dict @classmethod From e2f42e979a0e7fbaeed483816f7eb74ac4303cbc Mon Sep 17 00:00:00 2001 From: jlonge4 <91354480+jlonge4@users.noreply.github.com> Date: Thu, 9 Jan 2025 20:53:07 -0500 Subject: [PATCH 3/6] Update for 2.21 --- .../models/phi3/modeling_phi3.py | 911 ++++++++++++++---- 1 file changed, 724 insertions(+), 187 deletions(-) diff --git a/src/neuronx_distributed_inference/models/phi3/modeling_phi3.py b/src/neuronx_distributed_inference/models/phi3/modeling_phi3.py index 80da03b..89847dd 100644 --- a/src/neuronx_distributed_inference/models/phi3/modeling_phi3.py +++ b/src/neuronx_distributed_inference/models/phi3/modeling_phi3.py @@ -1,5 +1,10 @@ # coding=utf-8 -# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. 
It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,12 +17,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -"""PyTorch Phi-3 model for NXD inference.""" - +"""PyTorch Phi model for NXD inference.""" +import copy import gc +import logging +import math from typing import List, Optional, Tuple, Type -from transformers import Phi3ForCausalLM + import torch from neuronx_distributed.parallel_layers import parallel_state # noqa: E402 from neuronx_distributed.parallel_layers.layers import ( # noqa: E402; noqa: E402; noqa: E402; noqa: E402; noqa: E402 @@ -25,60 +31,80 @@ ParallelEmbedding, RowParallelLinear, ) -from neuronx_distributed.parallel_layers.mappings import _gather_along_dim +from neuronx_distributed.parallel_layers.mappings import ( + gather_from_sequence_parallel_region, + reduce_from_tensor_model_parallel_region, + reduce_scatter_to_sequence_parallel_region, +) +from neuronx_distributed.parallel_layers.utils import get_padding_length +from neuronx_distributed.quantization.quantization_config import QuantizationType, QuantizedDtype +from neuronx_distributed.quantization.quantization_layers import ( # noqa: E402; noqa: E402; noqa: E402; noqa: E402; noqa: E402 + QuantizedColumnParallel, + QuantizedRowParallel, +) +from neuronxcc.nki._private_kernels.mlp import ( + mlp_fused_add_isa_kernel, + mlp_isa_kernel, + quant_mlp_fused_add_isa_kernel, + quant_mlp_isa_kernel, +) +from neuronxcc.nki._private_kernels.rmsnorm import rmsnorm_quant_isa_kernel +from neuronxcc.starfish.penguin.targets.nki.private_api import vnc from torch import nn -import torch.utils.checkpoint - +from torch_neuronx.xla_impl.ops import nki_jit +from transformers import Phi3ForCausalLM from transformers.activations import ACT2FN - +from transformers.models.phi3.modeling_phi3 import ( + Phi3RotaryEmbedding, + Phi3RMSNorm, + Phi3LongRoPEScaledRotaryEmbedding, +) from neuronx_distributed_inference.models.config import InferenceConfig, NeuronConfig # noqa: E402 from neuronx_distributed_inference.models.model_base import ( # noqa: E402 NeuronBaseForCausalLM, NeuronBaseModel, ) -from neuronx_distributed_inference.modules.attention.attention_base import ( - NeuronAttentionBase, -) +from neuronx_distributed_inference.modules.attention.attention_base import NeuronAttentionBase from neuronx_distributed_inference.modules.attention.gqa import ( # noqa: E402 - GroupQueryAttention_QKV, - GroupQueryAttention_O, + BaseGroupQueryAttention, ) - -from neuronx_distributed.parallel_layers import utils -from transformers.models.phi3.modeling_phi3 import ( - Phi3RotaryEmbedding, - Phi3RMSNorm, - Phi3LongRoPEScaledRotaryEmbedding, +from neuronx_distributed_inference.modules.attention.utils import ( + RotaryEmbedding, + preprocess_quantized_linear_layer, + transpose_parallel_linear_layer, ) -from transformers.models.phi3.configuration_phi3 import Phi3Config -import logging +from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm +from neuronx_distributed_inference.modules.flashdecode.utils import calculate_num_cores_per_group +from neuronx_distributed_inference.modules.lora_serving.lora_module import is_lora_module +from 
neuronx_distributed_inference.utils.distributed import get_tp_group -# Set up basic configuration for logging -logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s - %(levelname)s - %(message)s", - filename="debug.txt", # This will write to a file named debug.log - filemode="w", -) # 'w' mode overwrites the file each time +_PHI3_MODULE_MAP = {} -# Create a logger -logger = logging.getLogger(__name__) -_CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct" -_PHI3_ATTENTION_CLASSES = {} +logger = logging.getLogger("Neuron") def get_rmsnorm_cls(): - return Phi3RMSNorm + # Initialize to the appropriate implementation of RMSNorm + # If infer on NXD -> CustomRMSNorm + # If infer on CPU -> HF_RMSNorm (CustomRMSNorm does not work on CPU) + return CustomRMSNorm if parallel_state.model_parallel_is_initialized() else Phi3RMSNorm + + +def preshard_hook_fn(module: torch.nn.Module, model_state_dict: dict, prefix: str) -> bool: + if isinstance(module, (BaseGroupQueryAttention,)): + return module.preshard_hook(model_state_dict, prefix) + + return False def _register_module(key: str, cls: Type[nn.Module]): - _PHI3_ATTENTION_CLASSES[key] = cls + _PHI3_MODULE_MAP[key] = cls def register_module(key: str): """ - Register a module for use in NeuronLlama. + Register a module for use in NeuronPhi3. Arguments: key: String used to identify the module @@ -128,6 +154,14 @@ def convert_state_dict_to_neuron(phi3_state_dict, cfg: InferenceConfig): class NeuronPhi3InferenceConfig(InferenceConfig): + def add_derived_config(self): + self.num_cores_per_group = 1 + if self.neuron_config.flash_decoding_enabled: + num_attn_heads, num_kv_heads = self.num_attention_heads, self.num_key_value_heads + self.num_cores_per_group = calculate_num_cores_per_group( + num_attn_heads, num_kv_heads, self.neuron_config.tp_degree + ) + def get_required_attributes(self) -> List[str]: return [ "hidden_size", @@ -148,94 +182,441 @@ def get_neuron_config_cls(cls) -> Type[NeuronConfig]: class NeuronPhi3MLP(nn.Module): + """ + This class just replace the linear layers (gate_proj, up_proj and down_proj) with column and row parallel layers + """ + def __init__(self, config: InferenceConfig): super().__init__() - self.config = config self.neuron_config = config.neuron_config - self.tp_degree = config.neuron_config.tp_degree self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.activation_fn = ACT2FN[config.hidden_act] + self.act_fn = ACT2FN[config.hidden_act] self.sequence_parallel_enabled = getattr( self.neuron_config, "sequence_parallel_enabled", False ) self.sequence_dimension = 1 if self.sequence_parallel_enabled else None - + self.rms_norm_eps = config.rms_norm_eps + self.mlp_kernel_enabled = self.neuron_config.mlp_kernel_enabled + self.quantized_mlp_kernel_enabled = self.neuron_config.quantized_mlp_kernel_enabled + self.rmsnorm_quantize_kernel_enabled = self.neuron_config.rmsnorm_quantize_kernel_enabled + self.quantized_kernel_lower_bound = self.neuron_config.quantized_kernel_lower_bound + self.logical_neuron_cores = self.neuron_config.logical_neuron_cores + mlp_bias = getattr(config, "mlp_bias", False) if parallel_state.model_parallel_is_initialized(): - self.gate_proj = ColumnParallelLinear( - self.hidden_size, - self.intermediate_size, - bias=False, - gather_output=False, - dtype=config.neuron_config.torch_dtype, - pad=True, - sequence_parallel_enabled=False, - sequence_dimension=None, - ) + if self.quantized_mlp_kernel_enabled: + # Quantized MLP kernels expect intermediate size to be 
multiple of 128, so we need to pad + tp_degree = self.neuron_config.tp_degree + self.intermediate_size += ( + get_padding_length(self.intermediate_size // tp_degree, 128) * tp_degree + ) + logger.debug(f"Quantized intermediate_size: {self.intermediate_size}") + + quantization_type = QuantizationType(self.neuron_config.quantization_type) + quantized_dtype = QuantizedDtype.F8E4M3 + self.gate_proj = QuantizedColumnParallel( + input_size=self.hidden_size, + output_size=self.intermediate_size, + bias=mlp_bias, + gather_output=False, + sequence_parallel_enabled=False, + dtype=config.neuron_config.torch_dtype, + quantized_dtype=quantized_dtype, + quantization_type=quantization_type, + tensor_model_parallel_group=get_tp_group(config), + ) + self.up_proj = QuantizedColumnParallel( + input_size=self.hidden_size, + output_size=self.intermediate_size, + bias=mlp_bias, + gather_output=False, + sequence_parallel_enabled=False, + dtype=config.neuron_config.torch_dtype, + quantized_dtype=quantized_dtype, + quantization_type=quantization_type, + tensor_model_parallel_group=get_tp_group(config), + ) + self.down_proj = QuantizedRowParallel( + input_size=self.intermediate_size, + output_size=self.hidden_size, + bias=mlp_bias, + quantization_type=quantization_type, + input_is_parallel=True, + dtype=config.neuron_config.torch_dtype, + quantized_dtype=quantized_dtype, + sequence_parallel_enabled=False, + quantization_per_channel_axis=0, + tensor_model_parallel_group=get_tp_group(config), + ) - self.up_proj = ColumnParallelLinear( - self.hidden_size, - self.intermediate_size, - bias=False, - gather_output=False, - dtype=config.neuron_config.torch_dtype, - pad=True, - sequence_parallel_enabled=False, - sequence_dimension=None, + else: + self.gate_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + bias=mlp_bias, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + pad=True, + sequence_parallel_enabled=False, + sequence_dimension=None, + tensor_model_parallel_group=get_tp_group(config), + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + bias=mlp_bias, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + pad=True, + sequence_parallel_enabled=False, + sequence_dimension=None, + tensor_model_parallel_group=get_tp_group(config), + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=mlp_bias, + input_is_parallel=True, + dtype=config.neuron_config.torch_dtype, + pad=True, + sequence_parallel_enabled=self.sequence_parallel_enabled, + sequence_dimension=self.sequence_dimension, + tensor_model_parallel_group=get_tp_group(config), + reduce_dtype=config.neuron_config.rpl_reduce_dtype, + ) + + if self.mlp_kernel_enabled: + if self.quantized_mlp_kernel_enabled: + preprocess_quantized_linear_layer(self.gate_proj) + preprocess_quantized_linear_layer(self.up_proj) + preprocess_quantized_linear_layer(self.down_proj) + + else: + # Transpose the weights to the layout expected by kernels + self.gate_proj.weight = transpose_parallel_linear_layer(self.gate_proj.weight) + self.up_proj.weight = transpose_parallel_linear_layer(self.up_proj.weight) + self.down_proj.weight = transpose_parallel_linear_layer(self.down_proj.weight) + + else: + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias) + + def 
_kernel_enabled_quantized_mlp(self, x, fused_rmsnorm, rmsnorm, residual, adapter_ids): + grid = (vnc(self.logical_neuron_cores),) + fused_residual = residual is not None + logger.debug( + f"MLP: quantized kernel, fused_residual={fused_residual}, fused_rmsnorm={fused_rmsnorm}, logical_neuron_cores={self.logical_neuron_cores}" + ) + + # Can't do residual add in the kernel if SP is enabled + if fused_residual: + assert ( + not self.sequence_parallel_enabled + ), "Quantized MLP cannot have both fused residual add and sequence parallel RMSnorm!" + # Using fused residual add + _mlp_fwd_call = nki_jit()(quant_mlp_fused_add_isa_kernel) + else: + _mlp_fwd_call = nki_jit()(quant_mlp_isa_kernel) + + # Handle SP RMSnorm + x_orig_dtype = x.dtype + if self.sequence_parallel_enabled: + # This RMSNormQuant kernel will do quantization inside, so we pass the + # lower_bound for clipping. + # If we don't use this kernel, the MLP kernel below will do the + # quantization, so we also pass lower_bound to that kernel. + if self.rmsnorm_quantize_kernel_enabled: + logger.debug( + "Running Quantized MLP kernel with sequence-parallel RMSnorm-Quantize kernel!" + ) + _rmsnorm_quant_fwd_call = nki_jit()(rmsnorm_quant_isa_kernel) + quant_rmsnorm_out = torch.zeros( + size=( + x.shape[0], # batch size + x.shape[1], # sequence length + x.shape[2] + 4, # hidden size + 4 bytes for packing fp32 scale + ), + dtype=torch.int8, + device=x.device, + ) + ln_w = rmsnorm.weight.unsqueeze(0) + lower_bound = self.quantized_kernel_lower_bound + _rmsnorm_quant_fwd_call[grid]( + x, ln_w, lower_bound, quant_rmsnorm_out, kernel_name="QuantOnly" + ) + x = gather_from_sequence_parallel_region( + quant_rmsnorm_out, + self.sequence_dimension, + process_group=get_tp_group(self.config), + ) + + else: + logger.debug( + "Running Quantized MLP kernel with external (native compiler) sequence-parallel RMSnorm!" 
+ ) + x = gather_from_sequence_parallel_region( + x, self.sequence_dimension, process_group=get_tp_group(self.config) + ) + + # Build output tensor + output_tensor_seqlen = x.shape[1] + if fused_residual: + # seqlen dim is doubled to store the residual add output + output_tensor_seqlen *= 2 + + output_tensor = torch.zeros( + size=( + x.shape[0], # batch size + output_tensor_seqlen, + self.hidden_size, # hidden size + ), + dtype=x_orig_dtype, + device=x.device, + ) + + # Grab weights + # all weights of the layers are stored in (out, in) shape + # unsqueeze so that shape of RMS gamma weight is [1, hidden] instead of [hidden] + ln_w = rmsnorm.weight.unsqueeze(0) + gate_w = self.gate_proj.weight.data + gate_w_scale = self.gate_proj.weight_scale + up_w = self.up_proj.weight.data + up_w_scale = self.up_proj.weight_scale + down_w = self.down_proj.weight.data + down_w_scale = self.down_proj.weight_scale + lower_bound = self.quantized_kernel_lower_bound + + if fused_residual: + _mlp_fwd_call[grid]( + x, # attn_output + residual, # hidden + ln_w, # ln_w + gate_w, # gate_w + gate_w_scale, + up_w, # up_w + up_w_scale, + down_w, # down_w + down_w_scale, + lower_bound, + output_tensor, # out + fused_rmsnorm=fused_rmsnorm, + eps=self.rms_norm_eps, + kernel_name="MLP", + store_add=True, ) + original_seqlen = x.shape[1] + residual = output_tensor[:, original_seqlen:, :] + output_tensor = output_tensor[:, :original_seqlen, :] + else: + _mlp_fwd_call[grid]( + x, # hidden + # should be fine to pass gamma is as a dummy even if not using fused rmsnorm + ln_w, + gate_w, # gate_w + gate_w_scale, + up_w, # up_w + up_w_scale, + down_w, # down_w + down_w_scale, + lower_bound, + output_tensor, # out + # Run RMSNorm inside the kernel if NOT using SP rmsnorm + fused_rmsnorm=fused_rmsnorm, + eps=self.rms_norm_eps, + kernel_name="MLP", + ) + residual = None - self.down_proj = RowParallelLinear( - self.intermediate_size, - self.hidden_size, - bias=False, - input_is_parallel=True, - dtype=config.neuron_config.torch_dtype, - pad=True, - sequence_parallel_enabled=self.sequence_parallel_enabled, - sequence_dimension=self.sequence_dimension, + # All-reduce or reduce-scatter, depending on whether SP is enabled + if self.sequence_parallel_enabled: + output_tensor = reduce_scatter_to_sequence_parallel_region( + output_tensor, self.sequence_dimension, process_group=get_tp_group(self.config) ) else: - self.gate_proj = nn.Linear( - self.hidden_size, self.intermediate_size, bias=False + output_tensor = reduce_from_tensor_model_parallel_region(output_tensor) + + logger.debug(f"Quantized MLP output shape {output_tensor.shape}") + return (output_tensor, residual) + + def _kernel_enabled_mlp(self, x, fused_rmsnorm, rmsnorm, residual, adapter_ids): + fused_residual = residual is not None + logger.debug( + f"MLP: kernel, fused_residual={fused_residual}, fused_rmsnorm={fused_rmsnorm}, logical_neuron_cores={self.logical_neuron_cores}" + ) + + # Choose which kernel to call + if fused_residual: + assert ( + not self.sequence_parallel_enabled + ), "MLP kernel cannot have both fused residual add and sequence parallel RMSnorm!" 
+ # Using fused residual add + _mlp_fwd_call = nki_jit()(mlp_fused_add_isa_kernel) + else: + _mlp_fwd_call = nki_jit()(mlp_isa_kernel) + + if self.sequence_parallel_enabled: + x = gather_from_sequence_parallel_region( + x, self.sequence_dimension, process_group=get_tp_group(self.config) ) - self.up_proj = nn.Linear( - self.hidden_size, self.intermediate_size, bias=False + + # Build output tensor + output_tensor_seqlen = x.shape[1] + if fused_residual: + # seqlen dim is doubled to store the residual add output + output_tensor_seqlen *= 2 + + output_tensor = torch.zeros( + size=( + x.shape[0], # batch size + output_tensor_seqlen, + self.hidden_size, # hidden size + ), + dtype=x.dtype, + device=x.device, + ) + + # Grab weights + # all weights of the layers are stored in (out, in) shape + # unsqueeze so that shape of RMS gamma weight is [1, hidden] instead of [hidden] + ln_w = rmsnorm.weight.unsqueeze(0) + gate_w = self.gate_proj.weight.data + up_w = self.up_proj.weight.data + down_w = self.down_proj.weight.data + + grid = (vnc(self.logical_neuron_cores),) + + if fused_residual: + _mlp_fwd_call[grid]( + x, # attn_output + residual, # hidden + ln_w, # ln_w + gate_w, # gate_w + up_w, # up_w + down_w, # down_w + output_tensor, # out + fused_rmsnorm=fused_rmsnorm, + eps=self.rms_norm_eps, + kernel_name="MLP", + store_add=True, ) - self.down_proj = nn.Linear( - self.intermediate_size, self.hidden_size, bias=False + original_seqlen = x.shape[1] + residual = output_tensor[:, original_seqlen:, :] + output_tensor = output_tensor[:, :original_seqlen, :] + else: + _mlp_fwd_call[grid]( + x, # hidden + # should be fine to pass gamma is as a dummy even if not using fused rmsnorm + ln_w, + gate_w, + up_w, + down_w, + output_tensor, # out + # Run RMSNorm inside the kernel if NOT using SP rmsnorm + fused_rmsnorm=fused_rmsnorm, + eps=self.rms_norm_eps, + kernel_name="MLP", ) + residual = None - def forward(self, hidden_state): + # All-reduce or reduce-scatter, depending on whether SP is enabled if self.sequence_parallel_enabled: - x = _gather_along_dim(x, self.sequence_dimension) + output_tensor = reduce_scatter_to_sequence_parallel_region( + output_tensor, self.sequence_dimension, process_group=get_tp_group(self.config) + ) else: - x = hidden_state + output_tensor = reduce_from_tensor_model_parallel_region( + output_tensor, process_group=get_tp_group(self.config) + ) + + logger.debug(f"MLP output shape {output_tensor.shape}") + return (output_tensor, residual) - return self.down_proj(self.activation_fn(self.gate_proj(x)) * self.up_proj(x)) + def _native_mlp(self, x, rmsnorm, adapter_ids=None): + logger.debug("MLP: native compiler") + # all-gather is done here instead of CPL layers to + # avoid 2 all-gathers from up and gate projections + if self.sequence_parallel_enabled: + x = gather_from_sequence_parallel_region( + x, self.sequence_dimension, process_group=get_tp_group(self.config) + ) + + gate_proj_output = ( + self.gate_proj(x) + if not is_lora_module(self.gate_proj) + else self.gate_proj(x, adapter_ids) + ) + up_proj_output = ( + self.up_proj(x) if not is_lora_module(self.up_proj) else self.up_proj(x, adapter_ids) + ) + down_proj_input = self.act_fn(gate_proj_output) * up_proj_output + output = ( + self.down_proj(down_proj_input) + if not is_lora_module(self.up_proj) + else self.down_proj(down_proj_input, adapter_ids) + ) + logger.debug(f"MLP output shape {output.shape}") + return output + + def forward(self, x, rmsnorm=None, residual=None, adapter_ids=None): + """ + If residual is passed in, will fuse 
its add into the MLP kernel + + Returns a tuple of (output, residual), where residual is the output of the residual add + """ + if self.mlp_kernel_enabled: + fused_rmsnorm = not self.sequence_parallel_enabled + # Quantized MLP kernel + if self.quantized_mlp_kernel_enabled: + return self._kernel_enabled_quantized_mlp( + x, fused_rmsnorm, rmsnorm, residual, adapter_ids=adapter_ids + ) + # MLP kernel + return self._kernel_enabled_mlp( + x, fused_rmsnorm, rmsnorm, residual, adapter_ids=adapter_ids + ) + else: + # No kernel + return (self._native_mlp(x, rmsnorm, adapter_ids=adapter_ids), None) @register_module("NeuronPhi3Attention") class NeuronPhi3Attention(NeuronAttentionBase): - """Multi-headed attention from 'Attention Is All You Need' paper""" + """ + Compared with Phi3Attention, this class just + 1. replaces the q_proj, k_proj, v_proj with column parallel layer + 2. replaces the o_proj with row parallel layer + 3. update self.num_head to be self.num_head / tp_degree + 4. update self.num_key_value_heads to be self.num_key_value_heads / tp_degree + 5. update forward() method to adjust to changes from self.num_head + """ + + def __init__(self, config: InferenceConfig, tensor_model_parallel_group=None): + super().__init__(tensor_model_parallel_group=tensor_model_parallel_group) - def __init__(self, config: InferenceConfig): - super().__init__() self.config = config self.neuron_config = config.neuron_config self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_attention_heads self.num_key_value_heads = config.num_key_value_heads + self.head_dim = self.hidden_size // self.num_attention_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.padding_side = config.neuron_config.padding_side self.torch_dtype = config.neuron_config.torch_dtype + self.is_medusa = config.neuron_config.is_medusa + self.flash_decoding_enabled = config.neuron_config.flash_decoding_enabled + self.num_cores_per_group = config.num_cores_per_group + self.bias = getattr(config, "attention_bias", False) + self.rpl_reduce_dtype = config.neuron_config.rpl_reduce_dtype + self.mlp_kernel_enabled = config.neuron_config.mlp_kernel_enabled + self.rms_norm_eps = config.rms_norm_eps if parallel_state.model_parallel_is_initialized(): - self.tp_degree = parallel_state.get_tensor_model_parallel_size() + self.tp_degree = self.config.neuron_config.tp_degree else: self.tp_degree = 1 @@ -244,89 +625,146 @@ def __init__(self, config: InferenceConfig): self.sequence_parallel_enabled = self.neuron_config.sequence_parallel_enabled self.sequence_dimension = 1 if self.sequence_parallel_enabled else None + logger.debug( + f"Hello from NeuronPhi3Attention init! Is SP enabled? {self.sequence_parallel_enabled}. Dim? {self.sequence_dimension}" + ) - self.init_custom_gqa_properties() + self.init_gqa_properties() self.init_rope() - def init_custom_gqa_properties(self): - if (self.head_dim * self.num_attention_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_attention_heads})." 
- ) - - self.qkv_proj = GroupQueryAttention_QKV( - hidden_size=self.hidden_size, - head_dim=self.head_dim, - num_attention_heads=self.num_attention_heads, - num_key_value_heads=self.num_key_value_heads, - tp_degree=self.tp_degree, - dtype=self.torch_dtype, - bias=False, - gather_output=False, - fused_qkv=self.fused_qkv, - clip_qkv=self.clip_qkv, - sequence_parallel_enabled=self.sequence_parallel_enabled, - sequence_dimension=self.sequence_dimension, - ) - self.o_proj = GroupQueryAttention_O( - hidden_size=self.hidden_size, - head_dim=self.head_dim, - num_attention_heads=self.num_attention_heads, - num_key_value_heads=self.num_key_value_heads, - tp_degree=self.tp_degree, - dtype=self.torch_dtype, - bias=False, - input_is_parallel=True, - layer_name=self.o_proj_layer_name, - sequence_parallel_enabled=self.sequence_parallel_enabled, - sequence_dimension=self.sequence_dimension, - ) - self.num_heads = utils.divide( - self.qkv_proj.get_num_attention_heads(), self.tp_degree - ) - self.num_key_value_heads = utils.divide( - self.qkv_proj.get_num_key_value_heads(), self.tp_degree - ) - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - def init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = Phi3RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) + if not hasattr(self.config, "rope_scaling") or self.config.rope_scaling is None: + # TODO(yihsian): Check if we can just use our own implementation + if self.is_medusa: + self.rotary_emb = Phi3RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + self.rotary_emb = RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) else: - scaling_type = self.config.rope_scaling.get("type") - if scaling_type == "longrope": + rope_type = self.config.rope_scaling.get( + "rope_type", self.config.rope_scaling.get("type", None) + ) + if rope_type == "longrope": self.rotary_emb = Phi3LongRoPEScaledRotaryEmbedding( self.head_dim, self.config ) else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + # Phi3RotaryEmbedding automatically chooses the correct scaling type from config. + # Warning: The HF implementation may have precision issues when run on Neuron. + # We include it here for compatibility with other scaling types. + self.rotary_emb = Phi3RotaryEmbedding(self.config) + + +# TODO: Modularize RotaryEmbedding. See how HF transformers does it in 4.43. +# class Phi33RotaryEmbedding(nn.Module): +# """ +# Adapted from Phi3 4.43 impl +# * https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/Phi3/modeling_Phi3.py#L78 +# * https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/modeling_rope_utils.py#L345 + +# This implementation ensures inv_freq is calculated and stored in fp32. 
+# """ + +# def __init__( +# self, +# dim, +# max_position_embeddings=131072, +# base=500000.0, +# factor=8.0, +# low_freq_factor=1.0, +# high_freq_factor=4.0, +# original_max_position_embeddings=8192, +# ): +# super().__init__() +# self.dim = dim +# self.max_position_embeddings = max_position_embeddings +# self.base = base +# self.factor = factor +# self.low_freq_factor = low_freq_factor +# self.high_freq_factor = high_freq_factor +# self.old_context_len = original_max_position_embeddings +# self.register_buffer("inv_freq", None, persistent=False) + +# @torch.no_grad() +# def forward(self, x, position_ids): +# # x: [bs, num_attention_heads, seq_len, head_size] +# if self.inv_freq is None: +# inv_freq = 1.0 / ( +# self.base +# ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) +# ) + +# low_freq_wavelen = self.old_context_len / self.low_freq_factor +# high_freq_wavelen = self.old_context_len / self.high_freq_factor +# new_freqs = [] +# for freq in inv_freq: +# wavelen = 2 * math.pi / freq +# if wavelen < high_freq_wavelen: +# new_freqs.append(freq) +# elif wavelen > low_freq_wavelen: +# new_freqs.append(freq / self.factor) +# else: +# assert low_freq_wavelen != high_freq_wavelen +# smooth = (self.old_context_len / wavelen - self.low_freq_factor) / ( +# self.high_freq_factor - self.low_freq_factor +# ) +# new_freqs.append((1 - smooth) * freq / self.factor + smooth * freq) +# self.inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device) + +# inv_freq_expanded = ( +# self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) +# ) +# position_ids_expanded = position_ids[:, None, :].float() +# with torch.autocast(device_type=x.device.type, enabled=False): +# freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) +# emb = torch.cat((freqs, freqs), dim=-1) +# cos = emb.cos() +# sin = emb.sin() +# return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) class NeuronPhi3DecoderLayer(nn.Module): - def __init__(self, config: Phi3Config, layer_idx: int): - super().__init__() + """ + Just replace the attention with the NXD version, and MLP with the NXD version + """ + def __init__(self, config: InferenceConfig): + super().__init__() + self.hidden_size = config.hidden_size self.self_attn = NeuronPhi3Attention( - config=config, + config=config, tensor_model_parallel_group=get_tp_group(config) ) - self.hidden_size = config.hidden_size - self.mlp = NeuronPhi3MLP(config) - self.input_layernorm = get_rmsnorm_cls()( - config.hidden_size, eps=config.rms_norm_eps + logger.debug( + f"Instantiating RMSNorm modules with hidden size {config.hidden_size} and EPS {config.rms_norm_eps}" ) - - self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) - self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) + self.input_layernorm = None + if ( + not config.neuron_config.is_eagle_draft + or config.neuron_config.enable_eagle_draft_input_norm + ): + self.input_layernorm = get_rmsnorm_cls()( + config.hidden_size, + eps=config.rms_norm_eps, + ) self.post_attention_layernorm = get_rmsnorm_cls()( - config.hidden_size, eps=config.rms_norm_eps + config.hidden_size, + eps=config.rms_norm_eps, ) + self.qkv_kernel_enabled = config.neuron_config.qkv_kernel_enabled + self.mlp_kernel_enabled = config.neuron_config.mlp_kernel_enabled + self.rmsnorm_quantize_kernel_enabled = config.neuron_config.rmsnorm_quantize_kernel_enabled + self.mlp_kernel_fuse_residual_add = config.neuron_config.mlp_kernel_fuse_residual_add + 
self.sequence_parallel_enabled = config.neuron_config.sequence_parallel_enabled + self.config = config def forward( self, @@ -334,40 +772,95 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + adapter_ids=None, **kwargs, - ) -> Tuple[ - torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] - ]: + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) + + # RMSNorm (fused with QKV kernel when SP is disabled) + if (not self.qkv_kernel_enabled or self.sequence_parallel_enabled) and self.input_layernorm: + hidden_states = self.input_layernorm(hidden_states) # Self Attention - attn_outs = self.self_attn( + hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + adapter_ids=adapter_ids, + rmsnorm=self.input_layernorm, **kwargs, ) - hidden_states, present_key_value = attn_outs - hidden_states = residual + hidden_states + if self.mlp_kernel_enabled and self.mlp_kernel_fuse_residual_add: + assert ( + not self.sequence_parallel_enabled + ), "mlp_kernel_fuse_residual_add should be off when sequence parallelism is enabled" + # First residual add handled in the MLP kernel + hidden_states, residual = self.mlp( + hidden_states, + rmsnorm=self.post_attention_layernorm, + residual=residual, + adapter_ids=adapter_ids, + ) + else: + hidden_states = residual + hidden_states + residual = hidden_states + # RMSNorm (fused with QKV kernel when SP is disabled) + if not self.mlp_kernel_enabled or self.sequence_parallel_enabled: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, _ = self.mlp( + hidden_states, + rmsnorm=self.post_attention_layernorm, + adapter_ids=adapter_ids, + ) - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - return (hidden_states, present_key_value) + outputs = (hidden_states, present_key_value, cos_cache, sin_cache) + return outputs + + +class ResBlock(nn.Module): + """ + A Residual Block module. + + This module performs a linear transformation followed by a SiLU activation, + and then adds the result to the original input, creating a residual connection. + + Args: + hidden_size (int): The size of the hidden layers in the block. + """ + + def __init__(self, hidden_size): + super().__init__() + self.linear = nn.Linear(hidden_size, hidden_size) + # Initialize as an identity mapping + torch.nn.init.zeros_(self.linear.weight) + # Use SiLU activation to keep consistent with the Phi3 model + self.act = nn.SiLU() + + def forward(self, x): + """ + Forward pass of the ResBlock. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output after the residual connection and activation. 
+ """ + return x + self.act(self.linear(x)) class NeuronPhi3Model(NeuronBaseModel): + """ + The neuron version of the Phi3Model + """ + def setup_attr_for_model(self, config: InferenceConfig): # Needed for init_inference_optimization() - self.on_device_sampling = ( - config.neuron_config.on_device_sampling_config is not None - ) + self.on_device_sampling = config.neuron_config.on_device_sampling_config is not None self.tp_degree = config.neuron_config.tp_degree self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads @@ -375,8 +868,6 @@ def setup_attr_for_model(self, config: InferenceConfig): self.max_batch_size = config.neuron_config.max_batch_size self.buckets = config.neuron_config.buckets - # self._attn_implementation = config._attn_implementation - def init_model(self, config: InferenceConfig): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -387,17 +878,20 @@ def init_model(self, config: InferenceConfig): config.hidden_size, self.padding_idx, dtype=config.neuron_config.torch_dtype, - shard_across_embedding=True, - # We choose to shard across embedding dimension because this stops XLA from introducing - # rank specific constant parameters into the HLO. We could shard across vocab, but that - # would require us to use non SPMD parallel_model_trace. + shard_across_embedding=not config.neuron_config.vocab_parallel, + sequence_parallel_enabled=False, pad=True, + tensor_model_parallel_group=get_tp_group(config), + use_spmd_rank=config.neuron_config.vocab_parallel, ) + self.lm_head = ColumnParallelLinear( config.hidden_size, config.vocab_size, + gather_output=not self.on_device_sampling, bias=False, pad=True, + tensor_model_parallel_group=get_tp_group(config), ) else: self.embed_tokens = nn.Embedding( @@ -409,24 +903,48 @@ def init_model(self, config: InferenceConfig): config.hidden_size, config.vocab_size, bias=False, - pad=True, ) - self.layers = nn.ModuleList( - [ - NeuronPhi3DecoderLayer(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ] - ) - self.norm = get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - - def get_input_embeddings(self): - return self.embed_tokens + # In the target fp8 checkpoint, the 1st and last + # layers are not using fp8. 
+ updated_configs = [] + for i in range(config.num_hidden_layers): + # TODO: Remove hardcoded code to have non-quantized MLPs for first and last decoder block + if i == 0 or i == config.num_hidden_layers - 1: + non_quant_config = copy.deepcopy(config) + non_quant_config.neuron_config.quantized_mlp_kernel_enabled = False + updated_configs.append(non_quant_config) + else: + updated_configs.append(config) + self.layers = nn.ModuleList([NeuronPhi3DecoderLayer(conf) for conf in updated_configs]) + if not config.neuron_config.is_eagle_draft: + self.norm = get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps) + + if config.neuron_config.is_eagle_draft: + fc_bias = getattr(config, "fc_bias", False) + self.fc = ColumnParallelLinear( + config.hidden_size * 2, config.hidden_size, bias=fc_bias, gather_output=True + ) + self.is_medusa = config.neuron_config.is_medusa + self.num_medusa_heads = config.neuron_config.num_medusa_heads + self.medusa_speculation_length = config.neuron_config.medusa_speculation_length - def set_input_embeddings(self, value): - self.embed_tokens = value + if self.is_medusa: + if parallel_state.model_parallel_is_initialized(): + medusa_head_cls = ColumnParallelLinear + else: + medusa_head_cls = nn.Linear + for i in range(self.num_medusa_heads): + medusa_head = nn.Sequential( + *([ResBlock(config.hidden_size)] * 1), + medusa_head_cls( + config.hidden_size, + config.vocab_size, + gather_output=not self.on_device_sampling, + bias=False, + ), + ) + setattr(self, f"medusa_head_{i}", medusa_head) class NeuronPhi3ForCausalLM(NeuronBaseForCausalLM): @@ -435,7 +953,7 @@ class NeuronPhi3ForCausalLM(NeuronBaseForCausalLM): blocks for Neuron. Args: - LlamaForCausalLM (_type_): _description_ + Phi3ForCausalLM (_type_): _description_ """ _model_cls = NeuronPhi3Model @@ -445,14 +963,33 @@ def load_hf_model(model_path): return Phi3ForCausalLM.from_pretrained(model_path) @staticmethod - def convert_hf_to_neuron_state_dict( - state_dict: dict, config: InferenceConfig - ) -> dict: + def convert_hf_to_neuron_state_dict(state_dict: dict, config: InferenceConfig) -> dict: """This function should be over-ridden in child classes as needed""" - + neuron_config = config.neuron_config + # if neuron_config.fused_qkv: state_dict = convert_state_dict_to_neuron(state_dict, config) + + if neuron_config.vocab_parallel: + # TODO: this hack can be removed after replication_id is ready to use + state_dict["embed_tokens.rank_util.rank"] = torch.arange( + 0, neuron_config.local_ranks_size + ) + + # to facilitate rank usage in attention + num_layers = config.num_hidden_layers + tp_degree = neuron_config.tp_degree + for i in range(num_layers): + state_dict[f"layers.{i}.self_attn.rank_util.rank"] = torch.arange( + 0, tp_degree, dtype=torch.int32 + ) + # to facilitate rank usage in base model + state_dict["rank_util.rank"] = torch.arange(0, tp_degree, dtype=torch.int32) return state_dict + @staticmethod + def update_state_dict_for_tied_weights(state_dict): + state_dict["lm_head.weight"] = state_dict["embed_tokens.weight"].clone() + @classmethod def get_config_cls(cls): return NeuronPhi3InferenceConfig From f831564ea698f4bdc4c303cafc267e9674ff1987 Mon Sep 17 00:00:00 2001 From: jlonge4 <91354480+jlonge4@users.noreply.github.com> Date: Tue, 18 Mar 2025 08:57:54 -0400 Subject: [PATCH 4/6] handle GQA in convert_state_dict --- .../models/phi3/modeling_phi3.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/neuronx_distributed_inference/models/phi3/modeling_phi3.py 
b/src/neuronx_distributed_inference/models/phi3/modeling_phi3.py index 89847dd..7f030fa 100644 --- a/src/neuronx_distributed_inference/models/phi3/modeling_phi3.py +++ b/src/neuronx_distributed_inference/models/phi3/modeling_phi3.py @@ -130,11 +130,16 @@ def convert_state_dict_to_neuron(phi3_state_dict, cfg: InferenceConfig): ].clone().detach() # Get the fused QKV weight - fused_weight = phi3_state_dict[f"layers.{l}.self_attn.qkv_proj.weight"].clone().detach() + fused_attn = phi3_state_dict[f"layers.{l}.self_attn.qkv_proj.weight"].clone().detach() fused_gate_up = phi3_state_dict[f"layers.{l}.mlp.gate_up_proj.weight"].clone().detach() - + # Potentially handle GQA + if cfg.num_attention_heads > cfg.num_key_value_heads: + q_features = cfg.hidden_size + q_weight = fused_attn[:q_features] + k_weight, v_weight = torch.chunk(fused_attn[q_features:], 2, dim=0) # Split the fused weight into Q, K, and V using torch.chunk - q_weight, k_weight, v_weight = torch.chunk(fused_weight, 3, dim=0) + else: + q_weight, k_weight, v_weight = torch.chunk(fused_attn, 3, dim=0) gate, up = torch.chunk(fused_gate_up, 2, dim=0) # Add the split weights to the state dict From 14a947a61d64cdc78bd647ce653f295e8ff89765 Mon Sep 17 00:00:00 2001 From: jlonge4 <91354480+jlonge4@users.noreply.github.com> Date: Fri, 25 Apr 2025 16:11:53 -0400 Subject: [PATCH 5/6] update for new release --- .../models/phi3/modeling_phi3.py | 502 +++++++++--------- 1 file changed, 264 insertions(+), 238 deletions(-) diff --git a/src/neuronx_distributed_inference/models/phi3/modeling_phi3.py b/src/neuronx_distributed_inference/models/phi3/modeling_phi3.py index 7f030fa..8df99df 100644 --- a/src/neuronx_distributed_inference/models/phi3/modeling_phi3.py +++ b/src/neuronx_distributed_inference/models/phi3/modeling_phi3.py @@ -17,7 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""PyTorch Phi model for NXD inference.""" +"""PyTorch Phi3 model for NXD inference.""" import copy import gc import logging @@ -37,11 +37,7 @@ reduce_scatter_to_sequence_parallel_region, ) from neuronx_distributed.parallel_layers.utils import get_padding_length -from neuronx_distributed.quantization.quantization_config import QuantizationType, QuantizedDtype -from neuronx_distributed.quantization.quantization_layers import ( # noqa: E402; noqa: E402; noqa: E402; noqa: E402; noqa: E402 - QuantizedColumnParallel, - QuantizedRowParallel, -) +from neuronx_distributed.utils import cpu_mode from neuronxcc.nki._private_kernels.mlp import ( mlp_fused_add_isa_kernel, mlp_isa_kernel, @@ -49,16 +45,12 @@ quant_mlp_isa_kernel, ) from neuronxcc.nki._private_kernels.rmsnorm import rmsnorm_quant_isa_kernel -from neuronxcc.starfish.penguin.targets.nki.private_api import vnc +from neuronxcc.nki.language import nc from torch import nn from torch_neuronx.xla_impl.ops import nki_jit from transformers import Phi3ForCausalLM from transformers.activations import ACT2FN -from transformers.models.phi3.modeling_phi3 import ( - Phi3RotaryEmbedding, - Phi3RMSNorm, - Phi3LongRoPEScaledRotaryEmbedding, -) +from transformers.models.phi3.modeling_phi3 import Phi3RMSNorm, Phi3RotaryEmbedding from neuronx_distributed_inference.models.config import InferenceConfig, NeuronConfig # noqa: E402 from neuronx_distributed_inference.models.model_base import ( # noqa: E402 @@ -79,7 +71,7 @@ from neuronx_distributed_inference.modules.lora_serving.lora_module import is_lora_module from neuronx_distributed_inference.utils.distributed import get_tp_group -_PHI3_MODULE_MAP = {} +_Phi3_MODULE_MAP = {} logger = logging.getLogger("Neuron") @@ -88,7 +80,7 @@ def get_rmsnorm_cls(): # Initialize to the appropriate implementation of RMSNorm # If infer on NXD -> CustomRMSNorm # If infer on CPU -> HF_RMSNorm (CustomRMSNorm does not work on CPU) - return CustomRMSNorm if parallel_state.model_parallel_is_initialized() else Phi3RMSNorm + return Phi3RMSNorm if cpu_mode() else CustomRMSNorm def preshard_hook_fn(module: torch.nn.Module, model_state_dict: dict, prefix: str) -> bool: @@ -98,8 +90,49 @@ def preshard_hook_fn(module: torch.nn.Module, model_state_dict: dict, prefix: st return False +# Get the modules_to_not_convert from the neuron configs +def get_modules_to_not_convert(neuron_config: NeuronConfig): + return getattr(neuron_config, "modules_to_not_convert", None) + + +def get_updated_configs(config: InferenceConfig): + """ + Generate a list of configurations for each hidden layer in a Phi3 model. + + This function creates a list of InferenceConfig objects, one for each layer. It + modifies the configurations for certain layers based on which modules should not + be converted to quantized format. The function uses get_modules_to_not_convert() + to determine which modules should not be converted. + + Args: + config (InferenceConfig): The inference configuration for the model. + + Returns: + list[InferenceConfig]: A list of InferenceConfig objects, one for each layer in the model. + Each config may be either the original config or a modified version + with "quantized_mlp_kernel_enabled" as False for that specific layer. 
+ """ + updated_configs = [] + modules_to_not_convert = get_modules_to_not_convert(config.neuron_config) + if modules_to_not_convert is None: + modules_to_not_convert = [] + + for i in range(config.num_hidden_layers): + # If any of the MLP modules for this layer are in modules_to_not_convert + module_pattern = f"layers.{i}.mlp" + if any(module_pattern in module for module in modules_to_not_convert): + non_quant_config = copy.deepcopy(config) + non_quant_config.neuron_config.quantized_mlp_kernel_enabled = False + non_quant_config.neuron_config.activation_quantization_type = None + non_quant_config.neuron_config.quantize_clamp_bound = float('inf') + updated_configs.append(non_quant_config) + else: + updated_configs.append(config) + return updated_configs + + def _register_module(key: str, cls: Type[nn.Module]): - _PHI3_MODULE_MAP[key] = cls + _Phi3_MODULE_MAP[key] = cls def register_module(key: str): @@ -122,35 +155,38 @@ def inner(cls: Type[nn.Module]): return inner -def convert_state_dict_to_neuron(phi3_state_dict, cfg: InferenceConfig): +def convert_state_dict_to_fused_qkv(phi3_state_dict, cfg: InferenceConfig): for l in range(cfg.num_hidden_layers): - # Keep the original fused weight as Wqkv.weight - phi3_state_dict[f"layers.{l}.self_attn.Wqkv.weight"] = phi3_state_dict[ - f"layers.{l}.self_attn.qkv_proj.weight" - ].clone().detach() - # Get the fused QKV weight fused_attn = phi3_state_dict[f"layers.{l}.self_attn.qkv_proj.weight"].clone().detach() fused_gate_up = phi3_state_dict[f"layers.{l}.mlp.gate_up_proj.weight"].clone().detach() # Potentially handle GQA - if cfg.num_attention_heads > cfg.num_key_value_heads: - q_features = cfg.hidden_size - q_weight = fused_attn[:q_features] - k_weight, v_weight = torch.chunk(fused_attn[q_features:], 2, dim=0) - # Split the fused weight into Q, K, and V using torch.chunk + if cfg.neuron_config.fused_qkv: + phi3_state_dict[f"layers.{l}.self_attn.Wqkv.weight"] = phi3_state_dict[ + f"layers.{l}.self_attn.qkv_proj.weight" + ].clone().detach() else: - q_weight, k_weight, v_weight = torch.chunk(fused_attn, 3, dim=0) - gate, up = torch.chunk(fused_gate_up, 2, dim=0) + q_features = cfg.hidden_size + q_weight = fused_attn[:q_features].clone() + k_v = fused_attn[q_features:].clone() + k_weight, v_weight = torch.chunk(k_v, 2, dim=0) + k_weight = k_weight.clone() # Ensure separate memory + v_weight = v_weight.clone() # Ensure separate memory + + # Store split weights with correct naming structure + phi3_state_dict[f"layers.{l}.self_attn.q_proj.weight"] = q_weight + phi3_state_dict[f"layers.{l}.self_attn.k_proj.weight"] = k_weight + phi3_state_dict[f"layers.{l}.self_attn.v_proj.weight"] = v_weight + + del phi3_state_dict[f"layers.{l}.self_attn.qkv_proj.weight"] + + gate_up_split = torch.chunk(fused_gate_up, 2, dim=0) + gate = gate_up_split[0].clone() # Ensure separate memory + up = gate_up_split[1].clone() - # Add the split weights to the state dict - phi3_state_dict[f"layers.{l}.self_attn.qkv_proj.q_proj.weight"] = q_weight - phi3_state_dict[f"layers.{l}.self_attn.qkv_proj.k_proj.weight"] = k_weight - phi3_state_dict[f"layers.{l}.self_attn.qkv_proj.v_proj.weight"] = v_weight phi3_state_dict[f"layers.{l}.mlp.gate_proj.weight"] = gate phi3_state_dict[f"layers.{l}.mlp.up_proj.weight"] = up - # Remove the original qkv_proj weight - del phi3_state_dict[f"layers.{l}.self_attn.qkv_proj.weight"] del phi3_state_dict[f"layers.{l}.mlp.gate_up_proj.weight"] gc.collect() @@ -158,7 +194,7 @@ def convert_state_dict_to_neuron(phi3_state_dict, cfg: InferenceConfig): return 
phi3_state_dict -class NeuronPhi3InferenceConfig(InferenceConfig): +class Phi3InferenceConfig(InferenceConfig): def add_derived_config(self): self.num_cores_per_group = 1 if self.neuron_config.flash_decoding_enabled: @@ -173,11 +209,11 @@ def get_required_attributes(self) -> List[str]: "num_attention_heads", "num_hidden_layers", "num_key_value_heads", + "pad_token_id", "vocab_size", "max_position_embeddings", "rope_theta", "rms_norm_eps", - "pad_token_id", "hidden_act", ] @@ -208,97 +244,73 @@ def __init__(self, config: InferenceConfig): self.mlp_kernel_enabled = self.neuron_config.mlp_kernel_enabled self.quantized_mlp_kernel_enabled = self.neuron_config.quantized_mlp_kernel_enabled self.rmsnorm_quantize_kernel_enabled = self.neuron_config.rmsnorm_quantize_kernel_enabled - self.quantized_kernel_lower_bound = self.neuron_config.quantized_kernel_lower_bound - self.logical_neuron_cores = self.neuron_config.logical_neuron_cores + self.quantize_clamp_bound = self.neuron_config.quantize_clamp_bound + self.logical_nc_config = self.neuron_config.logical_nc_config + self.activation_quantization_type = self.neuron_config.activation_quantization_type mlp_bias = getattr(config, "mlp_bias", False) + + if self.neuron_config.quantized_mlp_kernel_enabled and self.quantize_clamp_bound == float('inf'): + logging.warning("quantize_clamp_bound is not specified in NeuronConfig. We will use the default value of 1200 for Phi3 models in quantized kernels.") + self.quantize_clamp_bound = 1200.0 if parallel_state.model_parallel_is_initialized(): - if self.quantized_mlp_kernel_enabled: - # Quantized MLP kernels expect intermediate size to be multiple of 128, so we need to pad + if self.neuron_config.quantized_mlp_kernel_enabled: + # # Quantized MLP kernels expect intermediate size to be multiple of 128, so we need to pad tp_degree = self.neuron_config.tp_degree self.intermediate_size += ( get_padding_length(self.intermediate_size // tp_degree, 128) * tp_degree ) logger.debug(f"Quantized intermediate_size: {self.intermediate_size}") - - quantization_type = QuantizationType(self.neuron_config.quantization_type) - quantized_dtype = QuantizedDtype.F8E4M3 - self.gate_proj = QuantizedColumnParallel( - input_size=self.hidden_size, - output_size=self.intermediate_size, - bias=mlp_bias, - gather_output=False, - sequence_parallel_enabled=False, - dtype=config.neuron_config.torch_dtype, - quantized_dtype=quantized_dtype, - quantization_type=quantization_type, - tensor_model_parallel_group=get_tp_group(config), - ) - self.up_proj = QuantizedColumnParallel( - input_size=self.hidden_size, - output_size=self.intermediate_size, - bias=mlp_bias, - gather_output=False, - sequence_parallel_enabled=False, - dtype=config.neuron_config.torch_dtype, - quantized_dtype=quantized_dtype, - quantization_type=quantization_type, - tensor_model_parallel_group=get_tp_group(config), - ) - self.down_proj = QuantizedRowParallel( - input_size=self.intermediate_size, - output_size=self.hidden_size, - bias=mlp_bias, - quantization_type=quantization_type, - input_is_parallel=True, - dtype=config.neuron_config.torch_dtype, - quantized_dtype=quantized_dtype, - sequence_parallel_enabled=False, - quantization_per_channel_axis=0, - tensor_model_parallel_group=get_tp_group(config), - ) - - else: - self.gate_proj = ColumnParallelLinear( - self.hidden_size, - self.intermediate_size, - bias=mlp_bias, - gather_output=False, - dtype=config.neuron_config.torch_dtype, - pad=True, - sequence_parallel_enabled=False, - sequence_dimension=None, - 
tensor_model_parallel_group=get_tp_group(config), - ) - self.up_proj = ColumnParallelLinear( - self.hidden_size, - self.intermediate_size, - bias=mlp_bias, - gather_output=False, - dtype=config.neuron_config.torch_dtype, - pad=True, - sequence_parallel_enabled=False, - sequence_dimension=None, - tensor_model_parallel_group=get_tp_group(config), - ) - self.down_proj = RowParallelLinear( - self.intermediate_size, - self.hidden_size, - bias=mlp_bias, - input_is_parallel=True, - dtype=config.neuron_config.torch_dtype, - pad=True, - sequence_parallel_enabled=self.sequence_parallel_enabled, - sequence_dimension=self.sequence_dimension, - tensor_model_parallel_group=get_tp_group(config), - reduce_dtype=config.neuron_config.rpl_reduce_dtype, - ) - + self.gate_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + bias=mlp_bias, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + pad=True, + sequence_parallel_enabled=False, + sequence_dimension=None, + tensor_model_parallel_group=get_tp_group(config), + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + bias=mlp_bias, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + pad=True, + sequence_parallel_enabled=False, + sequence_dimension=None, + tensor_model_parallel_group=get_tp_group(config), + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=mlp_bias, + input_is_parallel=True, + dtype=config.neuron_config.torch_dtype, + pad=True, + sequence_parallel_enabled=self.sequence_parallel_enabled, + sequence_dimension=self.sequence_dimension, + tensor_model_parallel_group=get_tp_group(config), + reduce_dtype=config.neuron_config.rpl_reduce_dtype, + ) if self.mlp_kernel_enabled: - if self.quantized_mlp_kernel_enabled: - preprocess_quantized_linear_layer(self.gate_proj) - preprocess_quantized_linear_layer(self.up_proj) - preprocess_quantized_linear_layer(self.down_proj) - + if self.neuron_config.quantized_mlp_kernel_enabled: + setattr( + self.gate_proj, + "post_create_quantized_module_hook", + preprocess_quantized_linear_layer, + ) + setattr( + self.up_proj, + "post_create_quantized_module_hook", + preprocess_quantized_linear_layer, + ) + setattr( + self.down_proj, + "post_create_quantized_module_hook", + preprocess_quantized_linear_layer, + ) else: # Transpose the weights to the layout expected by kernels self.gate_proj.weight = transpose_parallel_linear_layer(self.gate_proj.weight) @@ -310,11 +322,12 @@ def __init__(self, config: InferenceConfig): self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias) - def _kernel_enabled_quantized_mlp(self, x, fused_rmsnorm, rmsnorm, residual, adapter_ids): - grid = (vnc(self.logical_neuron_cores),) + def _kernel_enabled_quantized_mlp(self, x, rmsnorm, residual, adapter_ids): + grid = (nc(self.logical_nc_config),) fused_residual = residual is not None + fused_rmsnorm = rmsnorm is not None logger.debug( - f"MLP: quantized kernel, fused_residual={fused_residual}, fused_rmsnorm={fused_rmsnorm}, logical_neuron_cores={self.logical_neuron_cores}" + f"MLP: quantized kernel, fused_residual={fused_residual}, fused_rmsnorm={fused_rmsnorm}, logical_nc_config={self.logical_nc_config}" ) # Can't do residual add in the kernel if SP is enabled @@ -327,13 +340,17 @@ def _kernel_enabled_quantized_mlp(self, x, fused_rmsnorm, rmsnorm, residual, ada else: _mlp_fwd_call = nki_jit()(quant_mlp_isa_kernel) 
+ if fused_rmsnorm: + ln_w = rmsnorm.weight.unsqueeze(0) + else: + ln_w = torch.zeros(size=(1, self.hidden_size), dtype=x.dtype, device=x.device) # Handle SP RMSnorm x_orig_dtype = x.dtype if self.sequence_parallel_enabled: # This RMSNormQuant kernel will do quantization inside, so we pass the - # lower_bound for clipping. + # clamp_bound for clipping. # If we don't use this kernel, the MLP kernel below will do the - # quantization, so we also pass lower_bound to that kernel. + # quantization, so we also pass clamp_bound to that kernel. if self.rmsnorm_quantize_kernel_enabled: logger.debug( "Running Quantized MLP kernel with sequence-parallel RMSnorm-Quantize kernel!" @@ -348,10 +365,9 @@ def _kernel_enabled_quantized_mlp(self, x, fused_rmsnorm, rmsnorm, residual, ada dtype=torch.int8, device=x.device, ) - ln_w = rmsnorm.weight.unsqueeze(0) - lower_bound = self.quantized_kernel_lower_bound + clamp_bound = self.quantize_clamp_bound _rmsnorm_quant_fwd_call[grid]( - x, ln_w, lower_bound, quant_rmsnorm_out, kernel_name="QuantOnly" + x, ln_w, clamp_bound, quant_rmsnorm_out, kernel_name="QuantOnly" ) x = gather_from_sequence_parallel_region( quant_rmsnorm_out, @@ -386,14 +402,13 @@ def _kernel_enabled_quantized_mlp(self, x, fused_rmsnorm, rmsnorm, residual, ada # Grab weights # all weights of the layers are stored in (out, in) shape # unsqueeze so that shape of RMS gamma weight is [1, hidden] instead of [hidden] - ln_w = rmsnorm.weight.unsqueeze(0) gate_w = self.gate_proj.weight.data - gate_w_scale = self.gate_proj.weight_scale + gate_w_scale = self.gate_proj.scale up_w = self.up_proj.weight.data - up_w_scale = self.up_proj.weight_scale + up_w_scale = self.up_proj.scale down_w = self.down_proj.weight.data - down_w_scale = self.down_proj.weight_scale - lower_bound = self.quantized_kernel_lower_bound + down_w_scale = self.down_proj.scale + clamp_bound = self.quantize_clamp_bound if fused_residual: _mlp_fwd_call[grid]( @@ -406,7 +421,7 @@ def _kernel_enabled_quantized_mlp(self, x, fused_rmsnorm, rmsnorm, residual, ada up_w_scale, down_w, # down_w down_w_scale, - lower_bound, + clamp_bound, output_tensor, # out fused_rmsnorm=fused_rmsnorm, eps=self.rms_norm_eps, @@ -427,7 +442,7 @@ def _kernel_enabled_quantized_mlp(self, x, fused_rmsnorm, rmsnorm, residual, ada up_w_scale, down_w, # down_w down_w_scale, - lower_bound, + clamp_bound, output_tensor, # out # Run RMSNorm inside the kernel if NOT using SP rmsnorm fused_rmsnorm=fused_rmsnorm, @@ -447,10 +462,11 @@ def _kernel_enabled_quantized_mlp(self, x, fused_rmsnorm, rmsnorm, residual, ada logger.debug(f"Quantized MLP output shape {output_tensor.shape}") return (output_tensor, residual) - def _kernel_enabled_mlp(self, x, fused_rmsnorm, rmsnorm, residual, adapter_ids): + def _kernel_enabled_mlp(self, x, rmsnorm, residual, adapter_ids): fused_residual = residual is not None + fused_rmsnorm = rmsnorm is not None logger.debug( - f"MLP: kernel, fused_residual={fused_residual}, fused_rmsnorm={fused_rmsnorm}, logical_neuron_cores={self.logical_neuron_cores}" + f"MLP: kernel, fused_residual={fused_residual}, fused_rmsnorm={fused_rmsnorm}, logical_nc_config={self.logical_nc_config}" ) # Choose which kernel to call @@ -487,12 +503,15 @@ def _kernel_enabled_mlp(self, x, fused_rmsnorm, rmsnorm, residual, adapter_ids): # Grab weights # all weights of the layers are stored in (out, in) shape # unsqueeze so that shape of RMS gamma weight is [1, hidden] instead of [hidden] - ln_w = rmsnorm.weight.unsqueeze(0) + if fused_rmsnorm: + ln_w = rmsnorm.weight.unsqueeze(0) + 
else: + ln_w = torch.zeros(size=(1, self.hidden_size), dtype=x.dtype, device=x.device) gate_w = self.gate_proj.weight.data up_w = self.up_proj.weight.data down_w = self.down_proj.weight.data - grid = (vnc(self.logical_neuron_cores),) + grid = (nc(self.logical_nc_config),) if fused_residual: _mlp_fwd_call[grid]( @@ -540,7 +559,7 @@ def _kernel_enabled_mlp(self, x, fused_rmsnorm, rmsnorm, residual, adapter_ids): logger.debug(f"MLP output shape {output_tensor.shape}") return (output_tensor, residual) - def _native_mlp(self, x, rmsnorm, adapter_ids=None): + def _native_mlp(self, x, adapter_ids=None): logger.debug("MLP: native compiler") # all-gather is done here instead of CPL layers to # avoid 2 all-gathers from up and gate projections @@ -548,19 +567,20 @@ def _native_mlp(self, x, rmsnorm, adapter_ids=None): x = gather_from_sequence_parallel_region( x, self.sequence_dimension, process_group=get_tp_group(self.config) ) - gate_proj_output = ( self.gate_proj(x) if not is_lora_module(self.gate_proj) else self.gate_proj(x, adapter_ids) ) + up_proj_output = ( self.up_proj(x) if not is_lora_module(self.up_proj) else self.up_proj(x, adapter_ids) ) + down_proj_input = self.act_fn(gate_proj_output) * up_proj_output output = ( self.down_proj(down_proj_input) - if not is_lora_module(self.up_proj) + if not is_lora_module(self.down_proj) else self.down_proj(down_proj_input, adapter_ids) ) logger.debug(f"MLP output shape {output.shape}") @@ -569,23 +589,22 @@ def _native_mlp(self, x, rmsnorm, adapter_ids=None): def forward(self, x, rmsnorm=None, residual=None, adapter_ids=None): """ If residual is passed in, will fuse its add into the MLP kernel + If rmsnorm is passed in, will fuse the rmsnorm into the MLP kernel Returns a tuple of (output, residual), where residual is the output of the residual add """ if self.mlp_kernel_enabled: - fused_rmsnorm = not self.sequence_parallel_enabled # Quantized MLP kernel if self.quantized_mlp_kernel_enabled: return self._kernel_enabled_quantized_mlp( - x, fused_rmsnorm, rmsnorm, residual, adapter_ids=adapter_ids + x, rmsnorm, residual, adapter_ids=adapter_ids ) # MLP kernel - return self._kernel_enabled_mlp( - x, fused_rmsnorm, rmsnorm, residual, adapter_ids=adapter_ids - ) + return self._kernel_enabled_mlp(x, rmsnorm, residual, adapter_ids=adapter_ids) else: # No kernel - return (self._native_mlp(x, rmsnorm, adapter_ids=adapter_ids), None) + assert rmsnorm is None and residual is None + return (self._native_mlp(x, adapter_ids=adapter_ids), None) @register_module("NeuronPhi3Attention") @@ -607,7 +626,7 @@ def __init__(self, config: InferenceConfig, tensor_model_parallel_group=None): self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads self.num_key_value_heads = config.num_key_value_heads - self.head_dim = self.hidden_size // self.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_attention_heads) self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.padding_side = config.neuron_config.padding_side @@ -642,11 +661,7 @@ def init_rope(self): if not hasattr(self.config, "rope_scaling") or self.config.rope_scaling is None: # TODO(yihsian): Check if we can just use our own implementation if self.is_medusa: - self.rotary_emb = Phi3RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) + self.rotary_emb = Phi3RotaryEmbedding(self.config) else: self.rotary_emb = RotaryEmbedding( 
self.head_dim, @@ -657,9 +672,17 @@ def init_rope(self): rope_type = self.config.rope_scaling.get( "rope_type", self.config.rope_scaling.get("type", None) ) - if rope_type == "longrope": - self.rotary_emb = Phi3LongRoPEScaledRotaryEmbedding( - self.head_dim, self.config + if rope_type == "phi3": + self.rotary_emb = Phi3RotaryEmbedding( + dim=self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + factor=self.config.rope_scaling["factor"], + low_freq_factor=self.config.rope_scaling["low_freq_factor"], + high_freq_factor=self.config.rope_scaling["high_freq_factor"], + original_max_position_embeddings=self.config.rope_scaling[ + "original_max_position_embeddings" + ], ) else: # Phi3RotaryEmbedding automatically chooses the correct scaling type from config. @@ -669,71 +692,71 @@ def init_rope(self): # TODO: Modularize RotaryEmbedding. See how HF transformers does it in 4.43. -# class Phi33RotaryEmbedding(nn.Module): -# """ -# Adapted from Phi3 4.43 impl -# * https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/Phi3/modeling_Phi3.py#L78 -# * https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/modeling_rope_utils.py#L345 - -# This implementation ensures inv_freq is calculated and stored in fp32. -# """ - -# def __init__( -# self, -# dim, -# max_position_embeddings=131072, -# base=500000.0, -# factor=8.0, -# low_freq_factor=1.0, -# high_freq_factor=4.0, -# original_max_position_embeddings=8192, -# ): -# super().__init__() -# self.dim = dim -# self.max_position_embeddings = max_position_embeddings -# self.base = base -# self.factor = factor -# self.low_freq_factor = low_freq_factor -# self.high_freq_factor = high_freq_factor -# self.old_context_len = original_max_position_embeddings -# self.register_buffer("inv_freq", None, persistent=False) - -# @torch.no_grad() -# def forward(self, x, position_ids): -# # x: [bs, num_attention_heads, seq_len, head_size] -# if self.inv_freq is None: -# inv_freq = 1.0 / ( -# self.base -# ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) -# ) - -# low_freq_wavelen = self.old_context_len / self.low_freq_factor -# high_freq_wavelen = self.old_context_len / self.high_freq_factor -# new_freqs = [] -# for freq in inv_freq: -# wavelen = 2 * math.pi / freq -# if wavelen < high_freq_wavelen: -# new_freqs.append(freq) -# elif wavelen > low_freq_wavelen: -# new_freqs.append(freq / self.factor) -# else: -# assert low_freq_wavelen != high_freq_wavelen -# smooth = (self.old_context_len / wavelen - self.low_freq_factor) / ( -# self.high_freq_factor - self.low_freq_factor -# ) -# new_freqs.append((1 - smooth) * freq / self.factor + smooth * freq) -# self.inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device) - -# inv_freq_expanded = ( -# self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) -# ) -# position_ids_expanded = position_ids[:, None, :].float() -# with torch.autocast(device_type=x.device.type, enabled=False): -# freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) -# emb = torch.cat((freqs, freqs), dim=-1) -# cos = emb.cos() -# sin = emb.sin() -# return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) +class Phi3RotaryEmbedding(nn.Module): + """ + Adapted from Phi3 4.43 impl + * https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/Phi3/modeling_Phi3.py#L78 + * 
https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/modeling_rope_utils.py#L345 + + This implementation ensures inv_freq is calculated and stored in fp32. + """ + + def __init__( + self, + dim, + max_position_embeddings=131072, + base=500000.0, + factor=8.0, + low_freq_factor=1.0, + high_freq_factor=4.0, + original_max_position_embeddings=8192, + ): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.factor = factor + self.low_freq_factor = low_freq_factor + self.high_freq_factor = high_freq_factor + self.old_context_len = original_max_position_embeddings + self.register_buffer("inv_freq", None, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.inv_freq is None: + inv_freq = 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) + ) + + low_freq_wavelen = self.old_context_len / self.low_freq_factor + high_freq_wavelen = self.old_context_len / self.high_freq_factor + new_freqs = [] + for freq in inv_freq: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + new_freqs.append(freq) + elif wavelen > low_freq_wavelen: + new_freqs.append(freq / self.factor) + else: + assert low_freq_wavelen != high_freq_wavelen + smooth = (self.old_context_len / wavelen - self.low_freq_factor) / ( + self.high_freq_factor - self.low_freq_factor + ) + new_freqs.append((1 - smooth) * freq / self.factor + smooth * freq) + self.inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device) + + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) + position_ids_expanded = position_ids[:, None, :].float() + with torch.autocast(device_type=x.device.type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) class NeuronPhi3DecoderLayer(nn.Module): @@ -744,7 +767,7 @@ class NeuronPhi3DecoderLayer(nn.Module): def __init__(self, config: InferenceConfig): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = NeuronPhi3Attention( + self.self_attn = _Phi3_MODULE_MAP[config.neuron_config.attn_cls]( config=config, tensor_model_parallel_group=get_tp_group(config) ) self.mlp = NeuronPhi3MLP(config) @@ -766,11 +789,19 @@ def __init__(self, config: InferenceConfig): ) self.qkv_kernel_enabled = config.neuron_config.qkv_kernel_enabled self.mlp_kernel_enabled = config.neuron_config.mlp_kernel_enabled + self.quantized_mlp_kernel_enabled = config.neuron_config.quantized_mlp_kernel_enabled self.rmsnorm_quantize_kernel_enabled = config.neuron_config.rmsnorm_quantize_kernel_enabled self.mlp_kernel_fuse_residual_add = config.neuron_config.mlp_kernel_fuse_residual_add self.sequence_parallel_enabled = config.neuron_config.sequence_parallel_enabled + self.is_prefill_stage = config.neuron_config.is_prefill_stage self.config = config + if self.is_prefill_stage and self.config.neuron_config.is_mlp_quantized(): + # for CTE, quantized MLP kernel does not support fused rmsnorm + self.mlp_kernel_fused_rmsnorm = False + else: + self.mlp_kernel_fused_rmsnorm = not self.sequence_parallel_enabled + def forward( self, hidden_states: torch.Tensor, @@ -778,6 +809,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_value: 
Optional[Tuple[torch.Tensor]] = None, adapter_ids=None, + rotary_position_ids: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -794,6 +826,7 @@ def forward( past_key_value=past_key_value, adapter_ids=adapter_ids, rmsnorm=self.input_layernorm, + rotary_position_ids=rotary_position_ids, **kwargs, ) @@ -812,11 +845,14 @@ def forward( hidden_states = residual + hidden_states residual = hidden_states # RMSNorm (fused with QKV kernel when SP is disabled) - if not self.mlp_kernel_enabled or self.sequence_parallel_enabled: + if self.mlp_kernel_enabled and self.mlp_kernel_fused_rmsnorm: + rmsnorm = self.post_attention_layernorm + else: hidden_states = self.post_attention_layernorm(hidden_states) + rmsnorm = None hidden_states, _ = self.mlp( hidden_states, - rmsnorm=self.post_attention_layernorm, + rmsnorm=rmsnorm, adapter_ids=adapter_ids, ) @@ -910,18 +946,9 @@ def init_model(self, config: InferenceConfig): bias=False, ) - # In the target fp8 checkpoint, the 1st and last - # layers are not using fp8. - updated_configs = [] - for i in range(config.num_hidden_layers): - # TODO: Remove hardcoded code to have non-quantized MLPs for first and last decoder block - if i == 0 or i == config.num_hidden_layers - 1: - non_quant_config = copy.deepcopy(config) - non_quant_config.neuron_config.quantized_mlp_kernel_enabled = False - updated_configs.append(non_quant_config) - else: - updated_configs.append(config) + updated_configs = get_updated_configs(config) self.layers = nn.ModuleList([NeuronPhi3DecoderLayer(conf) for conf in updated_configs]) + if not config.neuron_config.is_eagle_draft: self.norm = get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps) @@ -971,8 +998,7 @@ def load_hf_model(model_path): def convert_hf_to_neuron_state_dict(state_dict: dict, config: InferenceConfig) -> dict: """This function should be over-ridden in child classes as needed""" neuron_config = config.neuron_config - # if neuron_config.fused_qkv: - state_dict = convert_state_dict_to_neuron(state_dict, config) + state_dict = convert_state_dict_to_fused_qkv(state_dict, config) if neuron_config.vocab_parallel: # TODO: this hack can be removed after replication_id is ready to use @@ -997,4 +1023,4 @@ def update_state_dict_for_tied_weights(state_dict): @classmethod def get_config_cls(cls): - return NeuronPhi3InferenceConfig + return Phi3InferenceConfig \ No newline at end of file From f3aa3eb8606f0c0d40b5246cd0fc8d9bb05df405 Mon Sep 17 00:00:00 2001 From: jlonge4 <91354480+jlonge4@users.noreply.github.com> Date: Fri, 25 Apr 2025 16:43:48 -0400 Subject: [PATCH 6/6] Add detach().contiguous() --- .../models/phi3/modeling_phi3.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/neuronx_distributed_inference/models/phi3/modeling_phi3.py b/src/neuronx_distributed_inference/models/phi3/modeling_phi3.py index 8df99df..5d03e91 100644 --- a/src/neuronx_distributed_inference/models/phi3/modeling_phi3.py +++ b/src/neuronx_distributed_inference/models/phi3/modeling_phi3.py @@ -1,4 +1,4 @@ -# coding=utf-8 +# coding=utf-8 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
# # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX @@ -164,14 +164,14 @@ def convert_state_dict_to_fused_qkv(phi3_state_dict, cfg: InferenceConfig): if cfg.neuron_config.fused_qkv: phi3_state_dict[f"layers.{l}.self_attn.Wqkv.weight"] = phi3_state_dict[ f"layers.{l}.self_attn.qkv_proj.weight" - ].clone().detach() + ].clone().detach().contiguous() else: q_features = cfg.hidden_size - q_weight = fused_attn[:q_features].clone() - k_v = fused_attn[q_features:].clone() + q_weight = fused_attn[:q_features].clone().detach().contiguous() + k_v = fused_attn[q_features:].clone().detach().contiguous() k_weight, v_weight = torch.chunk(k_v, 2, dim=0) - k_weight = k_weight.clone() # Ensure separate memory - v_weight = v_weight.clone() # Ensure separate memory + k_weight = k_weight.clone().detach().contiguous() # Ensure separate memory + v_weight = v_weight.clone().detach().contiguous() # Ensure separate memory # Store split weights with correct naming structure phi3_state_dict[f"layers.{l}.self_attn.q_proj.weight"] = q_weight phi3_state_dict[f"layers.{l}.self_attn.k_proj.weight"] = k_weight phi3_state_dict[f"layers.{l}.self_attn.v_proj.weight"] = v_weight @@ -181,8 +181,8 @@ def convert_state_dict_to_fused_qkv(phi3_state_dict, cfg: InferenceConfig): del phi3_state_dict[f"layers.{l}.self_attn.qkv_proj.weight"] gate_up_split = torch.chunk(fused_gate_up, 2, dim=0) - gate = gate_up_split[0].clone() # Ensure separate memory - up = gate_up_split[1].clone() + gate = gate_up_split[0].clone().detach().contiguous() # Ensure separate memory + up = gate_up_split[1].clone().detach().contiguous() phi3_state_dict[f"layers.{l}.mlp.gate_proj.weight"] = gate phi3_state_dict[f"layers.{l}.mlp.up_proj.weight"] = up
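
For reference, the slicing that convert_state_dict_to_fused_qkv performs on a fused Phi-3 checkpoint can be exercised in isolation. The sketch below is only a rough illustration: hidden_size, kv_rows, and intermediate_size are made-up toy values rather than real Phi-3 dimensions, and the tensors stand in for the qkv_proj and gate_up_proj weights handled above.

import torch

# Toy dimensions only; real Phi-3 checkpoints are much larger.
hidden_size = 8            # rows taken by Q (q_features == hidden_size above)
kv_rows = 8                # combined K + V rows in the fused projection
intermediate_size = 12     # rows taken by each of gate and up

fused_attn = torch.randn(hidden_size + kv_rows, hidden_size)
fused_gate_up = torch.randn(2 * intermediate_size, hidden_size)

# Q occupies the first hidden_size rows; the remainder splits evenly into K and V.
q_weight = fused_attn[:hidden_size].clone().detach().contiguous()
k_weight, v_weight = [
    t.clone().detach().contiguous() for t in torch.chunk(fused_attn[hidden_size:], 2, dim=0)
]

# gate_up_proj splits evenly into gate and up halves.
gate, up = [t.clone().detach().contiguous() for t in torch.chunk(fused_gate_up, 2, dim=0)]

assert q_weight.shape == (hidden_size, hidden_size)
assert k_weight.shape == v_weight.shape == (kv_rows // 2, hidden_size)
assert gate.shape == up.shape == (intermediate_size, hidden_size)

Calling contiguous() after the chunk gives each split weight its own storage instead of a view aliasing the fused tensor, which is the point of adding detach().contiguous() to every split in the patch above.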