From 67a82a611ef6135146e4bc496809d4d99d664f64 Mon Sep 17 00:00:00 2001
From: Erik Schultheis
Date: Fri, 3 May 2024 23:08:44 +0300
Subject: [PATCH 1/4] first attempt at moving cudnn out of the main file for faster compiles

---
 Makefile      |  15 ++-
 cudnn_att.cu  | 319 ++++++++++++++++++++++++++++++++++++++++++++++++++
 train_gpt2.cu | 280 ++------------------------------------------
 3 files changed, 339 insertions(+), 275 deletions(-)
 create mode 100644 cudnn_att.cu

diff --git a/Makefile b/Makefile
index 3af022564..865f6fa46 100644
--- a/Makefile
+++ b/Makefile
@@ -194,20 +194,23 @@ train_gpt2: train_gpt2.c
 test_gpt2: test_gpt2.c
 	$(CC) $(CFLAGS) $(INCLUDES) $(LDFLAGS) $< $(LDLIBS) $(OUTPUT_FILE)
 
-train_gpt2cu: train_gpt2.cu
-	$(NVCC) $(NVCC_FLAGS) $(PFLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(NVCC_LDFLAGS) $(CUDA_OUTPUT_FILE)
+cudnn_att.o: cudnn_att.cu
+	$(NVCC) -c $(NVCC_FLAGS) $(PFLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS)
+
+train_gpt2cu: train_gpt2.cu cudnn_att.o
+	$(NVCC) $(NVCC_FLAGS) $(PFLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) cudnn_att.o
 
 train_gpt2fp32cu: train_gpt2_fp32.cu
-	$(NVCC) $(NVCC_FLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(NVCC_LDFLAGS) $(CUDA_OUTPUT_FILE)
+	$(NVCC) $(NVCC_FLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE)
 
 test_gpt2cu: test_gpt2.cu
-	$(NVCC) $(NVCC_FLAGS) $(PFLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(NVCC_LDFLAGS) $(CUDA_OUTPUT_FILE)
+	$(NVCC) $(NVCC_FLAGS) $(PFLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE)
 
 test_gpt2fp32cu: test_gpt2_fp32.cu
-	$(NVCC) $(NVCC_FLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(NVCC_LDFLAGS) $(CUDA_OUTPUT_FILE)
+	$(NVCC) $(NVCC_FLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE)
 
 profile_gpt2cu: profile_gpt2.cu
-	$(NVCC) $(NVCC_FLAGS) $(PFLAGS) -lineinfo $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(NVCC_LDFLAGS) $(CUDA_OUTPUT_FILE)
+	$(NVCC) $(NVCC_FLAGS) $(PFLAGS) -lineinfo $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE)
 
 clean:
 	$(REMOVE_FILES) $(TARGETS)

diff --git a/cudnn_att.cu b/cudnn_att.cu
new file mode 100644
index 000000000..a9ce3bf76
--- /dev/null
+++ b/cudnn_att.cu
@@ -0,0 +1,319 @@
+// all cudnn-related functions are in this file, so that they don't need to be recompiled every time
+// we change some unrelated piece of the code.
+// TODO this currently duplicates some of the utilities from the main file
+
+#include <cudnn_frontend.h>
+#include <nvtx3/nvToolsExt.h>
+#include <cuda_bf16.h>
+
+// Specific configurations based on the enabled precision
+#if defined(ENABLE_FP32)
+typedef float floatX;
+
+// use fp16 (note: this may require gradient scaler, currently not implemented!)
+#elif defined(ENABLE_FP16)
+typedef half floatX;
+#define CUBLAS_LOWP CUDA_R_16F
+
+#else // Default to bfloat16
+typedef __nv_bfloat16 floatX;
+#endif
+
+// CUDA error checking
+static void cudaCheck(cudaError_t error, const char *file, int line) {
+    if (error != cudaSuccess) {
+        printf("[CUDA ERROR] at file %s:%d:\n%s\n", file, line,
+               cudaGetErrorString(error));
+        exit(EXIT_FAILURE);
+    }
+};
+#define cudaCheck(err) (cudaCheck(err, __FILE__, __LINE__))
+
+// Profiler utils
+namespace {
+    class NvtxRange {
+     public:
+        NvtxRange(const char* s) { nvtxRangePush(s); }
+
+        NvtxRange(const std::string& base_str, int number) {
+            std::string range_string = base_str + " " + std::to_string(number);
+            nvtxRangePush(range_string.c_str());
+        }
+
+        ~NvtxRange() { nvtxRangePop(); }
+    };
+}
+#define NVTX_RANGE_FN() NvtxRange nvtx_range(__FUNCTION__)
+
+namespace fe = cudnn_frontend;
+#if CUBLAS_LOWP == CUDA_R_16BF
+#define CUDNN_16BIT fe::DataType_t::BFLOAT16
+#else
+#define CUDNN_16BIT fe::DataType_t::HALF
+#endif
+
+static cudnnHandle_t cudnn_handle;
+static size_t cudnn_workspace_size = 0; // dynamically allocated as needed (up to 256MiB!)
+static void* cudnn_workspace = NULL;
+#define checkCudnnErr(err) assert((int)err == 0);
+
+using graph_tensors_fwd = std::tuple<std::shared_ptr<fe::graph::Graph>,
+std::shared_ptr<fe::graph::Tensor_attributes>, // Q,
+std::shared_ptr<fe::graph::Tensor_attributes>, // K,
+std::shared_ptr<fe::graph::Tensor_attributes>, // V,
+std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale,
+std::shared_ptr<fe::graph::Tensor_attributes>, // O
+std::shared_ptr<fe::graph::Tensor_attributes>>; // Stats
+
+using graph_tensors_bwd = std::tuple<std::shared_ptr<fe::graph::Graph>,
+std::shared_ptr<fe::graph::Tensor_attributes>, // Q,
+std::shared_ptr<fe::graph::Tensor_attributes>, // K,
+std::shared_ptr<fe::graph::Tensor_attributes>, // V,
+std::shared_ptr<fe::graph::Tensor_attributes>, // O
+std::shared_ptr<fe::graph::Tensor_attributes>, // dO
+std::shared_ptr<fe::graph::Tensor_attributes>, // Stats
+std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale,
+std::shared_ptr<fe::graph::Tensor_attributes>, // dQ,
+std::shared_ptr<fe::graph::Tensor_attributes>, // dK,
+std::shared_ptr<fe::graph::Tensor_attributes>>; // dV
+
+// Need a cache because graph->build_operation_graph() is slow but everything else seems fast
+using cache_type_fwd = std::unordered_map<std::size_t, graph_tensors_fwd>;
+using cache_type_bwd = std::unordered_map<std::size_t, graph_tensors_bwd>;
+
+// Loosely based on cuDNN frontend samples functions and massively simplified
+template <typename... Args>
+auto lookup_cache_or_build_graph_fwd(Args... args) {
+    static cache_type_fwd user_maintained_cache_fwd;
+    auto [B, H, T, HS, is_inference_only] = std::make_tuple(args...);
+
+    auto graph = std::make_shared<fe::graph::Graph>();
+    graph->set_io_data_type(CUDNN_16BIT)
+          .set_intermediate_data_type(fe::DataType_t::FLOAT)
+          .set_compute_data_type(fe::DataType_t::FLOAT);
+
+    // QKV is (B, T, 3, NH, HS) which cuDNN can handle directly without an external permute
+    auto Q = graph->tensor(fe::graph::Tensor_attributes()
+                               .set_name("Q")
+                               .set_dim({B, H, T, HS})
+                               .set_stride({3 * H * HS * T, HS, 3 * H * HS, 1}));
+    auto K = graph->tensor(fe::graph::Tensor_attributes()
+                               .set_name("K")
+                               .set_dim({B, H, T, HS})
+                               .set_stride({3 * H * HS * T, HS, 3 * H * HS, 1}));
+    auto V = graph->tensor(fe::graph::Tensor_attributes()
+                               .set_name("V")
+                               .set_dim({B, H, T, HS})
+                               .set_stride({3 * H * HS * T, HS, 3 * H * HS, 1}));
+    auto attn_scale = graph->tensor(fe::graph::Tensor_attributes()
+                                        .set_name("attn_scale")
+                                        .set_dim({1, 1, 1, 1})
+                                        .set_stride({1, 1, 1, 1})
+                                        .set_is_pass_by_value(true)
+                                        .set_data_type(fe::DataType_t::FLOAT));
+
+    auto sdpa_options = fe::graph::SDPA_attributes().set_name("flash_attention");
+    sdpa_options.set_is_inference(is_inference_only);
+    sdpa_options.set_attn_scale(attn_scale);
+    sdpa_options.set_causal_mask(true);
+
+    // Create the graph operation and get the output tensors back
+    auto [O, stats] = graph->sdpa(Q, K, V, sdpa_options);
+
+    // Output is (B, T, NH, HS) BF16/FP16 and stats for backward pass is (B, NH, T) FP32
+    O->set_output(true).set_dim({B, H, T, HS}).set_stride({H * HS * T, HS, H * HS, 1});
+
+    assert(stats == nullptr || is_inference_only == false);
+    if (is_inference_only == false) {
+        stats->set_output(true).set_data_type(fe::DataType_t::FLOAT)
+              .set_dim({B, H, T, 1})
+              .set_stride({H * T, T, 1, 1});
+    }
+
+    assert(graph->validate().is_good());
+    auto key = graph->key();
+    auto it = user_maintained_cache_fwd.find(key);
+    if (it != user_maintained_cache_fwd.end()) {
+        return it->second;
+    }
+
+    // Build the operation graph and execution part (this is the VERY SLOW PART)
+    assert(graph->build_operation_graph(cudnn_handle).is_good());
+    auto plans = graph->create_execution_plans({fe::HeurMode_t::A});
+    assert(graph->check_support(cudnn_handle).is_good());
+    assert(graph->build_plans(cudnn_handle).is_good());
+
+    auto tuple = std::make_tuple(graph, Q, K, V, attn_scale, O, stats);
+    user_maintained_cache_fwd.insert({key, tuple});
+    return tuple;
+}
+
+template <typename... Args>
+auto lookup_cache_or_build_graph_bwd(Args... args) {
+    static cache_type_bwd user_maintained_cache_bwd;
+    auto [B, NH, T, HS] = std::make_tuple(args...);
+
+    auto graph = std::make_shared<fe::graph::Graph>();
+    graph->set_io_data_type(CUDNN_16BIT)
+          .set_intermediate_data_type(fe::DataType_t::FLOAT)
+          .set_compute_data_type(fe::DataType_t::FLOAT);
+
+    // (B, N, 3, NH, HS)
+    // must come from inp (which means we also need to convert THAT to FP16)
+    auto Q = graph->tensor(fe::graph::Tensor_attributes()
+                               .set_name("Q")
+                               .set_dim({B, NH, T, HS})
+                               .set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}));
+    auto K = graph->tensor(fe::graph::Tensor_attributes()
+                               .set_name("K")
+                               .set_dim({B, NH, T, HS})
+                               .set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}));
+    auto V = graph->tensor(fe::graph::Tensor_attributes()
+                               .set_name("V")
+                               .set_dim({B, NH, T, HS})
+                               .set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}));
+    auto O = graph->tensor(fe::graph::Tensor_attributes()
+                               .set_name("O")
+                               .set_dim({B, NH, T, HS})
+                               .set_stride({NH * HS * T, HS, NH * HS, 1}));
+    auto dO = graph->tensor(fe::graph::Tensor_attributes()
+                                .set_name("dO")
+                                .set_dim({B, NH, T, HS})
+                                .set_stride({NH * HS * T, HS, NH * HS, 1}));
+
+    auto stats = graph->tensor(fe::graph::Tensor_attributes()
+                                   .set_name("stats")
+                                   .set_dim({B, NH, T, 1})
+                                   .set_stride({NH * T, T, 1, 1})
+                                   .set_data_type(fe::DataType_t::FLOAT));
+    auto attn_scale = graph->tensor(fe::graph::Tensor_attributes()
+                                        .set_name("attn_scale")
+                                        .set_dim({1, 1, 1, 1})
+                                        .set_stride({1, 1, 1, 1})
+                                        .set_is_pass_by_value(true)
+                                        .set_data_type(fe::DataType_t::FLOAT));
+    auto sdpa_backward_options = fe::graph::SDPA_backward_attributes()
+                                     .set_name("flash_attention_backward")
+                                     .set_causal_mask(true)
+                                     .set_attn_scale(attn_scale);
+
+    // Create the graph operation and get the output tensors back
+    auto [dQ, dK, dV] = graph->sdpa_backward(Q, K, V, O, dO, stats, sdpa_backward_options);
+
+    dQ->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1});
+    dK->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1});
+    dV->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1});
+
+    assert(graph->validate().is_good());
+    auto key = graph->key();
+    auto it = user_maintained_cache_bwd.find(key);
+    if (it != user_maintained_cache_bwd.end()) {
+        return it->second;
+    }
+
+    // Build the operation graph and execution part (this is the VERY SLOW PART)
+    assert(graph->build_operation_graph(cudnn_handle).is_good());
+    auto plans = graph->create_execution_plans({fe::HeurMode_t::A});
+    assert(graph->check_support(cudnn_handle).is_good());
+    assert(graph->build_plans(cudnn_handle).is_good());
+
+    auto tuple = std::make_tuple(graph, Q, K, V, O, dO, stats, attn_scale, dQ, dK, dV);
+    user_maintained_cache_bwd.insert({key, tuple});
+    return tuple;
+}
+
+void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS)
+                             float* stats, // output for backward pass: (B, NH, T)
+                             floatX* inp, // input: (B, T, 3, NH, HS) QKV
+                             int B, int T, int NH, int C) {
+    NVTX_RANGE_FN();
+    int HS = C / NH; // number of features per head
+    bool is_inference_only = (stats == nullptr);
+
+    // Get graph and tensors from cache (or generate it on first use)
+    auto [graph, Q, K, V, attn_scale, O, softmax_stats] =
+        lookup_cache_or_build_graph_fwd(B, NH, T, HS, is_inference_only);
+
+    // Prepare all the tensor pointers for executing the graph
+    void* devPtrQ = inp;
+    void* devPtrK = (inp + C);
+    void* devPtrV = (inp + 2 * C);
+    float attn_scale_cpu = 1.0 / sqrtf(HS);
+    void* devPtrO = out;
+
+    // Build variant pack
+    std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
+        {Q, devPtrQ}, {K, devPtrK}, {V, devPtrV}, {attn_scale, &attn_scale_cpu}, {O, devPtrO}};
+
+    // Add the stats tensor unless we are only doing inference (only needed for backward pass)
+    if (is_inference_only == false) {
+        variant_pack[softmax_stats] = stats;
+    }
+
+    // Reallocate the workspace if the required size is greater than the current workspace
+    // By default, cuDNN uses up to 256MiB of workspace, so we don't want to just allocate the maximum
+    if (graph->get_workspace_size() > cudnn_workspace_size) {
+        if (cudnn_workspace_size > 0) {
+            cudaCheck(cudaFree(cudnn_workspace));
+        }
+        cudnn_workspace_size = graph->get_workspace_size();
+        cudaCheck(cudaMalloc(&cudnn_workspace, cudnn_workspace_size));
+    }
+
+    // Execute graph
+    assert(graph->execute(cudnn_handle, variant_pack, cudnn_workspace).is_good());
+    cudaCheck(cudaGetLastError());
+}
+
+void attention_backward_cudnn(floatX* dqkvr, // output
+                              floatX* dout, floatX* qkvr, floatX* o, float* stats, // inputs
+                              int B, int T, int NH, int C) {
+    NVTX_RANGE_FN();
+    int HS = C / NH; // number of features per head
+
+    // Get graph and tensors from cache (or generate it on first use)
+    auto [graph, Q, K, V, O, dO, Stats, attn_scale, dQ, dK, dV] =
+        lookup_cache_or_build_graph_bwd(B, NH, T, HS);
+
+    // Prepare all the tensor pointers for executing the graph
+    void* devPtrQ = qkvr;
+    void* devPtrK = (qkvr + NH * HS);
+    void* devPtrV = (qkvr + 2 * NH * HS);
+    void* devPtrO = o;
+    void* devPtrdO = dout;
+    void* devPtrStats = stats;
+    float attn_scale_cpu = 1.0 / sqrtf(HS);
+
+    void* devPtrdQ = dqkvr;
+    void* devPtrdK = (dqkvr + NH * HS);
+    void* devPtrdV = (dqkvr + 2 * NH * HS);
+
+    // Build variant pack that links each tensor to its data pointer
+    std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
+        {Q, devPtrQ}, {K, devPtrK}, {V, devPtrV}, {O, devPtrO}, {dO, devPtrdO}, {Stats, devPtrStats},
+        {dQ, devPtrdQ}, {dK, devPtrdK}, {dV, devPtrdV},
+        {attn_scale, &attn_scale_cpu}};
+
+    // Reallocate the workspace if the required size is greater than the current workspace
+    // By default, cuDNN uses up to 256MiB of workspace, so we don't want to just allocate the maximum
+    if (graph->get_workspace_size() > cudnn_workspace_size) {
+        if (cudnn_workspace_size > 0) {
+            cudaCheck(cudaFree(cudnn_workspace));
+        }
+        cudnn_workspace_size = graph->get_workspace_size();
+        cudaCheck(cudaMalloc(&cudnn_workspace, cudnn_workspace_size));
+    }
+
+    // Execute graph
+    assert(graph->execute(cudnn_handle, variant_pack, cudnn_workspace).is_good());
+    cudaCheck(cudaGetLastError());
+}
+
+void create_cudnn() {
+    checkCudnnErr(cudnnCreate(&cudnn_handle));
+}
+
+void destroy_cudnn() {
+    if (cudnn_workspace != NULL) { cudaCheck(cudaFree(cudnn_workspace)); }
+    checkCudnnErr(cudnnDestroy(cudnn_handle));
+}
\ No newline at end of file
diff --git a/train_gpt2.cu b/train_gpt2.cu
index 084a7b860..3e90d7164 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -115,21 +115,6 @@ const ncclDataType_t ncclFloatX = ncclBfloat16;
 #endif
 #endif
 
-#ifdef ENABLE_CUDNN
-#include <cudnn_frontend.h>
-namespace fe = cudnn_frontend;
-#if CUBLAS_LOWP == CUDA_R_16BF
-#define CUDNN_16BIT fe::DataType_t::BFLOAT16
-#else
-#define CUDNN_16BIT fe::DataType_t::HALF
-#endif
-
-static cudnnHandle_t cudnn_handle;
-static size_t cudnn_workspace_size = 0; // dynamically allocated as needed (up to 256MiB!)
-static void* cudnn_workspace = NULL; -#define checkCudnnErr(err) assert((int)err == 0); -#endif // ENABLE_CUDNN - // ---------------------------------------------------------------------------- // CUDA utils @@ -495,258 +480,20 @@ void printf0(const char *format, ...) { // ---------------------------------------------------------------------------- // cuDNN path #ifdef ENABLE_CUDNN - -using graph_tensors_fwd = std::tuple, - std::shared_ptr, // Q, - std::shared_ptr, // K, - std::shared_ptr, // V, - std::shared_ptr, // Attn_scale, - std::shared_ptr, // O - std::shared_ptr>; // Stats - -using graph_tensors_bwd = std::tuple, - std::shared_ptr, // Q, - std::shared_ptr, // K, - std::shared_ptr, // V, - std::shared_ptr, // O - std::shared_ptr, // dO - std::shared_ptr, // Stats - std::shared_ptr, // Attn_scale, - std::shared_ptr, // dQ, - std::shared_ptr, // dK, - std::shared_ptr>; // dV - -// Need a cache because graph->build_operation_graph() is slow but everything else seems fast -using cache_type_fwd = std::unordered_map; -using cache_type_bwd = std::unordered_map; - -// Loosely based on cuDNN frontend samples functions and massively simplified -template -auto lookup_cache_or_build_graph_fwd(Args... args) { - static cache_type_fwd user_maintained_cache_fwd; - auto [B, H, T, HS, is_inference_only] = std::make_tuple(args...); - - auto graph = std::make_shared(); - graph->set_io_data_type(CUDNN_16BIT) - .set_intermediate_data_type(fe::DataType_t::FLOAT) - .set_compute_data_type(fe::DataType_t::FLOAT); - - // QKV is (B, T, 3, NH, HS) which cuDNN can handle directly without an external permute - auto Q = graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim({B, H, T, HS}) - .set_stride({3 * H * HS * T, HS, 3 * H * HS, 1})); - auto K = graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim({B, H, T, HS}) - .set_stride({3 * H * HS * T, HS, 3 * H * HS, 1})); - auto V = graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim({B, H, T, HS}) - .set_stride({3 * H * HS * T, HS, 3 * H * HS, 1})); - auto attn_scale = graph->tensor(fe::graph::Tensor_attributes() - .set_name("attn_scale") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_is_pass_by_value(true) - .set_data_type(fe::DataType_t::FLOAT)); - - auto sdpa_options = fe::graph::SDPA_attributes().set_name("flash_attention"); - sdpa_options.set_is_inference(is_inference_only); - sdpa_options.set_attn_scale(attn_scale); - sdpa_options.set_causal_mask(true); - - // Create the graph operation and get the output tensors back - auto [O, stats] = graph->sdpa(Q, K, V, sdpa_options); - - // Output is (B, T, NH, HS) BF16/FP16 and stats for backward pass is (B, NH, T) FP32 - O->set_output(true).set_dim({B, H, T, HS}).set_stride({H * HS * T, HS, H * HS, 1}); - - assert(stats == nullptr || is_inference_only == false); - if (is_inference_only == false) { - stats->set_output(true).set_data_type(fe::DataType_t::FLOAT) - .set_dim({B, H, T, 1}) - .set_stride({H * T, T, 1, 1}); - } - - assert(graph->validate().is_good()); - auto key = graph->key(); - auto it = user_maintained_cache_fwd.find(key); - if (it != user_maintained_cache_fwd.end()) { - return it->second; - } - - // Build the operation graph and execution part (this is the VERY SLOW PART) - assert(graph->build_operation_graph(cudnn_handle).is_good()); - auto plans = graph->create_execution_plans({fe::HeurMode_t::A}); - assert(graph->check_support(cudnn_handle).is_good()); - assert(graph->build_plans(cudnn_handle).is_good()); - - auto tuple = 
std::make_tuple(graph, Q, K, V, attn_scale, O, stats); - user_maintained_cache_fwd.insert({key, tuple}); - return tuple; -} - -template -auto lookup_cache_or_build_graph_bwd(Args... args) { - static cache_type_bwd user_maintained_cache_bwd; - auto [B, NH, T, HS] = std::make_tuple(args...); - - auto graph = std::make_shared(); - graph->set_io_data_type(CUDNN_16BIT) - .set_intermediate_data_type(fe::DataType_t::FLOAT) - .set_compute_data_type(fe::DataType_t::FLOAT); - - // (B, N, 3, NH, HS) - // must come from inp (which means we also need to convert THAT to FP16) - auto Q = graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim({B, NH, T, HS}) - .set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1})); - auto K = graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim({B, NH, T, HS}) - .set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1})); - auto V = graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim({B, NH, T, HS}) - .set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1})); - auto O = graph->tensor(fe::graph::Tensor_attributes() - .set_name("O") - .set_dim({B, NH, T, HS}) - .set_stride({NH * HS * T, HS, NH * HS, 1})); - auto dO = graph->tensor(fe::graph::Tensor_attributes() - .set_name("dO") - .set_dim({B, NH, T, HS}) - .set_stride({NH * HS * T, HS, NH * HS, 1})); - - auto stats = graph->tensor(fe::graph::Tensor_attributes() - .set_name("stats") - .set_dim({B, NH, T, 1}) - .set_stride({NH * T, T, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); - auto attn_scale = graph->tensor(fe::graph::Tensor_attributes() - .set_name("attn_scale") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_is_pass_by_value(true) - .set_data_type(fe::DataType_t::FLOAT)); - auto sdpa_backward_options = fe::graph::SDPA_backward_attributes() - .set_name("flash_attention_backward") - .set_causal_mask(true) - .set_attn_scale(attn_scale); - - // Create the graph operation and get the output tensors back - auto [dQ, dK, dV] = graph->sdpa_backward(Q, K, V, O, dO, stats, sdpa_backward_options); - - dQ->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}); - dK->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}); - dV->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}); - - assert(graph->validate().is_good()); - auto key = graph->key(); - auto it = user_maintained_cache_bwd.find(key); - if (it != user_maintained_cache_bwd.end()) { - return it->second; - } - - // Build the operation graph and execution part (this is the VERY SLOW PART) - assert(graph->build_operation_graph(cudnn_handle).is_good()); - auto plans = graph->create_execution_plans({fe::HeurMode_t::A}); - assert(graph->check_support(cudnn_handle).is_good()); - assert(graph->build_plans(cudnn_handle).is_good()); - - auto tuple = std::make_tuple(graph, Q, K, V, O, dO, stats, attn_scale, dQ, dK, dV); - user_maintained_cache_bwd.insert({key, tuple}); - return tuple; -} - +// functions defined in cudnn_att.cu +void create_cudnn(); +void destroy_cudnn(); void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS) float* stats, // output for backward pass: (B, NH, T) floatX* inp, // input: (B, T, 3, NH, HS) QKV - int B, int T, int NH, int C) { - NVTX_RANGE_FN(); - int HS = C / NH; // number of features per head - bool is_inference_only = (stats == nullptr); - - // Get graph and tensors from cache (or generate it on first use) - auto [graph, Q, K, V, attn_scale, O, softmax_stats] = - 
lookup_cache_or_build_graph_fwd(B, NH, T, HS, is_inference_only); - - // Prepare all the tensor pointers for executing the graph - void* devPtrQ = inp; - void* devPtrK = (inp + C); - void* devPtrV = (inp + 2 * C); - float attn_scale_cpu = 1.0 / sqrtf(HS); - void* devPtrO = out; - - // Build variant pack - std::unordered_map, void*> variant_pack = { - {Q, devPtrQ}, {K, devPtrK}, {V, devPtrV}, {attn_scale, &attn_scale_cpu}, {O, devPtrO}}; - - // Add the stats tensor unless we are only doing inference (only needed for backward pass) - if (is_inference_only == false) { - variant_pack[softmax_stats] = stats; - } - - // Reallocate the workspace if the required size is greater than the current workspace - // By default, cuDNN uses up to 256MiB of workspace, so we don't want to just allocate the maximum - if (graph->get_workspace_size() > cudnn_workspace_size) { - if (cudnn_workspace_size > 0) { - cudaCheck(cudaFree(cudnn_workspace)); - } - cudnn_workspace_size = graph->get_workspace_size(); - cudaCheck(cudaMalloc(&cudnn_workspace, cudnn_workspace_size)); - } - - // Execute graph - assert(graph->execute(cudnn_handle, variant_pack, cudnn_workspace).is_good()); - cudaCheck(cudaGetLastError()); -} + int B, int T, int NH, int C); void attention_backward_cudnn(floatX* dqkvr, // output floatX* dout, floatX* qkvr, floatX* o, float* stats, // inputs - int B, int T, int NH, int C) { - NVTX_RANGE_FN(); - int HS = C / NH; // number of features per head - - // Get graph and tensors from cache (or generate it on first use) - auto [graph, Q, K, V, O, dO, Stats, attn_scale, dQ, dK, dV] = - lookup_cache_or_build_graph_bwd(B, NH, T, HS); - - // Prepare all the tensor pointers for executing the graph - void* devPtrQ = qkvr; - void* devPtrK = (qkvr + NH * HS); - void* devPtrV = (qkvr + 2 * NH * HS); - void* devPtrO = o; - void* devPtrdO = dout; - void* devPtrStats = stats; - float attn_scale_cpu = 1.0 / sqrtf(HS); - - void* devPtrdQ = dqkvr; - void* devPtrdK = (dqkvr + NH * HS); - void* devPtrdV = (dqkvr + 2 * NH * HS); - - // Build variant pack that links each tensor to its data pointer - std::unordered_map, void*> variant_pack = { - {Q, devPtrQ}, {K, devPtrK}, {V, devPtrV}, {O, devPtrO}, {dO, devPtrdO}, {Stats, devPtrStats}, - {dQ, devPtrdQ}, {dK, devPtrdK}, {dV, devPtrdV}, - {attn_scale, &attn_scale_cpu}}; - - // Reallocate the workspace if the required size is greater than the current workspace - // By default, cuDNN uses up to 256MiB of workspace, so we don't want to just allocate the maximum - if (graph->get_workspace_size() > cudnn_workspace_size) { - if (cudnn_workspace_size > 0) { - cudaCheck(cudaFree(cudnn_workspace)); - } - cudnn_workspace_size = graph->get_workspace_size(); - cudaCheck(cudaMalloc(&cudnn_workspace, cudnn_workspace_size)); - } - - // Execute graph - assert(graph->execute(cudnn_handle, variant_pack, cudnn_workspace).is_good()); - cudaCheck(cudaGetLastError()); -} + int B, int T, int NH, int C); +#else +void create_cudnn() {} +void destroy_cudnn() {} #endif // ENABLE_CUDNN // ---------------------------------------------------------------------------- @@ -2600,10 +2347,8 @@ int main(int argc, char *argv[]) { cublasCheck(cublasSetMathMode(cublas_handle, cublas_math_mode)); if(cublas_compute_type); // unused in BF16 mode, avoid warning - // set up cuDNN - #ifdef ENABLE_CUDNN - checkCudnnErr(cudnnCreate(&cudnn_handle)); - #endif + // set up cuDNN (noop if not available) + create_cudnn(); printf0("| device | %-50s |\n", deviceProp.name); printf0("| TF32 | %-50s |\n", enable_tf32 ? 
"enabled" : "disabled"); @@ -2786,10 +2531,7 @@ int main(int argc, char *argv[]) { free(cpu_logits_raw); free(cpu_logits); free(gen_tokens); - #ifdef ENABLE_CUDNN - if (cudnn_workspace != NULL) { cudaCheck(cudaFree(cudnn_workspace)); } - checkCudnnErr(cudnnDestroy(cudnn_handle)); - #endif + destroy_cudnn(); cudaCheck(cudaFree(cublaslt_workspace)); cublasCheck(cublasDestroy(cublas_handle)); cublasCheck(cublasLtDestroy(cublaslt_handle)); From aa5bb258b6a998b33f861f4ce2e949a06bd2cb3b Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Sat, 4 May 2024 11:59:59 +0300 Subject: [PATCH 2/4] fixed up test and profile targets --- Makefile | 8 ++++---- profile_gpt2.cu | 9 ++------- test_gpt2.cu | 11 +++-------- 3 files changed, 9 insertions(+), 19 deletions(-) diff --git a/Makefile b/Makefile index 865f6fa46..8be333bfd 100644 --- a/Makefile +++ b/Makefile @@ -203,14 +203,14 @@ train_gpt2cu: train_gpt2.cu cudnn_att.o train_gpt2fp32cu: train_gpt2_fp32.cu $(NVCC) $(NVCC_FLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) -test_gpt2cu: test_gpt2.cu - $(NVCC) $(NVCC_FLAGS) $(PFLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) +test_gpt2cu: test_gpt2.cu cudnn_att.o + $(NVCC) $(NVCC_FLAGS) $(PFLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) cudnn_att.o test_gpt2fp32cu: test_gpt2_fp32.cu $(NVCC) $(NVCC_FLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) -profile_gpt2cu: profile_gpt2.cu - $(NVCC) $(NVCC_FLAGS) $(PFLAGS) -lineinfo $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) +profile_gpt2cu: profile_gpt2.cu cudnn_att.o + $(NVCC) $(NVCC_FLAGS) $(PFLAGS) -lineinfo $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) cudnn_att.o clean: $(REMOVE_FILES) $(TARGETS) diff --git a/profile_gpt2.cu b/profile_gpt2.cu index 5b043e022..2ea7f98cc 100644 --- a/profile_gpt2.cu +++ b/profile_gpt2.cu @@ -49,9 +49,7 @@ int main() { cublasCheck(cublasSetMathMode(cublas_handle, cublas_math_mode)); // setup the (global) cuBLASLt workspace cudaCheck(cudaMalloc(&cublaslt_workspace, cublaslt_workspace_size)); - #ifdef ENABLE_CUDNN - checkCudnnErr(cudnnCreate(&cudnn_handle)); - #endif + create_cudnn(); // build the GPT-2 model from a checkpoint GPT2 model; @@ -81,10 +79,7 @@ int main() { // free gpt2_free(&model); - #ifdef ENABLE_CUDNN - if (cudnn_workspace != NULL) { cudaCheck(cudaFree(cudnn_workspace)); } - checkCudnnErr(cudnnDestroy(cudnn_handle)); - #endif + destroy_cudnn(); cudaCheck(cudaFree(cublaslt_workspace)); cublasCheck(cublasDestroy(cublas_handle)); cublasCheck(cublasLtDestroy(cublaslt_handle)); diff --git a/test_gpt2.cu b/test_gpt2.cu index 67afa5065..1a7df6304 100644 --- a/test_gpt2.cu +++ b/test_gpt2.cu @@ -106,10 +106,8 @@ int main(int argc, char *argv[]) { cublasMath_t cublas_math_mode = enable_tf32 ? 
CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH; cublasCheck(cublasSetMathMode(cublas_handle, cublas_math_mode)); cudaCheck(cudaMalloc(&cublaslt_workspace, cublaslt_workspace_size)); - - #ifdef ENABLE_CUDNN - checkCudnnErr(cudnnCreate(&cudnn_handle)); - #endif + // set up cuDNN (noop if not available) + create_cudnn(); // build the GPT-2 model from a checkpoint GPT2 model; @@ -326,10 +324,7 @@ int main(int argc, char *argv[]) { free(grads_memory_cpu); free(grads_memory_cpu_float); gpt2_free(&model); - #ifdef ENABLE_CUDNN - if (cudnn_workspace != NULL) { cudaCheck(cudaFree(cudnn_workspace)); } - checkCudnnErr(cudnnDestroy(cudnn_handle)); - #endif + destroy_cudnn(); cudaCheck(cudaFree(cublaslt_workspace)); cublasCheck(cublasDestroy(cublas_handle)); cublasCheck(cublasLtDestroy(cublaslt_handle)); From 19c290d7e603d7a5a9162665767d4249959efbfa Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Sat, 4 May 2024 12:31:15 +0300 Subject: [PATCH 3/4] improved debugging for cudnn --- cudnn_att.cu | 62 ++++++++++++++++++++++++++++++---------------------- 1 file changed, 36 insertions(+), 26 deletions(-) diff --git a/cudnn_att.cu b/cudnn_att.cu index a9ce3bf76..2735bbd14 100644 --- a/cudnn_att.cu +++ b/cudnn_att.cu @@ -57,25 +57,35 @@ static size_t cudnn_workspace_size = 0; // dynamically allocated as needed (up t static void* cudnn_workspace = NULL; #define checkCudnnErr(err) assert((int)err == 0); +static void checkCudnnFE(fe::error_object e, const char *file, int line) { + if(!e.is_good()) { + printf("[CUDNN ERROR] at file %s:%d:\n%s\n", file, line, e.err_msg.c_str()); + exit(EXIT_FAILURE); + } +} +#define checkCudnnFE(err) checkCudnnFE(err, __FILE__, __LINE__) + using graph_tensors_fwd = std::tuple, -std::shared_ptr, // Q, -std::shared_ptr, // K, -std::shared_ptr, // V, -std::shared_ptr, // Attn_scale, -std::shared_ptr, // O -std::shared_ptr>; // Stats + std::shared_ptr, // Q, + std::shared_ptr, // K, + std::shared_ptr, // V, + std::shared_ptr, // Attn_scale, + std::shared_ptr, // O + std::shared_ptr // Stats +>; using graph_tensors_bwd = std::tuple, -std::shared_ptr, // Q, -std::shared_ptr, // K, -std::shared_ptr, // V, -std::shared_ptr, // O -std::shared_ptr, // dO -std::shared_ptr, // Stats -std::shared_ptr, // Attn_scale, -std::shared_ptr, // dQ, -std::shared_ptr, // dK, -std::shared_ptr>; // dV + std::shared_ptr, // Q, + std::shared_ptr, // K, + std::shared_ptr, // V, + std::shared_ptr, // O + std::shared_ptr, // dO + std::shared_ptr, // Stats + std::shared_ptr, // Attn_scale, + std::shared_ptr, // dQ, + std::shared_ptr, // dK, + std::shared_ptr // dV +>; // Need a cache because graph->build_operation_graph() is slow but everything else seems fast using cache_type_fwd = std::unordered_map; @@ -130,7 +140,7 @@ auto lookup_cache_or_build_graph_fwd(Args... args) { .set_stride({H * T, T, 1, 1}); } - assert(graph->validate().is_good()); + checkCudnnFE(graph->validate()); auto key = graph->key(); auto it = user_maintained_cache_fwd.find(key); if (it != user_maintained_cache_fwd.end()) { @@ -138,10 +148,10 @@ auto lookup_cache_or_build_graph_fwd(Args... 
args) { } // Build the operation graph and execution part (this is the VERY SLOW PART) - assert(graph->build_operation_graph(cudnn_handle).is_good()); + checkCudnnFE(graph->build_operation_graph(cudnn_handle)); auto plans = graph->create_execution_plans({fe::HeurMode_t::A}); - assert(graph->check_support(cudnn_handle).is_good()); - assert(graph->build_plans(cudnn_handle).is_good()); + checkCudnnFE(graph->check_support(cudnn_handle)); + checkCudnnFE(graph->build_plans(cudnn_handle)); auto tuple = std::make_tuple(graph, Q, K, V, attn_scale, O, stats); user_maintained_cache_fwd.insert({key, tuple}); @@ -204,7 +214,7 @@ auto lookup_cache_or_build_graph_bwd(Args... args) { dK->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}); dV->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}); - assert(graph->validate().is_good()); + checkCudnnFE(graph->validate()); auto key = graph->key(); auto it = user_maintained_cache_bwd.find(key); if (it != user_maintained_cache_bwd.end()) { @@ -212,10 +222,10 @@ auto lookup_cache_or_build_graph_bwd(Args... args) { } // Build the operation graph and execution part (this is the VERY SLOW PART) - assert(graph->build_operation_graph(cudnn_handle).is_good()); + checkCudnnFE(graph->build_operation_graph(cudnn_handle)); auto plans = graph->create_execution_plans({fe::HeurMode_t::A}); - assert(graph->check_support(cudnn_handle).is_good()); - assert(graph->build_plans(cudnn_handle).is_good()); + checkCudnnFE(graph->check_support(cudnn_handle)); + checkCudnnFE(graph->build_plans(cudnn_handle)); auto tuple = std::make_tuple(graph, Q, K, V, O, dO, stats, attn_scale, dQ, dK, dV); user_maintained_cache_bwd.insert({key, tuple}); @@ -261,7 +271,7 @@ void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS) } // Execute graph - assert(graph->execute(cudnn_handle, variant_pack, cudnn_workspace).is_good()); + checkCudnnFE(graph->execute(cudnn_handle, variant_pack, cudnn_workspace)); cudaCheck(cudaGetLastError()); } @@ -305,7 +315,7 @@ void attention_backward_cudnn(floatX* dqkvr, } // Execute graph - assert(graph->execute(cudnn_handle, variant_pack, cudnn_workspace).is_good()); + checkCudnnFE(graph->execute(cudnn_handle, variant_pack, cudnn_workspace)); cudaCheck(cudaGetLastError()); } From b087b9c819516378cb0f60b00a35f9d812d94304 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Sun, 5 May 2024 02:44:15 +0300 Subject: [PATCH 4/4] don't compile/link cudnn if not asked for it --- Makefile | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/Makefile b/Makefile index 8be333bfd..04cbfbb2a 100644 --- a/Makefile +++ b/Makefile @@ -18,6 +18,7 @@ NVCC_LDFLAGS = -lcublas -lcublasLt NVCC_INCLUDES = NVCC_LDLIBS = NCLL_INCUDES = +NVCC_CUDNN = # overridable flag for multi-GPU training. by default we won't build with cudnn # because it bloats up the compile time from a few seconds to ~minute USE_CUDNN ?= 0 @@ -81,6 +82,7 @@ ifeq ($(USE_CUDNN), 1) NVCC_INCLUDES += -I$(CUDNN_FRONTEND_PATH) NVCC_LDFLAGS += -lcudnn NVCC_FLAGS += -DENABLE_CUDNN + NVCC_CUDNN = cudnn_att.o else $(error ✗ cuDNN not found. 
See the Makefile for our currently hard-coded paths / install instructions) endif @@ -197,20 +199,20 @@ test_gpt2: test_gpt2.c cudnn_att.o: cudnn_att.cu $(NVCC) -c $(NVCC_FLAGS) $(PFLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) -train_gpt2cu: train_gpt2.cu cudnn_att.o - $(NVCC) $(NVCC_FLAGS) $(PFLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) cudnn_att.o +train_gpt2cu: train_gpt2.cu $(NVCC_CUDNN) + $(NVCC) $(NVCC_FLAGS) $(PFLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) $(NVCC_CUDNN) train_gpt2fp32cu: train_gpt2_fp32.cu $(NVCC) $(NVCC_FLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) -test_gpt2cu: test_gpt2.cu cudnn_att.o - $(NVCC) $(NVCC_FLAGS) $(PFLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) cudnn_att.o +test_gpt2cu: test_gpt2.cu $(NVCC_CUDNN) + $(NVCC) $(NVCC_FLAGS) $(PFLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) $(NVCC_CUDNN) test_gpt2fp32cu: test_gpt2_fp32.cu $(NVCC) $(NVCC_FLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) -profile_gpt2cu: profile_gpt2.cu cudnn_att.o - $(NVCC) $(NVCC_FLAGS) $(PFLAGS) -lineinfo $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) cudnn_att.o +profile_gpt2cu: profile_gpt2.cu $(NVCC_CUDNN) + $(NVCC) $(NVCC_FLAGS) $(PFLAGS) -lineinfo $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) $(NVCC_CUDNN) clean: $(REMOVE_FILES) $(TARGETS)
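
Note (not part of the patches above): after this series, the only cuDNN surface the rest of the code sees is the four functions declared in train_gpt2.cu behind ENABLE_CUDNN. The sketch below shows the intended call pattern; the wrapper name and buffer names are illustrative, the dqkv/dout layouts are inferred from the strides set in the backward graph, and only the four signatures themselves are taken from the declarations in patch 1.

    // Sketch only: exercises the ENABLE_CUDNN interface exported by cudnn_att.cu.
    // In the real code, create_cudnn()/destroy_cudnn() are called once from main(),
    // and the forward/backward calls happen inside the attention layer.
    void attention_roundtrip_sketch(floatX* out,   // (B, T, NH, HS) attention output
                                    float* stats,  // (B, NH, T) softmax stats; NULL => inference only
                                    floatX* qkv,   // (B, T, 3, NH, HS) packed QKV input
                                    floatX* dqkv,  // gradient w.r.t. qkv, same layout (assumed)
                                    floatX* dout,  // gradient w.r.t. out, same layout as out (assumed)
                                    int B, int T, int NH, int C) {
        create_cudnn();                                          // no-op when built with USE_CUDNN=0
        attention_forward_cudnn(out, stats, qkv, B, T, NH, C);
        attention_backward_cudnn(dqkv, dout, qkv, out, stats, B, T, NH, C);
        destroy_cudnn();                                         // frees the workspace and the handle
    }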