diff --git a/train_gpt2.cu b/train_gpt2.cu
index bf7f7bae8..913a65393 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -9,6 +9,11 @@
 because these are faster (just read, no write). This is okay for all activations
 except for those in the residual stream, where the gradients have to add. We make
 sure that those parts work out ok and that we do a += as necessary. E.g.,
 the layernorms are connected to the residuals so we += in layernorm backward.
+
+In this file we are using Mixed Precision training, so different activations,
+parameters, grads and buffers may be kept at different precisions, to take
+advantage of the fast low-precision hardware in the latest GPUs (bf16/fp16),
+and fp8 (coming soon^TM).
 */
 
 #include
@@ -29,18 +34,25 @@
 
 // ----------------------------------------------------------------------------
 // CUDA precision settings
+
+// turn on bf16 as default, done up here for now
 #define ENABLE_BF16
 
+// use bf16 (bfloat 16)
 #if defined(ENABLE_BF16)
 typedef __nv_bfloat16 floatX;
 typedef float floatN;
 #define CUBLAS_LOWP CUDA_R_16BF
 #define CUBLAS_LOWP_COMPUTE CUBLAS_COMPUTE_32F
+
+// use fp16 (note: this may require gradient scaler, currently not implemented!)
 #elif defined(ENABLE_FP16)
 typedef half floatX;
 typedef float floatN;
 #define CUBLAS_LOWP CUDA_R_16F
 #define CUBLAS_LOWP_COMPUTE CUBLAS_COMPUTE_32F
+
+// fallback for fp32
 #else
 typedef float floatX;
 typedef float floatN;
@@ -125,7 +137,7 @@ __device__ __host__ float random_f32(unsigned long long *state) { // random floa
 // This gives us a random number from threadIdx/blockIdx + a single seed for the entire GPU
 // todo - possibly overkill and we don't need such high quality random numbers? (tbd)
 // http://eiserloh.net/noise/SquirrelNoise5.hpp
-__device__ __host__ constexpr unsigned int SquirrelNoise5( int positionX, unsigned int seed )
+__device__ __host__ constexpr unsigned int SquirrelNoise5(int positionX, unsigned int seed)
 {
     constexpr unsigned int SQ5_BIT_NOISE1 = 0xd2a80a3f; // 11010010101010000000101000111111
     constexpr unsigned int SQ5_BIT_NOISE2 = 0xa884f197; // 10101000100001001111000110010111
@@ -146,27 +158,27 @@ __device__ __host__ constexpr unsigned int SquirrelNoise5( int positionX, unsign
     mangledBits ^= (mangledBits >> 17);
     return mangledBits;
 }
-__device__ __host__ constexpr unsigned int Get1dNoiseUint( int positionX, unsigned int seed )
+__device__ __host__ constexpr unsigned int Get1dNoiseUint(int positionX, unsigned int seed)
 {
-    return SquirrelNoise5( positionX, seed );
+    return SquirrelNoise5(positionX, seed);
 }
-__device__ __host__ constexpr unsigned int Get2dNoiseUint( int indexX, int indexY, unsigned int seed )
+__device__ __host__ constexpr unsigned int Get2dNoiseUint(int indexX, int indexY, unsigned int seed)
 {
     constexpr int PRIME_NUMBER = 198491317; // Large prime number with non-boring bits
-    return SquirrelNoise5( indexX + (PRIME_NUMBER * indexY), seed );
+    return SquirrelNoise5(indexX + (PRIME_NUMBER * indexY), seed);
 }
-__device__ __host__ constexpr float Get1dNoiseZeroToOne( int index, unsigned int seed )
+__device__ __host__ constexpr float Get1dNoiseZeroToOne(int index, unsigned int seed)
 {
     constexpr double ONE_OVER_MAX_UINT = (1.0 / (double) 0xFFFFFFFF);
-    return (float)( ONE_OVER_MAX_UINT * (double) SquirrelNoise5( index, seed ) );
+    return (float)(ONE_OVER_MAX_UINT * (double) SquirrelNoise5(index, seed));
 }
-__device__ __host__ constexpr float Get2dNoiseZeroToOne( int indexX, int indexY, unsigned int seed )
+__device__ __host__ constexpr float Get2dNoiseZeroToOne(int indexX, int indexY, unsigned int seed)
 {
     constexpr double ONE_OVER_MAX_UINT = (1.0 / (double) 0xFFFFFFFF);
-    return (float)( ONE_OVER_MAX_UINT * (double) Get2dNoiseUint( indexX, indexY, seed ) );
+    return (float)(ONE_OVER_MAX_UINT * (double) Get2dNoiseUint(indexX, indexY, seed));
 }
 
-// Stochastic rounding built on top of Squirel Noise above (with seed updated per step via xorshift)
+// stochastic rounding built on top of Squirrel Noise above (with seed updated per step via xorshift)
 __device__ __forceinline__ void stochastic_rounding(float in, __nv_bfloat16 *out, unsigned int seed) {
     // todo - is this stochastic rounding *too good*? can we cut any corners?
     unsigned int random = Get2dNoiseUint(threadIdx.x, blockIdx.x, seed);
@@ -1116,7 +1128,7 @@ void attention_forward(floatX* out, floatX* qkvr, floatX* att,
     const floatX beta_lowp = (floatX)beta;
     void* alpha_ptr = (CUBLAS_LOWP_COMPUTE == CUBLAS_COMPUTE_16F) ? (void*)&alpha_lowp : (void*)&alpha;
     void* beta_ptr = (CUBLAS_LOWP_COMPUTE == CUBLAS_COMPUTE_16F) ? (void*)&beta_lowp : (void*)&beta;
-    
+
     floatX* preatt = inp;
     cublasCheck(cublasGemmStridedBatchedEx(cublas_handle,
                                      CUBLAS_OP_T, CUBLAS_OP_N,
@@ -1150,7 +1162,7 @@ void attention_forward(floatX* out, floatX* qkvr, floatX* att,
                                      B * NH,
                                      CUBLAS_LOWP_COMPUTE,
                                      CUBLAS_GEMM_DEFAULT));
-    
+
     // now unpermute
     // y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
     num_blocks = CEIL_DIV(B * T * C, block_size);
@@ -1194,7 +1206,7 @@ void matmul_backward(floatX* dinp, floatX* dweight, floatX* dbias,
     float one = 1.0f;
     float zero = 0.0f;
     // backward to input, uses = in the backward pass (set the gradient)
-    cublasCheck(cublasGemmEx(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, C, B*T, OC, &one, 
+    cublasCheck(cublasGemmEx(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, C, B*T, OC, &one,
                              weight, CUBLAS_LOWP, C, dout, CUBLAS_LOWP, OC, &zero,
                              dinp, CUBLAS_LOWP, C, CUBLAS_LOWP_COMPUTE, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
    // backward to weight, uses += in the backward pass (accumulate the gradient)
@@ -1254,14 +1266,14 @@ void attention_backward(floatX* dinp, floatX* dqkvr, floatX* dpreatt, floatX* da
     unpermute_kernel_backward<<<num_blocks, block_size>>>(scratch, dout, B, T, NH, HS);
     cudaCheck(cudaGetLastError());
     // backward into datt
-    
+
     cublasCheck(cublasGemmStridedBatchedEx(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, T, T, HS, alpha_ptr,
                                            v, CUBLAS_LOWP, HS, T * HS, scratch, CUBLAS_LOWP, HS, T * HS, beta_ptr,
                                            datt, CUBLAS_LOWP, T, T * T, B * NH, CUBLAS_LOWP_COMPUTE, CUBLAS_GEMM_DEFAULT));
     // backward into dv
     cublasCheck(cublasGemmStridedBatchedEx(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T, HS, T, T, alpha_ptr,
-                                           scratch, CUBLAS_LOWP, HS, T * HS, att, CUBLAS_LOWP, T, T * T, beta_ptr, 
+                                           scratch, CUBLAS_LOWP, HS, T * HS, att, CUBLAS_LOWP, T, T * T, beta_ptr,
                                            dv, CUBLAS_LOWP, HS, T * HS, B * NH, CUBLAS_LOWP_COMPUTE, CUBLAS_GEMM_DEFAULT));
     // backward into preatt
@@ -1348,7 +1360,7 @@ void fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, GPT2Conf
     param_sizes[13] = L * C; // fcprojb
     param_sizes[14] = C; // lnfw
     param_sizes[15] = C; // lnfb
-    
+
     // Set parameter sizes
     // floatN gives us an option to keep layernorm params in FP32 if we want to
     for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
@@ -1621,7 +1633,7 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
     model->rng_state = 13371337;
 }
 
-void gpt2_forward(GPT2 *model, int* inputs, int* targets, int B, int T) {
+void gpt2_forward(GPT2 *model, int* inputs, int* targets, int B, int T) {
     // targets are optional and could be NULL
     // ensure the model was initialized or error out
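
Aside on the stochastic_rounding() helper referenced in the hunks above: bf16 keeps only the top 16 bits of an fp32 bit pattern, so always truncating (or always rounding to nearest) the tiny optimizer updates written into bf16 parameters introduces a systematic bias. Rounding up with probability proportional to the discarded low bits keeps the rounding unbiased in expectation. The sketch below is only a host-side C illustration of that idea under those assumptions; the names stochastic_round_to_bf16, bf16_to_float and xorshift32 are invented for the example, and the real device function in this patch uses the Squirrel-noise RNG and CUDA bf16 intrinsics instead.

// minimal host-side sketch of stochastic rounding fp32 -> bf16 (illustration only,
// not the kernel from the patch above)
#include <stdint.h>
#include <stdio.h>
#include <string.h>

// tiny PRNG, echoing the "seed updated per step via xorshift" comment in the patch
static uint32_t xorshift32(uint32_t *state) {
    uint32_t x = *state;
    x ^= x << 13;
    x ^= x >> 17;
    x ^= x << 5;
    *state = x;
    return x;
}

static uint16_t stochastic_round_to_bf16(float in, uint32_t random) {
    uint32_t bits;
    memcpy(&bits, &in, sizeof(bits));       // reinterpret the fp32 bit pattern
    uint32_t dropped = bits & 0xFFFFu;      // the 16 low bits bf16 will discard
    uint32_t threshold = random & 0xFFFFu;  // uniform 16-bit threshold
    if (dropped > threshold) {
        bits += 0x10000u;                   // round up by one bf16 ulp
    }
    return (uint16_t)(bits >> 16);          // keep the high 16 bits = bf16
}

static float bf16_to_float(uint16_t h) {
    uint32_t bits = (uint32_t)h << 16;      // bf16 is just truncated fp32
    float out;
    memcpy(&out, &bits, sizeof(out));
    return out;
}

int main(void) {
    float x = 1.0f + 1e-4f;                 // falls between two adjacent bf16 values near 1.0
    uint32_t rng = 1337u;                   // any nonzero seed works for xorshift32
    double acc = 0.0;
    int n = 100000;
    for (int i = 0; i < n; i++) {
        acc += bf16_to_float(stochastic_round_to_bf16(x, xorshift32(&rng)));
    }
    // each individual rounding is either 1.0 or ~1.0078125, but the average stays near x
    printf("fp32 value: %.7f  mean of %d stochastic bf16 roundings: %.7f\n", x, n, acc / n);
    return 0;
}

Compiled with any C compiler (e.g. cc -O2 sketch.c), the printed mean should sit close to the fp32 value even though every individual bf16 result is one of the two neighbouring representable values, which is the property the patch relies on when accumulating many small updates into low-precision parameters.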