Skip to content

Commit

Permalink
cuda : add F32 -> Q4_0 and F32 -> Q4_1 copy kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Dec 5, 2023
1 parent e8457c9 commit b2acede
Showing 1 changed file with 37 additions and 4 deletions.
41 changes: 37 additions & 4 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <stdio.h>
#include <atomic>
#include <assert.h>
#include <float.h>

#if defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h>
Expand Down Expand Up @@ -4587,20 +4588,20 @@ static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
block_q4_0 * dsti = (block_q4_0 *) cdsti;

float amax = 0.0f;
float max = 0.0f;
float vmax = 0.0f;

for (int j = 0; j < QK4_0; ++j) {
const float v = xi[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
max = v;
vmax = v;
}
}

const float d = max / -8;
const float d = vmax / -8;
const float id = d ? 1.0f/d : 0.0f;

y[i].d = d;
dsti->d = d;

for (int j = 0; j < QK4_0/2; ++j) {
const float x0 = xi[0 + j]*id;
Expand All @@ -4614,6 +4615,38 @@ static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
}
}

// Quantizes one block of QK4_1 floats into a single block_q4_1.
//
//   cxi   - source bytes, reinterpreted as QK4_1 consecutive floats
//   cdsti - destination bytes, reinterpreted as one block_q4_1
//
// The block's minimum and value range are found first. The step size
// (range / 15) is stored in dm.x and the minimum in dm.y. Every element is then
// mapped to a 4-bit code: round((v - min) / step), capped at 15. Element j goes
// into the low nibble of qs[j] and element j + QK4_1/2 into the high nibble.
static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
    const float * src = (const float *) cxi;
    block_q4_1  * dst = (block_q4_1 *) cdsti;

    // running extrema over the block
    float lo =  FLT_MAX;
    float hi = -FLT_MAX;

    for (int k = 0; k < QK4_1; ++k) {
        const float v = src[k];

        lo = v < lo ? v : lo;
        hi = v > hi ? v : hi;
    }

    // (1 << 4) - 1 == 15 is the largest 4-bit code
    const float step     = (hi - lo) / ((1 << 4) - 1);
    // a zero step (all values equal) maps every element to code 0
    const float inv_step = step != 0.0f ? 1.0f/step : 0.0f;

    dst->dm.x = step;
    dst->dm.y = lo;

    const int half_blk = QK4_1/2;

    for (int k = 0; k < half_blk; ++k) {
        const float q_lo = (src[k]            - lo)*inv_step;
        const float q_hi = (src[half_blk + k] - lo)*inv_step;

        // adding 0.5 then truncating rounds to nearest; cap at 15
        const uint8_t nib_lo = min(15, (int8_t)(q_lo + 0.5f));
        const uint8_t nib_hi = min(15, (int8_t)(q_hi + 0.5f));

        // pack both codes into one byte
        dst->qs[k] = nib_lo | (nib_hi << 4);
    }
}

template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
Expand Down

0 comments on commit b2acede

Please sign in to comment.