[DCU] fix DCU w8a8c8 GEMM shape #9115
Thanks for your contribution!
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@            Coverage Diff             @@
##           develop    #9115     +/-  ##
=========================================
- Coverage    53.34%   53.33%   -0.01%
=========================================
  Files          652      652
  Lines       105401   105404       +3
=========================================
  Hits         56222    56222
- Misses       49179    49182       +3

☔ View full report in Codecov by Sentry.
@@ -631,6 +631,7 @@ def __init__(self, config: LlamaConfig):
 )

 else:
+    print("self.quant_type: ", self.quant_type)
Please delete this comment.
done
@@ -372,7 +372,7 @@ def __init__(self, config: Qwen2Config):
 use_neox_rotary_style=self.use_neox,
 cachekv_int8_type=config.cachekv_int8_type,
 rank_id=config.tensor_parallel_rank,
-trans_qkvw=(False if paddle.is_compiled_with_rocm() and self.quant_type == "a8w8" else True),
+trans_qkvw=(False if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type else True),
This change is what fixes the GEMM shape for w8a8c8.
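For illustration, here is a minimal sketch of why the substring test changes the result for "a8w8c8". It is not the PaddleNLP code itself: `trans_qkvw_old`, `trans_qkvw_new`, and the `is_rocm` argument are hypothetical stand-ins (with `is_rocm` replacing the call to `paddle.is_compiled_with_rocm()`), and the "weight_only_int8" value is only an assumed example of a quant type that does not contain "a8w8".

```python
def trans_qkvw_old(is_rocm: bool, quant_type: str) -> bool:
    """Old condition: only the exact string "a8w8" disables the transpose."""
    return False if is_rocm and quant_type == "a8w8" else True


def trans_qkvw_new(is_rocm: bool, quant_type: str) -> bool:
    """New condition: any quant type containing "a8w8" disables the transpose."""
    return False if is_rocm and "a8w8" in quant_type else True


if __name__ == "__main__":
    # is_rocm=True stands in for a ROCm (DCU) build; this is an assumption
    # made so the sketch runs without Paddle installed.
    for quant_type in ("a8w8", "a8w8c8", "weight_only_int8"):
        old = trans_qkvw_old(True, quant_type)
        new = trans_qkvw_new(True, quant_type)
        print(f"{quant_type}: old={old}, new={new}")
```

With the exact comparison, "a8w8c8" fell through to `trans_qkvw=True` on ROCm builds, while the substring test treats it the same as "a8w8". One consequence of the new form is that any other quant type string containing "a8w8" is matched as well.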
f2f33dc to f5b2995
f5b2995 to 81c30d8
PR types
Bug fixes
PR changes
Others
Description
Fix the DCU GEMM shape when quant_type == "a8w8c8".