Mixtral: Mixture of Experts quantization #251

Merged: 21 commits from mixtral_moe into main, Dec 22, 2023
Conversation

@casper-hansen (Owner) commented Dec 11, 2023:

BIG NOTE: Pending more perplexity numbers. Looking to see if we can optimize before merging.

  • FP16: Perplexity 4.137
  • First iteration: Perplexity 6.469 581b416
  • Second iteration: Perplexity 5.165 240fdc8
  • Third iteration: Perplexity 4.294 075822c
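
For context, perplexity figures like those above are normally measured with a fixed-length sliding window over wikitext-2. The snippet below is a minimal, generic sketch of that recipe using the Hugging Face APIs; it is not the exact evaluation script behind these numbers, and the 2048-token window is an assumption.

```python
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

@torch.no_grad()
def wikitext2_perplexity(model_id: str, seqlen: int = 2048) -> float:
    # Generic perplexity recipe on wikitext-2; not necessarily the exact
    # script behind the numbers above.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.float16, device_map="auto"
    )
    model.eval()

    test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    ids = tokenizer("\n\n".join(test["text"]), return_tensors="pt").input_ids

    nlls = []
    n_chunks = ids.shape[1] // seqlen
    for i in range(n_chunks):
        chunk = ids[:, i * seqlen : (i + 1) * seqlen].to(model.device)
        # Labels equal the inputs; the model shifts them internally for the
        # next-token cross-entropy, averaged over the chunk.
        loss = model(chunk, labels=chunk).loss
        nlls.append(loss.float() * seqlen)
    return torch.exp(torch.stack(nlls).sum() / (n_chunks * seqlen)).item()
```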

@casper-hansen mentioned this pull request on Dec 14, 2023
@noah-kim-theori commented:

Referring to your code, I implemented Mixtral (transformers==4.36.2) in AutoAWQ 0.1.7 as a single-file gist. All of the MoE FFNs have post_attention_layernorm as their previous operation at transformers.models.mixtral.modeling_mixtral#L748. In your code, is there a specific reason why each remaining expert uses the last FFN layer of the previous expert as its previous operation?
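
For readers following the question: AutoAWQ describes each decoder layer as a list of scaling groups, roughly dicts of the form {prev_op, layers, inp, module2inspect}. The sketch below shows the variant being proposed here, where every expert's gate/up projection is scaled against post_attention_layernorm (the actual input to the MoE block) rather than against the previous expert's down projection. Attribute names follow transformers 4.36 Mixtral; the exact dict keys and input_feat keys are assumptions based on AutoAWQ's convention, so treat this as an illustration rather than the merged code.

```python
def mixtral_moe_scaling_groups(module, input_feat):
    # Sketch of per-layer AWQ scaling groups for Mixtral (dict keys assumed to
    # follow AutoAWQ's convention: prev_op / layers / inp / module2inspect).
    groups = []

    # Variant asked about above: every expert's gate (w1) and up (w3)
    # projection is scaled against the post-attention layernorm output,
    # i.e. the real input of the MoE block, instead of chaining each expert
    # to the previous expert's down projection (w2).
    groups.append(dict(
        prev_op=module.post_attention_layernorm,
        layers=[
            proj
            for expert in module.block_sparse_moe.experts
            for proj in (expert.w1, expert.w3)
        ],
        inp=input_feat["block_sparse_moe"],
        module2inspect=module.block_sparse_moe,
    ))

    # Each expert's down projection (w2) is scaled against the output of its
    # own up projection (w3).
    for i, expert in enumerate(module.block_sparse_moe.experts):
        groups.append(dict(
            prev_op=expert.w3,
            layers=[expert.w2],
            inp=input_feat[f"block_sparse_moe.experts.{i}.w2"],
        ))
    return groups
```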

@L1aoXingyu commented:

> Referring to your code, I implemented Mixtral (transformers==4.36.2) in AutoAWQ 0.1.7 as a single-file gist. All of the MoE FFNs have post_attention_layernorm as their previous operation at transformers.models.mixtral.modeling_mixtral#L748. In your code, is there a specific reason why each remaining expert uses the last FFN layer of the previous expert as its previous operation?

I have a similar implementation and got these perplexity results:

| model    | wikitext-2        | ptb                | c4                |
|----------|-------------------|--------------------|-------------------|
| fp16     | 4.135127067565918 | 15.207035064697266 | 8.126235008239746 |
| int4-rtn | 4.332545280456543 | 15.017653465270996 | 8.374772071838379 |
| int4-awq | 4.277602672576904 | 15.137899398803711 | 8.237201690673828 |

To my surprise, just using RTN already gives very strong performance.
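
For reference, round-to-nearest (RTN) here means plain group-wise asymmetric quantization with no activation-aware scaling or clipping search. A minimal sketch, assuming 4-bit quantization with a group size of 128 (not the exact code behind the table above):

```python
import torch

def rtn_int4_quantize(w: torch.Tensor, group_size: int = 128):
    """Asymmetric round-to-nearest INT4 quantization, one scale/zero per group.

    w has shape [out_features, in_features]; in_features must be divisible by
    group_size. Returns fake-quantized weights (useful for perplexity checks)
    plus the integer codes, scales and zero points.
    """
    out_f, in_f = w.shape
    g = w.reshape(out_f, in_f // group_size, group_size)

    w_max = g.amax(dim=-1, keepdim=True)
    w_min = g.amin(dim=-1, keepdim=True)
    scales = (w_max - w_min).clamp(min=1e-5) / 15          # 2**4 - 1 levels
    zeros = (-w_min / scales).round().clamp(0, 15)

    q = (g / scales + zeros).round().clamp(0, 15)
    w_dq = ((q - zeros) * scales).reshape(out_f, in_f)
    return w_dq, q.to(torch.uint8), scales, zeros

# Usage (fake-quantize a Linear in place, then measure perplexity):
# linear.weight.data = rtn_int4_quantize(linear.weight.data)[0]
```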

@noah-kim-theori commented Dec 21, 2023:

I just checked, and I also think that using RTN alone produces better results. Thanks for sharing.

@casper-hansen (Owner, Author) commented Dec 21, 2023:

> Referring to your code, I implemented Mixtral (transformers==4.36.2) in AutoAWQ 0.1.7 as a single-file gist. All of the MoE FFNs have post_attention_layernorm as their previous operation at transformers.models.mixtral.modeling_mixtral#L748. In your code, is there a specific reason why each remaining expert uses the last FFN layer of the previous expert as its previous operation?

Thanks for the reference implementation. I have been exhausting GPU credits trying to scale this model effectively. There is no specific reason for the current approach other than that it worked best in my tests; however, your implementation is better, as is evident from the results.

Do you want to raise a PR to merge your changes into this branch/PR so we can merge it into AutoAWQ? I can also do it if you don’t mind.

@casper-hansen (Owner, Author) commented:

>> Referring to your code, I implemented Mixtral (transformers==4.36.2) in AutoAWQ 0.1.7 as a single-file gist. All of the MoE FFNs have post_attention_layernorm as their previous operation at transformers.models.mixtral.modeling_mixtral#L748. In your code, is there a specific reason why each remaining expert uses the last FFN layer of the previous expert as its previous operation?

> I have a similar implementation and got these perplexity results:
>
> | model    | wikitext-2        | ptb                | c4                |
> |----------|-------------------|--------------------|-------------------|
> | fp16     | 4.135127067565918 | 15.207035064697266 | 8.126235008239746 |
> | int4-rtn | 4.332545280456543 | 15.017653465270996 | 8.374772071838379 |
> | int4-awq | 4.277602672576904 | 15.137899398803711 | 8.237201690673828 |
>
> To my surprise, just using RTN already gives very strong performance.

I updated the code with the new quantization of the layers and got a perplexity of 4.294. What did you do differently from the current implementation?

@noah-kim-theori commented:

I checked your commit, and it's fine to use it as is. Feel free to use it.

@vince62s commented:

Did you guys run an MMLU benchmark on the quantized model? I'm a bit disappointed: I'm getting 60 vs. 71. I did not see such a discrepancy with Mistral Instruct.

@casper-hansen (Owner, Author) commented:

> Did you guys run an MMLU benchmark on the quantized model? I'm a bit disappointed: I'm getting 60 vs. 71. I did not see such a discrepancy with Mistral Instruct.

Did you evaluate with fused modules?
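
For anyone reproducing this: "with fused modules" refers to loading the quantized checkpoint with layer fusion enabled, roughly as below. The flag follows the AutoAWQ API as of this release, and the model path is a placeholder; fused and non-fused paths can give slightly different benchmark numbers.

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = "path/to/mixtral-awq"   # placeholder path to a quantized checkpoint

tokenizer = AutoTokenizer.from_pretrained(quant_path)
# fuse_layers=True swaps in the fused attention/norm modules; evaluating with
# and without it is a meaningful comparison when scores look off.
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
```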

@casper-hansen merged commit 5b9f3c4 into main on Dec 22, 2023
@vince62s commented:

Well, I'm using OpenNMT-py, but I benchmarked your code speed-wise, and in fact the only two big contributors are FasterTransformer (which I replaced with FlashAttention-2 plus a KV cache doing the same thing) and "your" RMSNorm kernel; both of them account for the nice speed. By the way, GEMV works fine for batch sizes > 1, just a little slower than GEMM, but it works OK.
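
For anyone unfamiliar with the kernel being credited here: RMSNorm itself is only a few lines of PyTorch, and the fused kernel referenced above computes the same normalization in a single pass. A reference version, shown purely to make the math explicit:

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Reference RMSNorm: normalize by the root-mean-square of the hidden
    vector, then apply a learned per-channel gain. A fused CUDA kernel does
    the same computation in one pass; this version only shows the math."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dtype = x.dtype
        x = x.float()
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return (self.weight * x).to(dtype)
```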

@casper-hansen (Owner, Author) commented:

> Well, I'm using OpenNMT-py, but I benchmarked your code speed-wise, and in fact the only two big contributors are FasterTransformer (which I replaced with FlashAttention-2 plus a KV cache doing the same thing) and "your" RMSNorm kernel; both of them account for the nice speed. By the way, GEMV works fine for batch sizes > 1, just a little slower than GEMM, but it works OK.

Nice, I have been looking to replace the FasterTransformer modules with Flash Attention. The kernels in AutoAWQ are imported from other projects to maximize inference speed and to provide generalized modules. GEMV is great in many cases, especially for local models!

@vince62s commented Dec 22, 2023:

False alarm: I am getting 67.1 using the right RoPE theta. By the way, don't forget to make it an option, because @younesbelkada is already tagging this PR :)
NB: I am using a non-scaled, non-clipped AWQ GEMV version, so maybe yours will improve a little bit.
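
For context, the "right RoPE theta" is the base of the rotary-embedding frequency table: Mixtral ships with rope_theta=1e6, while fused attention modules often hard-code the Llama default of 10000. A generic sketch of how the base enters the rotary cache (assumed names, not the AutoAWQ module itself):

```python
import torch

def rope_cos_sin(head_dim: int, max_positions: int, rope_theta: float = 1_000_000.0):
    # The only thing "the right theta" changes is the inverse-frequency table.
    inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(max_positions).float()
    freqs = torch.outer(positions, inv_freq)     # [max_positions, head_dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)      # [max_positions, head_dim]
    return emb.cos(), emb.sin()

# A fused attention module that hard-codes rope_theta=10000.0 builds the wrong
# table for Mixtral; passing config.rope_theta through fixes it.
cos, sin = rope_cos_sin(head_dim=128, max_positions=4096, rope_theta=1_000_000.0)
```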

@younesbelkada (Collaborator) commented:

Thanks @vince62s, do you mean in the transformers integration for fused modules?

@vince62s commented:

Yes, here: https://github.com/casper-hansen/AutoAWQ/blob/main/awq/modules/fused/attn.py#L224
@casper-hansen knows about it because he mentioned it in an issue, but I'm flagging it in case some people have already tried this PR.

@younesbelkada (Collaborator) commented Dec 22, 2023:

Ah yes, that makes sense. Thanks for the heads-up; I will update it once I raise the PR in transformers!

@casper-hansen (Owner, Author) commented:

Ahh this was probably the problem I had with perplexity earlier. I forgot to modify everything to support the correct theta value. Thanks for pointing it out @vince62s, I now remember this as a problem :)

@casper-hansen deleted the mixtral_moe branch on December 23, 2023 at 14:04
@exceedzhang commented:

I ran the Mixtral-8x7B-v0.1 model and model.quantize() errors out:

File "/data1/apps/miniconda3/envs/Mixtral/lib/python3.10/site-packages/awq/modules/linear.py", line 79, in from_linear
    qweight[:, col] |= qweight_col << (i * awq_linear.w_bit)
IndexError: index 0 is out of bounds for dimension 1 with size 0
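
For anyone hitting the same traceback: the failing line is part of the loop that packs eight 4-bit values into each int32 column of qweight, along the lines of the sketch below. A packed dimension of size 0 means the layer being packed has too few columns for even one int32 group, which can happen if a very small layer such as Mixtral's 8-way router gate is quantized; presumably that is what the special instructions in the mixtral_quant example (referenced in the reply below) avoid. The sketch is only illustrative of the packing scheme, not a copy of awq/modules/linear.py.

```python
import torch

def pack_int4(intweight: torch.Tensor, w_bit: int = 4) -> torch.Tensor:
    """Illustrative packing: 32 // w_bit = 8 quantized values per int32 column.

    The packed tensor has intweight.shape[1] * w_bit // 32 columns, so a layer
    whose packing dimension is smaller than one full group ends up with zero
    columns, which is the "size 0" in the traceback above.
    """
    pack_num = 32 // w_bit
    mask = (1 << w_bit) - 1
    rows, cols = intweight.shape
    qweight = torch.zeros((rows, cols * w_bit // 32), dtype=torch.int32)
    for col in range(qweight.shape[1]):
        for i in range(pack_num):
            val = intweight[:, col * pack_num + i].to(torch.int32) & mask
            qweight[:, col] |= val << (i * w_bit)
    return qweight
```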

@vince62s commented:

For the sake of completeness, I ran my same MMLU script on the HF model from @casper-hansen:
ACC-all: 0.6682
So the calibration impact on PPL is clear, but the impact on the MMLU benchmark is nil. I am wondering whether the actual output is better or not after calibration.

@casper-hansen (Owner, Author) commented:

> I ran the Mixtral-8x7B-v0.1 model and model.quantize() errors out:
> File "/data1/apps/miniconda3/envs/Mixtral/lib/python3.10/site-packages/awq/modules/linear.py", line 79, in from_linear
>     qweight[:, col] |= qweight_col << (i * awq_linear.w_bit)
> IndexError: index 0 is out of bounds for dimension 1 with size 0

Please reference the mixtral_quant script as it has special instructions!

> For the sake of completeness, I ran my same MMLU script on the HF model from @casper-hansen: ACC-all: 0.6682. So the calibration impact on PPL is clear, but the impact on the MMLU benchmark is nil. I am wondering whether the actual output is better or not after calibration.

Glad to hear it’s performing well on MMLU. Can you share your benchmark script? I’m in the process of adding more evaluation scripts to AutoAWQ. I was thinking of using vLLM for optimized parallel evaluation.

@vince62s commented:

I am using my own adaptation (for OpenNMT-py) of this script: https://github.com/FranxYao/chain-of-thought-hub/tree/main/MMLU, which is almost the original implementation of MMLU (slightly different from the lm_eval harness used by the HF leaderboard).
In MMLU you expect a single answer token out of A, B, C, D. If the output is "Doe" instead of one of [A, B, C, D], that counts as a wrong answer; the HF leaderboard instead takes the best score out of [A, B, C, D], so there is always an answer. Anyway, it is just a slight difference, and the above script is way faster.
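
For readers unfamiliar with the distinction: leaderboard-style scoring never lets the model answer anything other than A/B/C/D, because it only compares the scores of those four tokens, whereas the stricter script generates a token and marks anything else wrong. A minimal sketch of the leaderboard-style selection using generic Hugging Face APIs (not either of the actual harnesses):

```python
import torch

def pick_mmlu_choice(model, tokenizer, prompt: str) -> str:
    """Leaderboard-style MMLU scoring: compare the next-token logits of
    ' A'..' D' and return the best choice, so there is always an answer.
    model/tokenizer are ordinary Hugging Face causal-LM objects."""
    choice_ids = [
        tokenizer(f" {c}", add_special_tokens=False).input_ids[-1] for c in "ABCD"
    ]
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        next_token_logits = model(**inputs).logits[0, -1]
    scores = next_token_logits[choice_ids]   # logits of ' A', ' B', ' C', ' D'
    return "ABCD"[int(scores.argmax())]
```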
