
[FEAT]: EETQ quantizer support #30262

Merged: 31 commits into huggingface:main on Apr 22, 2024

Conversation

@dtlzhuangz (Contributor) commented Apr 16, 2024

What does this PR do?

EETQ supports int8 per-channel weight-only quantization for NVIDIA GPUs. Its high-performance GEMM and GEMV kernels come from FasterTransformer and TensorRT-LLM. It requires no calibration dataset and does not require pre-quantizing your model. Moreover, the accuracy degradation is negligible owing to the per-channel quantization.
NetEase-FuXi/EETQ#13
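
For context, here is a minimal sketch of the on-the-fly quantization workflow this PR enables. The EetqConfig class and the facebook/opt-125m model id are taken from the PR's integration tests; the generation call at the end is only illustrative.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, EetqConfig

# int8 per-channel weight-only quantization: no calibration dataset and no
# pre-quantized checkpoint needed (requires the eetq package and an NVIDIA GPU)
quantization_config = EetqConfig()

model_id = "facebook/opt-125m"  # model id used in the PR's integration tests
tokenizer = AutoTokenizer.from_pretrained(model_id)

# weights are quantized on the fly while the model is loaded onto the GPU
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=quantization_config,
)

inputs = tokenizer("EETQ is", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))
```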

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Fixes NetEase-FuXi/EETQ#13

@dtlzhuangz (Contributor, Author) commented Apr 16, 2024

@younesbelkada
Please review the code and documentation to see if there is anything inappropriate.

@SunMarc (Member) left a comment

Awesome PR @dtlzhuangz! The EETQ library is written in such a way that the integration is very smooth. We can quantize on the fly, serialize the quantized model, and even reload it with minimal changes in transformers 🔥 I left a few minor comments. Make sure to fix the style with make style.

docs/source/en/main_classes/quantization.md (outdated, resolved)
Comment on lines 648 to 658
Make sure you have eetq installed from source: https://github.com/NetEase-FuXi/EETQ
```
git clone https://github.com/NetEase-FuXi/EETQ.git
cd EETQ/
git submodule update --init --recursive
pip install .
```
@SunMarc (Member) commented Apr 16, 2024

Is there a plan to release EETQ on PyPI?

docs/source/en/quantization.md (outdated, resolved)
docs/source/en/quantization.md (outdated, resolved)
src/transformers/integrations/__init__.py (outdated, resolved)
tests/quantization/eetq_integration/test_eetq.py (outdated, resolved)
src/transformers/quantizers/quantizer_eetq.py (outdated, resolved)
src/transformers/quantizers/auto.py (outdated, resolved)
src/transformers/quantizers/auto.py (outdated, resolved)
src/transformers/quantizers/auto.py (outdated, resolved)
@younesbelkada (Contributor) left a comment

Thanks so much for this great work! In addition to @SunMarc's comments, I have a few small additional ones:
1. Can you add pip install git+https://github.com/NetEase-FuXi/EETQ.git inside the quantization Dockerfile here: https://github.com/huggingface/transformers/blob/main/docker/transformers-quantization-latest-gpu/Dockerfile
2. Can you elaborate on the hardware restrictions in the documentation section? (i.e. whether it works only from CUDA compute capability 8.0 and above, or also 7.0, etc.)
3. Yes, let's use camel case for the newly introduced files (EETQ --> Eetq).
4. Can you make sure the styling checks pass (make fixup) to keep the CI happy?
Thanks again and looking forward to merging this!

@SunMarc (Member) left a comment

Hi @dtlzhuangz, thanks for the fast response. I've answered your questions. After fixing the issues pointed out by @younesbelkada, we will ask a core maintainer for a final review.

@dtlzhuangz (Contributor, Author) commented

> Hi @dtlzhuangz, thanks for the fast response. I've answered your questions. After fixing the issues pointed out by @younesbelkada, we will ask a core maintainer for a final review.

Sorry, could you help me fix the error of 'Import block is un-sorted or un-formatted'? I'm not quite familiar with the CI.

@SunMarc (Member) commented Apr 17, 2024

Yes, I took care of that. You just needed to run make style. I will try to run the tests on my setup to see if everything works!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@dtlzhuangz (Contributor, Author) commented

> Yes, I took care of that. You just needed to run make style. I will try to run the tests on my setup to see if everything works!

Thank you so much for your guidance and effort!

@younesbelkada (Contributor) left a comment

Very smooth integration! Thanks for delivering this to the community! LGTM with only two nits.

docs/source/en/quantization.md (outdated, resolved)
src/transformers/integrations/eetq.py (resolved)
r"""
Safety checker that arguments are correct
"""
accepted_weights = ["int8"]
Contributor

Out of curiosity: are there any plans to support 4-bit group-wise quantization as well?

Contributor Author

Sorry, not at the moment.

Contributor

ok no worries!

@SunMarc (Member) left a comment

I was able to build the Dockerfile and the tests are passing 🔥 Thanks again @dtlzhuangz for the clean PR. Do you have any plan to release the package on PyPI? Installing from source is not ideal, since it takes quite a lot of time to build the wheels, and users are subject to breaking changes since there is no release yet.

@dtlzhuangz (Contributor, Author) commented Apr 18, 2024

> I was able to build the Dockerfile and the tests are passing 🔥 Thanks again @dtlzhuangz for the clean PR. Do you have any plan to release the package on PyPI? Installing from source is not ideal, since it takes quite a lot of time to build the wheels, and users are subject to breaking changes since there is no release yet.

Sorry for replying to the question late. My colleague and I are setting out to do it, but the built files depend on the version of torch, and an error occurs if the versions mismatch. If there is no solution, we will have to install a specific version of torch when installing EETQ.

@dtlzhuangz (Contributor, Author) commented Apr 19, 2024

Hi @SunMarc @younesbelkada @amyeroberts, we have released the .whl on the release page and updated the documentation. Please review. Thanks!

@younesbelkada (Contributor) left a comment

Great work, thanks! I will let @amyeroberts make a final review and merge it if all is good. Thanks again for all your great work, @dtlzhuangz!

@amyeroberts (Collaborator) left a comment

Thanks for adding!

Just a few small comments to address

docs/source/en/quantization.md (outdated, resolved)
docs/source/en/quantization.md (outdated, resolved)
Comment on lines +104 to +108
```
modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert

if quantization_config.modules_to_not_convert is not None:
    modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
```
Collaborator

We might want to use sets here; otherwise we can end up with duplicate modules being added.

Contributor Author

I have added modules_to_not_convert = list(set(modules_to_not_convert))
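
For readers following along, a minimal sketch of what that deduplication looks like in context; the helper name is hypothetical and the variable names follow the quoted diff.

```python
def _get_modules_to_not_convert(quantization_config, modules_to_not_convert=None):
    # default to keeping the LM head un-quantized (as in the quoted diff)
    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert

    # merge in any user-specified modules from the config
    if quantization_config.modules_to_not_convert is not None:
        modules_to_not_convert.extend(quantization_config.modules_to_not_convert)

    # drop duplicates that appear when the config repeats a default entry
    return list(set(modules_to_not_convert))
```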

Comment on lines 42 to 43
```
if current_key_name is None:
    current_key_name = []
```
Collaborator

This should go outside of the for-loop; we only need to check it for None-ness once.

Contributor Author

Indeed. It has been done.
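
A minimal sketch of the resulting structure, with illustrative names rather than the PR's exact function:

```python
import torch.nn as nn


def _walk_and_replace(model: nn.Module, modules_to_not_convert, current_key_name=None):
    # the None check now runs once per call instead of once per child module
    if current_key_name is None:
        current_key_name = []

    for name, module in model.named_children():
        current_key_name.append(name)
        if ".".join(current_key_name) not in modules_to_not_convert:
            pass  # replacement of nn.Linear with the EETQ linear layer would happen here
        current_key_name.pop()
```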

```
def test_raise_if_non_quantized(self):
    model_id = "facebook/opt-125m"
    quantization_config = EetqConfig()
    _ = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=quantization_config)
```
Collaborator

This doesn't test that any error is raised here.

Contributor Author

I have removed it.
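
As an aside, a test that does assert a failure could target the accepted_weights safety checker quoted above; a unittest-style sketch, where the weights keyword and the exact exception type are assumptions rather than taken from the PR:

```python
import unittest

from transformers import EetqConfig


class EetqConfigTest(unittest.TestCase):
    def test_rejects_unsupported_weights(self):
        # the safety checker only accepts "int8", so an unsupported value should raise;
        # the "weights" keyword and ValueError type are assumptions for illustration
        with self.assertRaises(ValueError):
            EetqConfig(weights="int4")


if __name__ == "__main__":
    unittest.main()
```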

Comment on lines +84 to +83
```
if torch_dtype is None:
    torch_dtype = torch.float16
```
Collaborator

+1, a logger.info message should be added here.
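
A minimal sketch of what that could look like, using the standard transformers logging utility; the function name and message wording are illustrative:

```python
import torch
from transformers.utils import logging

logger = logging.get_logger(__name__)


def _resolve_torch_dtype(torch_dtype):
    # default to float16 for EETQ and tell the user about it, as requested in the review
    if torch_dtype is None:
        torch_dtype = torch.float16
        logger.info(
            "No torch_dtype was specified; defaulting to torch.float16 for EETQ quantization."
        )
    return torch_dtype
```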

@dtlzhuangz (Contributor, Author) commented Apr 21, 2024

Hi @amyeroberts, I have addressed all the comments. I don't think the CI errors are caused by my changes; they occurred after I modified quantization.md. Please take a look.

@amyeroberts (Collaborator) commented

@dtlzhuangz Regarding the failing tests - a fix has been merged into main. Could you rebase?

@younesbelkada (Contributor) commented

Re-ran the testing suite and the tests seem to pass now! 🤞

@amyeroberts (Collaborator) left a comment

Thanks for adding this and iterating!

@amyeroberts merged commit b4c18a8 into huggingface:main on Apr 22, 2024
22 checks passed
@dtlzhuangz (Contributor, Author) commented

Thank you all for your help! @SunMarc @amyeroberts @younesbelkada

@younesbelkada (Contributor) commented

Great work, thanks to everyone involved in this!

itazap pushed a commit that referenced this pull request May 14, 2024
* [FEAT]: EETQ quantizer support

* Update quantization.md

* Update docs/source/en/main_classes/quantization.md

Co-authored-by: Marc Sun <[email protected]>

* Update docs/source/en/quantization.md

Co-authored-by: Marc Sun <[email protected]>

* Update docs/source/en/quantization.md

Co-authored-by: Marc Sun <[email protected]>

* Update src/transformers/integrations/__init__.py

Co-authored-by: Marc Sun <[email protected]>

* Update src/transformers/integrations/__init__.py

Co-authored-by: Marc Sun <[email protected]>

* Update src/transformers/integrations/eetq.py

Co-authored-by: Marc Sun <[email protected]>

* Update src/transformers/integrations/eetq.py

Co-authored-by: Marc Sun <[email protected]>

* Update src/transformers/integrations/eetq.py

Co-authored-by: Marc Sun <[email protected]>

* Update tests/quantization/eetq_integration/test_eetq.py

Co-authored-by: Marc Sun <[email protected]>

* Update src/transformers/quantizers/auto.py

Co-authored-by: Marc Sun <[email protected]>

* Update src/transformers/quantizers/auto.py

Co-authored-by: Marc Sun <[email protected]>

* Update src/transformers/quantizers/auto.py

Co-authored-by: Marc Sun <[email protected]>

* Update src/transformers/quantizers/quantizer_eetq.py

Co-authored-by: Marc Sun <[email protected]>

* Update tests/quantization/eetq_integration/test_eetq.py

Co-authored-by: Marc Sun <[email protected]>

* Update src/transformers/quantizers/quantizer_eetq.py

Co-authored-by: Marc Sun <[email protected]>

* Update tests/quantization/eetq_integration/test_eetq.py

Co-authored-by: Marc Sun <[email protected]>

* Update tests/quantization/eetq_integration/test_eetq.py

Co-authored-by: Marc Sun <[email protected]>

* [FEAT]: EETQ quantizer support

* [FEAT]: EETQ quantizer support

* remove whitespaces

* update quantization.md

* style

* Update docs/source/en/quantization.md

Co-authored-by: Younes Belkada <[email protected]>

* add copyright

* Update quantization.md

* Update docs/source/en/quantization.md

Co-authored-by: amyeroberts <[email protected]>

* Update docs/source/en/quantization.md

Co-authored-by: amyeroberts <[email protected]>

* Address the comments by amyeroberts

* style

---------

Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: amyeroberts <[email protected]>

Successfully merging this pull request may close these issues.

Integration with Hugging Face transformers library