
[DRAFT] Speedup dequantize kernels #1221

Closed

Conversation


dfyz commented Apr 29, 2023

NB: this is a proof-of-concept PR which only modifies dequantize_block_q4_0() (this is the quantization method I'm most familiar with). If there is interest, I will modify all quantization kernels.
NB2: I only did rudimentary testing: I checked that ./main -b 512 -t 12 -n 40 -f prompts/dan.txt -m models/13B/ggml-model-q4_0.bin --no-mmap -s 12345 generates the same tokens with and without this diff. Again, if there is interest, I will test the changes more thoroughly.

High-level overview

dequantize_block_q4_0() should be memory-bound, but currently it's not. Here's what NSight Compute thinks:

[Screenshots: NSight Compute compute/memory utilization and roofline charts for the original kernel]

The most obvious problem I'm trying to fix in this PR is that using 1 thread per thread block to load the quantized block bytes is very inefficient. Threads within a thread block are arranged in 32-thread warps, and all threads in a warp always execute the same instruction simultaneously. So, instead of issuing many sequential loads using only the first thread of the warp, we can make all 32 threads issue different loads to dequantize different floats. These loads will then be coalesced whenever possible. The 32 threads will also perform arithmetic in parallel, but since we are memory-bound, it doesn't matter that much.

Other tweaks implemented here (guided by NSight Compute); a rough sketch combining these ideas follows the list:

  • running only one warp per thread block leads to low occupancy and warp scheduling stalls, so I increased the number of threads per thread block (1024 is the maximum; smaller values work, too)
  • we know that the number of quantized blocks is divisible by 2, and handling 2 blocks per warp results in better memory utilization
  • NSight Compute showed that the hardware failed to coalesce byte-sized loads for some reason, so I switched to 4-byte loads. We could even use uint4 to load the whole block at once, but unfortunately the block bytes are not aligned to 16 bytes.
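To make the core idea concrete, here is a minimal sketch of a warp-parallel kernel: one warp per quantized block, one output float per lane, several warps per thread block. This is not the code from this PR (the kernel name, the launcher, and the 256-thread constant are hypothetical), and it still uses byte-sized loads; the actual PR additionally handles 2 blocks per warp and switches to 4-byte loads.

#include <cuda_runtime.h>
#include <stdint.h>

#define QK4_0       32
#define DEQ_THREADS 256   // hypothetical; the PR uses up to 1024 threads per thread block

typedef struct {
    float   d;               // scale factor
    uint8_t qs[QK4_0 / 2];   // 32 quantized values packed as nibbles into 16 bytes
} block_q4_0;

// One warp dequantizes one quantized block; each lane produces one output float.
// Assumes blockDim.x is a multiple of the warp size (32).
static __global__ void dequantize_block_q4_0_warp(int nb, const void * vx, float * y) {
    const block_q4_0 * x = (const block_q4_0 *) vx;

    const int tid  = blockIdx.x * blockDim.x + threadIdx.x;
    const int qb   = tid / 32;   // index of the quantized block handled by this warp
    const int lane = tid % 32;   // output element within that block

    if (qb >= nb) {
        return;
    }

    const float d = x[qb].d;

    // Lane pairs read consecutive bytes, so the warp touches one contiguous 16-byte
    // region; even lanes take the low nibble, odd lanes the high nibble.
    const uint8_t q = x[qb].qs[lane / 2];
    const int     v = (lane & 1) ? (q >> 4) : (q & 0x0F);

    y[qb * QK4_0 + lane] = (v - 8) * d;
}

// Hypothetical launcher: one warp per quantized block, DEQ_THREADS/32 warps per thread block.
static void dequantize_row_q4_0_cuda_sketch(const void * vx, float * y, int k, cudaStream_t stream) {
    const int nb = k / QK4_0;
    const int warps_per_tb = DEQ_THREADS / 32;
    const int n_tb = (nb + warps_per_tb - 1) / warps_per_tb;
    dequantize_block_q4_0_warp<<<n_tb, DEQ_THREADS, 0, stream>>>(nb, vx, y);
}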

Here's the NSight Compute analysis for the new kernel:

[Screenshots: NSight Compute compute/memory utilization and roofline charts for the new kernel]

We still have a large stall waiting for the result of a 4-byte memory load, but I am out of ideas and don't think it's worth chasing further (see the performance measurements below):
[Screenshot 2023-04-29 at 01:56:14]

Performance measurements

These are the results of running ./perplexity -m models/13B/ggml-model-q4_0.bin -f wikitext-2-raw/wiki.test.raw under NSight Systems with --stats on an A100-SXM4-40GB GPU, before the PR:

 Time (%)  Total Time (ns)  Instances   Avg (ns)     Med (ns)    Min (ns)   Max (ns)   StdDev (ns)                             Name
 --------  ---------------  ---------  -----------  -----------  ---------  ---------  -----------  ----------------------------------------------------------
...
     29.0      489,024,865        449  1,089,142.2    644,828.0    608,412  3,791,913    546,249.9  dequantize_block_q4_0(const void *, float *)
...

And after the PR:

 Time (%)  Total Time (ns)  Instances   Avg (ns)     Med (ns)    Min (ns)   Max (ns)   StdDev (ns)                             Name
 --------  ---------------  ---------  -----------  -----------  ---------  ---------  -----------  ----------------------------------------------------------
...
      5.7       66,543,457        410    162,301.1     95,008.0     88,127    586,301     83,005.9  dequantize_block_q4_0(int, const void *, float *)
...

So, a ≈6.5x speedup in the dequantization kernel.

Unfortunately, I don't see any improvement in perplexity calculation times, presumably because even if we finish dequantization early, we still have to wait until the second host-to-device transfer finishes. Here is a screenshot of what's happening before the PR:
[Screenshot: NSight Systems timeline before the PR]

And after the PR:
[Screenshot: NSight Systems timeline after the PR]

However, the dequantization kernel becomes consistently faster across various GPUs. Some additional data points are below.

GeForce RTX 3070

BEFORE
 Time (%)  Total Time (ns)  Instances   Avg (ns)     Med (ns)    Min (ns)    Max (ns)   StdDev (ns)                                                  Name
 --------  ---------------  ---------  -----------  -----------  ---------  ----------  -----------  ----------------------------------------------------------------------------------------------------
...
     41.0      835,572,902        342  2,443,195.0  1,415,739.0  1,404,540   8,822,843  1,229,082.0  dequantize_block_q4_0(const void *, float *)
...

AFTER
 Time (%)  Total Time (ns)  Instances   Avg (ns)     Med (ns)    Min (ns)    Max (ns)   StdDev (ns)                                                  Name
 --------  ---------------  ---------  -----------  -----------  ---------  ----------  -----------  ----------------------------------------------------------------------------------------------------
...
     13.0      226,221,174        420    538,621.0    310,908.0    309,132   1,939,892    269,331.0  dequantize_block_q4_0(int, const void *, float *)
...

Tesla M40

BEFORE
 Time (%)  Total Time (ns)  Instances    Avg (ns)     Med (ns)    Min (ns)    Max (ns)   StdDev (ns)                                                  Name
 --------  ---------------  ---------  ------------  -----------  ---------  ----------  -----------  ----------------------------------------------------------------------------------------------------
     50.4    3,157,715,863        303  10,421,504.5  5,971,304.0  5,952,808  37,317,443  5,268,190.7  dequantize_block_q4_0(const void *, float *)
...

AFTER
 Time (%)  Total Time (ns)  Instances   Avg (ns)     Med (ns)    Min (ns)    Max (ns)   StdDev (ns)                                                  Name
 --------  ---------------  ---------  -----------  -----------  ---------  ----------  -----------  ----------------------------------------------------------------------------------------------------
...
      9.1      332,668,931        322  1,033,133.3    599,592.5    566,342   3,600,173    517,167.3  dequantize_block_q4_0(int, const void *, float *)
...


SlyEcho commented Apr 29, 2023

That could explain something.

I took a look at the CL kernels and saw that they didn't have any loops in them. Thinking this could be a nice way to improve performance, I changed them to work like the CUDA kernels do.

It was worse, a lot worse. I'm sure they could also be optimized, but we need to account for the "warps" and the "wavefronts". I'm all new to this stuff anyway, but it's been a good learning experience.

On a side note, there could also be a way to unify the CL and CUDA/ROCm kernels using some "clever" techniques, like how CLBlast also supports CUDA.


slaren commented Apr 29, 2023

Very interesting and very informative as always @dfyz! Do you think there would be any advantage to using cudaOccupancyMaxPotentialBlockSize instead of using a fixed number of threads?

My only concern with this is that currently adding new quantization types to cuBLAS is very easy: it's just an iteration on the reference C implementation. This looks a lot more complicated, and I am not sure if it is worth it if the bottleneck is elsewhere anyway. Could we get 90% of the way there just by improving occupancy? I know that the occupancy is currently very low; I had tested it with cudaOccupancyMaxActiveBlocksPerMultiprocessor during the initial implementation, but since it didn't affect the overall performance I preferred to keep it simple.

In any case, if this becomes the bottleneck in the future, I think we should definitely do this.


dfyz commented Apr 30, 2023

@slaren

Yeah, I totally agree with your overall message -- adding new quantization methods (which, judging by the issues/discussions, will keep appearing) by porting the reference implementation should be easy, and this is indeed way more complicated than that.

Let's keep this PR around in case anyone becomes interested in fast CUDA dequantization. It's a draft anyway, so I hope it doesn't clutter anything up.

Do you think there would be any advantage to using cudaOccupancyMaxPotentialBlockSize instead of using a fixed number of threads?

For kernels of this kind, I don't think it makes much difference -- e.g., it says I should launch 768 threads instead of 1024 for dequantize_block_q4_0, which doesn't result in any noticeable change in runtime. But yes, using the result of cudaOccupancyMaxPotentialBlockSize instead of hardcoded constants makes the code cleaner. I should do this if this PR becomes relevant.
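For reference, a hedged sketch of what that query could look like (reusing the hypothetical dequantize_block_q4_0_warp kernel and launcher variables from the sketch in the PR description; only the cudaOccupancyMaxPotentialBlockSize call itself is the real CUDA runtime API):

int min_grid_size = 0;
int block_size    = 0;
// Ask the runtime for a block size that maximizes occupancy for this kernel
// (no dynamic shared memory, no upper limit on the block size).
cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size,
                                   dequantize_block_q4_0_warp, 0, 0);
block_size -= block_size % 32;   // keep a whole number of warps per thread block

const int warps_per_tb = block_size / 32;
const int n_tb         = (nb + warps_per_tb - 1) / warps_per_tb;
dequantize_block_q4_0_warp<<<n_tb, block_size, 0, stream>>>(nb, vx, y);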

Could we get 90% of the way there just by improving occupancy?

I played around with this, and it seems to be one of those cases where "higher occupancy does not always equate to higher performance" (quoting the CUDA Best Practices Guide).

The original kernel runs in 1.66 milliseconds and 2,494,911 cycles, and has an occupancy of ≈30% (because we only use 1 warp, it is limited by the number of blocks we can fit on an SM):
[Screenshot: NSight Compute occupancy for the original kernel]

The easiest way to improve occupancy is to just increase the number of threads.

A hacky patch to do this
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 5a2701c..52eca3a 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -4,6 +4,8 @@
 #include <atomic>
 #include "ggml-cuda.h"

+#define THREAD_COUNT 768
+
 typedef uint16_t ggml_fp16_t;
 static_assert(sizeof(__half) == sizeof(ggml_fp16_t), "wrong fp16 size");

@@ -53,10 +55,12 @@ typedef struct {
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");

-static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
+static __global__ void dequantize_block_q4_0(int nb, const void * vx, float * y) {
     const block_q4_0 * x = (const block_q4_0 *) vx;

-    const int i = blockIdx.x;
+    const int i = blockIdx.x * THREAD_COUNT + threadIdx.x;
+
+    if (i >= nb) return;

     const float d = x[i].d;

@@ -199,7 +203,8 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {

 void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
     const int nb = k / QK4_0;
-    dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
+    const int n_thread_blocks = nb / THREAD_COUNT + (nb % THREAD_COUNT != 0);
+    dequantize_block_q4_0<<<n_thread_blocks, THREAD_COUNT, 0, stream>>>(nb, vx, y);
 }

 void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {

We get ≈90% occupancy after this:
[Screenshot: NSight Compute occupancy after the patch]

However, the results get worse: 2.08 milliseconds and 3,114,724 cycles.

In contrast, the kernel from this PR (with 768 threads) only achieves ≈70% occupancy:
[Screenshot: NSight Compute occupancy for the kernel from this PR]

Even though the occupancy is lower, the speed is much, much better: 304.19 microseconds and 449,557 cycles.

My simplified interpretation of these results: since the kernel is memory-bound, speeding it up is only possible if we a) decrease the overall number of memory transactions (by coalescing loads), or b) issue more loads when there is available memory bandwidth (by increasing occupancy). This PR targets a), and it seems like that is what matters in this case.

If you also want to test this, use something like this:

sudo /usr/local/NVIDIA-Nsight-Compute/ncu -o report --set full -c 1 ./perplexity -m ~/ggml-model-q4_0.bin -f ~/wiki.test.raw

This will save a report about the very first kernel launch (which happens to be a dequantize kernel) to report.ncu-rep, which can then be opened in the NSight Compute GUI. By increasing the value after -c, you can get more kernel launches in the report (but I think the first one is enough for performance analysis).


dfyz commented Apr 30, 2023

@SlyEcho

On a sidenote, there could also be a way to unify the CL and CUDA/ROCm kernels using some "clever" techniques

Oh wow, I managed to totally miss the fact that we also have OpenCL dequantize kernels. I'm also new to OpenCL, but based on this, it supports pretty much every GPU imaginable for computations. Maybe we can just drop the custom CUDA/ROCm kernels and replace them with the OpenCL ones?

JohannesGaessler commented:

I have a GTX 1070. It seems that on this GPU increasing occupancy does lead to better performance. I did a low-effort PR #1341 that increased the number of threads per block to 256. This gave me ~14% faster prompt processing with 33b and a prompt with 399 tokens. When I re-ran the tests with the kernel in this PR prompt processing was ~57% faster compared to master.


dfyz commented May 6, 2023

When I re-ran the tests with the kernel in this PR prompt processing was ~57% faster compared to master.

Whoa, this is nice to hear. Unfortunately I don't have a GTX 1070, but if there is interest, I can try speeding up other kernels in a similar manner. I don't know if it's worth it, though, since it does make the code uglier and harder to iterate on.

(#1341 should definitely be merged independently of this PR, since it is much simpler and cleaner than my approach.)


slaren commented May 6, 2023

I am in favour of merging anything that improves overall performance; if there are significant gains to be had here, I think it is great and should be done. In some mat muls, dequantization is also a bottleneck on my 3080 (for example, see the timeline in #1269), though overall the difference seems small.

There may be some changes to the quantization formats happening soon (#1305, #1073) though, so it may be better to wait for those.


JohannesGaessler commented May 7, 2023

Whoa, this is nice to hear. Unfortunately I don't have a GTX 1070, but if there is interest, I can try speeding up other kernels in similar manner. I don't know if it's worth it, since it does make the code uglier and harder to iterate on.

To put things into perspective: on my hardware, the hacky implementation that I did takes ~38 ms per token and the implementation in this PR takes ~28 ms per token, while generating a new token takes me ~500 ms (edit: numbers are for 33b q4_0). So even for a relatively long prompt with 1000 tokens, this PR only shaves off ~10 s of runtime, which is the equivalent of ~20 new tokens. I typically try to generate at least 500 new tokens, so the reduction of the total runtime would only be ~4%. For perplexity testing this would be very useful though, since there the speedup of the entire program is comparatively larger.
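For concreteness, the arithmetic behind that estimate, derived only from the numbers quoted above:

$$
\begin{aligned}
\text{saved per prompt token} &\approx 38\ \text{ms} - 28\ \text{ms} = 10\ \text{ms} \\
\text{saved on a 1000-token prompt} &\approx 1000 \cdot 10\ \text{ms} = 10\ \text{s} \approx 20\ \text{new tokens at } 500\ \text{ms/token} \\
\text{total runtime (1000 prompt + 500 new tokens)} &\approx 1000 \cdot 38\ \text{ms} + 500 \cdot 500\ \text{ms} \approx 288\ \text{s} \\
\text{relative saving} &\approx 10\ \text{s} / 288\ \text{s} \approx 3.5\%
\end{aligned}
$$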


SlyEcho commented May 7, 2023

... takes ~38 ms per token and the implementation in this PR takes ~28 ms per token. Meanwhile generating a new token takes me ~500 ms

What specs do you have, @JohannesGaessler? That seems like very low performance. The 1070 should be about equal to my Vega, but I get about 6.40 ms per token with the 256-threads patch, and 150 ms/run for new tokens.

(CPU i7 7700K 4.7 GHz OC, RAM: 32 GB 2777 MT/s DDR4, Linux 6.3.1)

Can you try running perplexity with --no-mmap? You can use a smaller test file; I usually test with only the first 100 lines of the wiki test file, which takes about a minute or two.

JohannesGaessler commented:

Are we talking about the same thing? Perhaps I should have clarified: I am talking about the performance for 33b q4_0 since 33b is the use case that I care about. For 7b I get 8.6 ms/t for prompt processing and 122 ms/t for generating new tokens.

My specs: Ryzen 3700X, 32GB RAM @ 3200 MHz, GTX 1070, Linux 6.3.0-1-MANJARO

Note that when I benchmarked performance I found that the speed at which new tokens are generated is essentially just proportional to memory bandwidth:

[Chart: new-token generation speed vs. memory bandwidth]

JohannesGaessler commented:

I quickly tried implementing a kernel for q4_0 myself: JohannesGaessler@390f0a9

On my hardware (GTX 1070), 7b perplexity on the first 100 lines of wikitext is 7% faster compared to the kernel in this PR (5.68 ms/t vs 6.07 ms/t). But more importantly, I think that my kernel is simpler (note that I implemented it relative to #1341).


dfyz commented May 8, 2023

But more importantly, I think that my kernel is simpler

Yup, I think your version is pretty much what I started with. The additional complexity in my version (see "Other tweaks implemented here" in the PR description) gave me very noticeable speedups on Ampere-class GPUs, but apparently it makes things worse on a GTX 1070. I think we should go with the simpler version.


dfyz commented May 15, 2023

Closing this, since this PR is outdated and largely superseded by @JohannesGaessler's efforts.

dfyz closed this on May 15, 2023