CodeLlama 34B errors out after 3+ completions #70

Closed
abacaj opened this issue Sep 25, 2023 · 10 comments · Fixed by #75

@abacaj

abacaj commented Sep 25, 2023

Running CodeLlama 34B using the latest AutoAWQ (installed from the repo):

  File "/home/anton/personal/transformer-experiments/env/lib/python3.10/site-packages/awq/modules/fused/attn.py", line 183, in forward
    self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
RuntimeError: The expanded size of the tensor (0) must match the existing size (17) at non-singleton dimension 2.  Target sizes: [1, 8, 0, 128].  Tensor sizes: [8, 17, 128]

To reproduce:

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_name_or_path = "TheBloke/CodeLlama-34B-AWQ"

# Load model
model = AutoAWQForCausalLM.from_quantized(
    model_name_or_path,
    fuse_layers=True,
    trust_remote_code=False,
    safetensors=True,
    max_new_tokens=1024,
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=False)

tokens = tokenizer(
    "# Write a python function to loop to 1000\n\ndef", return_tensors="pt"
).to("cuda")

# Generate output
for _ in range(10):
    generation_output = model.generate(
        **tokens,
        do_sample=True,
        temperature=0.2,
        top_p=0.95,
        top_k=0,
        max_new_tokens=512,
    )

    print(tokenizer.decode(generation_output[0], skip_special_tokens=True))

@abacaj
Author

abacaj commented Sep 25, 2023

Setting fuse_layers=False seems to work with the same code, though generation is slower.

@casper-hansen
Owner

Thank you for this! My best guess is that the number of tokens exceeds the cache size. I will have to investigate this.
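To make that guess concrete, here is a minimal pure-Python sketch of the slice arithmetic in the failing line. It assumes, without confirmation from this thread, that `start_pos` keeps growing across separate `generate()` calls, that the cache's sequence axis holds a fixed 1024 positions (the `max_new_tokens=1024` passed to `from_quantized`), and that every completion runs to 512 new tokens. It only models the multi-token prefill write, whose tensor size (17) equals the prompt length in the first traceback; single-token decode steps are not modeled. `slot_len` and `simulate_calls` are hypothetical helpers.

```python
def slot_len(cache_len: int, start_pos: int, seqlen: int) -> int:
    """Number of positions that cache[..., start_pos:start_pos + seqlen, ...] selects."""
    return len(range(cache_len)[start_pos:start_pos + seqlen])


def simulate_calls(cache_len: int, prompt_len: int, generated: list) -> list:
    """Record (call, start_pos, slot) for the prefill write of each generate() call.

    Assumptions: start_pos is never reset between calls, and after the prefill
    every new token except the last is fed back, so one call advances start_pos
    by prompt_len + n_new - 1.
    """
    start_pos = 0
    log = []
    for call, n_new in enumerate(generated, start=1):
        slot = slot_len(cache_len, start_pos, prompt_len)
        log.append((call, start_pos, slot))
        if slot != prompt_len:
            break  # the real tensor assignment would raise here
        start_pos += prompt_len + n_new - 1
    return log


for row in simulate_calls(cache_len=1024, prompt_len=17, generated=[512, 512, 512]):
    print(row)
# (1, 0, 17)
# (2, 528, 17)
# (3, 1056, 0)
```

Once `start_pos` reaches the cache length, the slice selects zero positions, which matches `Target sizes: [1, 8, 0, 128]` on the third completion. A `start_pos` just short of the end would instead select a partial slot (for example `slot_len(1024, 1020, 17) == 4`) and report a different size.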

@gestalt73

gestalt73 commented Sep 26, 2023

I've seen the same with other models. Thanks for the script, @abacaj. I'm going to put some other models through their paces to see if I can reproduce it.

AutoAWQ=0.1.0, python=3.10, cuda=11.8, rtx 3090

I can reproduce the error with every model I tried:

  • TheBloke/Llama-2-7b-Chat-AWQ
  • TheBloke/vicuna-7B-v1.5-AWQ
  • casperhansen/vicuna-7B-v1.5-AWQ

  File "/home/alansrobotlab/anaconda3/envs/textgen/lib/python3.11/site-packages/awq/modules/fused/attn.py", line 183, in forward
    self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
    ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The expanded size of the tensor (0) must match the existing size (24) at non-singleton dimension 2.  Target sizes: [1, 32, 0, 128].  Tensor sizes: [32, 24, 128]

@casper-hansen
Owner

casper-hansen commented Sep 26, 2023

Fixed this now in #75; at least, I can no longer reproduce this error, even when running for 1000 iterations with the script below.

@abacaj and @gestalt73, I would appreciate it if you could take the time to test the pull request and see if anything else breaks:

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_name_or_path = "casperhansen/vicuna-7b-v1.5-awq"
max_new_tokens = 1024

# Load model
model = AutoAWQForCausalLM.from_quantized(
    model_name_or_path,
    quant_filename="awq_model_w4_g128.pt",
    fuse_layers=True,
    trust_remote_code=False,
    safetensors=True,
    max_new_tokens=max_new_tokens,
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=False)

tokens = tokenizer(
    "# Write a python function to loop to 1000\n\ndef", return_tensors="pt"
).to("cuda")

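# Generate output
# Running tally of tokens returned by generate() (prompt + completion);
# it is zeroed after exceeding max_new_tokens, the cache budget passed to from_quantized().
cumulative_tokens = 0

for i in range(1000):
    if cumulative_tokens > max_new_tokens:
        cumulative_tokens = 0
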
    generation_output = model.generate(
        **tokens,
        do_sample=True,
        temperature=0.2,
        top_p=0.95,
        top_k=0,
        max_new_tokens=512,
    )

    num_tokens = len(generation_output[0])
    cumulative_tokens += num_tokens

    print(i, num_tokens, cumulative_tokens)

    # print(tokenizer.decode(generation_output[0], skip_special_tokens=True))

@gestalt73

gestalt73 commented Sep 26, 2023

Hey @casper-hansen, I ran it for a while with TheBloke/Llama-2-7b-Chat-AWQ. Things look normal until the first cache clear, and then they get weird. It doesn't error out, though.

Take a look at the output after the first set of cache-clear messages, around line 192.

Output is consistent for the first several generations. After the "resetting cache" message, generation starts out fine but degrades toward the end of line 209. From there on it's hit or miss, and I'm also seeing the huge runs of newlines that I would occasionally see in 0.1.0.

Fix KV cache shapes error 75 results.txt
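One plausible reading of the "resetting cache" behavior described above, as a toy sketch: when a write would overflow, the cache is cleared and `start_pos` goes back to 0. This is an assumption about the policy, not the code in #75; `ResettingCache` is a hypothetical class, and each slot stores a token id where a real cache would store key/value tensors.

```python
class ResettingCache:
    """Toy KV cache that starts over when a write would overflow (assumed policy)."""

    def __init__(self, max_len: int):
        self.max_len = max_len
        self.slots = [None] * max_len
        self.start_pos = 0
        self.resets = 0

    def write(self, tokens: list) -> None:
        if len(tokens) > self.max_len:
            raise ValueError("input longer than the whole cache")
        if self.start_pos + len(tokens) > self.max_len:
            # Start over: everything cached so far is forgotten and the
            # new tokens are stored from position 0 again.
            self.slots = [None] * self.max_len
            self.start_pos = 0
            self.resets += 1
        self.slots[self.start_pos:self.start_pos + len(tokens)] = tokens
        self.start_pos += len(tokens)

    def context(self) -> list:
        """Tokens that attention can still see."""
        return self.slots[:self.start_pos]


cache = ResettingCache(max_len=8)
cache.write([1, 2, 3, 4, 5])  # prompt
cache.write([6])              # one decode step
cache.write([7, 8, 9])        # would overflow, so the cache resets
print(cache.context(), cache.resets)  # [7, 8, 9] 1
```

After the reset, attention in this toy only sees the tokens written since, so the prompt and the earlier completion are gone; that kind of context loss could explain generations that start fine and then drift.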

@abacaj
Author

abacaj commented Sep 27, 2023

I no longer see the expanded-tensor error, but model generations with fused=True are different from (and worse than) those with fused=False.

@abacaj
Author

abacaj commented Sep 27, 2023

I added fused_true and fused_false samples here. I turned sampling off, so this should be greedy generation. With fused=False the output looks good.

https://gist.github.com/abacaj/aefb5e9dd85a6fc8b54b5b655a9a632e
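When comparing greedy outputs like these, it can help to locate the first token where the fused and unfused runs part ways. A small helper, assuming both runs use the same prompt and tokenizer and that the ids come from something like `generation_output[0].tolist()`; `first_divergence` and the example ids are hypothetical.

```python
from typing import Optional, Sequence


def first_divergence(a: Sequence[int], b: Sequence[int]) -> Optional[int]:
    """Return the first index where two token sequences differ, or None if identical.

    If one sequence is a prefix of the other, the index is the shorter length.
    """
    for i, (x, y) in enumerate(zip(a, b)):
        if x != y:
            return i
    if len(a) != len(b):
        return min(len(a), len(b))
    return None


# Hypothetical token ids from a fused and an unfused greedy run.
fused = [1, 2, 3, 9, 5]
unfused = [1, 2, 3, 4, 5]
print(first_divergence(fused, unfused))  # 3
```

Subtracting the prompt length from the returned index gives the generated-token position where the runs split. A split late in a long completion is what small numerical differences would tend to produce, while a split at the very first generated token points to something more serious.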

@casper-hansen
Owner

Thank you all for testing. Outputs that get weird or stop working as expected after the cache is reset are not good enough for me to merge the PR. I will have to explore:

  1. How to reset the cache without producing weird outputs
  2. How to grow the allocated cache dynamically as inputs are run through the model
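Item 2 can be sketched in miniature as a buffer that doubles its capacity when a write would overflow, keeping the filled prefix so no context is lost. This is a hypothetical illustration, not the approach the PR ended up taking: `GrowableCache` is an invented name, and each slot holds a token id where a real cache would hold key/value tensors of shape `(batch, heads, capacity, head_dim)`.

```python
class GrowableCache:
    """Toy KV cache whose capacity doubles whenever a write would overflow."""

    def __init__(self, initial_len: int = 4):
        self.slots = [None] * max(1, initial_len)  # capacity must start >= 1 to double
        self.start_pos = 0

    @property
    def capacity(self) -> int:
        return len(self.slots)

    def write(self, tokens: list) -> None:
        needed = self.start_pos + len(tokens)
        if needed > self.capacity:
            new_cap = self.capacity
            while new_cap < needed:
                new_cap *= 2
            # "Reallocate": keep the existing entries and append empty slots.
            self.slots = self.slots + [None] * (new_cap - self.capacity)
        self.slots[self.start_pos:needed] = tokens
        self.start_pos = needed


cache = GrowableCache(initial_len=4)
cache.write([1, 2, 3])  # prompt fits in the initial 4 slots
cache.write([4, 5, 6])  # needs 6 slots, so capacity doubles to 8
print(cache.capacity, cache.slots[:cache.start_pos])  # 8 [1, 2, 3, 4, 5, 6]
```

Doubling keeps the number of reallocations logarithmic in sequence length, at the cost of memory that is never given back; on a GPU each resize would also pay for copying the filled prefix.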

@casper-hansen
Owner

I switched up the approach entirely, and we now roll over the cache. This seems to produce correct outputs, and with the FT modules we get as close to the HF output as we can. They are not meant to produce exactly the same outputs, since slight numerical differences will lead to different outputs in some cases; however, they are very close now.
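A rolling cache can be sketched in miniature. This is an illustration under assumptions, not the code that landed: when a write would overflow by `excess` positions, the `excess` oldest entries are shifted out (what `torch.roll` along the sequence axis followed by overwriting the tail would achieve), so `start_pos` never runs past the end of the buffer and the most recent `max_len` tokens stay visible. `RollingCache` is a hypothetical name, and slots hold token ids instead of key/value tensors.

```python
class RollingCache:
    """Toy KV cache that keeps only the most recent max_len entries."""

    def __init__(self, max_len: int):
        self.max_len = max_len
        self.slots = [None] * max_len
        self.start_pos = 0

    def write(self, tokens: list) -> None:
        if len(tokens) > self.max_len:
            raise ValueError("input longer than the whole cache")
        excess = self.start_pos + len(tokens) - self.max_len
        if excess > 0:
            # Shift out the `excess` oldest entries and pad the end; the padded
            # tail is then overwritten by the new tokens below.
            self.slots = self.slots[excess:] + [None] * excess
            self.start_pos -= excess
        self.slots[self.start_pos:self.start_pos + len(tokens)] = tokens
        self.start_pos += len(tokens)

    def context(self) -> list:
        """Tokens that attention can still see."""
        return self.slots[:self.start_pos]


cache = RollingCache(max_len=8)
cache.write([1, 2, 3, 4, 5])  # prompt
cache.write([6])              # one decode step
cache.write([7, 8, 9])        # overflows by 1, so token 1 is rolled out
print(cache.context())        # [2, 3, 4, 5, 6, 7, 8, 9]
```

Only the oldest tokens are dropped on overflow, so recent context survives, which is consistent with the correct-looking outputs described here.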

@casper-hansen
Owner

I have closed this issue because the main error has been solved. However, there still seems to be a problem with the fused modules and the CodeLlama models, even though they should already be supported since GQA is implemented.
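For reference, a minimal sketch of the grouped-query attention (GQA) head mapping. Under GQA several query heads share one key/value head, which fits the cached tensors in the tracebacks above: 8 heads for CodeLlama 34B (`[8, 17, 128]`) versus 32 for the 7B model (`[32, 24, 128]`). The 64-query-head / 8-KV-head split for CodeLlama 34B comes from its published config rather than this thread, the contiguous grouping (query head `h` uses KV head `h // group_size`) follows Hugging Face's `repeat_kv`, and `kv_head_for_query_head` is a hypothetical helper.

```python
def kv_head_for_query_head(q_head: int, n_heads: int, n_kv_heads: int) -> int:
    """Index of the key/value head that query head `q_head` uses under GQA."""
    if n_heads % n_kv_heads != 0:
        raise ValueError("n_heads must be a multiple of n_kv_heads")
    if not 0 <= q_head < n_heads:
        raise IndexError("query head out of range")
    group_size = n_heads // n_kv_heads  # query heads sharing each KV head
    return q_head // group_size


# CodeLlama 34B (assumed config): 64 query heads share 8 KV heads.
groups = {}
for q in range(64):
    groups.setdefault(kv_head_for_query_head(q, 64, 8), []).append(q)
print(len(groups), groups[0])  # 8 [0, 1, 2, 3, 4, 5, 6, 7]

# Llama-2-7B has no GQA (32 query heads, 32 KV heads), so the mapping is the identity.
print(kv_head_for_query_head(5, 32, 32))  # 5
```

Any fused path has to size its cache by the KV head count and apply this mapping consistently; the thread does not establish where the remaining CodeLlama problem lies.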
