Skip to content

Commit df0d4c3

Browse files
committed
llama-model : add dots.llm1 architecture support (#14044)
Adds:
* Dots1Model to convert_hf_to_gguf.py
* Computation graph code to llama-model.cpp
* Chat template to llama-chat.cpp to detect this model's template.

---

The model is called "dots.llm1" (I decided to shorten it to dots1 or DOTS1 in the code generally) architecture. The only models that exist as of writing of this commit that follow this architecture are "dots.llm1.inst" and "dots.llm1.base" from here:
* https://huggingface.co/rednote-hilab/dots.llm1.inst
* https://huggingface.co/rednote-hilab/dots.llm1.base

The model architecture is a combination of Qwen and Deepseek parts, as seen here: https://github.com/huggingface/transformers/blob/ffe12627b4e84489d2ab91dd0ec00614855edc79/src/transformers/models/dots1/modular_dots1.py
1 parent fb85a28 commit df0d4c3

File tree

9 files changed

+326
-1
lines changed

9 files changed

+326
-1
lines changed

convert_hf_to_gguf.py

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5262,6 +5262,34 @@ def prepare_tensors(self):
52625262
raise ValueError(f"Unprocessed experts: {experts}")
52635263

52645264

5265+
@ModelBase.register("Dots1ForCausalLM")
5266+
class Dots1Model(Qwen2MoeModel):
5267+
model_arch = gguf.MODEL_ARCH.DOTS1
5268+
5269+
def __init__(self, *args, **kwargs):
5270+
super().__init__(*args, **kwargs)
5271+
self.hparams["num_experts"] = self.hparams["n_routed_experts"]
5272+
5273+
def set_gguf_parameters(self):
5274+
super().set_gguf_parameters()
5275+
self.gguf_writer.add_leading_dense_block_count(self.hparams["first_k_dense_replace"])
5276+
self.gguf_writer.add_expert_shared_count(self.hparams["n_shared_experts"])
5277+
self.gguf_writer.add_expert_weights_scale(self.hparams["routed_scaling_factor"])
5278+
self.gguf_writer.add_expert_weights_norm(self.hparams["norm_topk_prob"])
5279+
5280+
if self.hparams["scoring_func"] == "noaux_tc":
5281+
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
5282+
else:
5283+
raise ValueError(f"Unsupported scoring_func value: {self.hparams['scoring_func']}")
5284+
5285+
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
5286+
if name.endswith("e_score_correction_bias"):
5287+
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
5288+
if "shared_experts" in name:
5289+
return [(self.map_tensor_name(name), data_torch)]
5290+
return super().modify_tensors(data_torch, name, bid)
5291+
5292+
52655293
@ModelBase.register("PLMForCausalLM")
52665294
class PLMModel(TextModel):
52675295
model_arch = gguf.MODEL_ARCH.PLM

gguf-py/gguf/constants.py

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -343,6 +343,7 @@ class MODEL_ARCH(IntEnum):
343343
WAVTOKENIZER_DEC = auto()
344344
PLM = auto()
345345
BAILINGMOE = auto()
346+
DOTS1 = auto()
346347

347348

348349
class VISION_PROJECTOR_TYPE(IntEnum):
@@ -623,6 +624,7 @@ class MODEL_TENSOR(IntEnum):
623624
MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
624625
MODEL_ARCH.PLM: "plm",
625626
MODEL_ARCH.BAILINGMOE: "bailingmoe",
627+
MODEL_ARCH.DOTS1: "dots1"
626628
}
627629

628630
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -2044,6 +2046,30 @@ class MODEL_TENSOR(IntEnum):
20442046
MODEL_TENSOR.FFN_DOWN_SHEXP,
20452047
MODEL_TENSOR.FFN_UP_SHEXP,
20462048
],
2049+
MODEL_ARCH.DOTS1: [
2050+
MODEL_TENSOR.TOKEN_EMBD,
2051+
MODEL_TENSOR.OUTPUT_NORM,
2052+
MODEL_TENSOR.OUTPUT,
2053+
MODEL_TENSOR.ATTN_NORM,
2054+
MODEL_TENSOR.ATTN_Q,
2055+
MODEL_TENSOR.ATTN_Q_NORM,
2056+
MODEL_TENSOR.ATTN_K,
2057+
MODEL_TENSOR.ATTN_K_NORM,
2058+
MODEL_TENSOR.ATTN_V,
2059+
MODEL_TENSOR.ATTN_OUT,
2060+
MODEL_TENSOR.FFN_EXP_PROBS_B,
2061+
MODEL_TENSOR.FFN_NORM,
2062+
MODEL_TENSOR.FFN_GATE,
2063+
MODEL_TENSOR.FFN_GATE_EXP,
2064+
MODEL_TENSOR.FFN_GATE_INP,
2065+
MODEL_TENSOR.FFN_GATE_SHEXP,
2066+
MODEL_TENSOR.FFN_DOWN,
2067+
MODEL_TENSOR.FFN_DOWN_EXP,
2068+
MODEL_TENSOR.FFN_DOWN_SHEXP,
2069+
MODEL_TENSOR.FFN_UP,
2070+
MODEL_TENSOR.FFN_UP_EXP,
2071+
MODEL_TENSOR.FFN_UP_SHEXP,
2072+
],
20472073
# TODO
20482074
}
20492075

gguf-py/gguf/tensor_mapping.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -305,7 +305,7 @@ class TensorNameMap:
305305
),
306306

307307
MODEL_TENSOR.FFN_EXP_PROBS_B: (
308-
"model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3
308+
"model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1
309309
),
310310

311311
# Feed-forward up

src/llama-arch.cpp

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -72,6 +72,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
7272
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
7373
{ LLM_ARCH_PLM, "plm" },
7474
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
75+
{ LLM_ARCH_DOTS1, "dots1" },
7576
{ LLM_ARCH_UNKNOWN, "(unknown)" },
7677
};
7778

@@ -1555,6 +1556,34 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
15551556
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
15561557
},
15571558
},
1559+
{
1560+
LLM_ARCH_DOTS1,
1561+
{
1562+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1563+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
1564+
{ LLM_TENSOR_OUTPUT, "output" },
1565+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
1566+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
1567+
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
1568+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
1569+
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
1570+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
1571+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
1572+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
1573+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
1574+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
1575+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
1576+
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
1577+
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
1578+
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
1579+
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
1580+
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
1581+
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
1582+
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
1583+
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
1584+
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
1585+
}
1586+
},
15581587
{
15591588
LLM_ARCH_UNKNOWN,
15601589
{

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -76,6 +76,7 @@ enum llm_arch {
7676
LLM_ARCH_WAVTOKENIZER_DEC,
7777
LLM_ARCH_PLM,
7878
LLM_ARCH_BAILINGMOE,
79+
LLM_ARCH_DOTS1,
7980
LLM_ARCH_UNKNOWN,
8081
};
8182

src/llama-chat.cpp

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -183,6 +183,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
183183
return LLM_CHAT_TEMPLATE_BAILING;
184184
} else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
185185
return LLM_CHAT_TEMPLATE_LLAMA4;
186+
} else if (tmpl_contains("<|endofuserprompt|>")) {
187+
return LLM_CHAT_TEMPLATE_DOTS1;
186188
}
187189
return LLM_CHAT_TEMPLATE_UNKNOWN;
188190
}
@@ -643,6 +645,21 @@ int32_t llm_chat_apply_template(
643645
if (add_ass) {
644646
ss << "Assistant:";
645647
}
648+
} else if (tmpl == LLM_CHAT_TEMPLATE_DOTS1) {
649+
// dots.llm1.inst (DOTS1)
650+
for (auto message : chat) {
651+
std::string role(message->role);
652+
if (role == "system") {
653+
ss << "<|system|>" << message->content << "<|endofsystem|>";
654+
} else if (role == "user") {
655+
ss << "<|userprompt|>" << message->content << "<|endofuserprompt|>";
656+
} else {
657+
ss << "<|response|>" << message->content << "<|endofresponse|>";
658+
}
659+
}
660+
if (add_ass) {
661+
ss << "<|response|>";
662+
}
646663
} else {
647664
// template not supported
648665
return -1;

src/llama-chat.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,7 @@ enum llm_chat_template {
4343
LLM_CHAT_TEMPLATE_BAILING,
4444
LLM_CHAT_TEMPLATE_LLAMA4,
4545
LLM_CHAT_TEMPLATE_SMOLVLM,
46+
LLM_CHAT_TEMPLATE_DOTS1,
4647
LLM_CHAT_TEMPLATE_UNKNOWN,
4748
};
4849

0 commit comments

Comments (0)