Performance issues with cuBLAS and a bug #1129

Closed
cmp-nct opened this issue Apr 22, 2023 · 29 comments

@cmp-nct
Contributor

cmp-nct commented Apr 22, 2023

Performance with cuBLAS isn't there yet; in my tests it is more of a burden than a speedup for llama eval.
In a simple benchmark case it is absolutely amazing: multiplying 10 million F32 elements goes from over 1 second down to 20 milliseconds. So the raw improvement is a blast!

But in the llama case the overhead seems to be enormous: when enabling it generically, the average computation time shoots from 300ms up to 1500ms with cuBLAS.
I feel like the memory should have been prepared beforehand, and I don't think the thousands of CUDA cores are being used.
Is that loop really the best way to do it?

for (int64_t i03 = 0; i03 < ne03; i03++) {
    for (int64_t i02 = 0; i02 < ne02; i02++) {
        const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
        const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
        float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);

        CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
        CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));

        // compute
        CUBLAS_CHECK(
            cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
                    ne01, ne11, ne10,
                    &alpha, d_X, ne00,
                            d_Y, ne10,
                    &beta,  d_D, ne01));

ggml_compute_forward_mul_mat_use_blas() is evaluated too often. I think it would be better to make a generic "check" function that takes the node operation type as input.
I'm experimenting with another approach for that case; it's not finished yet, but I think we need to calculate the flops required.
Also, the function is called too often (init, compute, finalize); I'm modifying it to set a flag in the tensor instead, but that's just experimental atm.

There is a bug in ggml.c which causes the matrix multiplication to be executed by every thread when using CUDA with 32-bit input and output.

else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
cur = 0;
}

I think it should rather be like this:

else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
    cur = 0;
    if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
        node->n_tasks = 1;
        cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
    }
}

If n_tasks is not reduced to 1 when the tensor is handled by CUDA, nothing currently stops each thread from running the CUDA path (and surprisingly there are no crashes).

Finally: deciding whether CUDA should be enabled based on the required flops and the memory-transfer time, using a quick local benchmark at startup, might be a good generic solution.
But for our use case we know beforehand if we need it; every layer and every size is known.
So that flag in the tensor struct should be set manually (I currently use -1, 1 and 0: -1 = generic, 1 = force cuda if possible, 0 = disable cuda).
There should be a way to use all the CUDA cores; I have no experience with it (yet), but I'm sure we can get the multiplication through a lot faster.
It looks like the 30+ attention heads are calculated in sequence; CUDA could do that in one shot.

Also, when CUDA is enabled we need to make sure that the big matrix tensors are laid out properly in memory for CUDA, or they will be skipped. I didn't get to that point yet, still fighting ;)

@slaren
Collaborator

slaren commented Apr 22, 2023

I don't think we should expect to be able to do everything with cuBLAS; if the user has enough VRAM for that, they should be using a GPU implementation of llama instead. We could investigate the possibility of always keeping the largest matrices in the GPU, up to the available VRAM. I am not sure if that is going to be enough to use cuBLAS for generation, but it would improve the performance with batch processing.

Most of the time the GPU is sitting idle waiting for data, that is expected for the way this is done. We could likely improve utilization by allowing ggml to work on multiple mat muls simultaneously. There was some discussion about this in #1094, but I am not familiar with the graph processing code in ggml and I am not sure what would take to do that.

The loop effectively runs only once in llama.cpp, since ne03 and ne02 are always 1. Otherwise we could improve performance by working on those matrices at the same time, but that's not the case here.

@slaren
Collaborator

slaren commented Apr 22, 2023

I have tried implementing a cache for the quantized matrices (i.e. the weights), and even in the best case where the entire model fits in VRAM, for me this is only ~30% faster in the perplexity test. Considering how much more VRAM this requires, I am not sure that it is worth it.

Anyway, if anybody wants to give it a try, it is at slaren@5a606d5.

@SlyEcho
Collaborator

SlyEcho commented Apr 22, 2023

It's true that for F32 x F32 it doesn't change the thread count, but the function ggml_compute_forward_mul_mat_f32() will return early in the BLAS case from threads other than 0:

static void ggml_compute_forward_mul_mat_f32(
        const struct ggml_compute_params * params,
        const struct ggml_tensor * src0,
        const struct ggml_tensor * src1,
              struct ggml_tensor * dst) {
// ...
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
    if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
        if (params->ith != 0) {
            return;
        }
// ...

The variable cur is used to calculate the temporary work buffer size. F32 x F32 doesn't need any temporary workspace because it doesn't need to dequantize the matrices, so it is 0.

The GEMM computation is not a very big part of the time taken. It really depends on the hardware, but slow memory, slow video memory or a slow PCIe bus can totally ruin the performance.

@slaren, 30% improvement is pretty big in my opinion! On slower hardware it could be even more significant.

@slaren
Collaborator

slaren commented Apr 22, 2023

If there is interest in that, I can implement it properly and open a PR. However, to do it in a safe way in ggml, we would need #1093 to be able to tag the matrices that are constant and can be cached on the GPU safely. Currently, I am just assuming that when multiplying q x f32, the quantized matrix is the weights and therefore constant and safe to cache, but that may not always be the case for other uses of ggml.
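
The idea is basically a small lookup keyed by the host pointer of the weights. Very roughly (illustrative sketch, not the actual code in the branch):

// illustrative cache entry; one per weight tensor that has been uploaded
struct cuda_cached_weight {
    const void * host_data;  // tensor->data on the host, used as the key
    void       * dev_data;   // copy kept in VRAM
    size_t       size;
};

// before a q x f32 mat mul: if src0->data is already in the cache, reuse
// dev_data and skip the HostToDevice copy; otherwise upload it once,
// store the entry, and fall through to the normal path.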

@cmp-nct
Contributor Author

cmp-nct commented Apr 23, 2023

I think that a smart mix of GPU and CPU can be a very interesting library extension. Any improvements in that area will be awesome.

The #1093 suggestion was kind of approved by GG: extend the tensor node with more information.
Such as a "flag_enable_cuda" (-1, 0, 1) or a human-readable name; I am also tagging tensors as "high" or "low" performance, to enable better decision making for hybrid core scheduling (Intel 12th/13th gen+).
One more such flag would be whether the tensor is constant, another whether the tensor could be precached in VRAM (or on demand).
In my own source I use a struct, but GG was probably right to just use fixed-length fields instead of a struct, separated by prefix.
Basically there is no downside to adding anything to it; it's tiny in size.
To add something, two steps are needed (rough sketch below):

  1. add the new variables
  2. initialize them in the tensor_impl function
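
Roughly what I mean, with made-up field names (not an actual patch):

// in struct ggml_tensor (ggml.h), next to the existing fields:
int8_t flag_enable_cuda;  // -1 = generic decision, 1 = force CUDA if possible, 0 = disable CUDA
int8_t flag_is_constant;  // e.g. weights that never change and could be cached in VRAM
int8_t flag_perf_class;   // 1 = "high" performance tensor, 0 = "low" (hybrid core scheduling hint)

// in ggml_new_tensor_impl (ggml.c), together with the other defaults:
result->flag_enable_cuda = -1;
result->flag_is_constant =  0;
result->flag_perf_class  =  0;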

Strange performance:
The current performance disadvantage when enabling cuBLAS for all mat mul operations is huge; it's 4-5 times slower than on CPU.
I've not benchmarked it yet, but something must be incredibly slow. The GPU can hardly be slower at the calculation, so the only thing that comes to my mind is that the async memory copy is slow?

Of course a tiny multiplication is not worth moving to the GPU if the graph is mostly CPU based, but it shouldn't take 4-5 times longer when done that way. Maybe someone has insight into that already?
I've had mat mul timings of 1500ms per layer, where the CPU is at below 300ms. That's odd.

@slaren
Collaborator

slaren commented Apr 23, 2023

so the only thing that comes to my mind is that the async memory copy is slow?

I think that's all there is to it. The cuBLAS matrix multiplication is very fast, but the latency is high due to having to move the data between main RAM and the GPU. At some point the matrix is big enough that it takes longer to multiply it on the CPU than to copy the data to and from the GPU, but that only happens with big enough matrices; in our case that is when processing in large batches. @SlyEcho showed that here. It is the same issue with OpenBLAS: it is only worth using if the cost of dequantizing the matrices is lower than the time gained from the faster matrix multiplication. 4-5 times slower doesn't seem wrong; that's why the minimum batch size to use BLAS is 32.
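
For reference, the gating check is roughly this (paraphrased from ggml.c, the exact thresholds may differ between versions):

static bool ggml_compute_forward_mul_mat_use_blas(
        const struct ggml_tensor * src0,
        const struct ggml_tensor * src1,
              struct ggml_tensor * dst) {
    const int64_t ne10 = src1->ne[0];
    const int64_t ne0  = dst->ne[0];
    const int64_t ne1  = dst->ne[1];

    // only offload big, contiguous mat muls; ne1 is the batch dimension,
    // which is why BLAS only kicks in at batch sizes of 32 and above
    return ggml_is_contiguous(src0) &&
           ggml_is_contiguous(src1) &&
           ne0 >= 32 && ne1 >= 32 && ne10 >= 32;
}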

@cmp-nct
Contributor Author

cmp-nct commented Apr 23, 2023

Though it's really a huge delay. The PCIe bus is faster than RAM was a couple of years ago. I'll try to benchmark that area (though the async part makes it a bit harder).

One function I was working on (which needs the flags mentioned earlier to be usable) is an actual smart decision maker.
I wanted to benchmark the system on startup with a quick memory, GPU and CPU test,
then calculate the flops required and the memory movement required for GPU and for CPU work.
This can be used to make a smart decision based on the actual expected performance difference (better than the current "all three dimensions >= 32" check).
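
Roughly what I have in mind (everything below is hypothetical; the cpu/gpu/pcie numbers would come from the startup benchmark):

#include <stdbool.h>
#include <stdint.h>

// hypothetical decision helper for an m x k * k x n GEMM
static bool should_use_gpu(int64_t m, int64_t n, int64_t k,
                           double cpu_gflops, double gpu_gflops, double pcie_gb_per_s) {
    const double flops  = 2.0 * m * n * k;                            // mul + add per output element
    const double xfer_b = (double)(m*k + k*n + m*n) * sizeof(float);  // A and B in, C out

    const double t_cpu = flops / (cpu_gflops * 1e9);
    const double t_gpu = flops / (gpu_gflops * 1e9) + xfer_b / (pcie_gb_per_s * 1e9);

    return t_gpu < t_cpu;  // only offload when the transfer cost is amortized
}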

@gjmulder
Collaborator

gjmulder commented Apr 26, 2023

@cmp-nct @slaren

EDIT2: Even though I had the performance power governor enabled, the PCIe buses were power-saving by switching to 2.5 GT/sec. All results below are with 8 GT/sec lanes, across 4 old GTX 1080Ti GPUs with varying numbers of PCIe lanes.

Here are some data points which may be of interest, with a 65B q4_0 model and the first 406 lines of wiki.test.raw.

CPU                GPU             PCIe Gen3 Lanes   Batch Size   GPU Memory Util % (*)   Perplexity per pass (secs)
AMD 1650 16 Core   NVidia 1080Ti   4                 8            0                       233
AMD 1650 16 Core   NVidia 1080Ti   4                 64           7                       206
AMD 1650 16 Core   NVidia 1080Ti   4                 128          8                       124
AMD 1650 16 Core   NVidia 1080Ti   4                 256          6                       65
AMD 1650 16 Core   NVidia 1080Ti   4                 512          7                       47
AMD 1650 16 Core   NVidia 1080Ti   8                 512          -                       36
AMD 1650 16 Core   NVidia 1080Ti   16                512          -                       32
AMD 1650 16 Core   NVidia 1080Ti   16                512          10                      31

EDIT1:
(*) Percent of time over the past sample period during which global (GPU device) memory was being read or written.

$ git log | head -1
commit 54bb60e26858be251a0eb3cb70f80322aff804a0

@SlyEcho
Collaborator

SlyEcho commented Apr 26, 2023

It may be useful to see some kind of profile of what the GPU is doing in llama.cpp. For AMD, I can use rocprof; I tried to see if something similar is available on Nvidia, and there was just some confusing info.

@dfyz
Collaborator

dfyz commented Apr 26, 2023

For AMD, I can use rocprof; I tried to see if something similar is available on Nvidia, and there was just some confusing info

For high-level tracing of what's going on in the system, NSight Systems is the way to go. It will give you a trace very similar to what rocprof provides, with CPU and GPU activity at microsecond-level granularity.

With lots of kernels, looking at the trace can get overwhelming, so you might want to add some markers/ranges to pinpoint the problematic MUL operations.
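
For example, wrapping the interesting section in NVTX ranges makes it show up as a named block on the Nsight Systems timeline (minimal sketch, link with -lnvToolsExt):

#include <nvToolsExt.h>  // NVTX, ships with the CUDA toolkit

// around the mat mul you want to inspect:
nvtxRangePushA("ggml_cuda_mul_mat");
// ... cudaMemcpyAsync / cublasSgemm / cudaMemcpyAsync ...
nvtxRangePop();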

@slaren
Collaborator

slaren commented Apr 26, 2023

I found some inefficiencies in the current master; there was a copy that could be made in parallel with the dequantize.

In branch cuda-cache, using pinned memory, an additional stream to copy and dequantize at the same time, and caching the weights, it is currently like this:
[profiler trace screenshot]

Overall this brings perplexity time from 40m to 25m, so it's a nice improvement.
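
For clarity, the "pinned memory + additional stream" pattern looks roughly like this (a minimal sketch, not the actual code in the branch; the kernel name is made up):

// upload on one stream while the other stream dequantizes, using pinned
// host memory so that cudaMemcpyAsync is truly asynchronous
cudaStream_t copy_stream, main_stream;
cudaEvent_t  copy_done;
CUDA_CHECK(cudaStreamCreate(&copy_stream));
CUDA_CHECK(cudaStreamCreate(&main_stream));
CUDA_CHECK(cudaEventCreate(&copy_done));

void * host_buf;
CUDA_CHECK(cudaMallocHost(&host_buf, buf_size));  // pinned allocation

CUDA_CHECK(cudaMemcpyAsync(d_q, host_buf, buf_size, cudaMemcpyHostToDevice, copy_stream));
CUDA_CHECK(cudaEventRecord(copy_done, copy_stream));

// the main stream only waits when it actually needs the data
CUDA_CHECK(cudaStreamWaitEvent(main_stream, copy_done, 0));
dequantize_block<<<grid, block, 0, main_stream>>>(d_q, d_f32, n);  // hypothetical kernel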

@SlyEcho
Collaborator

SlyEcho commented Apr 26, 2023

@slaren I tried your branch on HIP; it seems to be running the dequant and, right after that, the GEMM (Cijk_...). VRAM is also used more.

Speed is only a little better: from 59 minutes (5.49 s/p) down to 44 (4.10 s/p).

[profiler trace screenshot]

EDIT: I think there might be something to optimize in the kernel launch arguments, you know with group and local sizes, etc.

@slaren
Collaborator

slaren commented Apr 26, 2023

I think there might be something to optimize in the kernel launch arguments, you know with group and local sizes, etc.

I tried that when I made the initial implementation of dequantization on the GPU, using cudaOccupancyMaxPotentialBlockSize, but it made no difference. Maybe now that everything is faster it is worth giving it another try.

Edit: Tested it again and still couldn't notice a difference.

@jon-chuang
Contributor

In branch cuda-cache, using pinned memory, an additional stream to copy and dequantize at the same time, and caching the weights, it is currently like this:

Why doesn't CUDA overlap the DeviceToHost copy with the sgemm kernel? Isn't cublasSgemm intelligent enough to determine when some rows are complete and can hence be copied?

@slaren
Collaborator

slaren commented Apr 26, 2023

No, I don't think it is.

@jon-chuang
Contributor

Here is what GPT-3.5 suggests to do:

// Set up CUDA streams
cudaStream_t compStream, copyStream;
cudaStreamCreate(&compStream);
cudaStreamCreate(&copyStream);

// Set up matrices A, B, and C using cuBLAS functions
// ...

// Set up variables for subset computation and copy
int subsetSize = 100;  // Size of each subset of rows
int numSubsets = m / subsetSize;  // Total number of subsets
int subsetIdx = 0;  // Index of current subset being computed
int offset = 0;  // Offset to current subset of rows in matrix C

// Loop over all subsets of rows
for (int i = 0; i < numSubsets; i++) {
    // Set the computation stream for cublasSgemm
    cublasSetStream(handle, compStream);

    // Compute subset of rows in matrix C using cublasSgemm
    cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, subsetSize, n, k, &alpha, 
                &d_A[offset], lda, &d_B[0], ldb, &beta, &d_C[offset], ldc);

    // Set the copy stream for cudaMemcpyAsync
    cudaStreamAttachMemAsync(copyStream, &d_C[offset], 0, cudaMemAttachSingle);

    // Copy computed subset of rows from device to host asynchronously
    cudaMemcpyAsync(&h_C[offset], &d_C[offset], subsetSize * n * sizeof(float),
                    cudaMemcpyDeviceToHost, copyStream);

    // Update subset index and offset
    subsetIdx++;
    offset = subsetIdx * subsetSize * n;
}

// Wait for all asynchronous operations to complete
cudaStreamSynchronize(compStream);
cudaStreamSynchronize(copyStream);

@jon-chuang
Contributor

From your profiling, it looks like this could be another 25-30% improvement if implemented.

@slaren
Collaborator

slaren commented Apr 26, 2023

If you want to make a proof of concept and show the results go ahead, a snippet from ChatGPT does not help.

@SlyEcho
Collaborator

SlyEcho commented Apr 26, 2023

I'm skeptical; these GEMM routines are already super optimized for the hardware, depending on the exact problem sizes. Besides, the DeviceToHost copy is small and doesn't take much time.

@jon-chuang
Contributor

jon-chuang commented Apr 26, 2023

If you want to make a proof of concept and show the results go ahead, a snippet from ChatGPT does not help.

I'm currently on holiday and won't have access to a CUDA machine for the next month.

The idea is clear: split up the sgemm call so that it operates on row subsets rather than the entire matrix; the completed subsets can then be copied back to the host asynchronously.

cudaStreamAttachMemAsync is what provides the synchronization between the compute and copy streams.

Besides, the DeviceToHost copy is small and doesn't take much time.

The profiling graph shows the exact opposite - that it takes 30% of the time. Unless the problem size profiled is much smaller than realistic? @slaren

these GEMM routines are super optimized to the hardware

This improvement is only intended for overlapping compute and communication, not improving compute.

@slaren
Collaborator

slaren commented Apr 27, 2023

Here is what an entire layer(*) looks like:
[profiler trace screenshot]

Most of the HostToDevice copies are already done in parallel with the dequantize so I don't think there is much to gain there. Maybe from the DeviceToHost copies.

(*) This is missing a bunch of f16 x f32 mat muls (V x KQ_soft_max) which don't have a dequantize step. This is a single tensor mat mul of (512 x 128 x 32) x (512 x 512 x 32), which is split into 32 mat muls, and each of those could be done in parallel for some gains. Currently, each of those is done sequentially, and overall takes about 30% of the per layer time.

This is what each of the 32 mat muls look like:
[profiler trace screenshot]
Yes, that looks awful; I am not sure why there is so much delay between the operations. There is an f32 to f16 conversion on the CPU, but that should be between the two HostToDevice copies. Maybe cuBLAS doesn't like f16 x f16 mat muls?

There is also another K x Q mat mul with size (128 x 512 x 32) x (128 x 512 x 32) which isn't done on the GPU at all because the matrices aren't contiguous in memory. That's another possible source of performance improvements.

Now, the big gains would come from implementing the little operations that are done between the mat muls in CUDA… so that the matrices don't have to be copied back and forth between the CPU and the device… and then we could do everything entirely on the GPU… maybe write our own q x f32 mat mul kernel so that we don't have to dequantize first… and you know where this goes: at that point ggml would be entirely GPU based, using "cuBLAS" for generation would be possible, and it would be an entirely different project. But it would be something fun to attempt in a "what if" fork.

@dfyz
Collaborator

dfyz commented Apr 27, 2023

This is a single tensor mat mul of (512 x 128 x 32) x (512 x 512 x 32), which is split into 32 mat muls, and each of those could be done in parallel for some gains.

I need to look at how exactly we multiply V x KQ_soft_max and K x Q in the ggml graph used by llama, but typically, cublasSgemm*Batched() functions are used to perform these two multiplications. You can use cublasSgemmStridedBatched() if these 32 matrices are placed at a constant distance relative to each other, and cublasSgemmBatched() if they are not (in this case, you also have to transfer the pointers to the matrices to the device).

Here's how FasterTransformer does it, for example. Note that I'm linking to a very old version of FT for clarity, since I don't think the current versions illustrate this concept very clearly.
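
A strided-batched call would look roughly like this (parameter names are illustrative, not taken from ggml):

// one call instead of 32 sequential cublasSgemm() calls; all 32 matrices
// must be spaced by a constant stride in device memory
CUBLAS_CHECK(cublasSgemmStridedBatched(
    handle, CUBLAS_OP_T, CUBLAS_OP_N,
    m, n, k,
    &alpha,
    d_A, lda, strideA,   // e.g. the 32 V matrices, back to back
    d_B, ldb, strideB,   // the 32 KQ_soft_max matrices
    &beta,
    d_C, ldc, strideC,
    32 /* batchCount */));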

Maybe cuBLAS doesn't like f16 x f16 mat muls?

I'm guessing from the kernel names that you have an Ampere GPU. Those should have FP16 tensor cores and absolutely no problems with f16 x f16 matmuls (in fact, it's f32 x f32 matmuls that are second-class citizens). Besides, the actual matmul kernels are so tiny in your screenshot I can't even see their names. You should probably check what the CPU was doing between scheduling the copy and the matmul kernels (I think that if you click on a GPU activity, NSight Systems will show you the corresponding CUDA API call).

[...] the big gains would come from implementing the little operations that are done between the mat muls in CUDA [...]

It may turn out to be not that hard. I believe you only need a softmax kernel, a layer norm kernel, and whatever activation function is used between the FFN layers in the model you choose. But even in a "what-if" scenario, I personally don't think it would be much fun. My impression is that you would basically end up with a very stripped-down version of FasterTransformer/ONNX Runtime, and at that point there is no good reason to not just use one of those.

I might be wrong, but so far I'm understanding that the philosophy of llama.cpp is:

  • we don't need GPUs with a lot of compute for generation, since it's essentially a lot of matrix-vector (or almost matrix-vector) products, which are memory-bound anyway
  • but if you happen to have any kind of GEMM accelerator (NVIDIA GPU, AMD GPU, Apple AMX, etc.) around, we will use it whenever we need to compute large matrix-matrix products (i.e., in prompt processing), which are compute-bound

If we look at things this way, this actually sounds pretty convincing:

[...] maybe write our own q x f32 mat mul kernel so that we don't have to dequantize first [...]

If we allow ourselves to use CUTLASS (which is a big if, since it's a large dependency), this is pretty doable (e.g., see this discussion) and should give a nice speedup.

@dfyz
Collaborator

dfyz commented Apr 27, 2023

cudaStreamAttachMemAsync() is what provides the synchronization between the compute and copy streams

Wait, what? Could you please clarify how exactly cudaStreamAttachMemAsync() in copyStream prevents cudaMemcpyAsync() (A) in copyStream from starting before cublasSgemm() (S) is finished in compStream? As far as I can see, A and S are entirely unrelated.

@jon-chuang
Contributor

jon-chuang commented Apr 27, 2023

cudaStreamAttachMemAsync() in copyStream prevents cudaMemcpyAsync() (A) in copyStream from starting before cublasSgemm() (S) is finished in compStream?

Yes, I think this was wrong, perhaps this will work:

// Set up variables for subset computation and copy
int subsetSize = 100;             // Size of each subset of rows
int numSubsets = m / subsetSize;  // Total number of subsets
int offset = 0;                   // Offset to current subset of rows in matrix C

// Set up CUDA streams and events (numSubsets must be known before allocating)
cudaStream_t *compStreams = new cudaStream_t[numSubsets];
cudaStream_t *copyStreams = new cudaStream_t[numSubsets];
cudaEvent_t  *events      = new cudaEvent_t[numSubsets];
for (int i = 0; i < numSubsets; i++) {
    cudaStreamCreate(&compStreams[i]);
    cudaStreamCreate(&copyStreams[i]);
    cudaEventCreate(&events[i]);
}

// Set up matrices A, B, and C using cuBLAS functions
// ...

// Loop over all subsets of rows
for (int i = 0; i < numSubsets; i++) {
    // Set the computation stream for cublasSgemm
    cublasSetStream(handle, compStreams[i]);

    // Compute subset of rows in matrix C using cublasSgemm
    cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, subsetSize, n, k, &alpha,
                &d_A[offset], lda, &d_B[0], ldb, &beta, &d_C[offset], ldc);

    // Record event after the computation of this subset is enqueued
    cudaEventRecord(events[i], compStreams[i]);

    // Make the copy stream wait for the GEMM of this subset to finish
    cudaStreamWaitEvent(copyStreams[i], events[i], 0);

    // Copy computed subset of rows from device to host asynchronously
    cudaMemcpyAsync(&h_C[offset], &d_C[offset], subsetSize * n * sizeof(float),
                    cudaMemcpyDeviceToHost, copyStreams[i]);

    // Update subset offset
    offset += subsetSize * n;
}

// Wait for all asynchronous operations to complete
cudaDeviceSynchronize();

// Clean up streams and events
for (int i = 0; i < numSubsets; i++) {
    cudaStreamDestroy(compStreams[i]);
    cudaStreamDestroy(copyStreams[i]);
    cudaEventDestroy(events[i]);
}
delete[] compStreams;
delete[] copyStreams;
delete[] events;

It explicitly synchronizes each subset of rows with a per-subset cudaEvent_t: the copy stream waits on the event recorded after the corresponding GEMM.

@SlyEcho
Collaborator

SlyEcho commented Apr 27, 2023

There is a f32 to f16 conversion on the CPU, but that should be between the two HostToDevice copies. Maybe cuBLAS doesn't like f16 x f16 mat muls?

I thought about doing the f16→f32 conversion on the GPU with a simple kernel; after that a normal Sgemm could be done.
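
Something along these lines (just a sketch, not tested):

#include <cuda_fp16.h>

// trivial elementwise conversion, one thread per element
__global__ void f16_to_f32(const __half * src, float * dst, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        dst[i] = __half2float(src[i]);
    }
}

// launch: f16_to_f32<<<(n + 255) / 256, 256, 0, stream>>>(d_src, d_dst, n);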

There is also another K x Q mat mul with size (128 x 512 x 32) x (128 x 512 x 32) which isn't done on the GPU at all because the matrices aren't contiguous in memory.

This was done because Gemm doesn't support them, and for OpenBLAS it would have meant copying the matrix to temporary memory. But for the GPU we are copying anyway, so I think it could be worth it.

Regarding the Strided/BatchedGemm for tensors, Nvidia mentions this in the docs:

On certain problem sizes, it might be advantageous to make multiple calls to cublasgemm in different CUDA streams, rather than use this API.

I don't know how many streams we would need. Creating a stream is expensive; when I wrote my hipBLAS code a month ago I didn't know this, and I gave up when it was too slow because I was creating and destroying a stream for every multiplication.

@slaren
Collaborator

slaren commented Apr 27, 2023

I tried using a different stream for each of the 32 mat muls, but didn't notice any improvement in performance. That's with pre-initialized streams, so it shouldn't have the overhead of creating and destroying the streams in each mat mul. However, I did observe a minor improvement in performance by doing the mat muls in f32 instead of f16. My GPU is indeed an Ampere RTX 3080, but there may be other factors. I have not looked at the difference in nsight yet (it doesn't work under WSL2), but it may simply be that converting f16 to f32 on the CPU is faster than converting f32 to f16. The f16 matrix is smaller too, so there is less data to convert that way. This conversion is necessary because the original data types for the mat mul are f16 x f32, and cuBLAS cannot do that directly.

Overall, after all the changes, I am seeing a speedup of 60-70% in the perplexity times. I am going to start opening PRs with each of the changes individually to make it easier to review. Some of the changes may cause some friction, in particular implementing host memory pinning and the weight cache may require changes to ggml that are dangerously close to adding GPU specific details to the ggml interface, which is not what we want (see this comment).

@SlyEcho
Collaborator

SlyEcho commented Apr 27, 2023

What about converting f16→f32 on the GPU? It would be less data to copy.

@dfyz
Collaborator

dfyz commented Apr 29, 2023

@slaren Since you and @SlyEcho have the cuBLAS stuff properly covered, I decided to take a look at the quantization kernels in #1221. It turns out they can be optimized quite a bit without introducing complexity (the code arguably becomes even simpler). If you are interested, I can make a real PR out of the draft I posted.

(at this point, the discussion in this issue is entirely unrelated to the initial issue description, but I'm not sure where else we should talk about cuBLAS/CUDA-related stuff)

@github-actions github-actions bot added the stale label Mar 25, 2024

github-actions bot commented Apr 9, 2024

This issue was closed because it has been inactive for 14 days since being marked as stale.

@github-actions github-actions bot closed this as completed Apr 9, 2024