
Bug: n_ctx reuses n_ctx_train when --ctx_size is not set, making deepseek-v2 models crash with out of memory even for a small output length. #8817

Closed
ClarkChin08 opened this issue Aug 2, 2024 · 12 comments · Fixed by #10136
Labels
bug (Something isn't working) · bug-unconfirmed · medium severity (Used to report medium severity bugs in llama.cpp, e.g. Malfunctioning Features but still useable)

Comments

@ClarkChin08
Contributor

ClarkChin08 commented Aug 2, 2024

What happened?

The deepseek-v2 model runs into an out-of-memory crash because the KV buffer allocation is about 43 GB for the model's default 160K context length. When you set -c or --ctx_size to 2048, inference works normally.

Name and Version

./build/bin/llama-cli -m deepseek-v2-lite-chat-q4_0.gguf -p "how to build a website?" -n 32 -e -ngl 29 -sm none
Linux build on master branch: c8a0090922bad576623de4aae227717085249262

What operating system are you seeing the problem on?

No response

Relevant log output

No response

ClarkChin08 added the bug-unconfirmed and medium severity labels on Aug 2, 2024
@ClarkChin08
Contributor Author

ClarkChin08 commented Aug 2, 2024

@characharm Your issue from #8483 is caused by the default n_ctx loaded from deepseek-v2, which is 163840 and leads to a KV buffer allocation of about 43 GB, exceeding the GPU memory limit. You can use "-c 2048" to set the context length, and it works well on the newest commit on master: c8a0090
You can also use "-sm layer" to split the model across multiple GPUs if you have more than one card.

@ClarkChin08
Contributor Author

ClarkChin08 commented Aug 2, 2024

@slaren Do you think this behavior is a bug: when the user does not set --ctx_size, llama.cpp reuses the model's n_ctx_train as n_ctx. In the deepseek-v2 case n_ctx_train is 160K, so even when the user's actual input and output are small, it allocates a very large KV buffer (about 43 GB here). Should we calculate the real n_ctx from the user input instead of reusing n_ctx_train?
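
Taking the non-interactive case as an example, "calculate the real n_ctx from the user input" could look like the minimal sketch below; required_n_ctx is a hypothetical helper, not an existing llama.cpp function, and the numbers in main are illustrative only.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Pick a context size that just fits a known prompt plus a fixed number of
// generated tokens, never exceeding the model's training context.
static uint32_t required_n_ctx(uint32_t n_prompt_tokens, uint32_t n_predict, uint32_t n_ctx_train) {
    return std::min<uint32_t>(n_prompt_tokens + n_predict, n_ctx_train);
}

int main() {
    // illustrative numbers: a short prompt, -n 32, deepseek-v2's 160K training context
    const uint32_t n_ctx = required_n_ctx(/*n_prompt_tokens=*/32, /*n_predict=*/32, /*n_ctx_train=*/163840);
    printf("n_ctx = %u instead of 163840\n", n_ctx); // prints 64
    return 0;
}

With the command from the report above (-n 32 and a short prompt), something along these lines would size the KV cache for well under a hundred tokens instead of 163840.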

@airMeng
Collaborator

airMeng commented Aug 2, 2024

I think this is not expected behavior, especially since the latest LLMs keep introducing longer and longer training context lengths.

airMeng added the bug label on Aug 2, 2024
@shibe2
Contributor

shibe2 commented Aug 2, 2024

I would classify this as an enhancement. I understand the idea as follows:

  1. when context size is not specified on command line,
  2. and size of input and output is known beforehand,
  3. and default context size is larger than what will be needed,

reduce the context size so it's just enough for the generation.

Note that condition 2 is not met in many cases, such as in interactive mode and in llama-server. If we want to avoid an OOM condition in those cases, we will need other ideas, such as:

  • have the back-end report how much memory it has available for the KV cache, and limit the default context size based on that (a rough sketch follows below);
  • don't allocate memory for the whole cache at once; let it grow as needed and go OOM only when it actually fills up to a critical point.
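
A rough sketch of the first idea above, assuming the back-end can report its free memory; the per-token cost n_layer * 2 * n_embd_kv * type_size is the usual plain K/V formula and ignores quantized or MLA caches, and none of the names below are actual llama.cpp APIs.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Bytes of K plus V stored per token for a plain (non-quantized) cache.
static uint64_t kv_bytes_per_token(uint64_t n_layer, uint64_t n_embd_kv, uint64_t type_size) {
    return n_layer * 2 * n_embd_kv * type_size;
}

// Clamp the default context size so the KV cache fits the reported budget,
// never exceeding the model's training context.
static uint32_t default_n_ctx_for_budget(uint64_t free_bytes, uint32_t n_ctx_train,
                                         uint64_t n_layer, uint64_t n_embd_kv, uint64_t type_size) {
    const uint64_t per_token  = kv_bytes_per_token(n_layer, n_embd_kv, type_size);
    const uint64_t max_tokens = per_token > 0 ? free_bytes / per_token : n_ctx_train;
    return (uint32_t) std::min<uint64_t>(max_tokens, n_ctx_train);
}

int main() {
    // Illustrative numbers only: a 16 GiB budget, a 160K training context,
    // 32 layers, 1024-dim K/V per layer, f16 elements.
    const uint64_t budget = 16ull * 1024 * 1024 * 1024;
    const uint32_t n_ctx  = default_n_ctx_for_budget(budget, 163840, 32, 1024, 2);
    printf("default n_ctx clamped to %u tokens\n", n_ctx);
    return 0;
}

For scale, the 43 GB buffer from the report works out to roughly a quarter of a megabyte of K/V data per token at the full 163840-token context, which is why a memory-based clamp would land far below n_ctx_train on a single consumer GPU.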

@slaren
Collaborator

slaren commented Aug 3, 2024

@slaren Do you think this behavior is a bug: when the user does not set --ctx_size, llama.cpp reuses the model's n_ctx_train as n_ctx. In the deepseek-v2 case n_ctx_train is 160K, so even when the user's actual input and output are small, it allocates a very large KV buffer (about 43 GB here). Should we calculate the real n_ctx from the user input instead of reusing n_ctx_train?

It is certainly not a bug, and the value of n_ctx is already taken from user input. However, I agree that setting the default n_ctx to the maximum is not good, and it would be good to revert to a smaller default (originally it was 512).

@airMeng
Collaborator

airMeng commented Aug 4, 2024

@shibe2 I agree it is more of a feature enhancement. I think it would be quite useful if llama.cpp could calculate the appropriate n_ctx, especially for serving. Any plans for it?

@ngxson
Collaborator

ngxson commented Aug 8, 2024

I agree with @slaren that we should revert to the default hard-coded value instead of using n_ctx_train. This is quite annoying, for example with Phi-3-mini-128k-instruct, where the context is set to 128k by default.

@shibe2
Contributor

shibe2 commented Aug 8, 2024

Tossing another idea around: define a cap at compile time (with some default), then use that value to limit the default n_ctx.
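
A minimal sketch of that idea; LLAMA_DEFAULT_N_CTX_CAP is an invented macro for illustration, not an existing llama.cpp build option.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Invented build-time cap on the default context size, overridable at compile time.
#ifndef LLAMA_DEFAULT_N_CTX_CAP
#define LLAMA_DEFAULT_N_CTX_CAP 8192
#endif

// When the user did not pass --ctx_size, take the training context but cap it.
static uint32_t effective_n_ctx(bool user_set, uint32_t user_n_ctx, uint32_t n_ctx_train) {
    if (user_set) {
        return user_n_ctx;
    }
    return std::min<uint32_t>(n_ctx_train, LLAMA_DEFAULT_N_CTX_CAP);
}

int main() {
    printf("deepseek-v2 default: %u\n", effective_n_ctx(false, 0, 163840));    // 8192
    printf("user override:       %u\n", effective_n_ctx(true, 2048, 163840));  // 2048
    return 0;
}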

@ggerganov
Owner

The main reason to use n_ctx_train by default was that there were several issues filed where people were using very short context + context shifts and didn't realize that the context gets discarded. We could set a new default, but I'm surprised that the current solution is causing confusion - the error message should hint that the KV cache does not fit in memory. Is it not clear? If so, maybe we can improve it?

@shibe2
Contributor

shibe2 commented Aug 9, 2024

When I most recently hit an OOM because of a large context, it said:

CUDA error: out of memory

and

Aborted (core dumped)

And that's on a system that doesn't even have CUDA installed.

@ggerganov
Owner

On my CUDA machine I get the following error:

GGML_CUDA=1 make -j && ./llama-cli -m models/llama-7b-v2/ggml-model-q4_0.gguf \
    -ngl 99 -p "I believe the meaning of life is" -c 1000000
...
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 2060 SUPER, compute capability 7.5, VMM: yes
llm_load_tensors: ggml ctx size =    0,27 MiB
llm_load_tensors: offloading 32 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 33/33 layers to GPU
llm_load_tensors:        CPU buffer size =    70,31 MiB
llm_load_tensors:      CUDA0 buffer size =  3577,56 MiB
..................................................................................................
llama_new_context_with_model: n_ctx      = 1000000
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 10000,0
llama_new_context_with_model: freq_scale = 1
ggml_backend_cuda_buffer_type_alloc_buffer: allocating 500000,00 MiB on device 0: cudaMalloc failed: out of memory
llama_kv_cache_init: failed to allocate buffer for kv cache
llama_new_context_with_model: llama_kv_cache_init() failed for self-attention cache
llama_init_from_gpt_params: error: failed to create context with model 'models/llama-7b-v2/ggml-model-q4_0.gguf'
main: error: unable to load model

And that's on a system that doesn't even have CUDA installed.

Huh, that does not seem right. Try to clean/rebuild

@ngxson
Collaborator

ngxson commented Aug 9, 2024

My initial idea was to add a hint message if llama_kv_cache_init fails, for example:

Failed to allocate memory for KV cache.
Hint: Context length is currently %d, which may be too big. Please use a smaller value, for example -c 1024

However, on my machine (Mac M3), running with -ngl 99 causes an OOM at "failed to allocate compute buffers" (not caused by llama_kv_cache_init). Running with -ngl 0 gives no error, although part of the memory is moved to swap. So, at least on Mac, we don't have a reliable way to tell the user that the KV cache is too big.
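
A minimal sketch of how such a hint could be printed; try_alloc_kv_cache below is a stand-in that only simulates the failure, not the real llama.cpp allocation path.

#include <cstdint>
#include <cstdio>

// Stand-in for the real KV cache allocation; here it simply simulates an OOM.
static bool try_alloc_kv_cache(uint32_t /*n_ctx*/) {
    return false;
}

// Print an actionable hint alongside the failure, including the offending n_ctx.
static bool init_kv_cache_with_hint(uint32_t n_ctx) {
    if (!try_alloc_kv_cache(n_ctx)) {
        fprintf(stderr, "failed to allocate memory for KV cache\n");
        fprintf(stderr, "hint: n_ctx is currently %u, which may be too big; "
                        "try a smaller value, for example -c 1024\n", n_ctx);
        return false;
    }
    return true;
}

int main() {
    init_kv_cache_with_hint(163840);
    return 0;
}

As noted above, on Metal the failure may instead surface when allocating the compute buffers, so a hint attached only to the KV cache allocation would not cover every OOM case.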
