Implement initial fp8 support
NVIDIA GPUs support two fp8 types: e5m2 and e4m3. PyTorch supports both
starting with version 2.1; note that safetensors currently does not support
these fully, but it will once this PR gets merged:
huggingface/safetensors#404

This change implements initial support for e5m2. e4m3 should be a better
fit in general, but:

- It has a smaller exponent range, so it requires adjusting the weights to
  fit into that range; Llama2 works fine without this, but Mistral breaks
  due to small weights that get rounded to zero.
- More critically, NVIDIA GPUs only support native fp8 to half/float
  conversion since Hopper (SM 9.0). fp8e5m2 has a fast emulation path
  because it has the same exponent range as fp16 (similarly to bfloat16,
  conversion just requires zero padding; see the sketch after this list),
  but fp8e4m3 emulation is impractically slow.
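
For illustration, here is a minimal sketch of why the e5m2 emulation path is
cheap: e5m2 shares fp16's 5-bit exponent, so widening an fp8e5m2 bit pattern
to fp16 is just an 8-bit left shift that zero-pads the mantissa, the same
trick bfloat16 uses against fp32. This is not code introduced by this change
(conversions go through cuda_fp8.h / PyTorch); it only demonstrates the bit
layout:

#include <stdint.h>
#include <cuda_fp16.h>

// fp8 e5m2 (1 sign, 5 exponent, 2 mantissa bits) widened to fp16
// (1 sign, 5 exponent, 10 mantissa bits): shifting left by 8 keeps the
// sign/exponent bits aligned (same bias of 15) and zero-pads the mantissa;
// Inf/NaN and subnormal patterns are preserved as well.
__device__ inline half fp8e5m2_to_half_bits(uint8_t b) {
	return __ushort_as_half((unsigned short)b << 8);
}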

We currently just use the built-in PyTorch conversion, which results in an
aggregate perplexity regression of roughly 0.5%. This can probably be
improved in the future.

The warp-parallel matmul now needs to process 4 elements at a time so that
each thread keeps loading 4 bytes per iteration, maximizing effective bandwidth.
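
For comparison, the pre-existing fp16 path reaches the same 4 bytes per lane
by reading two half elements per iteration. A rough sketch of that shape is
below; it is an assumption for illustration (the actual fp16 implementation
lives in src/helpers.cuh and is only partially shown in this diff), and it
reuses warpreduce_sum from src/helpers.cuh:

// Hedged sketch: one warp per row, each lane loads a 4-byte half2 per step.
__device__ inline float matmul_warppar_fp16_sketch(float* x, half* w, int i, int n) {
	assert(n % (warpSize * 2) == 0);
	int lane = threadIdx.x % warpSize;
	float val = 0.0f;
	for (int j = lane * 2; j < n; j += warpSize * 2) {
		float2 ww = __half22float2(*(half2*)&w[i * n + j]);
		val += ww.x * x[j];
		val += ww.y * x[j + 1];
	}
	return warpreduce_sum(val);
}
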
zeux committed Dec 29, 2023
1 parent e3a250a commit 551907d
Showing 6 changed files with 120 additions and 71 deletions.
20 changes: 19 additions & 1 deletion src/helpers.cuh
@@ -2,6 +2,7 @@

#include <assert.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <float.h>

__device__ inline float warpreduce_sum(float v) {
@@ -46,7 +47,8 @@ __device__ inline float blockreduce_max(float v) {
}

// regular mat*vec; naive and unoptimized (won't reach peak bw or flops)
__device__ inline float matmul(float* x, half* w, int i, int n) {
template <typename T>
__device__ inline float matmul(float* x, T* w, int i, int n) {
float val = 0.0f;
for (int j = 0; j < n; j++) {
val += float(w[i * n + j]) * x[j];
@@ -67,3 +69,19 @@ __device__ inline float matmul_warppar(float* x, half* w, int i, int n) {
}
return warpreduce_sum(val);
}

// warp-parallel mat*vec; each warp collaboratively computes mat*vec for a single row
// specialized for fp8 weights and ensures that we maximize transaction sizes by reading 4 bytes per thread
__device__ inline float matmul_warppar(float* x, __nv_fp8_e5m2* w, int i, int n) {
assert(n % (warpSize * 4) == 0);
int lane = threadIdx.x % warpSize;
float val = 0.0f;
for (int j = lane * 4; j < n; j += warpSize * 4) {
float4 ww = float4(*(__nv_fp8x4_e5m2*)&w[i * n + j]);
val += ww.x * x[j];
val += ww.y * x[j + 1];
val += ww.z * x[j + 2];
val += ww.w * x[j + 3];
}
return warpreduce_sum(val);
}
12 changes: 12 additions & 0 deletions src/infer.c
@@ -4,10 +4,22 @@
#include <stdio.h>
#include <stdlib.h>

// we only support CPU inference for fp16 and only when the compiler supports it natively
#if defined(__FLT16_MANT_DIG__)
typedef _Float16 dtype_t;
#else
typedef short dtype_t;
#endif

void prepare(struct Transformer* transformer) {
struct Config* p = &transformer->config;
struct RunState* s = &transformer->state;

if (transformer->weights.dsize != 2) {
fprintf(stderr, "FATAL: CPU backend only supports fp16 weights\n");
abort();
}

// we calloc instead of malloc to keep valgrind happy
int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
s->x = calloc(p->dim, sizeof(float));
76 changes: 46 additions & 30 deletions src/infer.cu
@@ -6,8 +6,6 @@
#include <math.h>
#include <stdio.h>

#include <cuda_fp16.h>

#include "helpers.cuh"

#define CUDA_CHECK(x) \
@@ -62,19 +60,19 @@ extern "C" void prepare_cuda(struct Transformer* transformer) {
weights->rms_att_weight[l] = (float*)cuda_devicecopy(weights->rms_att_weight[l], dim * sizeof(float));
weights->rms_ffn_weight[l] = (float*)cuda_devicecopy(weights->rms_ffn_weight[l], dim * sizeof(float));

weights->wq[l] = (dtype_t*)cuda_devicecopy(weights->wq[l], dim * dim * sizeof(dtype_t));
weights->wk[l] = (dtype_t*)cuda_devicecopy(weights->wk[l], dim * kv_dim * sizeof(dtype_t));
weights->wv[l] = (dtype_t*)cuda_devicecopy(weights->wv[l], dim * kv_dim * sizeof(dtype_t));
weights->wo[l] = (dtype_t*)cuda_devicecopy(weights->wo[l], dim * dim * sizeof(dtype_t));
weights->wq[l] = cuda_devicecopy(weights->wq[l], dim * dim * weights->dsize);
weights->wk[l] = cuda_devicecopy(weights->wk[l], dim * kv_dim * weights->dsize);
weights->wv[l] = cuda_devicecopy(weights->wv[l], dim * kv_dim * weights->dsize);
weights->wo[l] = cuda_devicecopy(weights->wo[l], dim * dim * weights->dsize);

weights->w1[l] = (dtype_t*)cuda_devicecopy(weights->w1[l], dim * hidden_dim * sizeof(dtype_t));
weights->w2[l] = (dtype_t*)cuda_devicecopy(weights->w2[l], dim * hidden_dim * sizeof(dtype_t));
weights->w3[l] = (dtype_t*)cuda_devicecopy(weights->w3[l], dim * hidden_dim * sizeof(dtype_t));
weights->w1[l] = cuda_devicecopy(weights->w1[l], dim * hidden_dim * weights->dsize);
weights->w2[l] = cuda_devicecopy(weights->w2[l], dim * hidden_dim * weights->dsize);
weights->w3[l] = cuda_devicecopy(weights->w3[l], dim * hidden_dim * weights->dsize);
}

weights->rms_final_weight = (float*)cuda_devicecopy(weights->rms_final_weight, dim * sizeof(float));
weights->token_embedding_table = (dtype_t*)cuda_devicecopy(weights->token_embedding_table, config->vocab_size * dim * sizeof(dtype_t));
weights->wcls = (dtype_t*)cuda_devicecopy(weights->wcls, dim * config->vocab_size * sizeof(dtype_t));
weights->token_embedding_table = cuda_devicecopy(weights->token_embedding_table, config->vocab_size * dim * weights->dsize);
weights->wcls = cuda_devicecopy(weights->wcls, dim * config->vocab_size * weights->dsize);

state->x = (float*)cuda_devicealloc(dim * sizeof(float));
state->xb = (float*)cuda_devicealloc(dim * sizeof(float));
@@ -91,7 +89,8 @@ extern "C" void prepare_cuda(struct Transformer* transformer) {
state->logits = (float*)cuda_hostalloc(config->vocab_size * sizeof(float));
}

__global__ static void kernel_embed(float* o, dtype_t* weight, int size) {
template <typename T>
__global__ static void kernel_embed(float* o, T* weight, int size) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
assert(i < size);

@@ -122,7 +121,8 @@ __global__ static void kernel_rmsnorm(float* o, float* x, float* weight, int siz
}
}

__global__ static void kernel_matmul_cls(float* xout, float* x, dtype_t* w, int n, int d) {
template <typename T>
__global__ static void kernel_matmul_cls(float* xout, float* x, T* w, int n, int d) {
int i = blockIdx.x;
assert(i < d);

@@ -133,12 +133,13 @@ __global__ static void kernel_matmul_cls(float* xout, float* x, dtype_t* w, int
}
}

__global__ static void kernel_matmul_qkv(float* qout, float* kout, float* vout, float* x, dtype_t* wq, dtype_t* wk, dtype_t* wv, int n, int d, int kvd) {
template <typename T>
__global__ static void kernel_matmul_qkv(float* qout, float* kout, float* vout, float* x, T* wq, T* wk, T* wv, int n, int d, int kvd) {
int i = blockIdx.x;
assert(i < d + kvd * 2);

float* out = i < d ? qout : (i < d + kvd ? kout : vout);
dtype_t* w = i < d ? wq : (i < d + kvd ? wk : wv);
T* w = i < d ? wq : (i < d + kvd ? wk : wv);
int j = i < d ? i : (i < d + kvd ? i - d : i - d - kvd);

float val = matmul_warppar(x, w, j, n);
@@ -147,7 +148,8 @@ __global__ static void kernel_matmul_qkv(float* qout, float* kout, float* vout,
}
}

__global__ static void kernel_matmul_attn(float* xout, float* x, dtype_t* w, int n, int d) {
template <typename T>
__global__ static void kernel_matmul_attn(float* xout, float* x, T* w, int n, int d) {
int i = blockIdx.x;
assert(i < d);

@@ -159,7 +161,8 @@ __global__ static void kernel_matmul_attn(float* xout, float* x, dtype_t* w, int
}
}

__global__ static void kernel_matmul_ffn13(float* xout, float* x, dtype_t* w1, dtype_t* w3, int n, int d) {
template <typename T>
__global__ static void kernel_matmul_ffn13(float* xout, float* x, T* w1, T* w3, int n, int d) {
int i = blockIdx.x;
assert(i < d);

@@ -176,7 +179,8 @@ __global__ static void kernel_matmul_ffn13(float* xout, float* x, dtype_t* w1, d
}
}

__global__ static void kernel_matmul_ffn2(float* xout, float* x, dtype_t* w, int n, int d) {
template <typename T>
__global__ static void kernel_matmul_ffn2(float* xout, float* x, T* w, int n, int d) {
int i = blockIdx.x;
assert(i < d);

@@ -312,7 +316,8 @@ __global__ static void kernel_attn_mix(float* xout, float* attb, kvtype_t* valb,
}
}

extern "C" float* forward_cuda(struct Transformer* transformer, int token, int pos, unsigned flags) {
template <typename T>
static float* forward(struct Transformer* transformer, int token, int pos, unsigned flags) {
profiler_begin();

// a few convenience variables
@@ -336,7 +341,7 @@ extern "C" float* forward_cuda(struct Transformer* transformer, int token, int p

// copy the token embedding into x
assert(token < p->vocab_size);
kernel_embed<<<dim / 32, 32>>>(x, w->token_embedding_table + token * dim, dim);
kernel_embed<<<dim / 32, 32>>>(x, (T*)w->token_embedding_table + token * dim, dim);
profiler_trigger("embed", 0);

// forward all the layers
@@ -348,8 +353,8 @@ extern "C" float* forward_cuda(struct Transformer* transformer, int token, int p
profiler_trigger("rmsnorm", 0);

// qkv matmuls for this position
kernel_matmul_qkv<<<dim + kv_dim * 2, 32>>>(s->q, s->k, s->v, s->xb, w->wq[l], w->wk[l], w->wv[l], dim, dim, kv_dim);
profiler_trigger("matmul_qkv", (dim + kv_dim * 2) * dim * sizeof(dtype_t));
kernel_matmul_qkv<<<dim + kv_dim * 2, 32>>>(s->q, s->k, s->v, s->xb, (T*)w->wq[l], (T*)w->wk[l], (T*)w->wv[l], dim, dim, kv_dim);
profiler_trigger("matmul_qkv", (dim + kv_dim * 2) * dim * sizeof(T));

// RoPE relative positional encoding: complex-valued rotate q and k in each head, and update kv cache
assert(dim % 64 == 0 && kv_dim % 64 == 0);
@@ -375,19 +380,19 @@ extern "C" float* forward_cuda(struct Transformer* transformer, int token, int p
profiler_trigger("attn_mix", p->n_kv_heads * (pos + 1) * head_size * sizeof(kvtype_t));

// final matmul to get the output of the attention
kernel_matmul_attn<<<dim, 32>>>(x, s->xb, w->wo[l], dim, dim);
profiler_trigger("matmul_attn", dim * dim * sizeof(dtype_t));
kernel_matmul_attn<<<dim, 32>>>(x, s->xb, (T*)w->wo[l], dim, dim);
profiler_trigger("matmul_attn", dim * dim * sizeof(T));

// ffn rmsnorm
kernel_rmsnorm<<<1, rmsnorm_size>>>(s->xb, x, w->rms_ffn_weight[l], dim);
profiler_trigger("rmsnorm", 0);

// self.w2(F.silu(self.w1(x)) * self.w3(x)) + pre-rmsnorm residual
kernel_matmul_ffn13<<<hidden_dim, 32>>>(s->hb, s->xb, w->w1[l], w->w3[l], dim, hidden_dim);
profiler_trigger("matmul_ffn13", 2 * hidden_dim * dim * sizeof(dtype_t));
kernel_matmul_ffn13<<<hidden_dim, 32>>>(s->hb, s->xb, (T*)w->w1[l], (T*)w->w3[l], dim, hidden_dim);
profiler_trigger("matmul_ffn13", 2 * hidden_dim * dim * sizeof(T));

kernel_matmul_ffn2<<<dim, 32>>>(x, s->hb, w->w2[l], hidden_dim, dim);
profiler_trigger("matmul_ffn2", dim * hidden_dim * sizeof(dtype_t));
kernel_matmul_ffn2<<<dim, 32>>>(x, s->hb, (T*)w->w2[l], hidden_dim, dim);
profiler_trigger("matmul_ffn2", dim * hidden_dim * sizeof(T));
}

if (flags & FF_UPDATE_KV_ONLY) {
@@ -402,12 +407,23 @@ extern "C" float* forward_cuda(struct Transformer* transformer, int token, int p
profiler_trigger("rmsnorm", 0);

// classifier into logits
kernel_matmul_cls<<<p->vocab_size, 32>>>(s->logits, x, w->wcls, dim, p->vocab_size);
profiler_trigger("matmul_cls", p->vocab_size * dim * sizeof(dtype_t));
kernel_matmul_cls<<<p->vocab_size, 32>>>(s->logits, x, (T*)w->wcls, dim, p->vocab_size);
profiler_trigger("matmul_cls", p->vocab_size * dim * sizeof(T));

profiler_endsync();

CUDA_SYNC();

return s->logits;
}

extern "C" float* forward_cuda(struct Transformer* transformer, int token, int pos, unsigned flags) {
switch (transformer->weights.dsize) {
case 1:
return forward<__nv_fp8_e5m2>(transformer, token, pos, flags);
case 2:
return forward<half>(transformer, token, pos, flags);
default:
return NULL;
}
}
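
For context, weights.dsize is what drives the template dispatch above. A
hedged sketch of how a loader could derive it from a safetensors dtype string
follows; the dtype names are taken from huggingface/safetensors#404, the
helper name is hypothetical, and the actual loading code is outside this diff:

#include <string.h>

// Hypothetical helper: map a safetensors dtype string to the per-element
// size stored in weights->dsize (0 means the dtype is unsupported).
static int weight_dsize(const char* dtype) {
	if (strcmp(dtype, "F8_E5M2") == 0) return 1; // fp8 e5m2
	if (strcmp(dtype, "F16") == 0) return 2;     // fp16
	return 0;
}
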
25 changes: 12 additions & 13 deletions src/model.h
@@ -6,16 +6,13 @@

#define MAX_LAYERS 128

// Can switch between float and _Float16 (model rebuild required)
// Can switch between float and _Float16
#ifdef __CUDACC__
typedef half dtype_t;
typedef half kvtype_t;
#elif defined(__FLT16_MANT_DIG__)
typedef _Float16 dtype_t;
typedef _Float16 kvtype_t;
#else
// We can't use _Float16 on CPU but we can still run CUDA
typedef short dtype_t;
typedef short kvtype_t;
#endif

@@ -31,24 +28,26 @@ struct Config {
};

struct Weights {
int dsize; // 1 for fp8, 2 for fp16; determines type of void* below

// token embedding table
dtype_t* token_embedding_table; // (vocab_size, dim)
void* token_embedding_table; // (vocab_size, dim)
// weights for rmsnorms
float* rms_att_weight[MAX_LAYERS]; // (dim) rmsnorm weights
float* rms_ffn_weight[MAX_LAYERS]; // (dim)
// weights for matmuls. note dim == n_heads * head_size
dtype_t* wq[MAX_LAYERS]; // (dim, n_heads * head_size)
dtype_t* wk[MAX_LAYERS]; // (dim, n_kv_heads * head_size)
dtype_t* wv[MAX_LAYERS]; // (dim, n_kv_heads * head_size)
dtype_t* wo[MAX_LAYERS]; // (n_heads * head_size, dim)
void* wq[MAX_LAYERS]; // (dim, n_heads * head_size)
void* wk[MAX_LAYERS]; // (dim, n_kv_heads * head_size)
void* wv[MAX_LAYERS]; // (dim, n_kv_heads * head_size)
void* wo[MAX_LAYERS]; // (n_heads * head_size, dim)
// weights for ffn
dtype_t* w1[MAX_LAYERS]; // (hidden_dim, dim)
dtype_t* w2[MAX_LAYERS]; // (dim, hidden_dim)
dtype_t* w3[MAX_LAYERS]; // (hidden_dim, dim)
void* w1[MAX_LAYERS]; // (hidden_dim, dim)
void* w2[MAX_LAYERS]; // (dim, hidden_dim)
void* w3[MAX_LAYERS]; // (hidden_dim, dim)
// final rmsnorm
float* rms_final_weight; // (dim,)
// classifier weights for the logits, on the last layer
dtype_t* wcls;
void* wcls;
};

struct RunState {