diff --git a/include/ctranslate2/layers/transformer.h b/include/ctranslate2/layers/transformer.h
index a7183a30d..0286acc37 100644
--- a/include/ctranslate2/layers/transformer.h
+++ b/include/ctranslate2/layers/transformer.h
@@ -37,6 +37,32 @@ namespace ctranslate2 {
       const bool _tensor_parallel;
     };
 
+    class Moe : public Layer
+    {
+    public:
+      Moe(const models::Model& model,
+          const std::string& scope,
+          const bool pre_norm = true,
+          const ops::ActivationType activation_type = ops::ActivationType::ReLU);
+
+      void operator()(StorageView& input, StorageView& output) const;
+      DataType output_type() const override {
+        return _ffn_layers.back()->output_type();
+      }
+
+      dim_t output_size() const override {
+        return _ffn_layers.back()->output_size();
+      }
+
+    private:
+      const std::unique_ptr<const LayerNorm> _layer_norm;
+      const Dense _gate;
+      const bool _pre_norm;
+      const ops::ActivationType _activation_type;
+      const dim_t _num_experts_per_tok;
+      const std::vector<std::unique_ptr<const FeedForwardNetwork>> _ffn_layers;
+    };
+
     class TransformerEncoderLayer : public Layer
     {
     public:
@@ -96,11 +122,17 @@
                       dim_t offset = 0) const;
 
       DataType output_type() const override {
-        return _ff.output_type();
+        if (_ff)
+          return _ff->output_type();
+        else
+          return _moe->output_type();
       }
 
       dim_t output_size() const override {
-        return _ff.output_size();
+        if (_ff)
+          return _ff->output_size();
+        else
+          return _moe->output_size();
       }
 
       bool has_cross_attention() const {
@@ -117,7 +149,8 @@
       const std::unique_ptr<const LayerNorm> _input_layer_norm;
       const std::unique_ptr<const LayerNorm> _post_attention_layer_norm;
       const std::unique_ptr<const MultiHeadAttention> _encoder_attention;
-      const FeedForwardNetwork _ff;
+      const std::unique_ptr<const Moe> _moe;
+      const std::unique_ptr<const FeedForwardNetwork> _ff;
     };
 
     class TransformerEncoder : public Encoder
diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py
index cc5176806..7054b914c 100644
--- a/python/ctranslate2/converters/transformers.py
+++ b/python/ctranslate2/converters/transformers.py
@@ -1568,6 +1568,112 @@ def set_decoder(self, spec, module):
         gc.collect()
 
 
+@register_loader("MixtralConfig")
+class MixtralLoader(ModelLoader):
+    @property
+    def architecture_name(self):
+        return "MixtralForCausalLM"
+
+    def get_model_spec(self, model):
+        num_layers = model.config.num_hidden_layers
+
+        num_heads = model.config.num_attention_heads
+        num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
+        if num_heads_kv == num_heads:
+            num_heads_kv = None
+
+        sliding_window = getattr(model.config, "sliding_window", 0)
+
+        rope_scaling = getattr(model.config, "rope_scaling", None)
+        if rope_scaling:
+            rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_scaling["type"])
+            rotary_scaling_factor = rope_scaling["factor"]
+
+            if rotary_scaling_type is None:
+                raise NotImplementedError(
+                    "RoPE scaling type '%s' is not yet implemented. "
" + "The following RoPE scaling types are currently supported: %s" + % (rope_scaling["type"], ", ".join(_SUPPORTED_ROPE_SCALING.keys())) + ) + else: + rotary_scaling_type = None + rotary_scaling_factor = 1 + + spec = transformer_spec.TransformerDecoderModelSpec.from_config( + num_layers, + num_heads, + activation=common_spec.Activation.SWISH, + pre_norm=True, + ffn_glu=True, + rms_norm=True, + rotary_dim=0, + rotary_interleave=False, + rotary_scaling_type=rotary_scaling_type, + rotary_scaling_factor=rotary_scaling_factor, + rotary_base=getattr(model.config, "rope_theta", 10000), + num_heads_kv=num_heads_kv, + sliding_window=sliding_window, + num_local_experts=getattr(model.config, "num_local_experts", 8), + num_experts_per_tok=getattr(model.config, "num_experts_per_tok", 2) + ) + + self.set_decoder(spec.decoder, model.model) + self.set_linear(spec.decoder.projection, model.lm_head) + return spec + + def get_vocabulary(self, model, tokenizer): + tokens = super().get_vocabulary(model, tokenizer) + + extra_ids = model.config.vocab_size - len(tokens) + for i in range(extra_ids): + tokens.append("" % i) + + return tokens + + def set_vocabulary(self, spec, tokens): + spec.register_vocabulary(tokens) + + def set_config(self, config, model, tokenizer): + config.bos_token = tokenizer.bos_token + config.eos_token = tokenizer.eos_token + config.unk_token = tokenizer.unk_token + config.layer_norm_epsilon = model.config.rms_norm_eps + + def set_layer_norm(self, spec, layer_norm): + spec.gamma = layer_norm.weight + + def set_decoder(self, spec, module): + spec.scale_embeddings = False + self.set_embeddings(spec.embeddings, module.embed_tokens) + self.set_layer_norm(spec.layer_norm, module.norm) + + for layer_spec, layer in zip(spec.layer, module.layers): + self.set_layer_norm( + layer_spec.self_attention.layer_norm, layer.input_layernorm + ) + self.set_layer_norm( + layer_spec.moe.layer_norm, layer.post_attention_layernorm + ) + + wq = layer.self_attn.q_proj.weight + wk = layer.self_attn.k_proj.weight + wv = layer.self_attn.v_proj.weight + wo = layer.self_attn.o_proj.weight + + layer_spec.self_attention.linear[0].weight = torch.cat([wq, wk, wv]) + layer_spec.self_attention.linear[1].weight = wo + + self.set_linear(layer_spec.moe.gate, layer.block_sparse_moe.gate) + for ffn_spec, ffn in zip(layer_spec.moe.experts, layer.block_sparse_moe.experts): + self.set_linear(ffn_spec.linear_0, ffn.w1) + self.set_linear(ffn_spec.linear_0_noact, ffn.w3) + self.set_linear(ffn_spec.linear_1, ffn.w2) + + delattr(layer, "self_attn") + delattr(layer, "block_sparse_moe") + gc.collect() + + @register_loader("MixFormerSequentialConfig") class MixFormerSequentialLoader(ModelLoader): @property diff --git a/python/ctranslate2/specs/transformer_spec.py b/python/ctranslate2/specs/transformer_spec.py index c3f8d91be..bb3c9e017 100644 --- a/python/ctranslate2/specs/transformer_spec.py +++ b/python/ctranslate2/specs/transformer_spec.py @@ -9,20 +9,20 @@ class TransformerEncoderSpec(model_spec.LayerSpec): def __init__( - self, - num_layers: int, - num_heads: int, - pre_norm: bool = True, - no_final_norm: bool = False, - activation: common_spec.Activation = common_spec.Activation.RELU, - num_source_embeddings: int = 1, - embeddings_merge: common_spec.EmbeddingsMerge = common_spec.EmbeddingsMerge.CONCAT, - layernorm_embedding: bool = False, - relative_position: bool = False, - relative_attention_bias: bool = False, - ffn_glu: bool = False, - rms_norm: bool = False, - multi_query_attention: bool = False, + self, + num_layers: int, + 
num_heads: int, + pre_norm: bool = True, + no_final_norm: bool = False, + activation: common_spec.Activation = common_spec.Activation.RELU, + num_source_embeddings: int = 1, + embeddings_merge: common_spec.EmbeddingsMerge = common_spec.EmbeddingsMerge.CONCAT, + layernorm_embedding: bool = False, + relative_position: bool = False, + relative_attention_bias: bool = False, + ffn_glu: bool = False, + rms_norm: bool = False, + multi_query_attention: bool = False, ): """Initializes a Transformer encoder specification. @@ -74,35 +74,37 @@ def __init__( class TransformerDecoderSpec(model_spec.LayerSpec): def __init__( - self, - num_layers: int, - num_heads: int, - pre_norm: bool = True, - activation: common_spec.Activation = common_spec.Activation.RELU, - layernorm_embedding: bool = False, - with_encoder_attention: bool = True, - no_final_norm: bool = False, - project_in_out: bool = False, - relative_position: bool = False, - relative_attention_bias: bool = False, - alignment_layer: int = -1, - alignment_heads: int = 1, - ffn_glu: bool = False, - rms_norm: bool = False, - alibi: bool = False, - alibi_use_positive_positions: bool = False, - scale_alibi: bool = False, - rotary_dim: Optional[int] = None, - rotary_interleave: bool = True, - rotary_scaling_type: Optional[attention_spec.RotaryScalingType] = None, - rotary_scaling_factor: float = 1, - rotary_base: float = 10000, - parallel_residual: bool = False, - shared_layer_norm: bool = False, - multi_query_attention: bool = False, - num_heads_kv: Optional[int] = None, - head_dim: Optional[int] = None, - sliding_window: Optional[int] = None, + self, + num_layers: int, + num_heads: int, + pre_norm: bool = True, + activation: common_spec.Activation = common_spec.Activation.RELU, + layernorm_embedding: bool = False, + with_encoder_attention: bool = True, + no_final_norm: bool = False, + project_in_out: bool = False, + relative_position: bool = False, + relative_attention_bias: bool = False, + alignment_layer: int = -1, + alignment_heads: int = 1, + ffn_glu: bool = False, + rms_norm: bool = False, + alibi: bool = False, + alibi_use_positive_positions: bool = False, + scale_alibi: bool = False, + rotary_dim: Optional[int] = None, + rotary_interleave: bool = True, + rotary_scaling_type: Optional[attention_spec.RotaryScalingType] = None, + rotary_scaling_factor: float = 1, + rotary_base: float = 10000, + parallel_residual: bool = False, + shared_layer_norm: bool = False, + multi_query_attention: bool = False, + num_heads_kv: Optional[int] = None, + head_dim: Optional[int] = None, + sliding_window: Optional[int] = None, + num_local_experts: Optional[int] = None, + num_experts_per_tok: Optional[int] = None, ): """Initializes a Transformer decoder specification. @@ -142,6 +144,8 @@ def __init__( multi_query_attention: Use multi-query attention (alias for num_heads_kv=1). num_heads_kv: Number of attention heads for the key and value. sliding_window: Max sequence length to retain in KV Cache. 
+ num_local_experts: total experts in moe layer + num_experts_per_tok: number of experts used by each token """ if parallel_residual: if not pre_norm: @@ -176,10 +180,10 @@ def __init__( if sliding_window is not None: self.sliding_window = np.dtype("int32").type(sliding_window) if ( - not relative_position - and not relative_attention_bias - and not alibi - and rotary_dim is None + not relative_position + and not relative_attention_bias + and not alibi + and rotary_dim is None ): self.position_encodings = PositionEncoderSpec() if pre_norm and not no_final_norm: @@ -204,12 +208,14 @@ def __init__( num_heads_kv=num_heads_kv, head_dim=head_dim, sliding_window=sliding_window, + num_local_experts=num_local_experts, + num_experts_per_tok=num_experts_per_tok, ) for _ in range(num_layers) ] self.start_from_zero_embedding = False self.multi_query_attention = multi_query_attention or ( - num_heads_kv != num_heads + num_heads_kv != num_heads ) if project_in_out: @@ -219,13 +225,13 @@ def __init__( class TransformerEncoderLayerSpec(model_spec.LayerSpec): def __init__( - self, - relative_position=False, - relative_attention_bias=False, - ffn_glu=False, - rms_norm=False, - num_heads_kv=None, - sliding_window=None, + self, + relative_position=False, + relative_attention_bias=False, + ffn_glu=False, + rms_norm=False, + num_heads_kv=None, + sliding_window=None, ): self.self_attention = attention_spec.MultiHeadAttentionSpec( self_attention=True, @@ -240,22 +246,24 @@ def __init__( class TransformerDecoderLayerSpec(model_spec.LayerSpec): def __init__( - self, - with_encoder_attention=True, - relative_position=False, - relative_attention_bias=False, - ffn_glu=False, - rms_norm=False, - rotary_dim=None, - rotary_interleave=True, - rotary_scaling_type=None, - rotary_scaling_factor=1, - rotary_base=10000, - parallel_residual=False, - shared_layer_norm=False, - num_heads_kv=None, - head_dim=None, - sliding_window=None, + self, + with_encoder_attention=True, + relative_position=False, + relative_attention_bias=False, + ffn_glu=False, + rms_norm=False, + rotary_dim=None, + rotary_interleave=True, + rotary_scaling_type=None, + rotary_scaling_factor=1, + rotary_base=10000, + parallel_residual=False, + shared_layer_norm=False, + num_heads_kv=None, + head_dim=None, + sliding_window=None, + num_local_experts=None, + num_experts_per_tok=None, ): self.self_attention = attention_spec.MultiHeadAttentionSpec( self_attention=True, @@ -279,7 +287,13 @@ def __init__( sliding_window=sliding_window, ) - self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm) + if num_local_experts is not None and num_experts_per_tok is not None: + self.moe = MoeSpec(glu=ffn_glu, + rms_norm=rms_norm, + num_experts=num_local_experts, + num_experts_per_tok=num_experts_per_tok) + else: + self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm) if parallel_residual: if shared_layer_norm: @@ -301,6 +315,29 @@ def __init__(self, glu=False, rms_norm=False): self.linear_0_noact = common_spec.LinearSpec() +class MoeFeedForwardSpec(model_spec.LayerSpec): + def __init__(self, glu=False): + self.linear_0 = common_spec.LinearSpec() + self.linear_1 = common_spec.LinearSpec() + if glu: + self.linear_0_noact = common_spec.LinearSpec() + + +class MoeSpec(model_spec.LayerSpec): + def __init__(self, glu=False, + rms_norm=False, + num_experts=0, + num_experts_per_tok=2): + self.layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm) + self.gate = common_spec.LinearSpec() + self.num_experts_per_tok = np.dtype("int32").type(num_experts_per_tok) + self.experts = 
[MoeFeedForwardSpec( + glu=glu + ) + for _ in range(num_experts) + ] + + class PositionEncoderSpec(model_spec.LayerSpec): def __init__(self): self.encodings = model_spec.OPTIONAL @@ -327,7 +364,7 @@ class TransformerSpec(model_spec.SequenceToSequenceModelSpec): """ def __init__( - self, encoder: TransformerEncoderSpec, decoder: TransformerDecoderSpec + self, encoder: TransformerEncoderSpec, decoder: TransformerDecoderSpec ): """Initializes a Transformer model specification. @@ -349,22 +386,22 @@ def __init__( @classmethod def from_config( - cls, - num_layers: Union[int, Tuple[int, int]], - num_heads: int, - with_relative_position: bool = False, - pre_norm: bool = True, - no_final_norm: bool = False, - activation: common_spec.Activation = common_spec.Activation.RELU, - alignment_layer: int = -1, - alignment_heads: int = 1, - num_source_embeddings: int = 1, - embeddings_merge: common_spec.EmbeddingsMerge = common_spec.EmbeddingsMerge.CONCAT, - layernorm_embedding: bool = False, - relative_attention_bias: bool = False, - ffn_glu: bool = False, - rms_norm: bool = False, - multi_query_attention: bool = False, + cls, + num_layers: Union[int, Tuple[int, int]], + num_heads: int, + with_relative_position: bool = False, + pre_norm: bool = True, + no_final_norm: bool = False, + activation: common_spec.Activation = common_spec.Activation.RELU, + alignment_layer: int = -1, + alignment_heads: int = 1, + num_source_embeddings: int = 1, + embeddings_merge: common_spec.EmbeddingsMerge = common_spec.EmbeddingsMerge.CONCAT, + layernorm_embedding: bool = False, + relative_attention_bias: bool = False, + ffn_glu: bool = False, + rms_norm: bool = False, + multi_query_attention: bool = False, ): """Creates a Transformer model specification. @@ -480,31 +517,33 @@ def __init__(self, decoder: TransformerDecoderSpec): @classmethod def from_config( - cls, - num_layers: int, - num_heads: int, - pre_norm: bool = True, - activation: common_spec.Activation = common_spec.Activation.RELU, - layernorm_embedding: bool = False, - no_final_norm: bool = False, - project_in_out: bool = False, - with_relative_position: bool = False, - ffn_glu: bool = False, - rms_norm: bool = False, - alibi: bool = False, - alibi_use_positive_positions: bool = False, - scale_alibi: bool = False, - rotary_dim: Optional[int] = None, - rotary_interleave: bool = True, - rotary_scaling_type: Optional[attention_spec.RotaryScalingType] = None, - rotary_scaling_factor: float = 1, - rotary_base: float = 10000, - parallel_residual: bool = False, - shared_layer_norm: bool = False, - multi_query_attention: bool = False, - num_heads_kv: Optional[int] = None, - head_dim: Optional[int] = None, - sliding_window: Optional[int] = None, + cls, + num_layers: int, + num_heads: int, + pre_norm: bool = True, + activation: common_spec.Activation = common_spec.Activation.RELU, + layernorm_embedding: bool = False, + no_final_norm: bool = False, + project_in_out: bool = False, + with_relative_position: bool = False, + ffn_glu: bool = False, + rms_norm: bool = False, + alibi: bool = False, + alibi_use_positive_positions: bool = False, + scale_alibi: bool = False, + rotary_dim: Optional[int] = None, + rotary_interleave: bool = True, + rotary_scaling_type: Optional[attention_spec.RotaryScalingType] = None, + rotary_scaling_factor: float = 1, + rotary_base: float = 10000, + parallel_residual: bool = False, + shared_layer_norm: bool = False, + multi_query_attention: bool = False, + num_heads_kv: Optional[int] = None, + head_dim: Optional[int] = None, + sliding_window: 
Optional[int] = None,
+            num_local_experts: Optional[int] = None,
+            num_experts_per_tok: Optional[int] = None,
     ):
         """Creates a Transformer decoder model specification.
 
@@ -538,6 +577,8 @@ def from_config(
             multi_query_attention: Use multi-query attention (alias for num_heads_kv=1).
             num_heads_kv: Number of attention heads for the key and value.
             sliding_window: max sequence length to retain KV cache
+            num_local_experts: Number of local experts in the MoE layer.
+            num_experts_per_tok: Number of experts selected for each token.
         """
         decoder = TransformerDecoderSpec(
             num_layers,
@@ -565,6 +606,8 @@
             num_heads_kv=num_heads_kv,
             head_dim=head_dim,
             sliding_window=sliding_window,
+            num_local_experts=num_local_experts,
+            num_experts_per_tok=num_experts_per_tok,
         )
 
         return cls(decoder)
 
@@ -601,10 +644,10 @@ class TransformerEncoderModelSpec(model_spec.LanguageModelSpec):
     """Describes a Transformer encoder model (e.g. BERT)."""
 
     def __init__(
-        self,
-        encoder: TransformerEncoderSpec,
-        pooling_layer: bool = False,
-        pooling_activation: common_spec.Activation = common_spec.Activation.Tanh,
+            self,
+            encoder: TransformerEncoderSpec,
+            pooling_layer: bool = False,
+            pooling_activation: common_spec.Activation = common_spec.Activation.Tanh,
     ):
         """Initializes a Transformer encoder model specification.
diff --git a/src/layers/transformer.cc b/src/layers/transformer.cc
index 97b5669c1..2086cb976 100644
--- a/src/layers/transformer.cc
+++ b/src/layers/transformer.cc
@@ -1,4 +1,5 @@
 #include "ctranslate2/layers/transformer.h"
+#include "ctranslate2/sampling.h"
 
 #include <cmath>
 
@@ -54,6 +55,98 @@ namespace ctranslate2 {
       }
     }
 
+    Moe::Moe(const models::Model& model, const std::string& scope, const bool pre_norm,
+             const ops::ActivationType activation_type)
+      : _layer_norm(build_optional_layer<LayerNorm>(model, scope + "/layer_norm"))
+      , _gate(model, scope + "/gate")
+      , _pre_norm(pre_norm)
+      , _activation_type(activation_type)
+      , _num_experts_per_tok(model.get_attribute_with_default<int32_t>(scope + "/num_experts_per_tok", 2))
+      , _ffn_layers(build_layers_list<const FeedForwardNetwork>(
+            model,
+            scope + "/experts",
+            pre_norm,
+            activation_type)) {
+    }
+
+    void Moe::operator()(StorageView& input, StorageView& output) const {
+      auto orig_shape = input.shape();
+      StorageView* x = &input;
+      if (_layer_norm && _pre_norm) {
+        (*_layer_norm)(input, output);
+        x = &output;
+      }
+
+      const Device device = input.device();
+      const DataType dtype = input.dtype();
+
+      // Flatten the batch and time dimensions so that each row is one token.
+      x->reshape({-1, x->dim(-1)});
+      StorageView score(dtype, device);
+      // Router: one logit per expert for each token.
+      _gate(*x, score);
+
+      StorageView expert_indices(DataType::INT32);
+      StorageView expert_weights(dtype);
+      // Top-k selection of experts for each token.
+      const BestSampler sampler;
+      sampler(score, expert_indices, expert_weights, _num_experts_per_tok);
+      if (device != Device::CPU)
+        expert_weights = expert_weights.to(device);
+
+      // Normalize the selected gate logits into routing weights.
+      StorageView f_used(dtype, device);
+      ops::SoftMax()(expert_weights, f_used);
+      expert_weights = std::move(f_used);
+
+      expert_indices.reshape({-1});
+      // Duplicate each token once per selected expert.
+      ops::Tile(0, _num_experts_per_tok)(*x);
+      f_used.resize(x->shape());
+
+      auto expert_indices_vector = expert_indices.to_vector<int32_t>();
+      auto expert_weight_shape = expert_weights.shape();
+      expert_weight_shape.push_back(-1);
+      expert_weights.reshape({-1});
+
+      // Run each duplicated token through its selected expert and scale the
+      // result by the corresponding routing weight.
+      for (int i = 0; i < expert_indices.size(); ++i) {
+        ops::Slide slide_ops(0, i, 1, true);
+        StorageView xtmp(dtype, device);
+        StorageView ytmp(dtype, device);
+        StorageView tmp_weight(dtype, device);
+        slide_ops(*x, xtmp);
+        slide_ops(f_used, ytmp);
+        slide_ops(expert_weights, tmp_weight);
+        (*_ffn_layers[expert_indices_vector[i]])(xtmp, ytmp);
+        ops::Mul()(ytmp, tmp_weight.to(Device::CPU).reshape({}), ytmp);
+      }
+      f_used.reshape(expert_weight_shape);
+
+      // Split the per-expert outputs and sum them to get one output per token.
+      std::vector<dim_t> partitions(_num_experts_per_tok, 1);
+      std::vector<StorageView> output_list(_num_experts_per_tok, StorageView(dtype, device));
+      std::vector<StorageView*> p_output_list;
+      p_output_list.reserve(output_list.size());  // Reserve space for efficiency
+
+      // Convert objects to pointers
+      std::transform(output_list.begin(), output_list.end(), std::back_inserter(p_output_list),
+                     [](StorageView& obj) { return &obj; });
+      ops::Split split_ops(1, partitions);
+      split_ops(f_used, p_output_list);
+      auto shape_output = f_used.shape();
+      shape_output[1] = 1;
+      ops::Slide(1, 0, 1)(f_used, output);
+      while (output_list.size() > 1) {
+        StorageView item = output_list.back();
+        ops::Add()(output, item, output);
+        output_list.pop_back();
+      }
+      output.reshape(std::move(orig_shape));
+
+      if (_layer_norm) {
+        ops::Add()(input, output, output);
+
+        if (!_pre_norm)
+          (*_layer_norm)(output, output);
+      }
+    }
+
     TransformerEncoderLayer::TransformerEncoderLayer(const models::Model& model,
                                                      const std::string& scope,
@@ -89,6 +182,16 @@
       _ff(context, output);
     }
 
+    static std::unique_ptr<Moe> make_moe(const models::Model& model,
+                                         const std::string& scope,
+                                         const bool pre_norm = true,
+                                         const ops::ActivationType activation_type = ops::ActivationType::ReLU) {
+      const dim_t num_experts_per_tok = model.get_attribute_with_default<int32_t>(scope + "/num_experts_per_tok", -1);
+      if (num_experts_per_tok < 0)
+        return nullptr;
+      return std::make_unique<Moe>(model, scope, pre_norm, activation_type);
+    }
+
 
     TransformerDecoderLayer::TransformerDecoderLayer(const models::Model& model,
                                                      const std::string& scope,
@@ -113,7 +216,8 @@
                                                       /*self_attention=*/false,
                                                       pre_norm,
                                                       /*is_decoder=*/true))
-      , _ff(model, scope + "/ffn", pre_norm, activation_type) {
+      , _moe(make_moe(model, scope + "/moe", pre_norm, activation_type))
+      , _ff(!_moe ? std::make_unique<FeedForwardNetwork>(model, scope + "/ffn", pre_norm, activation_type) : nullptr) {
     }
 
     void TransformerDecoderLayer::operator()(const StorageView& input,
@@ -164,7 +268,10 @@
         if (_post_attention_layer_norm)
           (*_post_attention_layer_norm)(input, hidden);
 
-        _ff(hidden, output);
+        if (_moe)
+          (*_moe)(hidden, output);
+        else
+          (*_ff)(hidden, output);
         ops::Add()(output, input, output);
         ops::Add()(output, attn, output);
 
@@ -201,7 +308,10 @@
         context = std::move(output);
       }
 
-      _ff(context, output);
+      if (_moe)
+        (*_moe)(context, output);
+      else
+        (*_ff)(context, output);
     }
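
Usage sketch (not part of the patch): with the loader above registered for "MixtralConfig", a Mixtral checkpoint should go through the regular Transformers conversion path. The model name, output directory, and quantization mode below are illustrative assumptions, not values taken from this diff.

    from ctranslate2.converters import TransformersConverter

    # The converter picks the registered loader from the checkpoint's config class,
    # so MixtralForCausalLM models would be handled by the loader added above.
    converter = TransformersConverter("mistralai/Mixtral-8x7B-Instruct-v0.1")  # assumed model id
    converter.convert("mixtral-ct2", quantization="int8_float16", force=True)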