-
-
Notifications
You must be signed in to change notification settings - Fork 4.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support W4A8 quantization for vllm #5218
Conversation
Thanks for the PR! The team for NM will review this. We are going to have to split this work out into a few pieces. Will get back to you |
@HandH1998 I synced up with @comaniac This PR does 2 things:
For (2), we need to make a more holistic plan to support this concept and NM and Anyscale are already working on this. We do not want to just hack up one model file to support this as it will become very unmaintainable So, in terms of path forward for this PR, please remove the A8 compute from This will allow QQQ to land quickly with the Linear running at A8. @alexm-neuralmagic will review the kernels when he has some time this week. We can then incorporate the other layers as part of the broader effort we have going on in a separate PR |
@robertgshaw2-neuralmagic This will reduce some performance gains. Could you share a more general approach and timeline? Perhaps our team can collaborate to implement the corresponding features in this PR, and then you can modify it into a general version later. |
@zhyncs The goal of making this W4A8 optimization "production-ready" is exactly why I also think it is a good idea to land the first step as simply having this only be enabled for Linear modules as We have an RFC tracking Int8 W8A8 that lays out some of the high-level goals, including graph fusion through torch.compile #3975. We have been landing many kernels recently and hope to share the torch.compile prototype soon! |
I see what you want to do. Hope you can finish supporting dynamic Activation Quantization successfully. I will try to split the activation quantization with |
SG - then we can add in the quantized execution of non-Linear layers as part of our broader project |
@robertgshaw2-neuralmagic We have finished the mentioned work. This may help you achieve your high-level goals. |
Thank you! This looks much cleaner |
@HandH1998 - where are these kernels from? Is there an open source library that we are tracking for these? I know these are adapted from Marlin so I think we could ramp up on them quickly, but I wanted to understand if this is something that we [Neural Magic + rest of the vLLM team] need to maintain or if we are just simply tracking a remote repo |
It's adapted from Marlin and modified as necessary.
yep.
There seems to be no need at the moment. |
Im just trying to understand who is responsible for maintaining these kernels moving forward? We currently have two models
So I am trying to understand which case this is |
I think it should be the latter. On another note, speaking off the topic, the performance of AWQ in vLLM is almost the worst among several mainstream frameworks. I also do not recommend continuing to use 3rd party, but rather maintaining our own set. |
sg - we are reviewing re: AWQ. I think we can run AWQ models with Marlin if we support zero points in Marlin. We would welcome a contribution |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewed the marlin modified kernel, great work guys! Left some comments.
// may be evicted immediately; used for quantized weights B, which are only | ||
// accessed precisely once and should thus not pollute the L2 cache which we | ||
// need for inputs A and outputs C. | ||
__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have found in our tests that "stream" PTX is crashing on H100. Is it possible for you to use "cp.async.cg" and "cp.async.ca" without any createpolicy of cachehints? You can also test it on H100 and see if it crashes for you.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will try to modify functions cp_async4_stream
and cp_async1_stream
as you suggest next week. I think it should be easy to do.
"l"(glob_ptr), "n"(BYTES)); | ||
} | ||
|
||
// NOTE(HandH1998): cp.async.cg only support BYTES = 16, however, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for documenting this! It is helpful
// thread 0, 1, 2, 3. For more details, refer to mma operand A layout as s1's | ||
// size is not fixed, we can not shuffle before inference we shuffle it when | ||
// fetching s1 from global memory to shared memory, that's why s1_sh_wr is | ||
// like this |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is "s1" statically defined with the quantization of the model or it is computed dynamically during inference?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The latter. Actually It is the dynamic per-token quantization scale, needing to be computed online.
FragC frag_c[thread_m_blocks][4][2]; | ||
FragS_GROUP frag_s3[2][4]; | ||
FragS_CHANNEL frag_s1[thread_m_blocks]; | ||
FragS_CHANNEL frag_s2[2][4]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
During execution, the code uses only FragS_GROUP o FragS_CHANNEL. You can use if (constexpr) to avoid always defining both, or you are relying on the compiler to eliminate them anyway?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, it only uses FragS_CHANNEL for per-channel dequantization, while it needs both FragS_GROUP and FragS_CHANNEL for per-group weight quantization. I will consider your suggestion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, I am trying to define the register array frag_s3
in if constexpr
, but it is invisible to external variables. Do you know a method to solve it?
frag_b1 = dequant_per_group(b_quant_shift, frag_s3[k % 2][j], 1); | ||
} else { | ||
int b_quant_shift = b_quant << 4; | ||
frag_b0 = dequant_per_channel(b_quant); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why did you need to apply "dequant_per_channel" here and not on the final result (as it is done in the original Marlin, before write to C)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I understand why you did it this way. You actually have a scale applied at the end on the output, and here you simply convert 4 bit to 8 bit by extracting bits. But maybe I'm missing something.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are right. Here dequant_per_channel
just converts INT4 to INT8 by left shifting 4 bits.
// finally have to globally reduce over the results. As the striped partioning | ||
// minimizes the number of such reductions and our outputs are usually rather | ||
// small, we perform this reduction serially in L2 cache. | ||
auto global_reduce = [&](bool first = false, bool last = false) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would be good to document here that the global_reduce works on int32 elements (since the original marlin was half), and mention also that this is the reason that you needed to add the temporary buffer C (since the original Marlin was reducing directly on the output buffer)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will fix it later.
const int USER_THREADS = | ||
256; // Note: This is only used with user-provided thread_k/n | ||
const int STAGES = 4; // 4 pipeline stages fit into shared memory | ||
const int SHARED_MEM = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can change this code to dynamically detect the L1 cache size (like we did in gptq marlin and marlin24). This will provide you a better estimate of L1 cache, since sometimes you will need more L1 due to scales or increased batching.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good suggestion! I will modify it.
__CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ | ||
__CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) | ||
|
||
void qqq_cuda(const void* A, const void* B, void* C, void* D, void* s1, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you rename it to "marlin_qqq_cuda", I think it is better describing what is happening there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK.
} // namespace qqq | ||
} // namespace vllm | ||
|
||
torch::Tensor qqq_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here, "marlin_qqq_gemm" is a better name in my opinion
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK.
Another thing - what is the model serialization format? And how do we make models in this format? |
https://github.com/handh1998/qqq?tab=readme-ov-file#quantize-model Due to the changes from the original pull request, the performance will have a slight difference compared to the data in the image. |
frag_b1 = dequant_per_group(b_quant_shift, frag_s3[k % 2][j], 1); | ||
} else { | ||
int b_quant_shift = b_quant << 4; | ||
frag_b0 = dequant_per_channel(b_quant); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I understand why you did it this way. You actually have a scale applied at the end on the output, and here you simply convert 4 bit to 8 bit by extracting bits. But maybe I'm missing something.
// Efficiently dequantize an int32 value into a full B-fragment of 4 int8 values | ||
// for weight per channel dequant. | ||
__device__ inline FragB dequant_per_channel(int q) { | ||
static constexpr int MASK = 0xf0f0f0f0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't the mask supposed to be 0x0f0f0f0f? why do you extract the higher 4 bits?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For per-channel weight dequantization, I position the INT4 weight into the upper 4 bits of an INT8 by multiplying by 16, essentially performing a left shift by 4 bits. I use the mentioned mask and line 624 to achieve it. Note that the 8 INT4 weights of one INT32 are shuffled offline to ensure fetching the correct weights for every thread of a wrap.
Hi @alexm-neuralmagic These are the illustrations in the paper on QQQ drawn by @HandH1998 , hoping to help you review the code. The paper will be published on arxiv soon, please stay tuned. |
@zhyncs thanks for the figure, this is really helpful |
You guys can maintain qqq_gemm like you handle marlin. And remember to keep the copyright information at https://github.com/HandH1998/vllm/blob/b4b677138732ac952be8dbc87e44a5a58dffeb73/csrc/quantization/qqq/qqq_gemm_kernel.cu#L2-L7. If I have other optimization ideas, I will let you know. |
@robertgshaw2-neuralmagic @alexm-neuralmagic We have fixed the most issues you mentioned. For more details, please refer to our paper. |
@robertgshaw2-neuralmagic @alexm-neuralmagic We have rebased our code on vllm 0.5.0 and hope you guys can review it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add some unit tests for this kernel?
csrc/ops.h
Outdated
torch::Tensor marlin_qqq_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, | ||
torch::Tensor& s1, torch::Tensor& s2, | ||
torch::Tensor& s3, torch::Tensor& workspace, | ||
int64_t size_m, int64_t size_n, int64_t size_k); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you mark as const the operands that aren't modified? I know that torch::tensor.data_ptr doesn't respect constness so this will be a somewhat nonfunctional change, but this will make it easier to define torch metafunctions, which are necessary to support this function in torch.compile
torch::Tensor marlin_qqq_gemm(torch::Tensor const& a, torch::Tensor const& b_q_weight,
torch::Tensor& s1, torch::Tensor& s2,
torch::Tensor& s3, torch::Tensor& workspace,
int64_t size_m, int64_t size_n, int64_t size_k);
// Wait until barrier reaches `count`, then lock for current threadblock. | ||
__device__ inline void barrier_acquire(int* lock, int count) { | ||
if (threadIdx.x == 0) { | ||
int state = -1; | ||
do | ||
// Guarantee that subsequent writes by this threadblock will be visible | ||
// globally. | ||
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" | ||
: "=r"(state) | ||
: "l"(lock)); | ||
while (state != count); | ||
} | ||
__syncthreads(); | ||
} | ||
|
||
// Release barrier and increment visitation count. | ||
__device__ inline void barrier_release(int* lock, bool reset = false) { | ||
__syncthreads(); | ||
if (threadIdx.x == 0) { | ||
if (reset) { | ||
lock[0] = 0; | ||
return; | ||
} | ||
int val = 1; | ||
// Make sure that all writes since acquiring this barrier are visible | ||
// globally, while releasing the barrier. | ||
asm volatile("fence.acq_rel.gpu;\n"); | ||
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" | ||
: | ||
: "l"(lock), "r"(val)); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are these exactly the same as the functions in marlin_cuda_kernel.cu
? If so, could you factor these out and place them in a common location to be included by both?
// Since B-accesses have non-constant stride they have to be computed at | ||
// runtime; we break dependicies between subsequent accesses with a tile by | ||
// maintining multiple pointers (we have enough registers), a tiny | ||
// optimization. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: dependicies -> dependencies and maintining -> maintaining
I think codespell should catch this, so make sure you can run format.sh
on this to get the build green :)
@tlrmchlsmth We are adding the GEMM unit test and fixing the issues. We will release the code next Monday. |
f30daa6
to
e28f2bb
Compare
Hi @robertgshaw2-neuralmagic @alexm-neuralmagic @tlrmchlsmth May you help review this latest code? Thanks. If everything goes smoothly, after this PR is merged, out team @HandH1998 plan to continue working on INT4-FP8 QQQ (W4A8) based on this. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@HandH1998 thanks for adding the integration test! PR looks good to me now.
I am going to go ahead and accept it, but before landing it, let's run down the issue that @brisker ran into, just to make sure nothing funny is going on. @brisker could you share the result of vllm's collect_env.py script? I'd like to make sure you haven't uncovered some edge cases we're missing. The output of that script looks good to me on an A100.
Also @brisker could you also try re-running your test script with
assert(torch.allclose(out1, out2, atol=1e-3, rtol=1e-3))
That's going to be a better way to check if the implementations match.
@HandH1998 It turns out that, when packing the model, if quanti scales is absmax, nothing weird happens. But if quanti scales is not absmax, for the w4a8 kernel in this PR, the So for the weird results I mentioned above, after modifying if I modify further: modify
into
then out1 and out2 becomes close again So maybe this is due to the special calculation pipeline for this w4a8 kernel. But anyway, in my opinion, this is a little weird, because normally, we should get quantized Any explaination on this pipeline?
|
@brisker Thanks for your report! I have confirmed it is a bug in |
@HandH1998 |
@tlrmchlsmth Thanks for your review! I plan to resolve conflicts on Thursday. Then you can go ahead. |
@HandH1998 I mean, although the gemm is w4a8, but if it's quantized w must be converted from fp16 weights, there may exists accuracy risks, since many models are trained in bf16, and maybe the quantization tuning process(fake-quant )must also be done in bf16. https://github.com/HandH1998/QQQ/blob/main/QQQ/gptq/qlinear/qlinear_marlin.py#L181 |
@brisker BF16 is not supported in this PR. We don't have enough time to support it for now, though we would like to put it in our list. If you are interested in it, we welcome you submit a PR to our QQQ repo. This mainly includes changes about: offline quant and per-group dequant kernel. The former may be easy to achieve. |
@tlrmchlsmth I have resolved conflicts. |
A few days ago, I said ok, but that is for w4a8-with-no-group. Today I tried w4a8-gs128, now I think maybe there is another bug. The new
|
@brisker this wouldn't be a problem with any of the code in this PR, right, but rather an issue in https://github.com/HandH1998/QQQ? |
BTW, I just tried rerunning the failing jobs |
Marked PR as ready and resolved merge conflict, let's see if it is green now |
@tlrmchlsmth @mgoin The failed jobs is using |
Hi @HandH1998 I think you would see a more stable result if you increased the number of samples. Generally we use -l 1000 instead of 250 |
@tlrmchlsmth @mgoin Hi, I have solved the CR issue when using -l 1000 instead of 250. However, new CR issues came out. I think the new issues are not relevant with our code. Could you please take a look? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for letting us know. The PP test is a flaky test so I will attempt to land
so if I use vllm for qqq, what is the speed up ratio compared to gptq-marlin kernel? in this pr |
Signed-off-by: Alvant <[email protected]>
We have proposed a W4A8 quantization solution QQQ and integrated it into vllm. QQQ can not only achieve the similar performance of the leading W4A8, W8A8, and W4A16 quantization methods but also significantly accelerate inference—achieving up to 2.24x, 2.10x, and 1.25x speed boosting compared to FP16, W8A8, and W4A16(Marlin), respectively.
News or Update
QQQLinearMethod
to support QQQ for all the models. Note that this will reduce some inference performance. You should use this repo to reproduce the inference speed results of our paper.Usage
You can export the quantized model weights with this repo QQQ(only support llama for now). Our paper will be published on arXiv soon. Here is an offline inference example.
Experiments
Here we provide our experiment results using this repo.
Settings
Model Performance
We evaluated the model performance on WikiText2 and five zero-shot tasks.
Throughput
We conducted the same-batch throughput comparison of quantized LLaMA-2 models under various batch sizes. The input sequence length is 1024 and the output sequence length is 128.
W4A8 GEMM Performance
We implement the W4A8 GEMM based on Marlin GEMM. Thanks to their great work! Here is the speedup over PyTorch FP16 GEMM (Calling CUTLASS) of all GEMMs under different numbers of input tokens. The weight matrix size is (N=8192, K=21760).