
SYCL bug: DeepSeek-V2-Lite-Chat-Q4_K_M does not work as expected #12390

Closed

aubreyli opened this issue Mar 14, 2025 · 18 comments · Fixed by #12399
Labels
bug Something isn't working

Comments

@aubreyli
Contributor

aubreyli commented Mar 14, 2025

Name and Version

root@alc-ai:/home/aubrey/work/llama-gpu# ./build/bin/llama-cli --version
version: 4887 (8fcb563)
built with Intel(R) oneAPI DPC++/C++ Compiler 2025.0.4 (2025.0.4.20241205) for x86_64-unknown-linux-gnu

Operating systems

No response

Which llama.cpp modules do you know to be affected?

No response

Command line

 ./build/bin/llama-cli -m /srv/models/DeepSeek-V2-Lite-Chat-Q4_K_M/DeepSeek-V2-Lite-64x1.5B-Chat-Q4_K_M.gguf -ngl 99 -sm none -mg 0 -p "what is your name?" -n 30 -no-cnv

Problem description & steps to reproduce

root@alc-ai:/home/aubrey/work/llama-gpu# ./build/bin/llama-cli -m /srv/models/DeepSeek-V2-Lite-Chat-Q4_K_M/DeepSeek-V2-Lite-64x1.5B-Chat-Q4_K_M.gguf -ngl 99 -sm none -mg 0 -p "what is your name?" -n 30 -no-cnv
build: 4887 (8fcb563) with Intel(R) oneAPI DPC++/C++ Compiler 2025.0.4 (2025.0.4.20241205) for x86_64-unknown-linux-gnu
main: llama backend init
main: load the model and apply lora adapter, if any
llama_model_load_from_file_impl: using device SYCL0 (Intel(R) Arc(TM) A770 Graphics) - 15473 MiB free
llama_model_loader: loaded meta data with 47 key-value pairs and 377 tensors from /srv/models/DeepSeek-V2-Lite-Chat-Q4_K_M/DeepSeek-V2-Lite-64x1.5B-Chat-Q4_K_M.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv 0: general.architecture str = deepseek2
llama_model_loader: - kv 1: general.type str = model
llama_model_loader: - kv 2: general.name str = DeepSeek V2 Lite Chat
llama_model_loader: - kv 3: general.finetune str = Chat
llama_model_loader: - kv 4: general.basename str = DeepSeek-V2-Lite
llama_model_loader: - kv 5: general.size_label str = 64x1.5B
llama_model_loader: - kv 6: general.license str = other
llama_model_loader: - kv 7: general.license.name str = deepseek
llama_model_loader: - kv 8: general.license.link str = https://github.com/deepseek-ai/DeepSe...
llama_model_loader: - kv 9: deepseek2.block_count u32 = 27
llama_model_loader: - kv 10: deepseek2.context_length u32 = 163840
llama_model_loader: - kv 11: deepseek2.embedding_length u32 = 2048
llama_model_loader: - kv 12: deepseek2.feed_forward_length u32 = 10944
llama_model_loader: - kv 13: deepseek2.attention.head_count u32 = 16
llama_model_loader: - kv 14: deepseek2.attention.head_count_kv u32 = 16
llama_model_loader: - kv 15: deepseek2.rope.freq_base f32 = 10000.000000
llama_model_loader: - kv 16: deepseek2.attention.layer_norm_rms_epsilon f32 = 0.000001
llama_model_loader: - kv 17: deepseek2.expert_used_count u32 = 6
llama_model_loader: - kv 18: deepseek2.leading_dense_block_count u32 = 1
llama_model_loader: - kv 19: deepseek2.vocab_size u32 = 102400
llama_model_loader: - kv 20: deepseek2.attention.kv_lora_rank u32 = 512
llama_model_loader: - kv 21: deepseek2.attention.key_length u32 = 192
llama_model_loader: - kv 22: deepseek2.attention.value_length u32 = 128
llama_model_loader: - kv 23: deepseek2.expert_feed_forward_length u32 = 1408
llama_model_loader: - kv 24: deepseek2.expert_count u32 = 64
llama_model_loader: - kv 25: deepseek2.expert_shared_count u32 = 2
llama_model_loader: - kv 26: deepseek2.expert_weights_scale f32 = 1.000000
llama_model_loader: - kv 27: deepseek2.expert_weights_norm bool = false
llama_model_loader: - kv 28: deepseek2.expert_gating_func u32 = 1
llama_model_loader: - kv 29: deepseek2.rope.dimension_count u32 = 64
llama_model_loader: - kv 30: deepseek2.rope.scaling.type str = yarn
llama_model_loader: - kv 31: deepseek2.rope.scaling.factor f32 = 40.000000
llama_model_loader: - kv 32: deepseek2.rope.scaling.original_context_length u32 = 4096
llama_model_loader: - kv 33: deepseek2.rope.scaling.yarn_log_multiplier f32 = 0.070700
llama_model_loader: - kv 34: tokenizer.ggml.model str = gpt2
llama_model_loader: - kv 35: tokenizer.ggml.pre str = deepseek-llm
llama_model_loader: - kv 36: tokenizer.ggml.tokens arr[str,102400] = ["!", """, "#", "$", "%", "&", "'", ...
llama_model_loader: - kv 37: tokenizer.ggml.token_type arr[i32,102400] = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv 38: tokenizer.ggml.merges arr[str,99757] = ["Ġ Ġ", "Ġ t", "Ġ a", "i n", "h e...
llama_model_loader: - kv 39: tokenizer.ggml.bos_token_id u32 = 100000
llama_model_loader: - kv 40: tokenizer.ggml.eos_token_id u32 = 100001
llama_model_loader: - kv 41: tokenizer.ggml.padding_token_id u32 = 100001
llama_model_loader: - kv 42: tokenizer.ggml.add_bos_token bool = true
llama_model_loader: - kv 43: tokenizer.ggml.add_eos_token bool = false
llama_model_loader: - kv 44: tokenizer.chat_template str = {% if not add_generation_prompt is de...
llama_model_loader: - kv 45: general.quantization_version u32 = 2
llama_model_loader: - kv 46: general.file_type u32 = 15
llama_model_loader: - type f32: 108 tensors
llama_model_loader: - type q5_0: 14 tensors
llama_model_loader: - type q8_0: 13 tensors
llama_model_loader: - type q4_K: 229 tensors
llama_model_loader: - type q6_K: 13 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type = Q4_K - Medium
print_info: file size = 9.65 GiB (5.28 BPW)
load: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect
load: special tokens cache size = 2
load: token to piece cache size = 0.6408 MB
print_info: arch = deepseek2
print_info: vocab_only = 0
print_info: n_ctx_train = 163840
print_info: n_embd = 2048
print_info: n_layer = 27
print_info: n_head = 16
print_info: n_head_kv = 16
print_info: n_rot = 64
print_info: n_swa = 0
print_info: n_swa_pattern = 1
print_info: n_embd_head_k = 192
print_info: n_embd_head_v = 128
print_info: n_gqa = 1
print_info: n_embd_k_gqa = 3072
print_info: n_embd_v_gqa = 2048
print_info: f_norm_eps = 0.0e+00
print_info: f_norm_rms_eps = 1.0e-06
print_info: f_clamp_kqv = 0.0e+00
print_info: f_max_alibi_bias = 0.0e+00
print_info: f_logit_scale = 0.0e+00
print_info: f_attn_scale = 0.0e+00
print_info: n_ff = 10944
print_info: n_expert = 64
print_info: n_expert_used = 6
print_info: causal attn = 1
print_info: pooling type = 0
print_info: rope type = 0
print_info: rope scaling = yarn
print_info: freq_base_train = 10000.0
print_info: freq_scale_train = 0.025
print_info: n_ctx_orig_yarn = 4096
print_info: rope_finetuned = unknown
print_info: ssm_d_conv = 0
print_info: ssm_d_inner = 0
print_info: ssm_d_state = 0
print_info: ssm_dt_rank = 0
print_info: ssm_dt_b_c_rms = 0
print_info: model type = 16B
print_info: model params = 15.71 B
print_info: general.name = DeepSeek V2 Lite Chat
print_info: n_layer_dense_lead = 1
print_info: n_lora_q = 0
print_info: n_lora_kv = 512
print_info: n_ff_exp = 1408
print_info: n_expert_shared = 2
print_info: expert_weights_scale = 1.0
print_info: expert_weights_norm = 0
print_info: expert_gating_func = softmax
print_info: rope_yarn_log_mul = 0.0707
print_info: vocab type = BPE
print_info: n_vocab = 102400
print_info: n_merges = 99757
print_info: BOS token = 100000 '<|begin▁of▁sentence|>'
print_info: EOS token = 100001 '<|end▁of▁sentence|>'
print_info: EOT token = 100001 '<|end▁of▁sentence|>'
print_info: PAD token = 100001 '<|end▁of▁sentence|>'
print_info: LF token = 185 'Ċ'
print_info: EOG token = 100001 '<|end▁of▁sentence|>'
print_info: max token length = 256
load_tensors: loading model tensors, this can take a while... (mmap = true)
load_tensors: offloading 27 repeating layers to GPU
load_tensors: offloading output layer to GPU
load_tensors: offloaded 28/28 layers to GPU
load_tensors: CPU_Mapped model buffer size = 112.50 MiB
load_tensors: SYCL0 model buffer size = 9767.98 MiB
.....................................................................................
llama_context: constructing llama_context
llama_context: n_seq_max = 1
llama_context: n_ctx = 4096
llama_context: n_ctx_per_seq = 4096
llama_context: n_batch = 2048
llama_context: n_ubatch = 512
llama_context: causal_attn = 1
llama_context: flash_attn = 0
llama_context: freq_base = 10000.0
llama_context: freq_scale = 0.025
llama_context: n_ctx_per_seq (4096) < n_ctx_train (163840) -- the full capacity of the model will not be utilized
Running with Environment Variables:
GGML_SYCL_DEBUG: 0
GGML_SYCL_DISABLE_OPT: 0
Build with Macros:
GGML_SYCL_FORCE_MMQ: no
GGML_SYCL_F16: no
Found 2 SYCL devices:
|  |                   |                                       |       |Max    |        |Max  |Global |                     |
|  |                   |                                       |       |compute|Max work|sub  |mem    |                     |
|ID|        Device Type|                                   Name|Version|units  |group   |group|size   |       Driver version|
|--|-------------------|---------------------------------------|-------|-------|--------|-----|-------|---------------------|
| 0| [level_zero:gpu:0]|                Intel Arc A770 Graphics|  12.55|    512|    1024|   32| 16225M|         1.6.32224+14|
| 1| [level_zero:gpu:1]|                 Intel UHD Graphics 770|   12.2|     32|     512|   32| 62707M|         1.6.32224+14|
SYCL Optimization Feature:
|ID|        Device Type|Reorder|
|--|-------------------|-------|
| 0| [level_zero:gpu:0]|      Y|
| 1| [level_zero:gpu:1]|      N|
llama_context: SYCL_Host output buffer size = 0.39 MiB
init: kv_size = 4096, offload = 1, type_k = 'f16', type_v = 'f16', n_layer = 27, can_shift = 0
init: SYCL0 KV buffer size = 1080.00 MiB
llama_context: KV self size = 1080.00 MiB, K (f16): 648.00 MiB, V (f16): 432.00 MiB
llama_context: SYCL0 compute buffer size = 213.03 MiB
llama_context: SYCL_Host compute buffer size = 12.01 MiB
llama_context: graph nodes = 1924
llama_context: graph splits = 2
common_init_from_params: KV cache shifting is not supported for this context, disabling KV cache shifting
common_init_from_params: setting dry_penalty_last_n to ctx_size = 4096
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
main: llama threadpool init, n_threads = 8

system_info: n_threads = 8 (n_threads_batch = 8) / 32 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX_VNNI = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | BMI2 = 1 | LLAMAFILE = 1 | OPENMP = 1 | AARCH64_REPACK = 1 |

sampler seed: 2656463
sampler params:
repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
dry_multiplier = 0.000, dry_base = 1.750, dry_allowed_length = 2, dry_penalty_last_n = 4096
top_k = 40, top_p = 0.950, min_p = 0.050, xtc_probability = 0.000, xtc_threshold = 0.100, typical_p = 1.000, top_n_sigma = -1.000, temp = 0.800
mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampler chain: logits -> logit-bias -> penalties -> dry -> top-k -> typical -> top-p -> min-p -> xtc -> temp-ext -> dist
generate: n_ctx = 4096, n_batch = 2048, n_predict = 30, n_keep = 1

what is your name? is the difference between a man and a boy?
2 Answers | Add Yours
A man is an adult human male, while a boy is a

llama_perf_sampler_print: sampling time = 1.10 ms / 36 runs ( 0.03 ms per token, 32786.89 tokens per second)
llama_perf_context_print: load time = 3147.22 ms
llama_perf_context_print: prompt eval time = 288.22 ms / 6 tokens ( 48.04 ms per token, 20.82 tokens per second)
llama_perf_context_print: eval time = 1660.91 ms / 29 runs ( 57.27 ms per token, 17.46 tokens per second)
llama_perf_context_print: total time = 1952.91 ms / 35 tokens

First Bad Commit

No response

Relevant log output

@fairydreaming
Collaborator

@aubreyli In non-interactive mode (-no-cnv) you have to include a proper prompt template in your prompt. For DeepSeek-V2 Lite it will be something like this: -p "User: what is your name?\n\nAssistant:"

@aubreyli
Contributor Author

@fairydreaming same issue here:

root@alc-ai:/home/aubrey/work/llama-gpu# ./build/bin/llama-cli -m /srv/models/DeepSeek-V2-Lite-Chat-Q4_K_M/DeepSeek-V2-Lite-64x1.5B-Chat-Q4_K_M.gguf -ngl 99 -sm none -mg 0 -p "User: what is your name?\n\nAssistant:" -n 30 -no-cnv
build: 4887 (8fcb563) with Intel(R) oneAPI DPC++/C++ Compiler 2025.0.4 (2025.0.4.20241205) for x86_64-unknown-linux-gnu
main: llama backend init
main: load the model and apply lora adapter, if any
llama_model_load_from_file_impl: using device SYCL0 (Intel(R) Arc(TM) A770 Graphics) - 15473 MiB free
llama_model_loader: loaded meta data with 47 key-value pairs and 377 tensors from /srv/models/DeepSeek-V2-Lite-Chat-Q4_K_M/DeepSeek-V2-Lite-64x1.5B-Chat-Q4_K_M.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.

--snip--

sampler seed: 457823384
sampler params:
repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
dry_multiplier = 0.000, dry_base = 1.750, dry_allowed_length = 2, dry_penalty_last_n = 4096
top_k = 40, top_p = 0.950, min_p = 0.050, xtc_probability = 0.000, xtc_threshold = 0.100, typical_p = 1.000, top_n_sigma = -1.000, temp = 0.800
mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampler chain: logits -> logit-bias -> penalties -> dry -> top-k -> typical -> top-p -> min-p -> xtc -> temp-ext -> dist
generate: n_ctx = 4096, n_batch = 2048, n_predict = 30, n_keep = 1

User: what is your name?

Assistant: This is my first time using this program. I have a question about how to use it.

Sure, I'd be happy to

llama_perf_sampler_print: sampling time = 1.37 ms / 42 runs ( 0.03 ms per token, 30746.71 tokens per second)
llama_perf_context_print: load time = 3257.36 ms
llama_perf_context_print: prompt eval time = 501.55 ms / 12 tokens ( 41.80 ms per token, 23.93 tokens per second)
llama_perf_context_print: eval time = 1659.25 ms / 29 runs ( 57.22 ms per token, 17.48 tokens per second)
llama_perf_context_print: total time = 2165.43 ms / 41 tokens

@fairydreaming
Collaborator

@aubreyli Hmm, that's weird. Where can I download this model file?

@aubreyli
Contributor Author

@fairydreaming you can download it from here:
https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/tree/main
and convert it to Q4_K_M with the conversion tools in llama.cpp.

@aubreyli
Contributor Author

@fairydreaming The same model file works properly with the CUDA backend:

$ ./build/bin/llama-cli -m /srv/models/DeepSeek-V2-Lite-Chat-Q4_K_M/DeepSeek-V2-Lite-64x1.5B-Chat-Q4_K_M.gguf -ngl 99 -sm none -mg 0 -p "User: what is your name?\n\nAssistant:" -n 30 -no-cnv
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA GeForce RTX 3060, compute capability 8.6, VMM: yes
build: 4790 (438a839) with cc (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0 for x86_64-linux-gnu
main: llama backend init
main: load the model and apply lora adapter, if any
llama_model_load_from_file_impl: using device CUDA0 (NVIDIA GeForce RTX 3060) - 11809 MiB free
llama_model_loader: loaded meta data with 47 key-value pairs and 377 tensors from /srv/models/DeepSeek-V2-Lite-Chat-Q4_K_M/DeepSeek-V2-Lite-64x1.5B-Chat-Q4_K_M.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.

----snip----

sampler seed: 563822659
sampler params:
repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
dry_multiplier = 0.000, dry_base = 1.750, dry_allowed_length = 2, dry_penalty_last_n = 4096
top_k = 40, top_p = 0.950, min_p = 0.050, xtc_probability = 0.000, xtc_threshold = 0.100, typical_p = 1.000, top_n_sigma = -1.000, temp = 0.800
mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampler chain: logits -> logit-bias -> penalties -> dry -> top-k -> typical -> top-p -> min-p -> xtc -> temp-ext -> dist
generate: n_ctx = 4096, n_batch = 2048, n_predict = 30, n_keep = 1

User: what is your name?

Assistant: I am DeepSeek Chat, an intelligent assistant developed by DeepSeek company. [end of text]

llama_perf_sampler_print: sampling time = 1.47 ms / 28 runs ( 0.05 ms per token, 19099.59 tokens per second)
llama_perf_context_print: load time = 2124.35 ms
llama_perf_context_print: prompt eval time = 81.80 ms / 12 tokens ( 6.82 ms per token, 146.69 tokens per second)
llama_perf_context_print: eval time = 197.45 ms / 15 runs ( 13.16 ms per token, 75.97 tokens per second)
llama_perf_context_print: total time = 289.15 ms / 27 tokens

@fairydreaming
Collaborator

fairydreaming commented Mar 14, 2025

@aubreyli I confirm the problem: when running DeepSeek V2 Lite on a GPU in a SYCL build (I used my RTX 4090 card), the model generates nonsense answers.

User: what is your name?

Assistant:entrenament d'equitació

Instruccions:

1. Llegeix l'escenari i escolta l'espectacle.

The model's answers look somewhat coherent, but they seem to ignore the user prompt.

@fairydreaming added the bug label and removed the bug-unconfirmed label on Mar 14, 2025
@qnixsynapse
Contributor

Please run it with GGML_SYCL_DEBUG=1 set in env.

@fairydreaming
Collaborator

> Please run it with GGML_SYCL_DEBUG=1 set in env.

sycl_debug.txt

@qnixsynapse
Contributor

qnixsynapse commented Mar 14, 2025

> Please run it with GGML_SYCL_DEBUG=1 set in env.
>
> sycl_debug.txt

Interesting. Does test-backend-ops pass without failures?

@fairydreaming
Collaborator

@qnixsynapse Yeah:

$ ./bin/test-backend-ops test -b SYCL0
Testing 2 devices

Backend 1/2: SYCL0
Running with Environment Variables:
  GGML_SYCL_DEBUG: 0
  GGML_SYCL_DISABLE_OPT: 0
Build with Macros:
  GGML_SYCL_FORCE_MMQ: no
  GGML_SYCL_F16: no
Found 1 SYCL devices:
|  |                   |                                       |       |Max    |        |Max  |Global |                     |
|  |                   |                                       |       |compute|Max work|sub  |mem    |                     |
|ID|        Device Type|                                   Name|Version|units  |group   |group|size   |       Driver version|
|--|-------------------|---------------------------------------|-------|-------|--------|-----|-------|---------------------|
| 0|       [cuda:gpu:0]|                NVIDIA GeForce RTX 4090|    8.9|    128|    1024|   32| 25386M|            CUDA 12.4|
SYCL Optimization Feature:
|ID|        Device Type|Reorder|
|--|-------------------|-------|
| 0|       [cuda:gpu:0]|      N|
  Device description: NVIDIA GeForce RTX 4090
  Device memory: 24210 MB (23818 MB free)

  ABS(type=f16,ne_a=[128,2,2,2],v=0): not supported [SYCL0] 
  ABS(type=f16,ne_a=[5,7,11,13],v=0): not supported [SYCL0] 
  SGN(type=f16,ne_a=[128,2,2,2],v=0): not supported [SYCL0] 
  SGN(type=f16,ne_a=[5,7,11,13],v=0): not supported [SYCL0] 
  NEG(type=f16,ne_a=[128,2,2,2],v=0): not supported [SYCL0] 
  NEG(type=f16,ne_a=[5,7,11,13],v=0): not supported [SYCL0] 
  STEP(type=f16,ne_a=[128,2,2,2],v=0): not supported [SYCL0] 
  STEP(type=f16,ne_a=[5,7,11,13],v=0): not supported [SYCL0] 
  TANH(type=f16,ne_a=[128,2,2,2],v=0): not supported [SYCL0] 
  TANH(type=f16,ne_a=[5,7,11,13],v=0): not supported [SYCL0] 
  ELU(type=f16,ne_a=[128,2,2,2],v=0): not supported [SYCL0] 
  ELU(type=f16,ne_a=[5,7,11,13],v=0): not supported [SYCL0] 
  RELU(type=f16,ne_a=[128,2,2,2],v=0): not supported [SYCL0] 
  RELU(type=f16,ne_a=[5,7,11,13],v=0): not supported [SYCL0] 
  SIGMOID(type=f16,ne_a=[128,2,2,2],v=0): not supported [SYCL0] 
  SIGMOID(type=f16,ne_a=[5,7,11,13],v=0): not supported [SYCL0] 
  GELU(type=f16,ne_a=[128,2,2,2],v=0): not supported [SYCL0] 
  GELU(type=f16,ne_a=[5,7,11,13],v=0): not supported [SYCL0] 
  GELU_QUICK(type=f16,ne_a=[128,2,2,2],v=0): not supported [SYCL0] 
  GELU_QUICK(type=f16,ne_a=[5,7,11,13],v=0): not supported [SYCL0] 
  SILU(type=f16,ne_a=[128,2,2,2],v=0): not supported [SYCL0] 
  SILU(type=f16,ne_a=[5,7,11,13],v=0): not supported [SYCL0] 
  HARDSWISH(type=f16,ne_a=[128,2,2,2],v=0): not supported [SYCL0] 
  HARDSWISH(type=f16,ne_a=[5,7,11,13],v=0): not supported [SYCL0] 
  HARDSIGMOID(type=f16,ne_a=[128,2,2,2],v=0): not supported [SYCL0] 
  HARDSIGMOID(type=f16,ne_a=[5,7,11,13],v=0): not supported [SYCL0] 
  EXP(type=f16,ne_a=[128,2,2,2],v=0): not supported [SYCL0] 
  EXP(type=f16,ne_a=[5,7,11,13],v=0): not supported [SYCL0] 
  ABS(type=f16,ne_a=[128,2,2,2],v=1): not supported [SYCL0] 
  ABS(type=f16,ne_a=[5,7,11,13],v=1): not supported [SYCL0] 
  SGN(type=f16,ne_a=[128,2,2,2],v=1): not supported [SYCL0] 
  SGN(type=f16,ne_a=[5,7,11,13],v=1): not supported [SYCL0] 
  NEG(type=f16,ne_a=[128,2,2,2],v=1): not supported [SYCL0] 
  NEG(type=f16,ne_a=[5,7,11,13],v=1): not supported [SYCL0] 
  STEP(type=f16,ne_a=[128,2,2,2],v=1): not supported [SYCL0] 
  STEP(type=f16,ne_a=[5,7,11,13],v=1): not supported [SYCL0] 
  TANH(type=f16,ne_a=[128,2,2,2],v=1): not supported [SYCL0] 
  TANH(type=f16,ne_a=[5,7,11,13],v=1): not supported [SYCL0] 
  ELU(type=f16,ne_a=[128,2,2,2],v=1): not supported [SYCL0] 
  ELU(type=f16,ne_a=[5,7,11,13],v=1): not supported [SYCL0] 
  RELU(type=f16,ne_a=[128,2,2,2],v=1): not supported [SYCL0] 
  RELU(type=f16,ne_a=[5,7,11,13],v=1): not supported [SYCL0] 
  SIGMOID(type=f16,ne_a=[128,2,2,2],v=1): not supported [SYCL0] 
  SIGMOID(type=f16,ne_a=[5,7,11,13],v=1): not supported [SYCL0] 
  GELU(type=f16,ne_a=[128,2,2,2],v=1): not supported [SYCL0] 
  GELU(type=f16,ne_a=[5,7,11,13],v=1): not supported [SYCL0] 
  GELU_QUICK(type=f16,ne_a=[128,2,2,2],v=1): not supported [SYCL0] 
  GELU_QUICK(type=f16,ne_a=[5,7,11,13],v=1): not supported [SYCL0] 
  SILU(type=f16,ne_a=[128,2,2,2],v=1): not supported [SYCL0] 
  SILU(type=f16,ne_a=[5,7,11,13],v=1): not supported [SYCL0] 
  HARDSWISH(type=f16,ne_a=[128,2,2,2],v=1): not supported [SYCL0] 
  HARDSWISH(type=f16,ne_a=[5,7,11,13],v=1): not supported [SYCL0] 
  HARDSIGMOID(type=f16,ne_a=[128,2,2,2],v=1): not supported [SYCL0] 
  HARDSIGMOID(type=f16,ne_a=[5,7,11,13],v=1): not supported [SYCL0] 
  EXP(type=f16,ne_a=[128,2,2,2],v=1): not supported [SYCL0] 
  EXP(type=f16,ne_a=[5,7,11,13],v=1): not supported [SYCL0] 
  ABS(type=f32,ne_a=[128,2,2,2],v=0): not supported [SYCL0] 
  ABS(type=f32,ne_a=[5,7,11,13],v=0): not supported [SYCL0] 
  SGN(type=f32,ne_a=[128,2,2,2],v=0): not supported [SYCL0] 
  SGN(type=f32,ne_a=[5,7,11,13],v=0): not supported [SYCL0] 
  NEG(type=f32,ne_a=[128,2,2,2],v=0): OK
  NEG(type=f32,ne_a=[5,7,11,13],v=0): OK
  STEP(type=f32,ne_a=[128,2,2,2],v=0): OK
  STEP(type=f32,ne_a=[5,7,11,13],v=0): OK
  TANH(type=f32,ne_a=[128,2,2,2],v=0): OK
  TANH(type=f32,ne_a=[5,7,11,13],v=0): OK
...
  CROSS_ENTROPY_LOSS(type=f32,ne=[10,5,4,3]): not supported [SYCL0] 
  CROSS_ENTROPY_LOSS(type=f32,ne=[30000,1,1,1]): not supported [SYCL0] 
  CROSS_ENTROPY_LOSS_BACK(type=f32,ne=[10,5,4,3]): not supported [SYCL0] 
  CROSS_ENTROPY_LOSS_BACK(type=f32,ne=[30000,1,1,1]): not supported [SYCL0] 
  OPT_STEP_ADAMW(type=f32,ne=[10,5,4,3]): not supported [SYCL0] 
  4193/4193 tests passed
  Backend SYCL0: OK

The attached zip contains txt files with the tensor values printed from the CPU and SYCL backends. They seem to diverge more and more as the inference progresses.

tensors.zip

@fairydreaming
Collaborator

I narrowed the problem down to GGML_OP_MUL_MAT/GGML_OP_MUL_MAT_ID. When I remove them from the list of operations supported by SYCL, everything starts working correctly.
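
For readers following along, removing ops from the supported list amounts to having the backend's supports_op check return false for them, so the graph scheduler falls back to the CPU backend for those nodes. Below is a minimal sketch of that kind of temporary filter; the helper name and placement are assumptions for illustration, not code from ggml-sycl.cpp:

```cpp
// Sketch only: the kind of temporary filter used to bisect the bug, not code
// from the repository. The idea is to report MUL_MAT / MUL_MAT_ID as
// unsupported so the scheduler runs them on the CPU backend instead.
#include "ggml.h"

// hypothetical helper; the real check lives in the SYCL backend's supports_op callback
static bool debug_sycl_supports_op(const struct ggml_tensor * op) {
    switch (op->op) {
        case GGML_OP_MUL_MAT:
        case GGML_OP_MUL_MAT_ID:
            return false;   // force CPU fallback while debugging
        default:
            return true;    // leave everything else on SYCL
    }
}
```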

@qnixsynapse
Copy link
Contributor

> They seem to diverge more and more as the inference progresses.

Does #12366 help? I know there is a really bad memory leak in the backend.

Might also need to test with this: #12391

Unfortunately, I only have an A750, so I can't test this model here.

@fairydreaming
Copy link
Collaborator

fairydreaming commented Mar 15, 2025

@qnixsynapse I narrowed it down further to multiplications involving these tensors (filtered by size):

| name                        | ne0  | ne1  | type |
|-----------------------------|------|------|------|
| blk.*.ffn_down_exps.weight  | 1408 | 2048 | 8    |
| blk.*.ffn_gate_shexp.weight | 2048 | 2816 | 8    |
| blk.*.ffn_up_shexp.weight   | 2048 | 2816 | 8    |

The remaining multiplications can be offloaded to SYCL without any negative consequences.
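
One way such a narrowing step could look is to extend the temporary filter from above with a shape check on the weight operand. This is only an illustration of the approach under those assumptions, not the actual debugging code:

```cpp
// Sketch only (hypothetical helper): fall back to CPU only for MUL_MAT / MUL_MAT_ID
// whose weight operand matches the suspicious shapes listed in the table above.
#include "ggml.h"

static bool is_suspect_mul_mat(const struct ggml_tensor * op) {
    if (op->op != GGML_OP_MUL_MAT && op->op != GGML_OP_MUL_MAT_ID) {
        return false;
    }
    const struct ggml_tensor * w = op->src[0];   // weight matrix (src[0] by ggml convention)
    return (w->ne[0] == 1408 && w->ne[1] == 2048) ||
           (w->ne[0] == 2048 && w->ne[1] == 2816);
}
```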

@fairydreaming
Collaborator

I found the cause, and it has nothing to do with matrix multiplication. The problem is the addition of tensor views in build_moe_ffn(). I added ggml_dup() to turn it into an addition of contiguous tensors, and everything started working correctly:

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 4e908733..3f7c99c3 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -933,8 +933,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     // aggregate experts
     ggml_tensor * moe_out = nullptr;
     for (int i = 0; i < n_expert_used; ++i) {
-        ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
-                experts->nb[2], i*experts->nb[1]);
+        ggml_tensor * cur_expert = ggml_dup(ctx0, ggml_view_2d(ctx0, experts, n_embd, n_tokens,
+                experts->nb[2], i*experts->nb[1]));
 
         if (i == 0) {
             moe_out = cur_expert;
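
Why this helps: experts stacks the selected experts' outputs along dimension 1, and the per-expert ggml_view_2d keeps the per-token stride experts->nb[2] between its rows, which skips over the other experts' outputs, so the view is not contiguous; ggml_dup() copies it into a freshly allocated, packed tensor before the addition. A small CPU-only sketch of that layout (toy sizes, not the model's):

```cpp
// CPU-only illustration of why the per-expert views in build_moe_ffn() are
// non-contiguous and why ggml_dup() makes them contiguous. Sizes are illustrative.
#include "ggml.h"
#include <stdio.h>

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int n_embd = 8, n_expert_used = 6, n_tokens = 4;
    struct ggml_tensor * experts = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, n_expert_used, n_tokens);

    // same construction as in build_moe_ffn(): slice out expert i across all tokens;
    // the row stride is experts->nb[2], which jumps over the other experts' outputs
    const int i = 1;
    struct ggml_tensor * cur = ggml_view_2d(ctx, experts, n_embd, n_tokens,
                                            experts->nb[2], i*experts->nb[1]);

    printf("view contiguous: %d\n", ggml_is_contiguous(cur));                 // prints 0
    printf("dup  contiguous: %d\n", ggml_is_contiguous(ggml_dup(ctx, cur)));  // prints 1

    ggml_free(ctx);
    return 0;
}
```

So the original graph asks the backend to add a non-contiguous view to moe_out, which is where the SYCL binary op went wrong.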

@aubreyli
Contributor Author

This patch works on my side as well. Thanks @fairydreaming!

@fairydreaming
Collaborator

@aubreyli I made a PR that fixes this by supporting non-contiguous tensors in the SYCL backend binary ops. If you have time, please test whether it works: #12399
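
For context on what "supporting non-contiguous tensors" means in a binary op: instead of computing a flat element index (which is only valid for packed data), the kernel derives each element's address from the tensor's byte strides nb[]. A simplified CPU-side sketch of that addressing pattern, assuming same-shape inputs and no broadcasting; this is not the SYCL code from the PR:

```cpp
// Illustration of stride-aware element addressing for dst = a + b (f32).
// Works for non-contiguous views because it uses the nb[] byte strides
// instead of assuming a packed layout. Not the actual kernel from #12399.
#include "ggml.h"
#include <stdint.h>

static void add_f32_strided(const struct ggml_tensor * a, const struct ggml_tensor * b,
                            struct ggml_tensor * dst) {
    for (int64_t i3 = 0; i3 < dst->ne[3]; ++i3)
    for (int64_t i2 = 0; i2 < dst->ne[2]; ++i2)
    for (int64_t i1 = 0; i1 < dst->ne[1]; ++i1)
    for (int64_t i0 = 0; i0 < dst->ne[0]; ++i0) {
        const float * pa = (const float *)((const char *) a->data   + i0*a->nb[0]   + i1*a->nb[1]   + i2*a->nb[2]   + i3*a->nb[3]);
        const float * pb = (const float *)((const char *) b->data   + i0*b->nb[0]   + i1*b->nb[1]   + i2*b->nb[2]   + i3*b->nb[3]);
        float       * pd = (float       *)((char       *) dst->data + i0*dst->nb[0] + i1*dst->nb[1] + i2*dst->nb[2] + i3*dst->nb[3]);
        *pd = *pa + *pb;
    }
}
```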

@qnixsynapse
Contributor

@fairydreaming Excellent work! Thank you! I think adding some non-contiguous tests for binary ops in test-backend-ops would be useful too.

@aubreyli
Contributor Author

@fairydreaming I tested #12399 and it works on my Arc A770 with the DeepSeek-V2-Lite-Chat model. Thanks for your great work!
I added a minor comment embedded in the code.
