Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] Speedup dequantize kernels #1221

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 31 additions & 16 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <cuda_fp16.h>
Expand All @@ -7,7 +8,13 @@
// Half-precision values are carried as raw 16-bit integers under this name.
typedef uint16_t ggml_fp16_t;
// Compile-time check that CUDA's __half has the same size as ggml_fp16_t
// (presumably so buffers of one can be reinterpreted as the other — confirm at the use sites).
static_assert(sizeof(__half) == sizeof(ggml_fp16_t), "wrong fp16 size");

// Number of quantized values held by one block_q4_0 (two 4-bit values per qs byte).
#define QK4_0 32
// Threads per warp, threads per thread block used for the q4_0 dequantize launch,
// and the resulting number of warps in one thread block.
#define WARP_SIZE 32
#define THREAD_COUNT 1024
#define WARP_COUNT (THREAD_COUNT / WARP_SIZE)

// NOTE(review): QK4_0 is defined a second time here with the same value; this looks
// like a leftover from the change that moved the define — confirm and drop one copy.
#define QK4_0 32
// How many consecutive q4_0 blocks a single warp dequantizes.
#define QK4_0_Q_BLOCKS_PER_WARP 2
// How many q4_0 blocks one whole thread block dequantizes (all warps together).
#define QK4_0_Q_BLOCKS_PER_THREAD_BLOCK (WARP_COUNT * QK4_0_Q_BLOCKS_PER_WARP)
typedef struct {
float d; // delta
uint8_t qs[QK4_0 / 2]; // nibbles / quants
Expand Down Expand Up @@ -53,26 +60,28 @@ typedef struct {
} block_q8_0;
static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");

// Dequantizes q4_0 blocks to floats, one output value per thread per block.
//
// Launch layout: 1-D grid, THREAD_COUNT threads per thread block. Every warp
// handles QK4_0_Q_BLOCKS_PER_WARP consecutive quantized blocks, and lane `l`
// of the warp produces element `l` of each of those blocks, so the 32 lanes
// write 32 consecutive floats of y.
//
//   nb - number of block_q4_0 entries readable through vx
//   vx - pointer to the nb quantized blocks
//   y  - output buffer, receives nb * QK4_0 floats
//
// Blocks at or past nb are skipped individually, so nb does not have to be a
// multiple of QK4_0_Q_BLOCKS_PER_WARP and the grid may be rounded up.
static __global__ void dequantize_block_q4_0(int nb, const void * vx, float * y) {
    // One lane per quantized value: the mapping below only covers a whole
    // block when a block holds exactly one warp's worth of values.
    static_assert(QK4_0 == WARP_SIZE, "q4_0 kernel maps one lane to one quantized value");

    const block_q4_0 * x = (const block_q4_0 *) vx;

    const unsigned lane_id = threadIdx.x % WARP_SIZE;
    const unsigned warp_id = threadIdx.x / WARP_SIZE;

    // Signed on purpose: comparing against the signed nb must not wrap when
    // nb is zero or negative.
    const int start_block_id = (int) (blockIdx.x * QK4_0_Q_BLOCKS_PER_THREAD_BLOCK + warp_id * QK4_0_Q_BLOCKS_PER_WARP);

    // The 16 qs bytes are read as four 32-bit words of eight nibbles each.
    // Lane l takes nibble (l % 8) of word (l / 8); with the low nibble of each
    // byte first, this is the same order as the byte-wise low/high nibble walk.
    const unsigned int_id = lane_id / 8;
    const unsigned shift  = 4*(lane_id % 8);

#pragma unroll
    for (int i = 0; i < QK4_0_Q_BLOCKS_PER_WARP; ++i) {
        const int block_id = start_block_id + i;

        // Per-block tail guard: later blocks of this warp are even further
        // out of range, so the whole warp can stop here.
        if (block_id >= nb) {
            return;
        }

        const unsigned * int_qs = (const unsigned *) x[block_id].qs;
        const unsigned nibble = (int_qs[int_id] >> shift) & 0xf;

        // Stored nibbles are offset by 8; recentre, then scale by the block delta.
        const float v = ((int)nibble - 8) * x[block_id].d;

        y[block_id*QK4_0 + lane_id] = v;
    }
}

Expand Down Expand Up @@ -197,9 +206,15 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
}
}

// Returns how many thread blocks are needed to cover nb quantized blocks when
// each thread block processes q_blocks_per_thread_block of them: the number of
// completely filled thread blocks, plus one more if any blocks are left over.
static int get_thread_block_count(int nb, int q_blocks_per_thread_block) {
    const int full_thread_blocks = nb / q_blocks_per_thread_block;
    const int leftover_q_blocks  = nb % q_blocks_per_thread_block;

    if (leftover_q_blocks != 0) {
        return full_thread_blocks + 1;
    }
    return full_thread_blocks;
}

// Enqueues the q4_0 dequantization of k values on `stream`.
//
//   vx     - device pointer to k / QK4_0 quantized blocks
//   y      - device pointer receiving the dequantized floats
//   k      - number of values to dequantize; k / QK4_0 blocks are processed
//   stream - CUDA stream the kernel is launched on
//
// The launch is asynchronous; this function does not synchronize the stream.
void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
    const int nb = k / QK4_0;

    // Nothing to dequantize. Returning early also avoids a launch with a grid
    // size of 0, which is an invalid launch configuration.
    if (nb <= 0) {
        return;
    }

    // Each warp consumes QK4_0_Q_BLOCKS_PER_WARP blocks at a time.
    assert(nb % QK4_0_Q_BLOCKS_PER_WARP == 0);

    // Round up so a partially filled last thread block is still launched.
    const int n_thread_blocks = get_thread_block_count(nb, QK4_0_Q_BLOCKS_PER_THREAD_BLOCK);
    dequantize_block_q4_0<<<n_thread_blocks, THREAD_COUNT, 0, stream>>>(nb, vx, y);
}

void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
Expand Down