diff --git a/python/ctranslate2/converters/opennmt_py.py b/python/ctranslate2/converters/opennmt_py.py
index 8143977e0..cbc40d7ea 100644
--- a/python/ctranslate2/converters/opennmt_py.py
+++ b/python/ctranslate2/converters/opennmt_py.py
@@ -103,8 +103,12 @@ def _get_model_spec_lm(opt, variables, src_vocabs, tgt_vocabs, num_source_embedd
     with_alibi = getattr(opt, "max_relative_positions", 0) == -2
     activation_fn = getattr(opt, "pos_ffn_activation_fn", "relu")
     num_heads = getattr(opt, "heads", 8)
+    num_kv = getattr(opt, "num_kv", 0)
+    if num_kv == num_heads:
+        num_kv = None
     rotary_dim = 0 if with_rotary else None
     ffn_glu = activation_fn == "silu"
+    sliding_window = getattr(opt, "sliding_window", 0)
 
     model_spec = transformer_spec.TransformerDecoderModelSpec.from_config(
         opt.dec_layers,
@@ -117,8 +121,12 @@ def _get_model_spec_lm(opt, variables, src_vocabs, tgt_vocabs, num_source_embedd
         rotary_dim=rotary_dim,
         rotary_interleave=True,
         multi_query_attention=getattr(opt, "multiquery", False),
+        num_heads_kv=num_kv,
+        sliding_window=sliding_window,
     )
 
+    model_spec.config.layer_norm_epsilon = getattr(opt, "norm_eps", 1e-6)
+
     set_transformer_decoder(
         model_spec.decoder,
         variables,
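
For context (not part of the patch): a minimal, self-contained sketch of the option mapping introduced by the first hunk. The option names mirror the diff, but the _Opt class and resolve_num_heads_kv helper are made up for illustration only. A checkpoint trained with num_kv equal to the number of attention heads is effectively plain multi-head attention, so the spec receives None rather than a redundant key/value head count.

class _Opt:
    """Stand-in for an OpenNMT-py training options namespace (illustrative only)."""

    def __init__(self, heads, num_kv):
        self.heads = heads
        self.num_kv = num_kv


def resolve_num_heads_kv(opt):
    # Same logic as the hunk above: num_kv equal to the head count means
    # regular multi-head attention, so None is passed to the model spec.
    num_heads = getattr(opt, "heads", 8)
    num_kv = getattr(opt, "num_kv", 0)
    if num_kv == num_heads:
        num_kv = None
    return num_kv


assert resolve_num_heads_kv(_Opt(heads=32, num_kv=8)) == 8      # grouped-query attention
assert resolve_num_heads_kv(_Opt(heads=32, num_kv=32)) is None  # collapses to standard MHA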