
Improve cuBLAS performance by using a memory pool #1094

Merged (4 commits into ggerganov:master) on Apr 21, 2023

Conversation

slaren (Collaborator) commented Apr 20, 2023

Previously, CUDA memory was allocated and freed as needed for each mat mul operation, which is very inefficient.
Using a memory pool instead makes this about 30-50% faster on my machine.
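In simplified form, the idea looks something like this (an illustration of the approach only, not the exact code in the PR; the names and the pool size are made up):

#include <cuda_runtime.h>
#include <cstddef>

// Freed buffers are kept in a small table and handed back out for later allocations
// of the same or smaller size, instead of calling cudaMalloc/cudaFree around every mat mul.
#define CUDA_POOL_SIZE 16

struct cuda_buffer {
    void * ptr  = nullptr;
    size_t size = 0;
};

static cuda_buffer g_cuda_pool[CUDA_POOL_SIZE];

static void * cuda_pool_malloc(size_t size) {
    for (int i = 0; i < CUDA_POOL_SIZE; ++i) {
        cuda_buffer & b = g_cuda_pool[i];
        if (b.ptr != nullptr && b.size >= size) { // reuse a cached buffer
            void * p = b.ptr;
            b.ptr = nullptr;
            return p;
        }
    }
    void * p = nullptr;
    cudaMalloc(&p, size); // nothing suitable cached, allocate for real
    return p;
}

static void cuda_pool_free(void * ptr, size_t size) {
    for (int i = 0; i < CUDA_POOL_SIZE; ++i) {
        cuda_buffer & b = g_cuda_pool[i];
        if (b.ptr == nullptr) { // cache the buffer for later reuse
            b.ptr  = ptr;
            b.size = size;
            return;
        }
    }
    cudaFree(ptr); // pool is full, release for real
}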

PR:

7B q4_0 3.59 seconds per pass - ETA 0.65 hours
7B f16  4.59 seconds per pass - ETA 0.83 hours
7B f32  6.68 seconds per pass - ETA 1.22 hours

 7B q4_0 llama_print_timings: prompt eval time =  5493.53 ms /   631 tokens (    8.71 ms per token)
13B q4_0 llama_print_timings: prompt eval time =  9003.90 ms /   631 tokens (   14.27 ms per token)
30B q4_0 llama_print_timings: prompt eval time = 19682.41 ms /   631 tokens (   31.19 ms per token)

Master:

7B q4_0 5.09 seconds per pass - ETA 0.93 hours

 7B q4_0 prompt eval time =  7840.48 ms /   631 tokens (   12.43 ms per token)
13B q4_0 prompt eval time = 13826.48 ms /   631 tokens (   21.91 ms per token)

glinscott (Collaborator) commented:

Very nice! On 30B, 8.58 seconds per pass now, 10.99 seconds per pass before.

dfyz (Collaborator) left a comment:

In general, would you be interested in an implementation that uses cuda{Malloc,Free}Async() with a cudaMemPool_t (see this section of the CUDA programming guide) instead of a custom memory pool? I can try to come up with and benchmark an implementation based on that.

Pros:

  • harder to get wrong, smaller code diff
  • allows setting a memory limit in bytes rather than in number of allocations
  • potentially faster

Cons:

  • CUDA-specific, harder to port to other GPU-aware BLAS'es
  • requires CUDA 11.2
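A minimal sketch of what that could look like, assuming the device's default memory pool (the release threshold below is an arbitrary example, and the function names are just for illustration):

#include <cuda_runtime.h>
#include <cstdio>

// Stream-ordered allocation with cudaMallocAsync/cudaFreeAsync (CUDA >= 11.2).
static void * scratch_malloc_async(size_t size, cudaStream_t stream) {
    static bool configured = false;
    if (!configured) {
        int device = 0, supported = 0;
        cudaGetDevice(&device);
        cudaDeviceGetAttribute(&supported, cudaDevAttrMemoryPoolsSupported, device);
        if (!supported) {
            fprintf(stderr, "stream-ordered allocator not supported on this device\n");
            return nullptr;
        }
        // keep up to ~1 GiB of freed memory cached in the pool instead of returning it to the OS
        cudaMemPool_t pool;
        cudaDeviceGetDefaultMemPool(&pool, device);
        unsigned long long threshold = 1ull << 30;
        cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold);
        configured = true;
    }
    void * ptr = nullptr;
    cudaMallocAsync(&ptr, size, stream); // served from the pool when possible
    return ptr;
}

static void scratch_free_async(void * ptr, cudaStream_t stream) {
    cudaFreeAsync(ptr, stream); // returned to the pool, reusable by later allocations on the stream
}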

ggml-cuda.cu Outdated
Comment on lines 188 to 189
if (std::atomic_compare_exchange_strong(&b->ptr, &p, (uintptr_t) ptr)) {
b->size = size;
dfyz (Collaborator) commented:

I think this introduces a race condition: another thread can observe a non-nullptr pointer with a wrong size. E.g., consider this execution history of two threads T0 and T1:

  • T0 calls ggml_cuda_pool_malloc(LARGE_SIZE). The pool is empty, so T0 calls cudaMalloc(LARGE_SIZE) and gets the resulting pointer pLarge.
  • T1 calls ggml_cuda_pool_malloc(SMALL_SIZE). The pool is again empty, so T1 calls cudaMalloc(SMALL_SIZE) and gets the resulting pointer pSmall.
  • T0 calls ggml_cuda_pool_free(pLarge). cuda_buffer_pool[0]->ptr is NULL, so T0 makes an update: cuda_buffer_pool[0] = {.ptr = pLarge, .size = LARGE_SIZE}.
  • T0 calls ggml_cuda_pool_malloc(LARGE_SIZE). cuda_buffer_pool[0]->size >= LARGE_SIZE && cuda_buffer_pool[0]->ptr != nullptr, so T0 makes an update (cuda_buffer_pool[0] = {.ptr = nullptr, .size = LARGE_SIZE}) and gets pLarge.
  • HERE BE DRAGONS T1 calls ggml_cuda_pool_free(pSmall). cuda_buffer_pool[0]->ptr is nullptr, so T1 tries to make an update. It makes a successful CAS on line 188 (so that cuda_buffer_pool[0] = {.ptr = pSmall, .size = LARGE_SIZE}), but then gets preempted by the OS scheduler.
  • T0 calls ggml_cuda_pool_malloc(LARGE_SIZE). cuda_buffer_pool[0]->size >= LARGE_SIZE && cuda_buffer_pool[0]->ptr != nullptr, so it makes an (irrelevant) update and gets pSmall.

So now T0 thinks it has LARGE_SIZE bytes, while in fact it only has SMALL_SIZE bytes.

ggml-cuda.cu Outdated
if (b->size >= size) {
uintptr_t ptr = atomic_load(&b->ptr);
if (ptr) {
if (std::atomic_compare_exchange_strong(&b->ptr, &ptr, 0)) {
dfyz (Collaborator) commented:

Doesn't this have an ABA problem? The scenario I'm thinking of goes like this:

  • we get preempted just before the CAS and lose the race to ptr to another thread
  • this other thread eventually frees ptr with cudaFree()
  • some other thread calls cudaMalloc() with a smaller size and gets the same pointer as ptr (i.e., it is equal to it as an integer), then frees it into the same pool slot we are trying to use
  • we wake up, do the CAS (which succeeds because the new pointer is equal to the old one as an integer) and start using the pointer with the wrong size

This is highly unlikely to happen in practice, but I think it is technically possible, unless CUDART never returns the same pointer twice from cudaMalloc().

slaren (Collaborator, Author) commented Apr 21, 2023

@dfyz I think it is very unlikely that a general purpose allocator like that will be faster than this, but if you want to give it a try, please do. On the other hand, if it isn't compatible with HIP and #1087 is merged, we would have to write a custom allocator anyway.

Good points about the synchronization issues; this is always harder than it looks. In my tests llama doesn't make any concurrent mat muls with BLAS, so my thought here was mostly about making it future proof, but if that is too complicated we may as well replace it with a spin lock with zero performance impact right now.

SlyEcho (Collaborator) commented Apr 21, 2023

HIP supports the same async malloc/free operations and mempool stuff. Technically, support should be checked via the device's feature flags, but that would make our code more difficult.

When I was first getting hipBLAS working, I also tried hipMallocAsync and it really didn't make any difference. When I look at the profile, hipMalloc/Free is a tiny part of the matmul operation which is still dominated by hipMemcpy. Maybe if I try the changes in this PR it will reveal something different.

I think the only real impact would be to keep the big weight matrices on the device permanently, as much as will fit.

dfyz (Collaborator) commented Apr 21, 2023

@slaren

I think it is very unlikely that a general purpose allocator like that will be faster than this

Yeah, I agree a general purpose allocator would have no chance here. My hope here is just that the cudaMemPool_t does something similar to your caching scheme, e.g., they say:

the driver attempts to reuse memory that was previously freed via cudaFreeAsync() before attempting to allocate more memory from the OS. For example, memory freed in a stream can immediately be reused for a subsequent allocation request in the same stream

I don't know if it will work in practice, but it's a fun direction to explore (independently of this PR, of course).

Could you please clarify which commands (and which hardware) you used for benchmarking? I guess that the output with ETA [...] hours is the output of the perplexity command, but I'm less sure about the prompt eval time lines.

we may as well replace it with a spin lock with zero performance impact right now

I think a spinlock is a great idea here!

@SlyEcho

When I look at the profile, hipMalloc/Free is a tiny part of the matmul operation

I'd need to take a look at the CUDA trace to confirm this, but I think that cudaMalloc() is notoriously slow, to the point that removing it gives you a visible speedup, even though the real bottleneck is data transfer. The situation might be different on AMD devices.

I think the only real impact would be to keep the big weight matrices on the device permanently, as much as will fit

This makes sense, and I think the custom memory pool from this PR can be extended to handle this. Another fun direction to explore.

ggerganov (Owner) commented Apr 21, 2023

@slaren

If you manage to write an efficient q4_x_q8 CUDA kernel you will probably achieve another factor of 2 or more speed-up.

Anyway - great work as always!

Regarding the race condition:
Without spending too much time on this, it looks to me like it is not thread safe. @dfyz has already outlined a problematic scenario, and just from my experience, you won't be able to guarantee thread safety here with just an atomic variable.
I would recommend removing the atomic pointers, as we don't have a case where we do parallel BLAS calls.

On that note - would there be any sort of benefit theoretically from parallel GPU BLAS calls? I've tried on CPU and it doesn't help.

Can we add a parameter that specifies the maximum amount of VRAM to use?
For example, people with less VRAM would prefer to cap the allocation so they can run even the 65B models and still get benefits from the GPU.

SlyEcho (Collaborator) commented Apr 21, 2023

There could be a lot of benefit from parallel GPU BLAS calls. For example we could use the GemmStridedBatched functions to compute the multiplication over the entire tensor instead of for-looping the outer dimensions (but I think llama.cpp doesn't need it anyway?).

But they also mention this:

On certain problem sizes, it might be advantageous to make multiple calls to cublasgemm in different CUDA streams, rather than use this API.

Currently (on AMD at least; it could be that Nvidia is better at this), there is just one row of operations for one matmul calculation:
(profiler screenshot)

@slaren, do you think it would be worth having multiple cudaStreams, so that some operations could happen at the same time as others?

(In OpenCL (like clBLAS), every operation gives you event objects that other operations can wait on, so the entire ggml compute graph could be posted to the GPU command stream, letting it sort out the dependencies and compute in parallel. But that is just my guess as to whether it actually works like that.)

slaren (Collaborator, Author) commented Apr 21, 2023

I would recommend removing the atomic pointers, as we don't have a case where we do parallel BLAS calls.

Yes, I will replace the whole synchronization of the memory pool with a spin lock for now, just to make sure that it doesn't cause issues in the future.
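Something along these lines should be enough (a minimal sketch; the function name and structure here are illustrative, not the actual change):

#include <atomic>
#include <cstddef>

static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;

static void * cuda_pool_malloc_locked(size_t size) {
    // spin until we own the pool; essentially uncontended today since
    // BLAS mat muls are not issued from multiple threads yet
    while (g_cuda_pool_lock.test_and_set(std::memory_order_acquire)) {
    }
    void * ptr = nullptr;
    // ... plain (non-atomic) pool lookup / cudaMalloc fallback goes here ...
    g_cuda_pool_lock.clear(std::memory_order_release);
    return ptr;
}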

On that note - would there be any sort of benefit theoretically from parallel GPU BLAS calls?

I think so. Currently the GPU sits idle while waiting for the data, and the PCIe bus is unused while computing the gemm; with more threads it would be possible to keep copying data for the next mat mul while a gemm is being computed in a different thread. It will probably require some changes to the CUDA code, since we would need to use a different stream per thread, but if it is not very hard to change ggml to allow this, I think it is worth giving it a try.

Can we add a parameter that specifies the maximum amount of VRAM to use?

This PR does increase VRAM usage. I didn't test with 65B, but with 30B it is still very low, below what any discrete GPU would have. If we implement multi-threaded BLAS mat muls, that could change, though.

slaren (Collaborator, Author) commented Apr 21, 2023

@slaren, do you think it would be worth having multiple cudaStreams, so that some operations could happen at the same time as others?

Yes, that's what I was thinking.

SlyEcho (Collaborator) commented Apr 21, 2023

I think maybe the simplest thing is to have two streams, for src0 and src1. Then, before the Gemm, synchronize them.
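Roughly like this (a sketch; the function, buffer names and gemm parameters are made up for illustration):

#include <cuda_runtime.h>
#include <cublas_v2.h>

// Copy src0 and src1 on separate streams so the uploads overlap, then
// synchronize both before launching the gemm on the cuBLAS handle's stream.
static void gemm_with_two_streams(cublasHandle_t handle,
                                  const float * h_src0, float * d_src0, size_t src0_bytes,
                                  const float * h_src1, float * d_src1, size_t src1_bytes,
                                  float * d_dst, int m, int n, int k) {
    cudaStream_t stream0, stream1;
    cudaStreamCreateWithFlags(&stream0, cudaStreamNonBlocking);
    cudaStreamCreateWithFlags(&stream1, cudaStreamNonBlocking);

    // the two host-to-device uploads can run concurrently
    cudaMemcpyAsync(d_src0, h_src0, src0_bytes, cudaMemcpyHostToDevice, stream0);
    cudaMemcpyAsync(d_src1, h_src1, src1_bytes, cudaMemcpyHostToDevice, stream1);

    // wait for both copies before the multiplication
    cudaStreamSynchronize(stream0);
    cudaStreamSynchronize(stream1);

    const float alpha = 1.0f, beta = 0.0f;
    cublasSgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k,
                &alpha, d_src0, k, d_src1, k, &beta, d_dst, m);

    cudaStreamDestroy(stream0);
    cudaStreamDestroy(stream1);
}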

slaren (Collaborator, Author) commented Apr 21, 2023

If you manage to write an efficient q4_x_q8 CUDA kernel you will probably achieve another factor of 2 or more speed-up.

Here is an interesting article about how that could be done: https://siboehm.com/articles/22/CUDA-MMM
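The article starts from essentially the naive kernel below and then builds tiling and vectorization on top of it (just a sketch for context; a real q4_x_q8 kernel would also have to dequantize on the fly):

// Naive SGEMM: one thread per output element of the M x N result (illustrative only).
// A is M x K, B is K x N, C is M x N, all row-major.
__global__ void sgemm_naive(int M, int N, int K,
                            const float * A, const float * B, float * C) {
    const int row = blockIdx.y * blockDim.y + threadIdx.y;
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row < M && col < N) {
        float acc = 0.0f;
        for (int kk = 0; kk < K; ++kk) {
            acc += A[row * K + kk] * B[kk * N + col];
        }
        C[row * N + col] = acc;
    }
}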

slaren (Collaborator, Author) commented Apr 21, 2023

Could you please clarify which commands (and which hardware) you used for benchmarking? I guess that the output with ETA [...] hours is the output of the perplexity command, but I'm less sure about the prompt eval time lines.

Yes, the perplexity times are just from running perplexity with default params and the standard wiki test file. The only additional parameter is -t 12, which is what works well for me.

The prompt eval times are just from running main with a large prompt file and -b 512 -n 1 -t 12.

Edit: I also use --no-mmap, just to make sure that everything is loaded before the first eval.

slaren (Collaborator, Author) commented Apr 21, 2023

The synchronization issues should be resolved, and with a 30B q4_0 model I see a VRAM usage of ~2 GB, so I don't expect that it will be a problem even with 65B models on very low VRAM GPUs.

We could do better if we were able to predict the allocations; currently a single pass makes these allocations:

ggml_compute_forward_mul_mat_q_f32: (6656 x 6656) x (512 x 6656) => (512 x 6656)
cuda allocated 169 MB: total: 169 MB
cuda allocated 13 MB: total: 182 MB
cuda allocated 13 MB: total: 195 MB
cuda allocated 26 MB: total: 221 MB

ggml_compute_forward_mul_mat_q_f32: (17920 x 6656) x (512 x 6656) => (512 x 17920)
cuda allocated 455 MB: total: 676 MB
cuda allocated 35 MB: total: 711 MB
cuda allocated 71 MB: total: 782 MB

ggml_compute_forward_mul_mat_q_f32: (32000 x 6656) x (512 x 6656) => (512 x 32000)
cuda allocated 812 MB: total: 1595 MB
cuda allocated 126 MB: total: 1721 MB

In principle, we could allocate only the memory for the (32000 x 6656) x (512 x 6656) mat mul and reuse these buffers for the smaller mat muls. For now this isn't an issue, but that may change if we implement multi-threaded mat muls.

@slaren slaren merged commit 50cb666 into ggerganov:master Apr 21, 2023
@slaren slaren deleted the cuda-pool branch April 21, 2023 19:59
Dampfinchen commented Apr 27, 2023

@slaren I wonder if tensor cores could provide another huge speedup when optimizing for them. Judging by how GPU architectures with similar raw performance (one with and one without tensor cores) perform in llama's cuBLAS implementation right now, tensor cores are either not in use at all or used inefficiently. I've found the following documents about that matter, maybe you can look into them:

https://forums.developer.nvidia.com/t/turing-arch-int4-ops-with-tensor-cores/66656
https://docs.nvidia.com/cuda/cublas/index.html?highlight=tensor%20cores#tensor-core-usage

Thanks for your great work. I'm looking forward to your next innovation for cuBLAS!

slaren (Collaborator, Author) commented Apr 27, 2023

cuBLAS already uses tensor cores, see https://docs.nvidia.com/cuda/cublas/index.html#tensor-core-usage

Unfortunately, it seems that NVIDIA intends to deprecate INT4 support in tensor cores, and it has already been removed in sm_90 (Hopper / H100), but we may still benefit from it with RTX 20 to 40 series cards if we write our own mat mul kernel in the future.

dfyz (Collaborator) commented Apr 27, 2023

cuBLAS already uses tensor cores

As far as I can see (I'm running your changes from #1207), we only use cuBLAS for SGEMMs. TF32 tensor cores on Ampere and later (which can accelerate single-precision computations) are not used by default, since they result in reduced precision. You can test this by, for example, applying this patch:

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 0c01863..d8f2d6e 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -360,6 +360,8 @@ void ggml_init_cublas() {
         CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream2, cudaStreamNonBlocking));
         CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvent, cudaEventDisableTiming));
 
+        CUBLAS_CHECK(cublasSetMathMode(g_cublasH, CUBLAS_TF32_TENSOR_OP_MATH));
+
         // configure logging to stdout
         // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
     }

Here's what the first matmul in the transformer layer looks like without the patch applied (i.e., without tensor cores):
(profiler screenshot, without tensor cores)

And here's what it looks like with the patch applied (i.e., with tensor cores):
(profiler screenshot, with tensor cores)

Without tensor cores, the SGEMM takes much longer than the dequantize kernel. With tensor cores, it's the other way round.

I don't know if we should enable TF32 tensor cores, though. I might be measuring it wrong, but we are actually not bottlenecked by the SGEMM speed, so while tensor cores give an impressive speed-up, the overall prompt processing time stays largely the same.

slaren (Collaborator, Author) commented Apr 27, 2023

Nice! I see a 5% overall speedup from enabling tensor cores in my current testing branch. It's not much, but it may be worth running a perplexity test to make sure the loss of precision is not too bad.

Starting with cuBLAS version 11.0.0, the library will automatically make use of Tensor Core capabilities wherever possible, unless they are explicitly disabled by selecting pedantic compute modes in cuBLAS

I guess the documentation is wrong here, or maybe the heuristic is failing to choose the fastest algorithm.

SlyEcho (Collaborator) commented Apr 27, 2023

I found a partial solution to the non-contiguous matrices. If they are contiguous in 2D, they can be computed by Gemm (on all platforms), because we are looping over the higher dimensions. But this creates a lot of tiny multiplications; I wonder whether that StridedBatchedGemm could solve it.

slaren (Collaborator, Author) commented Apr 27, 2023

This is the f16 x f32 mat mul with a lot of streams, converting f16 to f32 on the GPU, and with tensor cores:
(profiler screenshot)
Looks a lot better now.

dfyz (Collaborator) commented Apr 28, 2023

I'm a little confused about what's happening with V x KQ_soft_max and K x Q matmuls (the ones that are parallel) in my nsys profile. I don't see them happening on GPU at all:

(profiler screenshot with NVTX annotations)

(I added some NVTX annotations with node labels to make it clear where the parallel matmuls are supposed to appear on the GPU)

I'm running this command with the latest changes from #1207 to process a relatively large prompt from the repo on an A100: ./main -b 512 -t 12 -n 1 -f prompts/dan.txt -m models/13B/ggml-model-q4_0.bin --no-mmap

What am I missing?

slaren (Collaborator, Author) commented Apr 28, 2023

This is what two layers look like now (https://github.com/slaren/llama.cpp/tree/cuda-f16f32):
(profiler screenshot)

I suspect that most of the time between layers is the missing non-contiguous mat mul. So there is probably a lot to gain there.

@dfyz in this screenshot it is the noisy lines that appear at the start of the layer (actually it is the end of the previous layer). Maybe your prompt is not big enough? I am using the perplexity tool here.

dfyz (Collaborator) commented Apr 28, 2023

in this screenshot it is the noisy lines that appear at the start of the layer (actually it is the end of the previous layer). Maybe your prompt is not big enough? I am using the perplexity tool here.

Ah yes, I can see those now when I run the perplexity tool. I guess prompts/dan.txt is not large enough indeed.

I would still try to handle both matmuls (the missing non-contiguous one, and the one you currently throw multiple streams at) with cublasSgemmStridedBatched(), but it seems like a lot of experimentation is currently going on in cuda-f16f32, so I don't want to get in the way. :) Maybe multiple streams are indeed the way to go.
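For reference, a sketch of what that batched call could look like (the variable names, strides, and transpose convention here are assumptions, not taken from the existing code):

#include <cublas_v2.h>

// One strided-batched gemm instead of looping a plain gemm over the heads (illustrative only).
static cublasStatus_t batched_head_gemm(cublasHandle_t handle,
                                        const float * d_A, const float * d_B, float * d_C,
                                        int m, int n, int k, int n_heads) {
    const float alpha = 1.0f, beta = 0.0f;
    return cublasSgemmStridedBatched(handle,
        CUBLAS_OP_T, CUBLAS_OP_N,
        m, n, k,
        &alpha,
        d_A, k, (long long) m * k,   // lda, strideA: offset between consecutive heads of A
        d_B, k, (long long) k * n,   // ldb, strideB
        &beta,
        d_C, m, (long long) m * n,   // ldc, strideC
        n_heads);
}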

slaren (Collaborator, Author) commented Apr 28, 2023

My intuition is that this is better than batched because it allows us to keep copying memory while doing the compute, but I may be wrong. It is worth a try for sure.

SlyEcho (Collaborator) commented Apr 28, 2023

I did an experiment where I tried to copy over even non-contiguous tensors using cublasSetVector(Async), looping over the outer dimensions and matrix rows (there is no matrix version of this, which is a shame). You can see that at least rocBLAS uses device kernels to accelerate the non-contiguous path, and I assume cuBLAS is similar, but it's not open source, so I can't tell.

EDIT: oh, the results were inconclusive, maybe I did something wrong.

slaren (Collaborator, Author) commented Apr 28, 2023

We could probably use cudaMemcpy2D or similar, but I think that the fastest way will be to upload it as non-contiguous and make it contiguous in a kernel.
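A kernel along those lines could be as simple as this (a sketch; it assumes the rows themselves are contiguous and only the row stride differs, similar to the assertion in the snippet below):

// Pack a row-strided 2D slice into a contiguous device buffer (illustrative only).
__global__ void k_pack_rows(const char * src, char * dst,
                            size_t row_bytes, size_t src_stride, int nrows) {
    const int    row = blockIdx.y;
    const size_t i   = (size_t) blockIdx.x * blockDim.x + threadIdx.x;
    if (row < nrows && i < row_bytes) {
        dst[(size_t) row * row_bytes + i] = src[(size_t) row * src_stride + i];
    }
}
// launched e.g. as: k_pack_rows<<<dim3((row_bytes + 255)/256, nrows), 256, 0, stream>>>(...)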

SlyEcho (Collaborator) commented Apr 28, 2023

@slaren I made this kind of crap; it seems to speed up from 59 minutes to 47:

static cudaError_t ggml_cuda_copy_tensor_compatible_2D(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) {
    const uint64_t ne0 = src->ne[0];
    const uint64_t ne1 = src->ne[1];
    const uint64_t nb0 = src->nb[0];
    const uint64_t nb1 = src->nb[1];
    const uint64_t nb2 = src->nb[2];
    const uint64_t nb3 = src->nb[3];
    const enum ggml_type type = src->type;
    const size_t ts = GGML_TYPE_SIZE[type];
    const size_t bs = GGML_BLCK_SIZE[type];

    const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3);
    if (ggml_is_contiguous(src)) {
        return cudaMemcpyAsync(dst, x, ts*ne0*ne1/bs, cudaMemcpyHostToDevice, stream);
    } else {
        GGML_ASSERT(nb0 == ts); // don't think about it now yet :(
        //fprintf(stderr, "cudaMemcpy2DAsync(dst, %ld*%ld/%ld, x, %ld, ts*ne0/bs, %ld, cudaMemcpyHostToDevice, stream);\n",
        //    ts, ne0, bs, nb1, ts, ne0, bs, ne1);
        return cudaMemcpy2DAsync(dst, ts*ne0/bs, x, nb1, ts*ne0/bs, ne1, cudaMemcpyHostToDevice, stream);
    }
}

Then just use that everywhere instead of cudaMemcpy2DAsync() and in ggml_compute_forward_mul_mat_use_blas() remove the checks for contiguous tensors.

EDIT: perplexity done, Q4_0 --memory_f32:

[649]6.2762,[650]6.2801,[651]6.2858,[652]6.2865,[653]6.2908,[654]6.2844,[655]6.2838,

llama_print_timings:        load time = 18068.80 ms
llama_print_timings:      sample time =     0.00 ms /     1 runs   (    0.00 ms per run)
llama_print_timings: prompt eval time = 2976711.34 ms / 335360 tokens (    8.88 ms per token)
llama_print_timings:        eval time =     0.00 ms /     1 runs   (    0.00 ms per run)
llama_print_timings:       total time = 3009391.94 ms
