
Commit aa7ccec

Inf1delis authored and arthw committed
llama : add Deepseek MoE v1 & GigaChat models (ggml-org#10827)
* Add deepseek v1 arch & gigachat template
* improve template code
* add readme
* delete comments
* remove comment
* fix format
* lint llama.cpp
* fix order of deepseek and deepseek2, move gigachat template to the end of func
* fix order of deepseek and deepseek2 in constants; mark shared exp as deepseek arch need
* remove comments
* move deepseek above deepseek2
* change placement of gigachat chat template
1 parent 1fa6dc2 · commit aa7ccec

7 files changed: +423 -3 lines

README.md (+1)

@@ -132,6 +132,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
 - [x] [Jais](https://huggingface.co/inceptionai/jais-13b-chat)
 - [x] [Bielik-11B-v2.3](https://huggingface.co/collections/speakleash/bielik-11b-v23-66ee813238d9b526a072408a)
 - [x] [RWKV-6](https://github.com/BlinkDL/RWKV-LM)
+- [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct)
 
 #### Multimodal
 
convert_hf_to_gguf.py (+94)

@@ -664,6 +664,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "8b5a93ed704057481f240da0be7e7dca721d7f8f4755263b6807227a2cbeae65":
             # ref: https://huggingface.co/sentence-transformers/stsb-roberta-base
             res = "roberta-bpe"
+        if chkhsh == "ad851be1dba641f2e3711822f816db2c265f788b37c63b4e1aeacb9ee92de8eb":
+            # ref: https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct
+            res = "gigachat"
 
         if res is None:
             logger.warning("\n")
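Note: these `chkhsh` fingerprints are produced by convert_hf_to_gguf_update.py, which hashes the token ids a tokenizer emits for a fixed probe string. A minimal sketch of the idea, assuming the tokenizer loads with plain AutoTokenizer; the probe text below is a stand-in, not the real `chktxt` from the script, so the resulting hash will differ:

# Sketch of the pre-tokenizer fingerprint checked above. The real probe
# string is much longer; this stand-in yields a different hash and is for
# illustration only.
from hashlib import sha256
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ai-sage/GigaChat-20B-A3B-instruct")
chktok = tokenizer.encode("Hello, world! 123")  # stand-in probe text
chkhsh = sha256(str(chktok).encode()).hexdigest()
print(chkhsh)  # convert_hf_to_gguf.py compares this against its known hashes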
@@ -3427,6 +3430,97 @@ def prepare_tensors(self):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@Model.register("DeepseekForCausalLM")
+class DeepseekModel(Model):
+    model_arch = gguf.MODEL_ARCH.DEEPSEEK
+
+    def set_vocab(self):
+        try:
+            self._set_vocab_sentencepiece()
+        except FileNotFoundError:
+            self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        if "head_dim" in hparams:
+            rope_dim = hparams["head_dim"]
+        else:
+            rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
+
+        self.gguf_writer.add_rope_dimension_count(rope_dim)
+        self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+        self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
+        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+        self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
+        self.gguf_writer.add_expert_weights_scale(1.0)
+        self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
+        self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    @staticmethod
+    def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
+        if n_head_kv is not None and n_head != n_head_kv:
+            n_head = n_head_kv
+        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
+                .swapaxes(1, 2)
+                .reshape(weights.shape))
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        n_head = self.hparams["num_attention_heads"]
+        n_kv_head = self.hparams.get("num_key_value_heads")
+
+        if name.endswith(("q_proj.weight", "q_proj.bias")):
+            data_torch = DeepseekModel.permute(data_torch, n_head, n_head)
+        if name.endswith(("k_proj.weight", "k_proj.bias")):
+            data_torch = DeepseekModel.permute(data_torch, n_head, n_kv_head)
+
+        # process the experts separately
+        if name.find("mlp.experts") != -1:
+            n_experts = self.hparams["n_routed_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @Model.register("DeepseekV2ForCausalLM")
 class DeepseekV2Model(Model):
     model_arch = gguf.MODEL_ARCH.DEEPSEEK2
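
Note: in modify_tensors() above, per-expert weights are buffered until all n_routed_experts * 3 projection tensors for a layer have arrived, then each projection is stacked into a single 3D tensor. A self-contained toy sketch of that merge pattern (the shapes and layer index are illustrative, not GigaChat's real dimensions):

# Toy illustration of the expert-merge step: stack n_experts separate
# [out_dim, in_dim] matrices into one [n_experts, out_dim, in_dim] tensor,
# the layout GGUF uses for MoE expert weights.
import torch

n_experts, out_dim, in_dim = 4, 8, 16  # toy sizes
buffered = {
    f"model.layers.0.mlp.experts.{xid}.gate_proj.weight": torch.randn(out_dim, in_dim)
    for xid in range(n_experts)
}
stacked = torch.stack(
    [buffered[f"model.layers.0.mlp.experts.{xid}.gate_proj.weight"] for xid in range(n_experts)],
    dim=0,
)
assert stacked.shape == (n_experts, out_dim, in_dim)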

convert_hf_to_gguf_update.py (+1)

@@ -104,6 +104,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "chameleon",   "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/facebook/chameleon-7b", },
     {"name": "minerva-7b",  "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", },
     {"name": "roberta-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sentence-transformers/stsb-roberta-base"},
+    {"name": "gigachat",    "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct"},
 ]
 
gguf-py/gguf/constants.py (+29)

@@ -249,6 +249,7 @@ class MODEL_ARCH(IntEnum):
     OLMOE        = auto()
     OPENELM      = auto()
     ARCTIC       = auto()
+    DEEPSEEK     = auto()
     DEEPSEEK2    = auto()
     CHATGLM      = auto()
     BITNET       = auto()
@@ -412,6 +413,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.OLMOE:          "olmoe",
     MODEL_ARCH.OPENELM:        "openelm",
     MODEL_ARCH.ARCTIC:         "arctic",
+    MODEL_ARCH.DEEPSEEK:       "deepseek",
     MODEL_ARCH.DEEPSEEK2:      "deepseek2",
     MODEL_ARCH.CHATGLM:        "chatglm",
     MODEL_ARCH.BITNET:         "bitnet",
@@ -1158,6 +1160,29 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
     ],
+    MODEL_ARCH.DEEPSEEK: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+    ],
     MODEL_ARCH.DEEPSEEK2: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
@@ -1380,6 +1405,10 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ROPE_FREQS,
         MODEL_TENSOR.ATTN_ROT_EMBD,
     ],
+    MODEL_ARCH.DEEPSEEK: [
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+    ],
     MODEL_ARCH.DEEPSEEK2: [
         MODEL_TENSOR.ROPE_FREQS,
         MODEL_TENSOR.ATTN_ROT_EMBD,
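
Note: with these registrations in place, the gguf Python package resolves the new architecture by name. A quick sanity-check sketch, assuming a gguf-py checkout that contains this commit (the lookups fail on older versions):

# Sanity check of the new arch registration.
import gguf

assert gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.DEEPSEEK] == "deepseek"
# the shared-expert tensors are part of the deepseek tensor set
assert gguf.MODEL_TENSOR.FFN_UP_SHEXP in gguf.MODEL_TENSORS[gguf.MODEL_ARCH.DEEPSEEK]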

gguf-py/gguf/tensor_mapping.py (+3, -3)

@@ -306,7 +306,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.FFN_UP_SHEXP: (
             "model.layers.{bid}.mlp.shared_expert.up_proj",  # qwen2moe
-            "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek2
+            "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
         ),
 
         # AWQ-activation gate
@@ -338,7 +338,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.FFN_GATE_SHEXP: (
             "model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
-            "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek2
+            "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
         ),
 
         # Feed-forward down
@@ -379,7 +379,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.FFN_DOWN_SHEXP: (
             "model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe
-            "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek2
+            "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
         ),
 
         MODEL_TENSOR.ATTN_Q_NORM: (
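
Note: these shared-expert mappings are what let the converter translate Hugging Face tensor names into GGUF names for the new arch. A minimal sketch using gguf's get_tensor_name_map helper; the n_blocks value is an arbitrary example, not GigaChat's real layer count:

# Resolve an HF shared-expert tensor name to its GGUF name via the table
# patched above. n_blocks is an arbitrary example value.
import gguf

tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.DEEPSEEK, n_blocks=28)
name = tmap.get_name("model.layers.0.mlp.shared_experts.up_proj.weight",
                     try_suffixes=(".weight", ".bias"))
print(name)  # expected: blk.0.ffn_up_shexp.weight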
