imatrix : offload to GPU support #4957

Merged (10 commits into master on Jan 17, 2024)
Conversation

ggerganov
Owner

close #4931

Make use of the new backend scheduler eval callback introduced in #4935 in order to grab activations from the GPU memory.

Usage:

# Metal
make -j imatrix && ./imatrix -m model.gguf -f data.txt -ngl 99
# CUDA
LLAMA_CUBLAS=1 make -j imatrix && ./imatrix -m model.gguf -f data.txt -ngl 99

This should be significantly faster than the CPU-only path. I haven't confirmed the correctness of the results yet, so if you give this a try, please let me know whether the numbers are as expected.
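
(For readers following along, here is a rough sketch of the callback mechanism being used, assuming the two-phase ggml-backend eval callback from #4935; the function name collect_activations and the F32 assumption are illustrative, not the exact code of this PR.)

#include "ggml.h"
#include "ggml-backend.h"
#include <vector>

// Sketch of an eval callback in the style of #4935 (illustrative, not the PR's exact code).
// The scheduler invokes it twice per graph node: once with ask == true to find out whether
// we want this node's data, and once with ask == false after the node has been computed on
// its backend (CPU, Metal, CUDA, ...).
static bool collect_activations(struct ggml_tensor * t, bool ask, void * user_data) {
    (void) user_data;
    if (ask) {
        // only matrix multiplications carry the activations we care about
        return t->op == GGML_OP_MUL_MAT;
    }
    // the activations are the second operand (src1) of the matmul; copy them back to
    // host memory regardless of which backend buffer they currently live in
    const struct ggml_tensor * src1 = t->src[1];
    std::vector<float> data(ggml_nelements(src1));                 // assumes F32 activations
    ggml_backend_tensor_get(src1, data.data(), 0, ggml_nbytes(src1));
    // ... accumulate per-column sums of squared activations into the importance matrix ...
    return true;
}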

@ikawrakow
Contributor

Cool!

Tested on LLaMA-v1-7B. With 100 chunks, the calculation finishes in 52 seconds on the GPU (vs 9 minutes on the CPU)! Using the GPU imatrix with IQ2_XXS I get PPL = 8.5372; with the CPU imatrix I had PPL = 8.5435. So it looks like it is working.

@Artefact2
Collaborator

ROCm also seems to work on Linux.

make LLAMA_HIPBLAS=1 AMDGPU_TARGETS=gfx1030 imatrix

compute_imatrix: 1.41 seconds per pass - ETA 2 hours 9.52 minutes
[1]5.3541,[2]7.6622,[3]8.4022,[4]9.5683,[5]10.0294
make imatrix

compute_imatrix: 27.02 seconds per pass - ETA 41 hours 26.30 minutes
[1]5.3145,[2]7.6356,[3]8.4072,[4]9.5662,[5]10.0225

@TheBloke
Contributor

TheBloke commented Jan 15, 2024

I've just made my first GGUF repo that uses the new imatrix method, here: https://huggingface.co/TheBloke/Yi-34B-200K-DARE-megamerge-v8-GGUF

I used this PR to speed up the imatrix creation.

On a 34B model, with a 5,000-line (76,859-word - I didn't count the tokens) dataset, it took 21 minutes with -c 4096 -b 1024 and -t 10 on an L40 48GB GPU.

I used this command:

CUDA_VISIBLE_DEVICES=7  ./imatrix -m /workspace/process/brucethemoose_yi-34b-200k-dare-megamerge-v8/gguf/yi-34b-200k-dare-megamerge-v8.fp16.gguf -f /workspace/datasets/open-instruct-5K.txt -o /workspace/process/brucethemoose_yi-34b-200k-dare-megamerge-v8/gguf/yi-34b-200k-dare-megamerge-v8.imatrix -t 10 -c 4096 -ngl 35 -b 1024

Hope I did it right! The model is coherent at least! :)

@ikawrakow
Contributor

@TheBloke My experience is that it is better to use a context of 512 when computing the imatrix. When I run inference with a model quantized using an imatrix computed at a context of 512, I get lower perplexity even at a context of 4096. Not sure why this is the case.
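
A minimal example along those lines (the model and dataset paths here are placeholders; only the -c value differs from the commands above):

./imatrix -m model.fp16.gguf -f calibration-data.txt -o model.imatrix -ngl 99 -c 512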

@TheBloke
Contributor

Ah interesting, thanks for the info. I've not done any PPL testing on the result yet.

I'll try 512 next time then, thanks.

@kalomaze
Contributor

kalomaze commented Jan 16, 2024

> Not sure why this is the case.

My theory is that a shorter context gives you more chunks in total, which makes the set of "starting contexts" significantly more diverse, and therefore you get more varied activation data.
Maybe the sweet spot is around 128-256 tokens...

@askmyteapot

askmyteapot commented Jan 17, 2024

@ggerganov
Just did a test: this does not work with partial GPU offloading. I tried it on Mixtral 8x7B Instruct at 8-bit.

No errors were generated, but the file size was one tenth of the CPU-only run. Then, when I tried to quantize using the new imatrix, it couldn't find a boatload of layers in the matrix file. It worked fine for CPU-only.

Let me know if you require any further details.

Loading info [and command used]
D:\llama.cpp - Copy\build\bin\Release>imatrix -m D:\text-generation-webui\models\Mixtral-8x7B-Instruct-v0.1\ggml-model-q8.gguf -f wiki.train.raw -ngl 15 -c 512 -o D:\text-generation-webui\models\Mixtral-8x7B-Instruct-v0.1\ggml-model-q8.gguf.imat.data --chunks 2000
main: build = 1884 (0b2fca9a)
main: built with MSVC 19.38.33133.0 for x64
main: seed  = 1705443892
ggml_init_cublas: GGML_CUDA_FORCE_MMQ:   yes
ggml_init_cublas: CUDA_USE_TENSOR_CORES: no
ggml_init_cublas: found 1 CUDA devices:
  Device 0: Tesla P40, compute capability 6.1, VMM: yes
llama_model_loader: loaded meta data with 25 key-value pairs and 995 tensors from D:\text-generation-webui\models\Mixtral-8x7B-Instruct-v0.1\ggml-model-q8.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.name str              = F:\LLM_Models
llama_model_loader: - kv   2:                       llama.context_length u32              = 32768
llama_model_loader: - kv   3:                     llama.embedding_length u32              = 4096
llama_model_loader: - kv   4:                          llama.block_count u32              = 32
llama_model_loader: - kv   5:                  llama.feed_forward_length u32              = 14336
llama_model_loader: - kv   6:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv   7:                 llama.attention.head_count u32              = 32
llama_model_loader: - kv   8:              llama.attention.head_count_kv u32              = 8
llama_model_loader: - kv   9:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  10:                         llama.expert_count u32              = 8
llama_model_loader: - kv  11:                    llama.expert_used_count u32              = 2
llama_model_loader: - kv  12:                       llama.rope.freq_base f32              = 1000000.000000
llama_model_loader: - kv  13:                          general.file_type u32              = 7
llama_model_loader: - kv  14:                       tokenizer.ggml.model str              = llama
llama_model_loader: - kv  15:                      tokenizer.ggml.tokens arr[str,32000]   = ["<unk>", "<s>", "</s>", "<0x00>", "<...
llama_model_loader: - kv  16:                      tokenizer.ggml.scores arr[f32,32000]   = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv  17:                  tokenizer.ggml.token_type arr[i32,32000]   = [2, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...
llama_model_loader: - kv  18:                tokenizer.ggml.bos_token_id u32              = 1
llama_model_loader: - kv  19:                tokenizer.ggml.eos_token_id u32              = 2
llama_model_loader: - kv  20:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  21:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  22:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  23:                    tokenizer.chat_template str              = {{ bos_token }}{% for message in mess...
llama_model_loader: - kv  24:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:   65 tensors
llama_model_loader: - type  f16:   32 tensors
llama_model_loader: - type q8_0:  898 tensors
llm_load_vocab: special tokens definition check successful ( 259/32000 ).
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = SPM
llm_load_print_meta: n_vocab          = 32000
llm_load_print_meta: n_merges         = 0
llm_load_print_meta: n_ctx_train      = 32768
llm_load_print_meta: n_embd           = 4096
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 8
llm_load_print_meta: n_layer          = 32
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 4
llm_load_print_meta: n_embd_k_gqa     = 1024
llm_load_print_meta: n_embd_v_gqa     = 1024
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: n_ff             = 14336
llm_load_print_meta: n_expert         = 8
llm_load_print_meta: n_expert_used    = 2
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 1000000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 32768
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: model type       = 7B
llm_load_print_meta: model ftype      = Q8_0
llm_load_print_meta: model params     = 46.70 B
llm_load_print_meta: model size       = 46.22 GiB (8.50 BPW)
llm_load_print_meta: general.name     = F:\LLM_Models
llm_load_print_meta: BOS token        = 1 '<s>'
llm_load_print_meta: EOS token        = 2 '</s>'
llm_load_print_meta: UNK token        = 0 '<unk>'
llm_load_print_meta: LF token         = 13 '<0x0A>'
llm_load_tensors: ggml ctx size =    0.76 MiB
llm_load_tensors: offloading 15 repeating layers to GPU
llm_load_tensors: offloaded 15/33 layers to GPU
llm_load_tensors:        CPU buffer size = 47324.64 MiB
llm_load_tensors:      CUDA0 buffer size = 22058.91 MiB
....................................................................................................
llama_new_context_with_model: n_ctx      = 512
llama_new_context_with_model: freq_base  = 1000000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:  CUDA_Host KV buffer size =    34.00 MiB
llama_kv_cache_init:      CUDA0 KV buffer size =    30.00 MiB
llama_new_context_with_model: KV self size  =   64.00 MiB, K (f16):   32.00 MiB, V (f16):   32.00 MiB
llama_new_context_with_model: graph splits (measure): 5
llama_new_context_with_model:      CUDA0 compute buffer size =   108.03 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =   109.04 MiB

system_info: n_threads = 8 / 16 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 0 | VSX = 0 |
compute_imatrix: tokenizing the input ..
compute_imatrix: tokenization took 6012.49 ms
compute_imatrix: computing over 2000 chunks with batch_size 512
compute_imatrix: 17.95 seconds per pass - ETA 9 hours 58.30 minutes 
Final line
save_imatrix: stored collected data after 2000 chunks in D:\text-generation-webui\models\Mixtral-8x7B-Instruct-v0.1\ggml-model-q8.gguf.imat.data
[2000]4.7332,
Final estimate: PPL = 4.7332 +/- 0.01479

save_imatrix: stored collected data after 2000 chunks in D:\text-generation-webui\models\Mixtral-8x7B-Instruct-v0.1\ggml-model-q8.gguf.imat.data

llama_print_timings:        load time =   35827.98 ms
llama_print_timings:      sample time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings: prompt eval time = 25796936.71 ms / 1024000 tokens (   25.19 ms per token,    39.69 tokens per second)
llama_print_timings:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings:       total time = 25838386.90 ms / 1024001 tokens 

UPDATE:
I can't replicate it with Mistral 7B Instruct, but I can replicate it with Mixtral 8x7B. Will do some more testing.

UPDATE2:
It's only Mixtral 8x7B that has the failure. I tested Mistral 7B with various combinations of f16/q8_0 GGUF and full and partial offloading; all of them produced the expected file size and output. So something in the GPU offloading change has affected imatrix calculation for MoE models.

@ggerganov
Owner Author

Yes, there is an issue with Mixtral - will look into fixing it

@ikawrakow
Contributor

If I try Mixtral-8x7B with partial offload and --verbosity 2, I see that only the attention-related activations are collected:

...
llm_load_tensors: offloading 10 repeating layers to GPU
llm_load_tensors: offloaded 10/33 layers to GPU
llm_load_tensors:        CPU buffer size = 47324.64 MiB
llm_load_tensors:      CUDA0 buffer size = 14705.94 MiB
....................................................................................................
llama_new_context_with_model: n_ctx      = 512
llama_new_context_with_model: freq_base  = 1000000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:  CUDA_Host KV buffer size =    44.00 MiB
llama_kv_cache_init:      CUDA0 KV buffer size =    20.00 MiB
llama_new_context_with_model: KV self size  =   64.00 MiB, K (f16):   32.00 MiB, V (f16):   32.00 MiB
llama_new_context_with_model: graph splits (measure): 5
llama_new_context_with_model:      CUDA0 compute buffer size =   108.03 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =   109.04 MiB

system_info: n_threads = 32 / 64 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | 
compute_imatrix: tokenizing the input ..
compute_imatrix: tokenization took 7285.27 ms
compute_imatrix: computing over 100 chunks with batch_size 512
collect_imatrix[0]: blk.0.attn_q.weight, 4096 x 512, 0
collect_imatrix[1]: blk.0.attn_k.weight, 4096 x 512, 0
collect_imatrix[1]: blk.0.attn_v.weight, 4096 x 512, 0
collect_imatrix[1]: blk.0.attn_output.weight, 4096 x 512, 0
collect_imatrix[1]: blk.0.ffn_gate_inp.weight, 4096 x 512, 0
collect_imatrix[1]: blk.1.attn_q.weight, 4096 x 512, 0
collect_imatrix[1]: blk.1.attn_k.weight, 4096 x 512, 0
collect_imatrix[1]: blk.1.attn_v.weight, 4096 x 512, 0
collect_imatrix[1]: blk.1.attn_output.weight, 4096 x 512, 0
collect_imatrix[1]: blk.1.ffn_gate_inp.weight, 4096 x 512, 0
collect_imatrix[1]: blk.2.attn_q.weight, 4096 x 512, 0
collect_imatrix[1]: blk.2.attn_k.weight, 4096 x 512, 0
collect_imatrix[1]: blk.2.attn_v.weight, 4096 x 512, 0
collect_imatrix[1]: blk.2.attn_output.weight, 4096 x 512, 0
collect_imatrix[1]: blk.2.ffn_gate_inp.weight, 4096 x 512, 0
collect_imatrix[1]: blk.3.attn_q.weight, 4096 x 512, 0
collect_imatrix[1]: blk.3.attn_k.weight, 4096 x 512, 0
collect_imatrix[1]: blk.3.attn_v.weight, 4096 x 512, 0
collect_imatrix[1]: blk.3.attn_output.weight, 4096 x 512, 0
collect_imatrix[1]: blk.3.ffn_gate_inp.weight, 4096 x 512, 0
collect_imatrix[1]: blk.4.attn_q.weight, 4096 x 512, 0
collect_imatrix[1]: blk.4.attn_k.weight, 4096 x 512, 0
collect_imatrix[1]: blk.4.attn_v.weight, 4096 x 512, 0
collect_imatrix[1]: blk.4.attn_output.weight, 4096 x 512, 0
collect_imatrix[1]: blk.4.ffn_gate_inp.weight, 4096 x 512, 0
collect_imatrix[1]: blk.5.attn_q.weight, 4096 x 512, 0
collect_imatrix[1]: blk.5.attn_k.weight, 4096 x 512, 0
collect_imatrix[1]: blk.5.attn_v.weight, 4096 x 512, 0
collect_imatrix[1]: blk.5.attn_output.weight, 4096 x 512, 0
collect_imatrix[1]: blk.5.ffn_gate_inp.weight, 4096 x 512, 0
collect_imatrix[1]: blk.6.attn_q.weight, 4096 x 512, 0
collect_imatrix[1]: blk.6.attn_k.weight, 4096 x 512, 0
collect_imatrix[1]: blk.6.attn_v.weight, 4096 x 512, 0
collect_imatrix[1]: blk.6.attn_output.weight, 4096 x 512, 0
collect_imatrix[1]: blk.6.ffn_gate_inp.weight, 4096 x 512, 0
collect_imatrix[1]: blk.7.attn_q.weight, 4096 x 512, 0
collect_imatrix[1]: blk.7.attn_k.weight, 4096 x 512, 0
collect_imatrix[1]: blk.7.attn_v.weight, 4096 x 512, 0
collect_imatrix[1]: blk.7.attn_output.weight, 4096 x 512, 0
collect_imatrix[1]: blk.7.ffn_gate_inp.weight, 4096 x 512, 0
collect_imatrix[1]: blk.8.attn_q.weight, 4096 x 512, 0
collect_imatrix[1]: blk.8.attn_k.weight, 4096 x 512, 0
collect_imatrix[1]: blk.8.attn_v.weight, 4096 x 512, 0
collect_imatrix[1]: blk.8.attn_output.weight, 4096 x 512, 0
collect_imatrix[1]: blk.8.ffn_gate_inp.weight, 4096 x 512, 0
collect_imatrix[1]: blk.9.attn_q.weight, 4096 x 512, 0
collect_imatrix[1]: blk.9.attn_k.weight, 4096 x 512, 0
collect_imatrix[1]: blk.9.attn_v.weight, 4096 x 512, 0
collect_imatrix[1]: blk.9.attn_output.weight, 4096 x 512, 0
collect_imatrix[1]: blk.9.ffn_gate_inp.weight, 4096 x 512, 0
collect_imatrix[1]: blk.10.attn_q.weight, 4096 x 512, 0
collect_imatrix[1]: blk.10.attn_k.weight, 4096 x 512, 0
collect_imatrix[1]: blk.10.attn_v.weight, 4096 x 512, 0
collect_imatrix[1]: blk.10.attn_output.weight, 4096 x 512, 0
collect_imatrix[1]: blk.10.ffn_gate_inp.weight, 4096 x 512, 0
collect_imatrix[1]: blk.11.attn_q.weight, 4096 x 512, 0
collect_imatrix[1]: blk.11.attn_k.weight, 4096 x 512, 0
collect_imatrix[1]: blk.11.attn_v.weight, 4096 x 512, 0
collect_imatrix[1]: blk.11.attn_output.weight, 4096 x 512, 0
collect_imatrix[1]: blk.11.ffn_gate_inp.weight, 4096 x 512, 0
...

@ggerganov
Owner Author

Mixtral should be fixed now - the MUL_MAT_ID ops weren't handled in the callback. Please give it another try.
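
(A hedged sketch of what the fix amounts to in the callback's ask phase; the helper name want_tensor_data is illustrative, not the literal diff.)

#include "ggml.h"

// MoE models such as Mixtral route the expert FFN matmuls through GGML_OP_MUL_MAT_ID
// rather than GGML_OP_MUL_MAT, so both ops have to be requested from the scheduler;
// otherwise only the attention activations are collected, as seen in the log above.
static bool want_tensor_data(const struct ggml_tensor * t) {
    return t->op == GGML_OP_MUL_MAT || t->op == GGML_OP_MUL_MAT_ID;
}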

@askmyteapot

> Mixtral should be fixed now - the MUL_MAT_ID ops weren't handled in the callback. Please give it another try.

Did a quick test with 20 chunks and built a quant with it. It's working now.

Thanks for that.

@JianbangZ

JianbangZ commented Jan 17, 2024

This is great. I computed an imatrix with 1000 chunks in around 10 minutes for a 13B/14B model. This allows us to run some extensive experiments with large calibration datasets.

Base automatically changed from gg/sched-eval-callback-4931 to master January 17, 2024 16:39
@ggerganov ggerganov merged commit ba69bbc into master Jan 17, 2024
36 of 47 checks passed
jordankanter pushed a commit to jordankanter/llama.cpp that referenced this pull request Feb 3, 2024
* backend : add eval callback

ggml-ci

* backend : group nodes in a single compute when user don't need them

* backend : clean-up the implementation

ggml-ci

* simple : do not perform tensor data copy if not needed

* simple : fix

* imatrix : offload to GPU support

* imatrix : fix ggml_mul_mat_id hanlding

ggml-ci

* ci : add imatrix test

ggml-ci

* ci : rearrange output

ggml-ci
@Mihaiii
Contributor

Mihaiii commented Feb 12, 2024

> I've just made my first GGUF repo that uses the new imatrix method, here: https://huggingface.co/TheBloke/Yi-34B-200K-DARE-megamerge-v8-GGUF
>
> I used this PR to speed up the imatrix creation.
>
> On a 34B model, with a 5,000-line (76,859-word - I didn't count the tokens) dataset, it took 21 minutes with -c 4096 -b 1024 and -t 10 on an L40 48GB GPU.
>
> I used this command:
>
> CUDA_VISIBLE_DEVICES=7  ./imatrix -m /workspace/process/brucethemoose_yi-34b-200k-dare-megamerge-v8/gguf/yi-34b-200k-dare-megamerge-v8.fp16.gguf -f /workspace/datasets/open-instruct-5K.txt -o /workspace/process/brucethemoose_yi-34b-200k-dare-megamerge-v8/gguf/yi-34b-200k-dare-megamerge-v8.imatrix -t 10 -c 4096 -ngl 35 -b 1024
>
> Hope I did it right! The model is coherent at least! :)

Thanks for the exact command, but what is the content of "open-instruct-5K.txt"? Can I use part of the dataset I fine-tuned with for the imatrix? What if I fine-tune in chat format, so the data is not free text (e.g. wiki) - how should I format the txt file I use for the imatrix?

I think @TheBloke is on vacation these days, but if anyone else has some hints/clarifications, they would be much appreciated.

hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 1, 2024 (same commit message as above)
@Iridescent-gcrace

How do I use ./perplexity to measure the model after running imatrix?
I used this command to compute the imatrix:
CUDA_VISIBLE_DEVICES=0 LLAMA_CUBLAS=1 ./imatrix -m models/Mistral-7B-v0.1/Mistral-7B-v0.1-7B-F32.gguf -o models/Mistral-7B-v0.1/Mistral-q4-imatrix.gguf -f tests/wikitext-2-raw/wiki.train.txt -ngl 99

I took the resulting Mistral-q4-imatrix.gguf and ran:
./perplexity -m models/Mistral-7B-v0.1/Mistral-q4-imatrix.gguf -f tests/wikitext-2-raw/wiki.test.raw
But I get "error: unable to load model".

I also tried mv Mistral-q4-imatrix.gguf Mistral-q4-imatrix.imatrix and got the same error.

I want to ask whether perplexity is supported for this.


Successfully merging this pull request may close these issues.

Add importance matrix calculation to non-CPU back-ends