Support W4A8 quantization for vllm #5218

Merged
merged 1 commit into from
Jul 31, 2024

Conversation

HandH1998
Contributor

@HandH1998 HandH1998 commented Jun 3, 2024

We have proposed QQQ, a W4A8 quantization solution, and integrated it into vLLM. QQQ not only achieves performance similar to that of the leading W4A8, W8A8, and W4A16 quantization methods but also significantly accelerates inference, achieving speedups of up to 2.24x, 2.10x, and 1.25x over FP16, W8A8, and W4A16 (Marlin), respectively.

News or Update

  • [2024/06/17] Update!!! We release the QQQ paper on arXiv.
  • [2024/06/07] We fuse dynamic activation quantization into QQQLinearMethod to support QQQ for all the models. Note that this will reduce some inference performance. You should use this repo to reproduce the inference speed results of our paper.
  • [2024/06/03] We integrate QQQ into vLLM and release the code.

Usage

You can export the quantized model weights with the QQQ repo (only LLaMA is supported for now). Our paper will be published on arXiv soon. Here is an offline inference example.

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is a",
    "A pig",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
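model = "your_quantized_model_path"  # placeholder: path to the exported QQQ checkpoint
tokenizer = "your_tokenizer_path"  # placeholder: path to the matching tokenizer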

# Create an LLM.
llm = LLM(
    model=model,
    tokenizer=tokenizer,
)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

Experiments

Here we provide our experiment results using this repo.

Settings

  • vllm v0.4.1
  • 1 A100 80G
  • CUDA 11.8

Model Performance

We evaluated the model performance on WikiText2 and five zero-shot tasks.
[Figure: model_performance]

Throughput

We conducted the same-batch throughput comparison of quantized LLaMA-2 models under various batch sizes. The input sequence length is 1024 and the output sequence length is 128.
[Figure: speedup]

W4A8 GEMM Performance

We implement the W4A8 GEMM based on Marlin GEMM. Thanks to their great work! Here is the speedup over PyTorch FP16 GEMM (Calling CUTLASS) of all GEMMs under different numbers of input tokens. The weight matrix size is (N=8192, K=21760).
[Figure: gemm_performance]

@robertgshaw2-neuralmagic
Collaborator

Thanks for the PR!

The team for NM will review this. We are going to have to split this work out into a few pieces.

Will get back to you

@robertgshaw2-neuralmagic
Collaborator

@HandH1998 I synced up with @comaniac

This PR does 2 things:

  1. Introduces W4A8 GEMM for QQQLinearMethod
  2. Introduces the concept of A8 compute for layers that are not "Linear"

For (2), we need to make a more holistic plan to support this concept, and NM and Anyscale are already working on this. We do not want to just hack up one model file to support this, as it would become very unmaintainable.

So, in terms of path forward for this PR, please remove the A8 compute from LayerNorm, Attention, and SiluAndMul and we should not make any material changes to llama.py.

This will allow QQQ to land quickly with the Linear running at A8.

@alexm-neuralmagic will review the kernels when he has some time this week.

We can then incorporate the other layers as part of the broader effort we have going on in a separate PR

@zhyncs
Contributor

zhyncs commented Jun 4, 2024

remove the A8 compute from LayerNorm, Attention, and SiluAndMul

@robertgshaw2-neuralmagic This will reduce some performance gains. Could you share a more general approach and timeline? Perhaps our team can collaborate to implement the corresponding features in this PR, and then you can modify it into a general version later.
The existing W4A8 in the industry does not meet our accuracy requirements. That's why we proposed this PR. To maximize performance while ensuring accuracy, every effort has been made to make this quantization truly usable in an online production environment, not just for publishing papers.

@mgoin
Collaborator

mgoin commented Jun 4, 2024

@zhyncs The goal of making this W4A8 optimization "production-ready" is exactly why I also think it is a good idea to land the first step as simply having this only be enabled for Linear modules as QQQLinearMethod. With this, we can enable the method for all models, rather than just supporting Llama.

We have an RFC tracking Int8 W8A8 that lays out some of the high-level goals, including graph fusion through torch.compile #3975. We have been landing many kernels recently and hope to share the torch.compile prototype soon!

@HandH1998
Contributor Author

@zhyncs The goal of making this W4A8 optimization "production-ready" is exactly why I also think it is a good idea to land the first step as simply having this only be enabled for Linear modules as QQQLinearMethod. With this, we can enable the method for all models, rather than just supporting Llama.

We have an RFC tracking Int8 W8A8 that lays out some of the high-level goals, including graph fusion through torch.compile #3975. We have been landing many kernels recently and hope to share the torch.compile prototype soon!

I see what you want to do. I hope you can finish supporting dynamic activation quantization successfully. I will try to split the activation quantization out of LayerNorm, Attention, and SiluAndMul and put it in QQQLinearMethod. Note that this will reduce some inference performance.

@robertgshaw2-neuralmagic
Collaborator

@zhyncs The goal of making this W4A8 optimization "production-ready" is exactly why I also think it is a good idea to land the first step as simply having this only be enabled for Linear modules as QQQLinearMethod. With this, we can enable the method for all models, rather than just supporting Llama.
We have an RFC tracking Int8 W8A8 that lays out some of the high-level goals, including graph fusion through torch.compile #3975. We have been landing many kernels recently and hope to share the torch.compile prototype soon!

I see what you want to do. I hope you can finish supporting dynamic activation quantization successfully. I will try to split the activation quantization out of LayerNorm, Attention, and SiluAndMul and put it in QQQLinearMethod. Note that this will reduce some inference performance.

SG - then we can add in the quantized execution of non-Linear layers as part of our broader project

@HandH1998
Contributor Author

@robertgshaw2-neuralmagic We have finished the mentioned work. This may help you achieve your high-level goals.

@robertgshaw2-neuralmagic
Collaborator

Thank you! This looks much cleaner

@robertgshaw2-neuralmagic
Collaborator

@HandH1998 - where are these kernels from? Is there an open source library that we are tracking for these?

I know these are adapted from Marlin so I think we could ramp up on them quickly, but I wanted to understand if this is something that we [Neural Magic + rest of the vLLM team] need to maintain or if we are just simply tracking a remote repo

@zhyncs
Contributor

zhyncs commented Jun 7, 2024

It's adapted from Marlin and modified as necessary.
https://github.com/HandH1998/vllm/blob/b4b677138732ac952be8dbc87e44a5a58dffeb73/csrc/quantization/qqq/qqq_gemm_kernel.cu#L2-L7

need to maintain

yep.

simply tracking a remote repo

There seems to be no need at the moment.

@robertgshaw2-neuralmagic
Collaborator

It's adapted from Marlin and modified as necessary. https://github.com/HandH1998/vllm/blob/b4b677138732ac952be8dbc87e44a5a58dffeb73/csrc/quantization/qqq/qqq_gemm_kernel.cu#L2-L7

need to maintain

yep.

simply tracking a remote repo

There seems to be no need at the moment.

I'm just trying to understand who is responsible for maintaining these kernels moving forward.

We currently have two models

  • With Punica / Flashattention / AWQ , we track from a different repo maintained by a 3rd party (+ we pull them into vllm as needed)
  • With Marlin / cutlass / triton / paged_attention, we maintain them ourselves

So I am trying to understand which case this is

@zhyncs
Contributor

zhyncs commented Jun 7, 2024

which case this is

I think it should be the latter.

On another note, speaking off the topic, the performance of AWQ in vLLM is almost the worst among several mainstream frameworks. I also do not recommend continuing to use 3rd party, but rather maintaining our own set.

@robertgshaw2-neuralmagic
Collaborator

which case this is

I think it should be the latter.

On another note, speaking off the topic, the performance of AWQ in vLLM is almost the worst among several mainstream frameworks. I also do not recommend continuing to use 3rd party, but rather maintaining our own set.

sg - we are reviewing

re: AWQ. I think we can run AWQ models with Marlin if we support zero points in Marlin. We would welcome a contribution

Collaborator

@alexm-neuralmagic alexm-neuralmagic left a comment

Reviewed the marlin modified kernel, great work guys! Left some comments.

// may be evicted immediately; used for quantized weights B, which are only
// accessed precisely once and should thus not pollute the L2 cache which we
// need for inputs A and outputs C.
__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) {
Collaborator

We have found in our tests that "stream" PTX is crashing on H100. Is it possible for you to use "cp.async.cg" and "cp.async.ca" without any createpolicy of cachehints? You can also test it on H100 and see if it crashes for you.

Contributor Author

@HandH1998 HandH1998 Jun 8, 2024

I will try to modify functions cp_async4_stream and cp_async1_stream as you suggest next week. I think it should be easy to do.

"l"(glob_ptr), "n"(BYTES));
}

// NOTE(HandH1998): cp.async.cg only support BYTES = 16, however,
Collaborator

Thanks for documenting this! It is helpful

// thread 0, 1, 2, 3. For more details, refer to mma operand A layout as s1's
// size is not fixed, we can not shuffle before inference we shuffle it when
// fetching s1 from global memory to shared memory, that's why s1_sh_wr is
// like this
Collaborator

Is "s1" statically defined with the quantization of the model or it is computed dynamically during inference?

Contributor Author

The latter. It is actually the dynamic per-token quantization scale, which needs to be computed online.
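
To illustrate what a dynamic per-token scale means here, the following is a minimal pure-Python sketch, not the kernel's actual code path. It assumes symmetric INT8 quantization with one scale per token (row), in the spirit of the dynamic_quant helper posted later in this thread. The function name dynamic_per_token_quant and the fallback scale of 1.0 for all-zero rows are assumptions made for the example.

```python
def dynamic_per_token_quant(x):
    """Quantize each row (token) of x to INT8 with its own scale.

    x is a list of rows of floats. Returns (int8_rows, scales).
    The scale is computed at inference time from the row's absmax,
    which is why it cannot be precomputed or shuffled offline.
    """
    q_rows, scales = [], []
    for row in x:
        absmax = max(abs(v) for v in row)
        # Assumption for this sketch: an all-zero row falls back to scale 1.0.
        scale = absmax / 127.0 if absmax > 0 else 1.0
        q_rows.append([max(-128, min(127, round(v / scale))) for v in row])
        scales.append(scale)
    return q_rows, scales


tokens = [[1.0, -2.0, 0.5], [0.0, 0.0, 0.0]]
q, s = dynamic_per_token_quant(tokens)
# Dequantizing q[i][j] * s[i] recovers each value to within half a scale step.
```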

FragC frag_c[thread_m_blocks][4][2];
FragS_GROUP frag_s3[2][4];
FragS_CHANNEL frag_s1[thread_m_blocks];
FragS_CHANNEL frag_s2[2][4];
Collaborator

During execution, the code uses only FragS_GROUP or FragS_CHANNEL. You can use if constexpr to avoid always defining both, or are you relying on the compiler to eliminate them anyway?

Contributor Author

Actually, it only uses FragS_CHANNEL for per-channel dequantization, while it needs both FragS_GROUP and FragS_CHANNEL for per-group weight quantization. I will consider your suggestion.

Contributor Author

Hi, I am trying to define the register array frag_s3 inside if constexpr, but then it is not visible to code outside that block. Do you know a way to solve this?

frag_b1 = dequant_per_group(b_quant_shift, frag_s3[k % 2][j], 1);
} else {
int b_quant_shift = b_quant << 4;
frag_b0 = dequant_per_channel(b_quant);
Collaborator

Why did you need to apply "dequant_per_channel" here and not on the final result (as it is done in the original Marlin, before write to C)?

Collaborator

I think I understand why you did it this way. You actually have a scale applied at the end on the output, and here you simply convert 4 bit to 8 bit by extracting bits. But maybe I'm missing something.

Contributor Author

You are right. Here dequant_per_channel just converts INT4 to INT8 by left shifting 4 bits.

// finally have to globally reduce over the results. As the striped partioning
// minimizes the number of such reductions and our outputs are usually rather
// small, we perform this reduction serially in L2 cache.
auto global_reduce = [&](bool first = false, bool last = false) {
Collaborator

Would be good to document here that the global_reduce works on int32 elements (since the original marlin was half), and mention also that this is the reason that you needed to add the temporary buffer C (since the original Marlin was reducing directly on the output buffer)
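
The point about reducing on int32 elements can be sketched in plain Python. This is an illustrative reference under stated assumptions, not the kernel: INT8 x INT8 products are accumulated as integers in a temporary buffer across K-splits, and the scales are applied only after the last partial sum has been added. The names int8_gemm_split_k, s_tok (per-token scale), and s_ch (per-channel scale) are hypothetical.

```python
def int8_gemm_split_k(a, b, s_tok, s_ch, k_splits=2):
    """Reference W?A8 GEMM with integer accumulation across K-splits.

    a: M x K integer activations, b: K x N integer weights.
    s_tok: M per-token scales, s_ch: N per-channel scales.
    """
    m, k, n = len(a), len(a[0]), len(b[0])
    # Temporary accumulator; Python ints stand in for int32 here.
    c_tmp = [[0] * n for _ in range(m)]
    step = (k + k_splits - 1) // k_splits
    for k0 in range(0, k, step):  # each chunk plays the role of one partial result
        k1 = min(k0 + step, k)
        for i in range(m):
            for j in range(n):
                c_tmp[i][j] += sum(a[i][kk] * b[kk][j] for kk in range(k0, k1))
    # Scales are applied once, after the integer reduction is complete.
    return [[c_tmp[i][j] * s_tok[i] * s_ch[j] for j in range(n)] for i in range(m)]


a = [[1, -2, 3, 4], [5, 6, -7, 8]]
b = [[1, 0], [0, 1], [2, -1], [-3, 2]]
out = int8_gemm_split_k(a, b, s_tok=[0.5, 0.25], s_ch=[2.0, 4.0], k_splits=3)
```

Because the partial sums stay in the integer domain until the end, the result does not depend on how K is split, which is not guaranteed when reducing rounded half-precision values.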

Contributor Author

I will fix it later.

const int USER_THREADS =
256; // Note: This is only used with user-provided thread_k/n
const int STAGES = 4; // 4 pipeline stages fit into shared memory
const int SHARED_MEM =
Collaborator

You can change this code to dynamically detect the L1 cache size (like we did in gptq marlin and marlin24). This will provide you a better estimate of L1 cache, since sometimes you will need more L1 due to scales or increased batching.

Contributor Author

Good suggestion! I will modify it.

__CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)

void qqq_cuda(const void* A, const void* B, void* C, void* D, void* s1,
Collaborator

Could you rename it to "marlin_qqq_cuda"? I think it better describes what is happening there.

Contributor Author

OK.

} // namespace qqq
} // namespace vllm

torch::Tensor qqq_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
Collaborator

same here, "marlin_qqq_gemm" is a better name in my opinion

Contributor Author

OK.

@robertgshaw2-neuralmagic
Collaborator

robertgshaw2-neuralmagic commented Jun 7, 2024

Another thing - what is the model serialization format? And how do we make models in this format?

@zhyncs
Contributor

zhyncs commented Jun 7, 2024

Another thing - what is the model serialization format? And how do we make models in this format?

https://github.com/HandH1998/QQQ/blob/4ab12906bb144ca8977a077f7f191e218fc2e038/examples/quant_model.py#L76-L83

https://github.com/handh1998/qqq?tab=readme-ov-file#quantize-model

Due to the changes from the original pull request, the performance will have a slight difference compared to the data in the image.

// Efficiently dequantize an int32 value into a full B-fragment of 4 int8 values
// for weight per channel dequant.
__device__ inline FragB dequant_per_channel(int q) {
static constexpr int MASK = 0xf0f0f0f0;
Collaborator

Isn't the mask supposed to be 0x0f0f0f0f? why do you extract the higher 4 bits?

Contributor Author

For per-channel weight dequantization, I position the INT4 weight in the upper 4 bits of an INT8 by multiplying by 16, which is essentially a left shift by 4 bits. I use the mentioned mask and line 624 to achieve this. Note that the 8 INT4 weights of one INT32 are shuffled offline to ensure that every thread of a warp fetches the correct weights.
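
The effect of the 0xf0f0f0f0 mask can be checked with a small plain-Python sketch. It is a simplified illustration, not the kernel: the packing order used here (nibble i stored at bit 4*i) is an assumption for the example and ignores the offline shuffle mentioned above, and the helper names are hypothetical. It shows that masking the word as-is and masking the word shifted left by 4 each yield four INT8 values equal to 16 times the original INT4 weights.

```python
MASK = 0xF0F0F0F0


def pack_int4x8(w):
    """Pack 8 signed INT4 values into one 32-bit word (nibble i at bit 4*i)."""
    q = 0
    for i, v in enumerate(w):
        q |= (v & 0xF) << (4 * i)
    return q


def to_int8_bytes(word):
    """Split a 32-bit word into 4 signed INT8 values, lowest byte first."""
    out = []
    for j in range(4):
        byte = (word >> (8 * j)) & 0xFF
        out.append(byte - 256 if byte >= 128 else byte)
    return out


def unpack_to_int8(q):
    """Return (even-index weights * 16, odd-index weights * 16)."""
    hi = q & MASK          # odd-index nibbles already sit in the upper 4 bits
    lo = (q << 4) & MASK   # shift even-index nibbles into the upper 4 bits
    return to_int8_bytes(lo), to_int8_bytes(hi)


w = [-8, -1, 0, 7, 3, -4, 5, -6]
even, odd = unpack_to_int8(pack_int4x8(w))
print(even)  # [-128, 0, 48, 80]
print(odd)   # [-16, 112, -64, -96]
```

Keeping the sign bit of the INT4 value in the sign bit of the INT8 byte is what makes the upper-nibble placement preserve negative weights; the factor of 16 then has to be absorbed elsewhere, for example in the scale applied to the output.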

@zhyncs
Contributor

zhyncs commented Jun 7, 2024

Hi @alexm-neuralmagic, these are the illustrations from the QQQ paper, drawn by @HandH1998. We hope they help you review the code. The paper will be published on arXiv soon, please stay tuned.

@alexm-neuralmagic
Collaborator

@zhyncs thanks for the figure, this is really helpful

@HandH1998
Contributor Author

So I am trying to understand which case this is

You guys can maintain qqq_gemm like you handle marlin. And remember to keep the copyright information at https://github.com/HandH1998/vllm/blob/b4b677138732ac952be8dbc87e44a5a58dffeb73/csrc/quantization/qqq/qqq_gemm_kernel.cu#L2-L7. If I have other optimization ideas, I will let you know.

@HandH1998
Contributor Author

@robertgshaw2-neuralmagic @alexm-neuralmagic We have fixed most of the issues you mentioned. For more details, please refer to our paper.

@HandH1998
Contributor Author

@robertgshaw2-neuralmagic @alexm-neuralmagic We have rebased our code on vllm 0.5.0 and hope you guys can review it.

Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Could you add some unit tests for this kernel?

csrc/ops.h Outdated
Comment on lines 97 to 112
torch::Tensor marlin_qqq_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& s1, torch::Tensor& s2,
torch::Tensor& s3, torch::Tensor& workspace,
int64_t size_m, int64_t size_n, int64_t size_k);
Collaborator

Could you mark as const the operands that aren't modified? I know that torch::tensor.data_ptr doesn't respect constness so this will be a somewhat nonfunctional change, but this will make it easier to define torch metafunctions, which are necessary to support this function in torch.compile

torch::Tensor marlin_qqq_gemm(torch::Tensor const& a, torch::Tensor const& b_q_weight,
                              torch::Tensor& s1, torch::Tensor& s2,
                              torch::Tensor& s3, torch::Tensor& workspace,
                              int64_t size_m, int64_t size_n, int64_t size_k);

Comment on lines 229 to 260
// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
if (threadIdx.x == 0) {
int state = -1;
do
// Guarantee that subsequent writes by this threadblock will be visible
// globally.
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
: "=r"(state)
: "l"(lock));
while (state != count);
}
__syncthreads();
}

// Release barrier and increment visitation count.
__device__ inline void barrier_release(int* lock, bool reset = false) {
__syncthreads();
if (threadIdx.x == 0) {
if (reset) {
lock[0] = 0;
return;
}
int val = 1;
// Make sure that all writes since acquiring this barrier are visible
// globally, while releasing the barrier.
asm volatile("fence.acq_rel.gpu;\n");
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
:
: "l"(lock), "r"(val));
}
}
Collaborator

Are these exactly the same as the functions in marlin_cuda_kernel.cu? If so, could you factor these out and place them in a common location to be included by both?

Comment on lines 502 to 425
// Since B-accesses have non-constant stride they have to be computed at
// runtime; we break dependicies between subsequent accesses with a tile by
// maintining multiple pointers (we have enough registers), a tiny
// optimization.
Collaborator

nit: dependicies -> dependencies and maintining -> maintaining

I think codespell should catch this, so make sure you can run format.sh on this to get the build green :)

@HandH1998
Contributor Author

@tlrmchlsmth We are adding the GEMM unit test and fixing the issues. We will release the code next Monday.

@HandH1998 HandH1998 force-pushed the w4a8 branch 2 times, most recently from f30daa6 to e28f2bb Compare June 24, 2024 05:40
@zhyncs
Contributor

zhyncs commented Jun 24, 2024

Hi @robertgshaw2-neuralmagic @alexm-neuralmagic @tlrmchlsmth Could you help review the latest code? Thanks. If everything goes smoothly, after this PR is merged, our team @HandH1998 plans to continue working on INT4-FP8 QQQ (W4A8) based on this.

Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

@HandH1998 thanks for adding the integration test! PR looks good to me now.

I am going to go ahead and accept it, but before landing it, let's run down the issue that @brisker ran into, just to make sure nothing funny is going on. @brisker could you share the result of vllm's collect_env.py script? I'd like to make sure you haven't uncovered some edge cases we're missing. The output of that script looks good to me on an A100.

Also @brisker could you also try re-running your test script with
assert(torch.allclose(out1, out2, atol=1e-3, rtol=1e-3))
That's going to be a better way to check if the implementations match.
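
For reference, the check suggested above passes only if every element satisfies |out1 - out2| <= atol + rtol * |out2|, which is the documented criterion of torch.allclose. A pure-Python sketch of the same criterion follows; the helper name allclose_ref is hypothetical and it works on flat lists rather than tensors.

```python
def allclose_ref(xs, ys, atol=1e-3, rtol=1e-3):
    """Elementwise closeness test: |x - y| <= atol + rtol * |y| for all pairs."""
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(xs, ys))


# A small relative error passes.
print(allclose_ref([1.0], [1.0015]))                     # True
# The first two entries of out1 and out2 reported in this thread do not.
print(allclose_ref([-0.0216, 0.5732], [0.1066, -0.5508]))  # False
```

Combining an absolute and a relative tolerance avoids both spurious failures near zero and overly loose matches on large values, which makes it more informative than eyeballing printed tensors.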

@brisker

brisker commented Jul 24, 2024

@HandH1998
@tlrmchlsmth
I have figured out why the weird results occur:

It turns out that, when packing the model, nothing weird happens if the quantization scales are computed from absmax.

But if the quantization scales are not computed from absmax, then for the w4a8 kernel in this PR, B (the quantized w for the kernel) must be calculated from the fake-quant weight (quant-dequant weight), and the scales must be calculated from the original fp16 weight.

So for the weird results I mentioned above, after modifying
x_absmax = x.abs().amax(dim=-1, keepdim=True)
into
x_absmax = x.abs().amax(dim=-1, keepdim=True) * 0.4999 (any factor in (0, 1) causes weird results, not only 0.4999)

if I modify further:

modify

scale,scale_extra = get_scale(ori_fc.weight.data)
qqq_linear.pack(ori_fc, scale, scale_extra)
ori_fc.weight.data = w4_quant(ori_fc.weight.data)

into

scale,scale_extra = get_scale(ori_fc.weight.data)
ori_fc.weight.data = w4_quant(ori_fc.weight.data)
qqq_linear.pack(ori_fc, scale, scale_extra)  

then out1 and out2 become close again.

So maybe this is due to the special calculation pipeline of this w4a8 kernel.

But anyway, in my opinion, this is a little weird, because normally we should get both the quantized B and the scales from the original fp16 weight.

Any explanation of this pipeline?

@HandH1998 the very weird thing is that, if you change x_absmax = x.abs().amax(dim=-1, keepdim=True) to x_absmax = x.abs().amax(dim=-1, keepdim=True) * 0.4999 in the get_scale and w4_quant functions in my script, a large diff occurs.

I just tried it, and found this.

out1
tensor([[-0.0216,  0.5732,  0.4524,  ..., -0.4541,  0.2771, -0.4653],
        [-0.0375, -0.0845,  0.6333,  ...,  0.2949,  0.4790,  0.6089],
        [-0.1888, -0.1674, -0.4871,  ...,  0.2385,  0.7886,  0.2507],
        ...,
        [-0.0891,  0.2343,  0.1569,  ...,  0.6831, -0.1757, -0.5054],
        [-0.4187,  0.5977, -0.0225,  ...,  0.3330,  0.1484, -0.7993],
        [-0.4817, -0.1526, -0.1072,  ...,  0.4109, -0.0023, -0.5581]],
       device='cuda:0', dtype=torch.float16)
out2
tensor([[ 0.1066, -0.5508, -0.7207,  ...,  0.5381, -0.3894,  0.0118],
        [ 0.0127, -0.4534, -0.3826,  ..., -0.1334, -0.0036, -0.2688],
        [-0.1793, -0.0451,  0.4478,  ...,  0.4104,  0.6265,  0.5024],
        ...,
        [ 0.0429,  0.4233,  0.3572,  ..., -0.1161, -0.4062, -0.0836],
        [ 0.2568,  0.0115,  0.0422,  ...,  0.2424,  0.0394,  0.2054],
        [ 0.1392, -0.0248, -0.1234,  ..., -0.1874, -0.1741,  0.2067]],
       device='cuda:0', dtype=torch.float16)

@HandH1998
Contributor Author

HandH1998 commented Jul 24, 2024

@brisker Thanks for your report! I have confirmed it is a bug in the pack function: we missed the clamp operation for per-channel quantization, so a diff appears when the quant scale is not calculated from the absmax. I have now fixed it in https://github.com/HandH1998/QQQ/commit/ae9531ce7a4d92080704df1e2b32f18a79600ae7. You can try your script again to verify it.
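
One plausible way a missing clamp produces such a diff can be sketched in plain Python. This is an illustration under assumptions, not the actual QQQ pack code, and quantize_int4 is a hypothetical helper: when the scale is smaller than absmax / 7, round(x / scale) can fall outside the signed INT4 range, and keeping only the low 4 bits while packing silently wraps the value.

```python
Q_MAX = 7  # symmetric signed INT4 range used in this thread: [-7, 7]


def quantize_int4(x, scale, clamp=True):
    """Quantize x to a signed 4-bit value, optionally skipping the clamp."""
    q = round(x / scale)
    if clamp:
        q = max(-Q_MAX, min(Q_MAX, q))
    nibble = q & 0xF  # only 4 bits survive packing
    return nibble - 16 if nibble >= 8 else nibble


absmax = 1.0
scale = absmax * 0.4999 / Q_MAX  # scale not derived from the plain absmax
print(quantize_int4(absmax, scale, clamp=True))   # 7
print(quantize_int4(absmax, scale, clamp=False))  # -2
```

With a scale of exactly absmax / 7 the two variants agree, which is consistent with the observation that the mismatch only appeared once the scale was no longer computed from absmax.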

@brisker

brisker commented Jul 24, 2024

@HandH1998
After the bug fix, I tried it again and it is OK now.

@HandH1998
Contributor Author

@tlrmchlsmth Thanks for your review! I plan to resolve conflicts on Thursday. Then you can go ahead.

@brisker

brisker commented Jul 25, 2024

@HandH1998
Since many models are trained in bf16, is bf16 supported for this w4a8 PR?

I mean, although the GEMM is w4a8, if its quantized w must be converted from fp16 weights, there may be accuracy risks, since many models are trained in bf16 and the quantization tuning process (fake-quant) may also have to be done in bf16.

https://github.com/HandH1998/QQQ/blob/main/QQQ/gptq/qlinear/qlinear_marlin.py#L181

@HandH1998
Contributor Author

@brisker BF16 is not supported in this PR. We don't have enough time to support it for now, though we would like to put it on our list. If you are interested, you are welcome to submit a PR to our QQQ repo. It would mainly involve changes to the offline quantization and the per-group dequant kernel. The former may be easy to achieve.

@HandH1998
Contributor Author

@tlrmchlsmth I have resolved conflicts.

@brisker

brisker commented Jul 29, 2024

@HandH1998

A few days ago I said it was OK, but that was for w4a8 without groups.

Today I tried w4a8-gs128, and now assert(torch.allclose(out1, out2, atol=1e-3, rtol=1e-3)) fails. This time it fails with both absmax and absmax*0.4999.

I think maybe there is another bug.

The new get_scale, w4_quant, and qqq_linear init code is here:

qqq_linear = QQQ_Linear(
                4,
                group_size,
                2048,
                4096,
                False,
                weight_dtype=torch.float16
            )
def get_scale(x,group_size=-1):
    if group_size==-1:
        q_max = (2 ** (4-1)) - 1
        x_absmax = x.abs().amax(dim=-1, keepdim=True)*clip_ratio
        # if self.lwc:
        #     x_absmax = x_absmax * self.sigmoid(self.absbound_factor)

        scale = x_absmax / q_max
        return scale,None
    elif group_size==128:
        q_max = 2 ** (4) - 1
        zero_point = (q_max +1)/2
        reshaped_x = x.view(x.shape[0],int(x.shape[1]//group_size),group_size)
        xmax = reshaped_x.abs().amax(dim=-1, keepdim=True)*clip_ratio
        # if self.lwc:
        #     xmax = xmax * self.sigmoid(self.absbound_factor)
        xmin = -xmax
        scale = (xmax - xmin) / q_max
        # scale = scale.clamp(min=CLIPMIN, max=CLIPMAX)

        quant_value = torch.clamp(torch.round(reshaped_x / scale) + zero_point, 0, q_max)
        x_dequant =  scale * (quant_value - zero_point)
        x_dequant = x_dequant.view_as(x)


        ############ second step ##########
        second_step_q_max = 127  # int8
        absmax = x_dequant.abs().amax(dim=1, keepdim=True)
        scale_1 = absmax / second_step_q_max
        # import pdb;pdb.set_trace()
        # scale_1 = scale_1.clamp(min=CLIPMIN, max=CLIPMAX)
        return scale[:,:,0],scale_1

    else:
        raise RuntimeError
def w4_quant(x,group_size=-1):
    if group_size==-1:
        q_max = (2 ** (4-1)) - 1
        x_absmax = x.abs().amax(dim=-1, keepdim=True)*clip_ratio
        # x_absmax = x_absmax * 0.4999
        # if self.lwc:
        #     x_absmax = x_absmax * self.sigmoid(self.absbound_factor)
        scale = x_absmax / q_max
        # scale = scale.clamp(min=CLIPMIN, max=CLIPMAX)

        quant_value = torch.clamp(torch.round(x / scale), -q_max, q_max)

        x_dequant = quant_value * scale
        return x_dequant
    elif group_size==128:
        q_max = 2 ** (4) - 1
        zero_point = (q_max +1)/2
        reshaped_x = x.view(x.shape[0],int(x.shape[1]//group_size),group_size)
        # xmax = reshaped_x.amax(dim=-1, keepdim=True)
        # xmin = reshaped_x.amin(dim=-1,keepdim=True)  # bad direct w4a8-g128 acc, even worse than group=-1

        xmax = reshaped_x.abs().amax(dim=-1, keepdim=True)*clip_ratio
        # if self.lwc:
        #     xmax = xmax * self.sigmoid(self.absbound_factor)
        xmin = -xmax

        scale = (xmax - xmin) / q_max
        # scale = scale.clamp(min=CLIPMIN, max=CLIPMAX)

        quant_value = torch.clamp(torch.round(reshaped_x / scale) + zero_point, 0, q_max)
        x_dequant =  scale * (quant_value - zero_point)
        x_dequant = x_dequant.view_as(x)

        ###########


        ############ second step ##########
        second_step_q_max = 127  # int8
        absmax = x_dequant.abs().amax(dim=1, keepdim=True)
        scale_1 = absmax / second_step_q_max
        # scale_1 = scale_1.clamp(min=CLIPMIN, max=CLIPMAX)

        quant_value_1 = torch.clamp(torch.round(x_dequant / scale_1), -second_step_q_max, second_step_q_max)

        final_x_dequant = quant_value_1 * scale_1

        return final_x_dequant

    else:
        raise RuntimeError

@HandH1998 @zhyncs I got a w4a8 llama2-7b model with wiki-ppl=5.7 when tested using fake-quant, but after I convert the weights into the qqq-w4a8 format and test again using the CUDA kernel, the wiki-ppl becomes 10000. So I ran the following code to test the consistency between fake-quant and the qqq-w4a8 kernel, and I find a difference between the two that cannot be neglected.

(I'm sure the conversion itself is right, because before tuning the two ppl values are close and both reasonable: 7.5 and 7.53.) The QuantLinear class is extracted from here, and the dynamic_quant function is extracted from here.

Any advice on this? I think this may be a potential bug.

I think this issue is important because the accuracy is carefully tuned in the fake-quant domain, but there is a gap between fake-quant and the CUDA kernel, which means the carefully tuned accuracy may not be the real accuracy. We cannot tune w4a8 accuracy in the CUDA-kernel domain, right?

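import torch
import torch.nn as nn

# QuantLinear is the class extracted from the QQQ repo (see above). This import
# assumes it was saved locally as QuantLinear.py; adjust the module path as needed.
from QuantLinear import QuantLinear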

def dynamic_quant(x: torch.Tensor):
    quant_scale = x.abs().max(dim=-1, keepdim=True)[0].div(127.0).to(torch.float)
    quant_x = (x / quant_scale).round().clamp(-128, 127).to(torch.int8)
    dequant_x = quant_x*quant_scale
    return quant_x, quant_scale,dequant_x

def get_scale(x):
    q_max = (2 ** (4-1)) - 1
    x_absmax = x.abs().amax(dim=-1, keepdim=True)

    scale = x_absmax / q_max

    return scale,None


def w4_quant(x):
    q_max = (2 ** (4-1)) - 1
    x_absmax = x.abs().amax(dim=-1, keepdim=True)

    scale = x_absmax / q_max


    quant_value = torch.clamp(torch.round(x / scale), -q_max, q_max)

    x_dequant = quant_value * scale

    return x_dequant


ori_fc = nn.Linear(2048,4096,bias=False).cuda().half()

qqq_linear = QuantLinear(
                4,
                -1,
                2048,
                4096,
                False,
                weight_dtype=torch.float16
            )
qqq_linear.cuda()

input = torch.Tensor(10,2048).normal_().cuda().half()
quant_input, quant_input_scale,dequant_input = dynamic_quant(input)


scale,scale_extra = get_scale(ori_fc.weight.data)
qqq_linear.pack(ori_fc, scale, scale_extra)
ori_fc.weight.data = w4_quant(ori_fc.weight.data)

with torch.no_grad():
    out1 = ori_fc(dequant_input.half())

with torch.no_grad():
    out2 = qqq_linear.forward(input)
  
print((out1==out2).sum(),out1.numel())
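
A note on the comparison above: out1==out2 counts bit-identical elements, which is stricter than consistency usually requires. A fake-quant path (dequantize, then accumulate in floating point) and an integer-accumulate path can round differently even when they encode the same quantized values. The following is a minimal, dependency-free sketch in plain Python that compares the two paths with a tolerance instead. It is not the QQQ kernel and not the QuantLinear code; the helper names quantize_sym, fake_quant_dot, integer_dot, and all_close are hypothetical, and the shapes are chosen only for illustration. A perplexity jump from 5.7 to 10000 is far larger than rounding noise, so a tolerance-based check of this kind may help separate a real mismatch from expected floating-point differences.

```python
import random

def quantize_sym(row, q_max):
    """Symmetric quantization of one row: returns (integer codes, scale)."""
    absmax = max(abs(v) for v in row) or 1.0  # avoid a zero scale for an all-zero row
    scale = absmax / q_max
    codes = [max(-q_max, min(q_max, round(v / scale))) for v in row]
    return codes, scale

def fake_quant_dot(x_codes, x_scale, w_codes, w_scale):
    # Fake-quant path: dequantize each operand first, then accumulate in floating point.
    return sum((xc * x_scale) * (wc * w_scale) for xc, wc in zip(x_codes, w_codes))

def integer_dot(x_codes, x_scale, w_codes, w_scale):
    # Kernel-like path (an assumption for this sketch): accumulate integers exactly,
    # then apply both scales once at the end.
    acc = sum(xc * wc for xc, wc in zip(x_codes, w_codes))
    return acc * x_scale * w_scale

def all_close(a, b, rtol=1e-9, atol=1e-9):
    """True if every pair satisfies |p - q| <= atol + rtol * |q| (same form as torch.allclose)."""
    return all(abs(p - q) <= atol + rtol * abs(q) for p, q in zip(a, b))

random.seed(0)
k, n = 256, 16  # illustrative sizes: k input features, n output channels
x = [random.gauss(0, 1) for _ in range(k)]
W = [[random.gauss(0, 1) for _ in range(k)] for _ in range(n)]

x_codes, x_scale = quantize_sym(x, 127)     # int8 activations, one scale per token
w_q = [quantize_sym(row, 7) for row in W]   # int4 weights, one scale per output channel

out_fake = [fake_quant_dot(x_codes, x_scale, wc, ws) for wc, ws in w_q]
out_int = [integer_dot(x_codes, x_scale, wc, ws) for wc, ws in w_q]

exact_matches = sum(a == b for a, b in zip(out_fake, out_int))
print(exact_matches, "of", n, "outputs match bit-for-bit")
print("within tolerance:", all_close(out_fake, out_int))
```

The 1e-9 tolerances are an assumption that suits the float64 arithmetic of this sketch; with fp16 tensors the tolerances would need to be much looser.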

@tlrmchlsmth
Collaborator

@brisker this wouldn't be a problem with any of the code in this PR, right, but rather an issue in https://github.com/HandH1998/QQQ?

@tlrmchlsmth
Collaborator

BTW, I just tried rerunning the failing jobs

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 29, 2024
@mgoin
Collaborator

mgoin commented Jul 29, 2024

Marked PR as ready and resolved merge conflict, let's see if it is green now

@HandH1998
Contributor Author

@tlrmchlsmth @mgoin The failed job uses lm_eval to evaluate our model HandH1998/QQQ-Llama-3-8b-g128. Earlier, I ran the command bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 250 -f 5 -t 1 on my host to get the evaluation result for .buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml. The result was 0.484 for exact_match,strict-match and 0.492 for exact_match,flexible-extract, which does not match the CR result (0.448 for exact_match,strict-match).
Today I switched my attention backend from xformers to vllm-flash-attn and got a new result: 0.424 for exact_match,strict-match and 0.428 for exact_match,flexible-extract. It seems that the attention backend has a non-negligible impact on the final evaluation result. Either way, the results on my host still don't align with the CR result.
I wonder if you guys can reproduce the CR result. Please share your results here. Thanks!

@mgoin
Collaborator

mgoin commented Jul 30, 2024

Hi @HandH1998 I think you would see a more stable result if you increased the number of samples. Generally we use -l 1000 instead of 250
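
To see why the sample count matters: an accuracy measured on n samples has a binomial standard error of sqrt(p(1-p)/n), assuming the samples are independent. A minimal sketch using the strict-match values quoted in this thread (the function name accuracy_std_error is hypothetical):

```python
import math

def accuracy_std_error(p, n):
    """Binomial standard error of an accuracy p measured on n independent samples."""
    return math.sqrt(p * (1.0 - p) / n)

reference = 0.448  # strict-match value the thread calls the CR result
observed = 0.484   # strict-match value reported from the author's host with -l 250

for n in (250, 1000):
    se = accuracy_std_error(reference, n)
    gap = abs(observed - reference) / se
    print(f"n={n}: standard error ~ {se:.3f}, gap = {gap:.2f} standard errors")
```

Under these assumptions, the 0.036 gap between 0.484 and 0.448 is a little over one standard error at 250 samples, so it is within ordinary sampling noise. Going from 250 to 1000 samples halves the standard error, since it scales with 1/sqrt(n).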

@HandH1998
Contributor Author

@tlrmchlsmth @mgoin Hi, using -l 1000 instead of 250 resolved the CR issue. However, new CR failures have appeared, and I don't think they are related to our code. Could you please take a look?

Collaborator

@mgoin mgoin left a comment

Thanks for letting us know. The PP test is a flaky test so I will attempt to land

@mgoin mgoin merged commit 6512937 into vllm-project:main Jul 31, 2024
63 checks passed
kylesayrs pushed a commit to neuralmagic/vllm that referenced this pull request Aug 17, 2024
@RanchiZhao

RanchiZhao commented Aug 19, 2024

So if I use vLLM for QQQ, what is the speedup ratio compared to the gptq-marlin kernel in this PR?
I also wonder whether this PR's QQQ-vLLM only supports kv16. Here the author mentioned kv fp8?
