Skip to content

Commit

Permalink
minor comment fixes, more to come
Browse files Browse the repository at this point in the history
  • Loading branch information
karpathy committed Apr 23, 2024
1 parent 2491402 commit fd7da62
Showing 1 changed file with 29 additions and 17 deletions.
46 changes: 29 additions & 17 deletions train_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ because these are faster (just read, no write). This is okay for all activations
except for those in the residual stream, where the gradients have to add. We make
sure that those parts work out ok and that we do a += as necessary. E.g.,
the layernorms are connected to the residuals so we += in layernorm backward.
In this file we are using Mixed Precision training, so different activations,
parameters, grads and buffers may be kept at different precisions, to take
advantage of the fast low-precision hardware in the latest GPUs (bf16/fp16),
and fp8 (coming soon^TM).
*/

#include <stdio.h>
Expand All @@ -29,18 +34,25 @@ the layernorms are connected to the residuals so we += in layernorm backward.

// ----------------------------------------------------------------------------
// CUDA precision settings

// turn on bf16 as default, done up here for now
#define ENABLE_BF16

// use bf16 (bfloat 16)
#if defined(ENABLE_BF16)
typedef __nv_bfloat16 floatX;
typedef float floatN;
#define CUBLAS_LOWP CUDA_R_16BF
#define CUBLAS_LOWP_COMPUTE CUBLAS_COMPUTE_32F

// use fp16 (note: this may require a gradient scaler, which is currently not implemented!)
#elif defined(ENABLE_FP16)
typedef half floatX;
typedef float floatN;
#define CUBLAS_LOWP CUDA_R_16F
#define CUBLAS_LOWP_COMPUTE CUBLAS_COMPUTE_32F

// fallback for fp32
#else
typedef float floatX;
typedef float floatN;
Expand Down Expand Up @@ -125,7 +137,7 @@ __device__ __host__ float random_f32(unsigned long long *state) { // random floa
// This gives us a random number from threadIdx/blockIdx + a single seed for the entire GPU
// todo - possibly overkill and we don't need such high quality random numbers? (tbd)
// http://eiserloh.net/noise/SquirrelNoise5.hpp
__device__ __host__ constexpr unsigned int SquirrelNoise5( int positionX, unsigned int seed )
__device__ __host__ constexpr unsigned int SquirrelNoise5(int positionX, unsigned int seed)
{
constexpr unsigned int SQ5_BIT_NOISE1 = 0xd2a80a3f; // 11010010101010000000101000111111
constexpr unsigned int SQ5_BIT_NOISE2 = 0xa884f197; // 10101000100001001111000110010111
Expand All @@ -146,27 +158,27 @@ __device__ __host__ constexpr unsigned int SquirrelNoise5( int positionX, unsign
mangledBits ^= (mangledBits >> 17);
return mangledBits;
}
__device__ __host__ constexpr unsigned int Get1dNoiseUint( int positionX, unsigned int seed )
__device__ __host__ constexpr unsigned int Get1dNoiseUint(int positionX, unsigned int seed)
{
return SquirrelNoise5( positionX, seed );
return SquirrelNoise5(positionX, seed);
}
__device__ __host__ constexpr unsigned int Get2dNoiseUint( int indexX, int indexY, unsigned int seed )
__device__ __host__ constexpr unsigned int Get2dNoiseUint(int indexX, int indexY, unsigned int seed)
{
constexpr int PRIME_NUMBER = 198491317; // Large prime number with non-boring bits
return SquirrelNoise5( indexX + (PRIME_NUMBER * indexY), seed );
return SquirrelNoise5(indexX + (PRIME_NUMBER * indexY), seed);
}
__device__ __host__ constexpr float Get1dNoiseZeroToOne( int index, unsigned int seed )
__device__ __host__ constexpr float Get1dNoiseZeroToOne(int index, unsigned int seed)
{
constexpr double ONE_OVER_MAX_UINT = (1.0 / (double) 0xFFFFFFFF);
return (float)( ONE_OVER_MAX_UINT * (double) SquirrelNoise5( index, seed ) );
return (float)(ONE_OVER_MAX_UINT * (double) SquirrelNoise5(index, seed));
}
__device__ __host__ constexpr float Get2dNoiseZeroToOne( int indexX, int indexY, unsigned int seed )
__device__ __host__ constexpr float Get2dNoiseZeroToOne(int indexX, int indexY, unsigned int seed)
{
constexpr double ONE_OVER_MAX_UINT = (1.0 / (double) 0xFFFFFFFF);
return (float)( ONE_OVER_MAX_UINT * (double) Get2dNoiseUint( indexX, indexY, seed ) );
return (float)(ONE_OVER_MAX_UINT * (double) Get2dNoiseUint(indexX, indexY, seed));
}

// Stochastic rounding built on top of Squirrel Noise above (with seed updated per step via xorshift)
// stochastic rounding built on top of Squirrel Noise above (with seed updated per step via xorshift)
__device__ __forceinline__ void stochastic_rounding(float in, __nv_bfloat16 *out, unsigned int seed) {
// todo - is this stochastic rounding *too good*? can we cut any corners?
unsigned int random = Get2dNoiseUint(threadIdx.x, blockIdx.x, seed);
Expand Down Expand Up @@ -1116,7 +1128,7 @@ void attention_forward(floatX* out, floatX* qkvr, floatX* att,
const floatX beta_lowp = (floatX)beta;
void* alpha_ptr = (CUBLAS_LOWP_COMPUTE == CUBLAS_COMPUTE_16F) ? (void*)&alpha_lowp : (void*)&alpha;
void* beta_ptr = (CUBLAS_LOWP_COMPUTE == CUBLAS_COMPUTE_16F) ? (void*)&beta_lowp : (void*)&beta;

floatX* preatt = inp;
cublasCheck(cublasGemmStridedBatchedEx(cublas_handle,
CUBLAS_OP_T, CUBLAS_OP_N,
Expand Down Expand Up @@ -1150,7 +1162,7 @@ void attention_forward(floatX* out, floatX* qkvr, floatX* att,
B * NH,
CUBLAS_LOWP_COMPUTE,
CUBLAS_GEMM_DEFAULT));

// now unpermute
// y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
num_blocks = CEIL_DIV(B * T * C, block_size);
Expand Down Expand Up @@ -1194,7 +1206,7 @@ void matmul_backward(floatX* dinp, floatX* dweight, floatX* dbias,
float one = 1.0f;
float zero = 0.0f;
// backward to input, uses = in the backward pass (set the gradient)
cublasCheck(cublasGemmEx(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, C, B*T, OC, &one,
cublasCheck(cublasGemmEx(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, C, B*T, OC, &one,
weight, CUBLAS_LOWP, C, dout, CUBLAS_LOWP, OC, &zero,
dinp, CUBLAS_LOWP, C, CUBLAS_LOWP_COMPUTE, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// backward to weight, uses += in the backward pass (accumulate the gradient)
Expand Down Expand Up @@ -1254,14 +1266,14 @@ void attention_backward(floatX* dinp, floatX* dqkvr, floatX* dpreatt, floatX* da
unpermute_kernel_backward<<<num_blocks, block_size>>>(scratch, dout, B, T, NH, HS);
cudaCheck(cudaGetLastError());
// backward into datt

cublasCheck(cublasGemmStridedBatchedEx(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, T, T, HS, alpha_ptr,
v, CUBLAS_LOWP, HS, T * HS, scratch, CUBLAS_LOWP, HS, T * HS, beta_ptr,
datt, CUBLAS_LOWP, T, T * T, B * NH, CUBLAS_LOWP_COMPUTE, CUBLAS_GEMM_DEFAULT));

// backward into dv
cublasCheck(cublasGemmStridedBatchedEx(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T, HS, T, T, alpha_ptr,
scratch, CUBLAS_LOWP, HS, T * HS, att, CUBLAS_LOWP, T, T * T, beta_ptr,
scratch, CUBLAS_LOWP, HS, T * HS, att, CUBLAS_LOWP, T, T * T, beta_ptr,
dv, CUBLAS_LOWP, HS, T * HS, B * NH, CUBLAS_LOWP_COMPUTE, CUBLAS_GEMM_DEFAULT));

// backward into preatt
Expand Down Expand Up @@ -1348,7 +1360,7 @@ void fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, GPT2Conf
param_sizes[13] = L * C; // fcprojb
param_sizes[14] = C; // lnfw
param_sizes[15] = C; // lnfb

// Set parameter sizes
// floatN gives us an option to keep layernorm params in FP32 if we want to
for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
Expand Down Expand Up @@ -1621,7 +1633,7 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
model->rng_state = 13371337;
}

void gpt2_forward(GPT2 *model, int* inputs, int* targets, int B, int T) {
void gpt2_forward(GPT2 *model, int* inputs, int* targets, int B, int T) {
// targets are optional and could be NULL

// ensure the model was initialized or error out
Expand Down

0 comments on commit fd7da62

Please sign in to comment.